diff --git a/transport/httptransport/http_transport.go b/transport/httptransport/http_transport.go index 0716aca1e..32a91c3a6 100644 --- a/transport/httptransport/http_transport.go +++ b/transport/httptransport/http_transport.go @@ -210,8 +210,12 @@ func (h *httpTransport) Execute(ctx context.Context, transportInfo []byte, dealI } else { // do not follow http redirects for security reasons t.client = &http.Client{ + // Custom CheckRedirect function to limit redirects CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse + if len(via) >= 2 { // Limit to 2 redirects + return http.ErrUseLastResponse + } + return nil }, } diff --git a/transport/httptransport/http_transport_test.go b/transport/httptransport/http_transport_test.go index 240fad8a7..5799fe495 100644 --- a/transport/httptransport/http_transport_test.go +++ b/transport/httptransport/http_transport_test.go @@ -387,8 +387,8 @@ func TestDownloadFromPrivateIPs(t *testing.T) { } func TestDontFollowHttpRedirects(t *testing.T) { - // we should not follow http redirects for security reasons. If the target URL tries to redirect, the client should return 303 response instead. - // This test sets up two servers, with one redirecting to the other. Without the redirect check the download would have been completed successfully. + // we should not follow more than 2 http redirects for security reasons. If the target URL tries to redirect, the client should return 303 response instead. + // This test sets up 3 servers, with one redirecting to the other. Without the redirect check the download would have been completed successfully. rawSize := (100 * readBufferSize) + 30 ctx := context.Background() st := newServerTest(t, rawSize) @@ -422,8 +422,14 @@ func TestDontFollowHttpRedirects(t *testing.T) { redirectSvr := httptest.NewServer(redirectHandler) defer redirectSvr.Close() + var redirectHandler1 http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectSvr.URL, http.StatusSeeOther) + } + redirectSvr1 := httptest.NewServer(redirectHandler1) + defer redirectSvr1.Close() + of := getTempFilePath(t) - th := executeTransfer(t, ctx, New(nil, newDealLogger(t, ctx), NChunksOpt(numChunks)), carSize, types.HttpRequest{URL: redirectSvr.URL}, of) + th := executeTransfer(t, ctx, New(nil, newDealLogger(t, ctx), NChunksOpt(numChunks)), carSize, types.HttpRequest{URL: redirectSvr1.URL}, of) require.NotNil(t, th) evts := waitForTransferComplete(th)