11package redirecthandler
22
33import (
4+ "fmt"
45 "net/http"
56 "net/url"
67 "sync"
@@ -10,52 +11,70 @@ import (
1011 "go.uber.org/zap"
1112)
1213
13- // RedirectHandler handles HTTP redirects within an http.Client.
14- // It provides features such as redirect loop detection, security enhancements,
15- // and integration with client settings for fine-grained control over redirect behavior.
14+ // RedirectHandler contains configurations for handling HTTP redirects.
1615type RedirectHandler struct {
17- Logger logger.Logger
18- MaxRedirects int
19- VisitedURLs map [string ]int
20- VisitedURLsMutex sync.Mutex
21- SensitiveHeaders []string
16+ Logger logger.Logger // Logger instance for logging.
17+ MaxRedirects int // Maximum allowed redirects to prevent infinite loops.
18+ VisitedURLs map [string ]int // Tracks visited URLs to detect loops.
19+ VisitedURLsMutex sync.RWMutex // Mutex for safe concurrent access to VisitedURLs.
20+ SensitiveHeaders []string // Headers to be removed on cross-domain redirects.
21+ PermanentRedirects map [string ]string // Cache for permanent redirects
22+ PermRedirectsMutex sync.RWMutex // Mutex for safe concurrent access to PermanentRedirects
23+ RedirectHistories map [* http.Request ][]* url.URL // Map to track redirect history for each request
2224}
2325
24- // NewRedirectHandler creates a new instance of RedirectHandler with the provided logger
25- // and maximum number of redirects. It initializes internal structures and is ready to use.
26+ // NewRedirectHandler creates a new instance of RedirectHandler.
2627func NewRedirectHandler (logger logger.Logger , maxRedirects int ) * RedirectHandler {
2728 return & RedirectHandler {
28- Logger : logger ,
29- MaxRedirects : maxRedirects ,
30- VisitedURLs : make (map [string ]int ),
31- SensitiveHeaders : []string {"Authorization" , "Cookie" }, // Add other sensitive headers if needed
29+ Logger : logger ,
30+ MaxRedirects : maxRedirects ,
31+ VisitedURLs : make (map [string ]int ),
32+ SensitiveHeaders : []string {"Authorization" , "Cookie" },
33+ PermanentRedirects : make (map [string ]string ),
34+ RedirectHistories : make (map [* http.Request ][]* url.URL ),
3235 }
3336}
3437
38+ // AddSensitiveHeader allows adding configurable sensitive headers.
39+ func (r * RedirectHandler ) AddSensitiveHeader (header string ) {
40+ r .SensitiveHeaders = append (r .SensitiveHeaders , header )
41+ }
42+
3543// WithRedirectHandling applies the redirect handling policy to an http.Client.
36- // It sets the CheckRedirect function on the client to use the handler's logic.
3744func (r * RedirectHandler ) WithRedirectHandling (client * http.Client ) {
3845 client .CheckRedirect = r .checkRedirect
3946}
4047
41- // checkRedirect is the core function that implements the redirect handling logic.
42- // It is set as the CheckRedirect function on an http.Client and is called whenever
43- // the client encounters a 3XX response. It enforces the max redirects limit,
44- // detects redirect loops, applies security measures for cross-domain redirects,
45- // resolves relative redirects, and optimizes performance.
48+ // checkRedirect implements the redirect handling logic.
4649func (r * RedirectHandler ) checkRedirect (req * http.Request , via []* http.Request ) error {
47- // Redirect Loop Detection
48- r .VisitedURLsMutex .Lock ()
49- defer r .VisitedURLsMutex .Unlock ()
50- if _ , exists := r .VisitedURLs [req .URL .String ()]; exists {
51- r .Logger .Warn ("Detected redirect loop" , zap .String ("url" , req .URL .String ()))
52- return http .ErrUseLastResponse
50+ defer r .clearRedirectHistory (req ) // Ensure redirect history is always cleared to prevent memory leaks
51+
52+ // Check for cached permanent redirect
53+ if urlString , ok := r .checkPermanentRedirect (req .URL .String ()); ok && (req .Method == http .MethodGet || req .Method == http .MethodHead ) {
54+ parsedURL , err := url .Parse (urlString )
55+ if err != nil {
56+ r .Logger .Error ("Failed to parse URL from cache" , zap .String ("url" , urlString ), zap .Error (err ))
57+ // Continue with the original URL since the cached URL is invalid
58+ } else {
59+ req .URL = parsedURL // Use cached redirect location
60+ r .Logger .Info ("Using cached permanent redirect" , zap .String ("originalURL" , urlString ), zap .String ("redirectURL" , parsedURL .String ()))
61+ return nil
62+ }
5363 }
54- r .VisitedURLs [req .URL .String ()]++
5564
65+ // Track redirect history for the current request
66+ r .RedirectHistories [req ] = append (r .RedirectHistories [req ], req .URL )
67+
68+ // Check for redirect loops by analyzing the history
69+ if hasLoop (r .RedirectHistories [req ]) {
70+ r .Logger .Error ("Redirect loop detected" , zap .Any ("redirectHistory" , r .RedirectHistories [req ]))
71+ return fmt .Errorf ("redirect loop detected: %v" , r .RedirectHistories [req ])
72+ }
73+
74+ // Enforce max redirects
5675 if len (via ) >= r .MaxRedirects {
57- r .Logger .Warn ("Stopped after maximum redirects" , zap .Int ("maxRedirects" , r .MaxRedirects ))
58- return http . ErrUseLastResponse
76+ r .Logger .Warn ("Maximum redirects reached " , zap .Int ("maxRedirects" , r .MaxRedirects ))
77+ return & MaxRedirectsError { MaxRedirects : r . MaxRedirects }
5978 }
6079
6180 lastResponse := via [len (via )- 1 ].Response
@@ -66,69 +85,111 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request)
6685 return err
6786 }
6887
69- // Resolve relative redirects against the current request URL
7088 newReqURL , err := r .resolveRedirectURL (req .URL , location )
7189 if err != nil {
7290 r .Logger .Error ("Failed to resolve redirect URL" , zap .Error (err ))
7391 return err
7492 }
7593
76- // Security Measures
94+ // Apply security measures for cross-domain redirects
7795 if newReqURL .Host != req .URL .Host {
7896 r .secureRequest (req )
7997 }
8098
81- // Handling 303 See Other
82- if lastResponse .StatusCode == http .StatusSeeOther {
83- req .Method = http .MethodGet
84- req .Body = nil
85- req .GetBody = nil
86- req .ContentLength = 0
87- req .Header .Del ("Content-Type" )
88- r .Logger .Info ("Changed request method to GET for 303 See Other response" )
99+ // Cache permanent redirects
100+ if status .IsPermanentRedirect (lastResponse .StatusCode ) {
101+ r .cachePermanentRedirect (req .URL .String (), newReqURL .String ())
89102 }
90103
91- // Logging enhancements
92- r .Logger .Info ("Redirecting request" ,
93- zap .String ("originalURL" , req .URL .String ()),
94- zap .String ("newURL" , newReqURL .String ()),
95- zap .String ("method" , req .Method ),
96- zap .Int ("redirectCount" , len (via )),
97- )
98-
99- // Log removed sensitive headers
100- for _ , header := range r .SensitiveHeaders {
101- r .Logger .Info ("Removed sensitive header due to domain change" ,
102- zap .String ("header" , header ),
103- )
104+ // Special handling for 303 See Other
105+ if lastResponse .StatusCode == http .StatusSeeOther {
106+ r .adjustForSeeOther (req )
104107 }
105108
106- req .URL = newReqURL
109+ r .Logger .Info ("Redirecting request" , zap .String ("originalURL" , req .URL .String ()), zap .String ("newURL" , newReqURL .String ()), zap .Int ("redirectCount" , len (via )))
110+ req .URL = newReqURL // Update request URL to follow the redirect
107111 return nil
108112 }
109113
110- return http .ErrUseLastResponse
114+ return http .ErrUseLastResponse // No further action required if not a redirect status code
111115}
112116
113- // resolveRedirectURL resolves the redirect location URL against the current request URL
114- // to handle relative redirects accurately.
117+ // resolveRedirectURL resolves the redirect location URL against the current request URL.
115118func (r * RedirectHandler ) resolveRedirectURL (reqURL * url.URL , redirectURL * url.URL ) (* url.URL , error ) {
116- if redirectURL .IsAbs () {
117- return redirectURL , nil // Absolute URL, no need to resolve
119+ if ! redirectURL .IsAbs () {
120+ redirectURL . Scheme = reqURL . Scheme // Preserve the scheme
118121 }
119-
120- // Relative URL, resolve against the current request URL
121- absoluteURL := * reqURL
122- absoluteURL .Path = redirectURL .Path
123- absoluteURL .RawQuery = redirectURL .RawQuery
124- absoluteURL .Fragment = redirectURL .Fragment
125- return & absoluteURL , nil
122+ return redirectURL , nil
126123}
127124
128125// secureRequest removes sensitive headers from the request if the new destination is a different domain.
129126func (r * RedirectHandler ) secureRequest (req * http.Request ) {
130127 for _ , header := range r .SensitiveHeaders {
131128 req .Header .Del (header )
132- r .Logger .Info ("Removed sensitive header due to domain change" , zap .String ("header" , header ))
133129 }
134130}
131+
132+ // adjustForSeeOther adjusts the request for "303 See Other" responses.
133+ func (r * RedirectHandler ) adjustForSeeOther (req * http.Request ) {
134+ req .Method = http .MethodGet
135+ req .Body = nil
136+ req .GetBody = nil
137+ req .ContentLength = 0
138+ req .Header .Del ("Content-Type" )
139+ }
140+
// RedirectLoopError represents an error when a redirect loop is detected.
type RedirectLoopError struct {
	URL string // URL at which the loop was detected.
}

// Error implements the error interface and names the URL at which the loop
// was detected.
func (e *RedirectLoopError) Error() string {
	return fmt.Sprintf("redirect loop detected at %s", e.URL)
}
150+
// MaxRedirectsError represents an error when the maximum number of redirects
// is reached.
type MaxRedirectsError struct {
	MaxRedirects int // The redirect limit that was hit.
}

// Error implements the error interface and reports the limit that was hit.
func (e *MaxRedirectsError) Error() string {
	return fmt.Sprintf("maximum redirects reached: %d", e.MaxRedirects)
}
160+
161+ // cachePermanentRedirect caches the permanent redirect location.
162+ func (r * RedirectHandler ) cachePermanentRedirect (originalURL , redirectURL string ) {
163+ r .PermRedirectsMutex .Lock ()
164+ defer r .PermRedirectsMutex .Unlock ()
165+
166+ r .PermanentRedirects [originalURL ] = redirectURL
167+ }
168+
169+ // checkPermanentRedirect checks if there's a cached redirect for the given URL.
170+ func (r * RedirectHandler ) checkPermanentRedirect (originalURL string ) (string , bool ) {
171+ r .PermRedirectsMutex .RLock ()
172+ defer r .PermRedirectsMutex .RUnlock ()
173+
174+ url , exists := r .PermanentRedirects [originalURL ]
175+ return url , exists
176+ }
177+
// hasLoop reports whether the redirect history contains a loop, i.e. whether
// any URL (compared by its String form) appears more than once.
//
// Nil entries are skipped rather than dereferenced. An empty or nil history
// has no loop.
func hasLoop(history []*url.URL) bool {
	seen := make(map[string]struct{}, len(history))
	for _, u := range history {
		if u == nil {
			continue
		}
		key := u.String() // serialise once per entry
		if _, dup := seen[key]; dup {
			return true
		}
		seen[key] = struct{}{}
	}
	return false
}
189+
190+ // clearRedirectHistory clears the redirect history for a given request to prevent memory leaks.
191+ func (r * RedirectHandler ) clearRedirectHistory (req * http.Request ) {
192+ r .VisitedURLsMutex .Lock () // Use the appropriate mutex to synchronize access to RedirectHistories
193+ delete (r .RedirectHistories , req )
194+ r .VisitedURLsMutex .Unlock ()
195+ }
0 commit comments