@@ -43,20 +43,23 @@ func setAwsEnv() {
4343 os .Setenv ("AWS_SECRET_ACCESS_KEY" , "test" )
4444 os .Setenv ("AWS_ACCESS_KEY_ID" , "test" )
4545 os .Setenv ("AWS_ENDPOINT_URL" , customAWSEndpoint )
46+ os .Setenv ("AWS_S3_DISABLE_CHECKSUM" , "true" )
4647}
4748
4849func setup () {
4950 // setup environment variable to run AWS CLI/SDK
5051 setAwsEnv ()
5152
5253 // create bucket
53- if err := exec .Command ( //nolint:gosec
54+ cmd := exec .Command ( //nolint:gosec
5455 "aws" , "s3api" ,
5556 "create-bucket" ,
5657 "--bucket" , testBucket ,
5758 "--create-bucket-configuration" , fmt .Sprintf (
5859 "{\" LocationConstraint\" : \" %v\" }" , testRegion ),
59- ).Run (); err != nil {
60+ )
61+ if output , err := cmd .CombinedOutput (); err != nil {
62+ fmt .Printf ("Command failed: %v\n Output: %s\n " , err , string (output ))
6063 panic (err )
6164 }
6265}
@@ -146,11 +149,12 @@ func awsCmdPutKeys(keys []string) {
146149 testFile .Close ()
147150 }
148151 // sync to bucket
149- if err := exec .Command (
152+ cmd := exec .Command (
150153 "aws" , "s3" ,
151154 "sync" , tmpDir , fmt .Sprintf ("s3://%v" , testBucket ),
152- ).Run (); err != nil {
153-
155+ )
156+ if output , err := cmd .CombinedOutput (); err != nil {
157+ fmt .Printf ("Command failed: %v\n Output: %s\n " , err , string (output ))
154158 panic (err )
155159 }
156160}
@@ -360,6 +364,141 @@ func TestS3Get(t *testing.T) {
360364 assert .Equal (t , testObjectData , dataObject .String ())
361365}
362366
367+ func TestS3GetWithContext (t * testing.T ) {
368+ // ARRANGE
369+ setup ()
370+ defer teardown ()
371+
372+ awsCmdPopulateBucket ()
373+
374+ client , err := New ()
375+ require .NoError (t , err , "error creating s3 client" )
376+
377+ t .Run ("normal" , func (t * testing.T ) {
378+ var buf bytes.Buffer
379+ ctx := context .Background ()
380+
381+ // ACTION
382+ written , err := client .GetWithContext (
383+ ctx ,
384+ testBucket ,
385+ testObjectKey ,
386+ "" ,
387+ & buf ,
388+ )
389+
390+ // ASSERT
391+ require .NoError (t , err )
392+ assert .Equal (t , int64 (len (testObjectData )), written )
393+ assert .Equal (t , testObjectData , buf .String ())
394+ })
395+
396+ t .Run ("cancelled" , func (t * testing.T ) {
397+ var buf bytes.Buffer
398+ ctx , cancel := context .WithCancel (context .Background ())
399+ cancel () // cancel immediately
400+
401+ // ACTION
402+ written , err := client .GetWithContext (
403+ ctx ,
404+ testBucket ,
405+ testObjectKey ,
406+ "" ,
407+ & buf ,
408+ )
409+
410+ // ASSERT
411+ require .Error (t , err )
412+ assert .ErrorIs (t , err , context .Canceled )
413+ assert .Equal (t , int64 (0 ), written )
414+ })
415+
416+ t .Run ("cancel-during-processing" , func (t * testing.T ) {
417+ // We’ll cancel after a portion of the object has been written to the buffer.
418+ ctx , cancel := context .WithCancel (context .Background ())
419+ var buf bytes.Buffer
420+
421+ // Choose a threshold smaller than the total size so we cancel mid-stream.
422+ sw := & cancelAfterNWriter {
423+ dst : & buf ,
424+ cancel : cancel ,
425+ limit : 4 , // cancel after 4 bytes are written
426+ sleep : 0 * time .Millisecond , // optional; set to >0 to slow per-write
427+ }
428+
429+ // ACTION
430+ written , err := client .GetWithContext (
431+ ctx ,
432+ testBucket ,
433+ testObjectKey ,
434+ "" ,
435+ sw ,
436+ )
437+ t .Log ("written bytes:" , written )
438+ // ASSERT: it should end early with a context error and partial bytes written
439+ require .Error (t , err , "expected error due to mid-run cancellation" )
440+ assert .ErrorIs (t , err , context .Canceled )
441+ assert .GreaterOrEqual (t , written , int64 (1 ), "should write some bytes before cancel" )
442+ assert .Equal (t , written , int64 (buf .Len ()), "buffer length should match reported written" )
443+ assert .Less (t , written , int64 (len (testObjectData )), "should not complete full object" )
444+ })
445+ }
446+
447+ // cancelAfterNWriter writes at most limit bytes to dst.
448+ // Once limit is reached, it cancels ctx and returns context.Canceled.
449+ // If a single Write would exceed the limit, it performs a **partial write**
450+ // and then returns context.Canceled so the copy loop stops immediately.
451+ type cancelAfterNWriter struct {
452+ dst io.Writer
453+ cancel context.CancelFunc
454+ limit int64 // total bytes allowed before we cancel & error
455+ sleep time.Duration
456+ wrote int64
457+ }
458+
459+ func (w * cancelAfterNWriter ) Write (p []byte ) (int , error ) {
460+ if w .sleep > 0 {
461+ time .Sleep (w .sleep )
462+ }
463+
464+ remaining := w .limit - w .wrote
465+ if remaining <= 0 {
466+ // Already reached the limit: cancel & error without writing.
467+ if w .cancel != nil {
468+ w .cancel ()
469+ w .cancel = nil
470+ }
471+ return 0 , context .Canceled
472+ }
473+
474+ // If the incoming chunk exceeds the remaining budget, do a **partial write**.
475+ if int64 (len (p )) > remaining {
476+ // write only `remaining` bytes
477+ n , err := w .dst .Write (p [:remaining ])
478+ if err != nil {
479+ return n , err
480+ }
481+ w .wrote += int64 (n )
482+ // Now cancel & return error to abort the transfer
483+ if w .cancel != nil {
484+ w .cancel ()
485+ w .cancel = nil
486+ }
487+ // ignore underlying err to ensure we signal cancel; return context.Canceled with partial write
488+ return n , context .Canceled
489+ }
490+
491+ // Normal path: whole chunk fits.
492+ n , err := w .dst .Write (p )
493+ w .wrote += int64 (n )
494+ // If we *exactly* hit the limit after this write, cancel & error on the next call.
495+ if w .wrote >= w .limit && w .cancel != nil {
496+ w .cancel ()
497+ w .cancel = nil
498+ }
499+ return n , err
500+ }
501+
363502func TestS3GetByteRange (t * testing.T ) {
364503 // ARRANGE
365504 setup ()
0 commit comments