Skip to content

Commit 5750ac9

Browse files
authored
Merge pull request #202 from GeoNet/s3-context
feat: context aware s3 Get function
2 parents feb97b8 + cb05654 commit 5750ac9

2 files changed

Lines changed: 167 additions & 10 deletions

File tree

aws/s3/s3.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,22 +128,40 @@ func (s *S3) Client() *s3.Client {
128128
// Get gets the object referred to by key and version from bucket and writes it into b.
129129
// Version can be empty.
130130
func (s *S3) Get(bucket, key, version string, b *bytes.Buffer) error {
131+
_, err := s.GetWithContext(context.Background(), bucket, key, version, b)
132+
return err
133+
}
134+
135+
// GetWithContext gets the object referred to by key and version from bucket and writes it into w
136+
// using the provided context. It returns the number of bytes written to w.
137+
// Version can be empty.
138+
func (s *S3) GetWithContext(
139+
ctx context.Context,
140+
bucket, key, version string,
141+
w io.Writer,
142+
) (int64, error) {
143+
131144
input := s3.GetObjectInput{
132-
Key: aws.String(key),
133145
Bucket: aws.String(bucket),
146+
Key: aws.String(key),
134147
}
135148
if version != "" {
136149
input.VersionId = aws.String(version)
137150
}
138-
result, err := s.client.GetObject(context.TODO(), &input)
151+
152+
result, err := s.client.GetObject(ctx, &input)
139153
if err != nil {
140-
return err
154+
return 0, err
141155
}
142156
defer result.Body.Close()
143157

144-
_, err = b.ReadFrom(result.Body)
158+
n, err := io.Copy(w, result.Body)
145159

146-
return err
160+
// If the context is done, report its error instead of the copy result so callers can detect cancellation
161+
if ctx.Err() != nil {
162+
return n, ctx.Err()
163+
}
164+
return n, err
147165
}
148166

149167
// GetByteRange gets the specified byte range of an object referred to by key and version

aws/s3/s3_integration_test.go

Lines changed: 144 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,23 @@ func setAwsEnv() {
4343
os.Setenv("AWS_SECRET_ACCESS_KEY", "test")
4444
os.Setenv("AWS_ACCESS_KEY_ID", "test")
4545
os.Setenv("AWS_ENDPOINT_URL", customAWSEndpoint)
46+
os.Setenv("AWS_S3_DISABLE_CHECKSUM", "true")
4647
}
4748

4849
func setup() {
4950
// setup environment variable to run AWS CLI/SDK
5051
setAwsEnv()
5152

5253
// create bucket
53-
if err := exec.Command( //nolint:gosec
54+
cmd := exec.Command( //nolint:gosec
5455
"aws", "s3api",
5556
"create-bucket",
5657
"--bucket", testBucket,
5758
"--create-bucket-configuration", fmt.Sprintf(
5859
"{\"LocationConstraint\": \"%v\"}", testRegion),
59-
).Run(); err != nil {
60+
)
61+
if output, err := cmd.CombinedOutput(); err != nil {
62+
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
6063
panic(err)
6164
}
6265
}
@@ -146,11 +149,12 @@ func awsCmdPutKeys(keys []string) {
146149
testFile.Close()
147150
}
148151
// sync to bucket
149-
if err := exec.Command(
152+
cmd := exec.Command(
150153
"aws", "s3",
151154
"sync", tmpDir, fmt.Sprintf("s3://%v", testBucket),
152-
).Run(); err != nil {
153-
155+
)
156+
if output, err := cmd.CombinedOutput(); err != nil {
157+
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
154158
panic(err)
155159
}
156160
}
@@ -360,6 +364,141 @@ func TestS3Get(t *testing.T) {
360364
assert.Equal(t, testObjectData, dataObject.String())
361365
}
362366

367+
func TestS3GetWithContext(t *testing.T) {
368+
// ARRANGE
369+
setup()
370+
defer teardown()
371+
372+
awsCmdPopulateBucket()
373+
374+
client, err := New()
375+
require.NoError(t, err, "error creating s3 client")
376+
377+
t.Run("normal", func(t *testing.T) {
378+
var buf bytes.Buffer
379+
ctx := context.Background()
380+
381+
// ACTION
382+
written, err := client.GetWithContext(
383+
ctx,
384+
testBucket,
385+
testObjectKey,
386+
"",
387+
&buf,
388+
)
389+
390+
// ASSERT
391+
require.NoError(t, err)
392+
assert.Equal(t, int64(len(testObjectData)), written)
393+
assert.Equal(t, testObjectData, buf.String())
394+
})
395+
396+
t.Run("cancelled", func(t *testing.T) {
397+
var buf bytes.Buffer
398+
ctx, cancel := context.WithCancel(context.Background())
399+
cancel() // cancel immediately
400+
401+
// ACTION
402+
written, err := client.GetWithContext(
403+
ctx,
404+
testBucket,
405+
testObjectKey,
406+
"",
407+
&buf,
408+
)
409+
410+
// ASSERT
411+
require.Error(t, err)
412+
assert.ErrorIs(t, err, context.Canceled)
413+
assert.Equal(t, int64(0), written)
414+
})
415+
416+
t.Run("cancel-during-processing", func(t *testing.T) {
417+
// We’ll cancel after a portion of the object has been written to the buffer.
418+
ctx, cancel := context.WithCancel(context.Background())
419+
var buf bytes.Buffer
420+
421+
// Choose a threshold smaller than the total size so we cancel mid-stream.
422+
sw := &cancelAfterNWriter{
423+
dst: &buf,
424+
cancel: cancel,
425+
limit: 4, // cancel after 4 bytes are written
426+
sleep: 0 * time.Millisecond, // optional; set to >0 to slow per-write
427+
}
428+
429+
// ACTION
430+
written, err := client.GetWithContext(
431+
ctx,
432+
testBucket,
433+
testObjectKey,
434+
"",
435+
sw,
436+
)
437+
t.Log("written bytes:", written)
438+
// ASSERT: it should end early with a context error and partial bytes written
439+
require.Error(t, err, "expected error due to mid-run cancellation")
440+
assert.ErrorIs(t, err, context.Canceled)
441+
assert.GreaterOrEqual(t, written, int64(1), "should write some bytes before cancel")
442+
assert.Equal(t, written, int64(buf.Len()), "buffer length should match reported written")
443+
assert.Less(t, written, int64(len(testObjectData)), "should not complete full object")
444+
})
445+
}
446+
447+
// cancelAfterNWriter writes at most limit bytes to dst.
448+
// When a Write lands exactly on limit, it calls cancel; the next Write returns context.Canceled.
449+
// If a single Write would exceed the limit, it performs a **partial write**
450+
// and then returns context.Canceled so the copy loop stops immediately.
451+
type cancelAfterNWriter struct {
452+
dst io.Writer
453+
cancel context.CancelFunc
454+
limit int64 // total bytes allowed before we cancel & error
455+
sleep time.Duration
456+
wrote int64
457+
}
458+
459+
func (w *cancelAfterNWriter) Write(p []byte) (int, error) {
460+
if w.sleep > 0 {
461+
time.Sleep(w.sleep)
462+
}
463+
464+
remaining := w.limit - w.wrote
465+
if remaining <= 0 {
466+
// Already reached the limit: cancel & error without writing.
467+
if w.cancel != nil {
468+
w.cancel()
469+
w.cancel = nil
470+
}
471+
return 0, context.Canceled
472+
}
473+
474+
// If the incoming chunk exceeds the remaining budget, do a **partial write**.
475+
if int64(len(p)) > remaining {
476+
// write only `remaining` bytes
477+
n, err := w.dst.Write(p[:remaining])
478+
if err != nil {
479+
return n, err
480+
}
481+
w.wrote += int64(n)
482+
// Now cancel & return error to abort the transfer
483+
if w.cancel != nil {
484+
w.cancel()
485+
w.cancel = nil
486+
}
487+
// the underlying write succeeded (its error was returned above); report the partial count with context.Canceled
488+
return n, context.Canceled
489+
}
490+
491+
// Normal path: whole chunk fits.
492+
n, err := w.dst.Write(p)
493+
w.wrote += int64(n)
494+
// If this write brought us *exactly* to the limit, cancel now; the next call returns the error.
495+
if w.wrote >= w.limit && w.cancel != nil {
496+
w.cancel()
497+
w.cancel = nil
498+
}
499+
return n, err
500+
}
501+
363502
func TestS3GetByteRange(t *testing.T) {
364503
// ARRANGE
365504
setup()

0 commit comments

Comments
 (0)