Skip to content

Commit cbc65f0

Browse files
authored
Merge pull request #89 from Teamwork/httpx-exp-backoff
Fix request body handling in DoExponentialBackoff
2 parents 929cb93 + 2718338 commit cbc65f0

2 files changed

Lines changed: 111 additions & 9 deletions

File tree

httputilx/httputilx.go

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,9 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption
233233
backoff := o.initialBackoff
234234

235235
for attempt := 0; attempt <= o.maxRetries; attempt++ {
236-
reqClone := req.Clone(req.Context())
237-
if req.Body != nil {
238-
if seeker, ok := req.Body.(interface {
239-
Seek(int64, int) (int64, error)
240-
}); ok {
241-
_, _ = seeker.Seek(0, 0)
242-
}
243-
reqClone.Body = req.Body
236+
reqClone, err := cloneWithBody(req)
237+
if err != nil {
238+
return nil, errors.Wrap(err, "failed to clone request with body")
244239
}
245240

246241
resp, err := o.client.Do(reqClone)
@@ -272,3 +267,56 @@ func DoExponentialBackoff(req *http.Request, options ...ExponentialBackoffOption
272267

273268
return nil, fmt.Errorf("request failed after %d attempts", o.maxRetries+1)
274269
}
270+
271+
func cloneWithBody(req *http.Request) (*http.Request, error) {
272+
newReq := req.Clone(req.Context())
273+
if req.Body == nil {
274+
return newReq, nil
275+
}
276+
if req.GetBody != nil {
277+
var err error
278+
newReq.Body, err = req.GetBody()
279+
if err != nil {
280+
return nil, err
281+
}
282+
return newReq, nil
283+
}
284+
285+
if seeker, ok := req.Body.(io.Seeker); ok {
286+
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
287+
return nil, err
288+
}
289+
newReq.Body = req.Body
290+
newReq.GetBody = func() (io.ReadCloser, error) {
291+
if _, err := seeker.Seek(0, io.SeekStart); err != nil {
292+
return nil, err
293+
}
294+
return req.Body, nil
295+
}
296+
return newReq, nil
297+
}
298+
299+
bodyBytes, err := io.ReadAll(req.Body)
300+
if err != nil {
301+
return nil, err
302+
}
303+
if err := req.Body.Close(); err != nil {
304+
return nil, err
305+
}
306+
307+
createBody := func(bodyBytes []byte) func() (io.ReadCloser, error) {
308+
return func() (io.ReadCloser, error) {
309+
return io.NopCloser(bytes.NewReader(bodyBytes)), nil
310+
}
311+
}
312+
313+
req.GetBody = createBody(bodyBytes)
314+
req.Body, _ = req.GetBody()
315+
req.ContentLength = int64(len(bodyBytes))
316+
317+
newReq.GetBody = createBody(bodyBytes)
318+
newReq.Body, _ = newReq.GetBody()
319+
newReq.ContentLength = int64(len(bodyBytes))
320+
321+
return newReq, nil
322+
}

httputilx/httputilx_test.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ func TestDoExponentialBackoff(t *testing.T) {
230230
name string
231231
options []ExponentialBackoffOption
232232
handler http.HandlerFunc
233+
requestBody io.Reader
233234
wantBody string
234235
wantErr string
235236
wantAttempts int
@@ -323,6 +324,55 @@ func TestDoExponentialBackoff(t *testing.T) {
323324
wantErr: "",
324325
wantAttempts: 3,
325326
},
327+
{
328+
name: "RequestBodyCopiedOnRetry",
329+
options: []ExponentialBackoffOption{
330+
ExponentialBackoffWithConfig(4, 100*time.Millisecond, 5*time.Second, 2.0),
331+
},
332+
handler: func() http.HandlerFunc {
333+
initialBody := "request body content"
334+
335+
attempts := 0
336+
return func(w http.ResponseWriter, r *http.Request) {
337+
attempts++
338+
339+
if r.ContentLength != int64(len(initialBody)) {
340+
w.WriteHeader(http.StatusInternalServerError)
341+
_, _ = fmt.Fprintf(w, "wrong content-length: got %d, want %d", r.ContentLength, len(initialBody))
342+
return
343+
}
344+
345+
body, err := io.ReadAll(r.Body)
346+
if err != nil {
347+
w.WriteHeader(http.StatusInternalServerError)
348+
return
349+
}
350+
351+
if len(body) != len(initialBody) {
352+
w.WriteHeader(http.StatusInternalServerError)
353+
_, _ = fmt.Fprintf(w, "content-length mismatch: header=%d actual=%d", r.ContentLength, len(body))
354+
return
355+
}
356+
357+
// Verify body is correctly sent on all attempts
358+
if string(body) != initialBody {
359+
w.WriteHeader(http.StatusInternalServerError)
360+
_, _ = fmt.Fprintf(w, "incorrect body: %q", string(body))
361+
return
362+
}
363+
if attempts < 3 {
364+
w.WriteHeader(http.StatusInternalServerError)
365+
return
366+
}
367+
w.WriteHeader(http.StatusOK)
368+
_, _ = w.Write([]byte("body received correctly"))
369+
}
370+
}(),
371+
requestBody: io.NopCloser(bytes.NewBuffer([]byte("request body content"))),
372+
wantBody: "body received correctly",
373+
wantErr: "",
374+
wantAttempts: 3,
375+
},
326376
}
327377

328378
for _, tt := range tests {
@@ -334,7 +384,11 @@ func TestDoExponentialBackoff(t *testing.T) {
334384
}))
335385
defer ts.Close()
336386

337-
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
387+
method := http.MethodGet
388+
if tt.requestBody != nil {
389+
method = http.MethodPost
390+
}
391+
req, err := http.NewRequest(method, ts.URL, tt.requestBody)
338392
if err != nil {
339393
t.Fatalf("failed to create request: %v", err)
340394
}

0 commit comments

Comments
 (0)