Skip to content

Commit cb005a1

Browse files
committed
feat: add context in s3 Get functions
1 parent 93e1f03 commit cb005a1

File tree

2 files changed

+105
-11
lines changed

2 files changed

+105
-11
lines changed

aws/s3/s3.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,32 @@ func (s *S3) Client() *s3.Client {
128128
// Get gets the object referred to by key and version from bucket and writes it into b.
129129
// Version can be empty.
130130
func (s *S3) Get(bucket, key, version string, b *bytes.Buffer) error {
131+
return s.GetWithContext(context.Background(), bucket, key, version, b)
132+
}
133+
134+
// Get gets the object referred to by key and version from bucket and writes it into b.
135+
// with the provided context.
136+
// Version can be empty.
137+
func (s *S3) GetWithContext(ctx context.Context, bucket, key, version string, w io.Writer) error {
131138
input := s3.GetObjectInput{
132139
Key: aws.String(key),
133140
Bucket: aws.String(bucket),
134141
}
135142
if version != "" {
136143
input.VersionId = aws.String(version)
137144
}
138-
result, err := s.client.GetObject(context.TODO(), &input)
145+
result, err := s.client.GetObject(ctx, &input)
139146
if err != nil {
140147
return err
141148
}
142149
defer result.Body.Close()
143150

144-
_, err = b.ReadFrom(result.Body)
145-
151+
_, err = io.Copy(w, result.Body)
152+
if ctx.Err() != nil {
153+
return ctx.Err()
154+
}
146155
return err
156+
147157
}
148158

149159
// GetByteRange gets the specified byte range of an object referred to by key and version

aws/s3/s3_concurrent.go

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,34 +158,91 @@ func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
158158
// Version can be empty, but must be the same for all objects.
159159
func (s *S3Concurrent) GetAllConcurrently(bucket, version string, objects []types.Object) chan HydratedFile {
160160

161+
return s.GetAllConcurrentlyWithContext(context.Background(), bucket, version, objects)
162+
}
163+
164+
// GetAllConcurrently gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
165+
// to the returned output channel. The closure of this channel is handled, however it's the caller's
166+
// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
167+
// If the ConcurrencyManager is not initialised before calling GetAllConcurrently, an output channel
168+
// containing a single HydratedFile with an error is returned.
169+
// Version can be empty, but must be the same for all objects.
170+
func (s *S3Concurrent) GetAllConcurrentlyWithContext(
171+
ctx context.Context,
172+
bucket, version string,
173+
objects []types.Object,
174+
) chan HydratedFile {
175+
176+
output := make(chan HydratedFile, 1)
177+
178+
// Early cancel check
179+
select {
180+
case <-ctx.Done():
181+
output <- HydratedFile{Error: ctx.Err()}
182+
close(output)
183+
return output
184+
default:
185+
}
186+
161187
if s.manager == nil {
162-
output := make(chan HydratedFile, 1)
163-
output <- HydratedFile{Error: errors.New("error getting files from S3, Concurrency Manager not initialised")}
188+
output <- HydratedFile{
189+
Error: errors.New("error getting files from S3, Concurrency Manager not initialised"),
190+
}
164191
close(output)
165192
return output
166193
}
167194

168195
if s.manager.memoryTotalSize < s.manager.calculateRequiredMemoryFor(objects) {
169-
output := make(chan HydratedFile, 1)
170-
output <- HydratedFile{Error: fmt.Errorf("error: bytes requested greater than max allowed by server (%v)", s.manager.memoryTotalSize)}
196+
output <- HydratedFile{
197+
Error: fmt.Errorf(
198+
"error: bytes requested greater than max allowed by server (%v)",
199+
s.manager.memoryTotalSize,
200+
),
201+
}
171202
close(output)
172203
return output
173204
}
174-
// Secure memory for all objects upfront.
175-
s.manager.secureMemory(objects) // 0.
205+
206+
// Secure memory for all objects upfront
207+
s.manager.secureMemory(objects)
208+
209+
// IMPORTANT: ensure memory is released if context cancels before processing finishes
210+
go func() {
211+
<-ctx.Done()
212+
// Best-effort cleanup: release all secured memory
213+
for _, o := range objects {
214+
s.manager.releaseMemory(aws.ToInt64(o.Size))
215+
}
216+
}()
176217

177218
processFunc := func(input types.Object) HydratedFile {
219+
// Respect cancellation before starting work
220+
select {
221+
case <-ctx.Done():
222+
return HydratedFile{Error: ctx.Err()}
223+
default:
224+
}
225+
178226
buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
179227
key := aws.ToString(input.Key)
180-
err := s.Get(bucket, key, version, buf)
228+
229+
// Prefer context-aware S3 call if available
230+
err := s.GetWithContext(ctx, bucket, key, version, buf)
231+
232+
// If context was cancelled during S3 read, surface that
233+
if ctx.Err() != nil {
234+
return HydratedFile{Error: ctx.Err()}
235+
}
181236

182237
return HydratedFile{
183238
Key: key,
184239
Data: buf.Bytes(),
185240
Error: err,
186241
}
187242
}
188-
return s.manager.Process(processFunc, objects)
243+
244+
// Process already accepts a context internally, so pass it through
245+
return s.manager.ProcessWithContext(ctx, processFunc, objects)
189246
}
190247

191248
// getWorker retrieves a number of workers from the manager's worker pool.
@@ -259,6 +316,33 @@ func (cm *ConcurrencyManager) Process(asyncProcessor FileProcessor, objects []ty
259316
return workerGroup.returnOutput() // 2.
260317
}
261318

319+
// Functions for providing a fan-out/fan-in operation with provided context. Workers are taken from the
320+
// worker pool and added to a WorkerGroup. All workers are returned to the pool once
321+
// the jobs have finished.
322+
func (cm *ConcurrencyManager) ProcessWithContext(
323+
ctx context.Context,
324+
asyncProcessor FileProcessor,
325+
objects []types.Object,
326+
) chan HydratedFile {
327+
328+
workerGroup := cm.newWorkerGroup(ctx, asyncProcessor, cm.maxWorkersPerRequest)
329+
330+
go func() {
331+
for _, obj := range objects {
332+
select {
333+
case <-ctx.Done():
334+
workerGroup.stopWork()
335+
return
336+
default:
337+
workerGroup.addWork(obj)
338+
}
339+
}
340+
workerGroup.stopWork()
341+
}()
342+
343+
return workerGroup.returnOutput()
344+
}
345+
262346
// start begins a worker's process of making itself available for work, doing the work,
263347
// and repeat, until all work is done.
264348
func (w *worker) start(ctx context.Context, processor FileProcessor, roster chan *worker, wg *sync.WaitGroup) {

0 commit comments

Comments
 (0)