Skip to content

Commit 926b6db

Browse files
committed
Update mock logger functions to handle edge cases
1 parent 27ca368 commit 926b6db

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

mocklogger/mocklogger.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package mocklogger
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"os"
7+
"time"
8+
9+
"github.com/deploymenttheory/go-api-http-client/logger" // Assuming this is the package where Logger interface is defined
10+
"github.com/stretchr/testify/mock"
11+
"go.uber.org/zap"
12+
"go.uber.org/zap/zapcore"
13+
)
14+
15+
// MockLogger is a mock type for the Logger interface, embedding a *zap.Logger to satisfy the type requirement.
16+
type MockLogger struct {
17+
mock.Mock
18+
*zap.Logger
19+
logLevel logger.LogLevel
20+
}
21+
22+
// NewMockLogger creates a new instance of MockLogger with an embedded no-op *zap.Logger.
23+
func NewMockLogger() *MockLogger {
24+
return &MockLogger{
25+
Logger: zap.NewNop(),
26+
}
27+
}
28+
29+
// Ensure MockLogger implements the logger.Logger interface from the logger package
30+
var _ logger.Logger = (*MockLogger)(nil)
31+
32+
func (m *MockLogger) GetLogLevel() logger.LogLevel {
33+
args := m.Called()
34+
if args.Get(0) != nil { // Check if the Called method has a return value
35+
return args.Get(0).(logger.LogLevel)
36+
}
37+
return logger.LogLevelNone // Return LogLevelNone if no specific log level is set
38+
}
39+
40+
func (m *MockLogger) SetLevel(level logger.LogLevel) {
41+
m.logLevel = level
42+
m.Called(level)
43+
}
44+
45+
func (m *MockLogger) With(fields ...zapcore.Field) logger.Logger {
46+
m.Called(fields)
47+
// This is a mock implementation; adjust as necessary for your tests
48+
return m
49+
}
50+
51+
// Debug logs a message at the Debug level.
52+
func (m *MockLogger) Debug(msg string, fields ...zapcore.Field) {
53+
m.Called(msg, fields)
54+
if m.logLevel <= logger.LogLevelDebug {
55+
fmt.Printf("[DEBUG] %s\n", msg)
56+
}
57+
}
58+
59+
// Info logs a message at the Info level.
60+
func (m *MockLogger) Info(msg string, fields ...zapcore.Field) {
61+
m.Called(msg, fields)
62+
if m.logLevel <= logger.LogLevelInfo {
63+
fmt.Printf("[INFO] %s\n", msg)
64+
}
65+
}
66+
67+
// Error logs a message at the Error level and returns an error.
68+
func (m *MockLogger) Error(msg string, fields ...zapcore.Field) error {
69+
m.Called(msg, fields)
70+
if m.logLevel <= logger.LogLevelError {
71+
fmt.Printf("[ERROR] %s\n", msg)
72+
}
73+
return errors.New(msg)
74+
}
75+
76+
// Warn logs a message at the Warn level.
77+
func (m *MockLogger) Warn(msg string, fields ...zapcore.Field) {
78+
m.Called(msg, fields)
79+
if m.logLevel <= logger.LogLevelWarn {
80+
fmt.Printf("[WARN] %s\n", msg)
81+
}
82+
}
83+
84+
// Panic logs a message at the Panic level and then panics.
85+
func (m *MockLogger) Panic(msg string, fields ...zapcore.Field) {
86+
m.Called(msg, fields)
87+
if m.logLevel <= logger.LogLevelPanic {
88+
fmt.Printf("[PANIC] %s\n", msg)
89+
panic(msg)
90+
}
91+
}
92+
93+
// Fatal logs a message at the Fatal level and then calls os.Exit(1).
94+
func (m *MockLogger) Fatal(msg string, fields ...zapcore.Field) {
95+
m.Called(msg, fields)
96+
if m.logLevel <= logger.LogLevelFatal {
97+
fmt.Printf("[FATAL] %s\n", msg)
98+
os.Exit(1)
99+
}
100+
}
101+
102+
// LogRequestStart logs the start of an HTTP request.
103+
func (m *MockLogger) LogRequestStart(event string, requestID string, userID string, method string, url string, headers map[string][]string) {
104+
m.Called(event, requestID, userID, method, url, headers)
105+
// Mock logging implementation...
106+
}
107+
108+
// LogRequestEnd logs the end of an HTTP request.
109+
func (m *MockLogger) LogRequestEnd(event string, method string, url string, statusCode int, duration time.Duration) {
110+
m.Called(event, method, url, statusCode, duration)
111+
// Mock logging implementation...
112+
}
113+
114+
// LogError logs an error event.
115+
func (m *MockLogger) LogError(event string, method string, url string, statusCode int, serverStatusMessage string, err error, rawResponse string) {
116+
m.Called(event, method, url, statusCode, serverStatusMessage, err, rawResponse)
117+
// Mock logging implementation...
118+
}
119+
120+
// Example for LogAuthTokenError:
121+
func (m *MockLogger) LogAuthTokenError(event string, method string, url string, statusCode int, err error) {
122+
m.Called(event, method, url, statusCode, err)
123+
// Mock logging implementation...
124+
}
125+
126+
// LogCookies logs information about cookies.
127+
func (m *MockLogger) LogCookies(direction string, obj interface{}, method, url string) {
128+
// Use the mock framework to record that LogCookies was called with the specified arguments
129+
m.Called(direction, obj, method, url)
130+
fmt.Printf("[COOKIES] Direction: %s, Object: %v, Method: %s, URL: %s\n", direction, obj, method, url)
131+
}
132+
133+
// LogRetryAttempt logs a retry attempt.
134+
func (m *MockLogger) LogRetryAttempt(event string, method string, url string, attempt int, reason string, waitDuration time.Duration, err error) {
135+
m.Called(event, method, url, attempt, reason, waitDuration, err)
136+
// Mock logging implementation...
137+
fmt.Printf("[RETRY ATTEMPT] Event: %s, Method: %s, URL: %s, Attempt: %d, Reason: %s, Wait Duration: %s, Error: %v\n", event, method, url, attempt, reason, waitDuration, err)
138+
}
139+
140+
// LogRateLimiting logs rate limiting events.
141+
func (m *MockLogger) LogRateLimiting(event string, method string, url string, retryAfter string, waitDuration time.Duration) {
142+
m.Called(event, method, url, retryAfter, waitDuration)
143+
// Mock logging implementation...
144+
fmt.Printf("[RATE LIMITING] Event: %s, Method: %s, URL: %s, Retry After: %s, Wait Duration: %s\n", event, method, url, retryAfter, waitDuration)
145+
}
146+
147+
// LogResponse logs HTTP responses.
148+
func (m *MockLogger) LogResponse(event string, method string, url string, statusCode int, responseBody string, responseHeaders map[string][]string, duration time.Duration) {
149+
m.Called(event, method, url, statusCode, responseBody, responseHeaders, duration)
150+
// Mock logging implementation...
151+
fmt.Printf("[RESPONSE] Event: %s, Method: %s, URL: %s, Status Code: %d, Response Body: %s, Response Headers: %v, Duration: %s\n", event, method, url, statusCode, responseBody, responseHeaders, duration)
152+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package redirecthandler
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"testing"
7+
8+
"github.com/deploymenttheory/go-api-http-client/mocklogger"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestRedirectHandler_CheckRedirect(t *testing.T) {
13+
mockLogger := mocklogger.NewMockLogger()
14+
15+
// Set the mock logger to capture logs at all levels
16+
mockLogger.SetLevel(mocklogger.LogLevelDebug)
17+
18+
redirectHandler := NewRedirectHandler(mockLogger, 10)
19+
20+
reqURL, _ := url.Parse("http://example.com")
21+
req := &http.Request{URL: reqURL, Method: http.MethodPost}
22+
resp := &http.Response{
23+
Status: "303 See Other",
24+
StatusCode: http.StatusSeeOther,
25+
Header: http.Header{"Location": []string{"http://example.com/new"}},
26+
}
27+
28+
t.Run("Redirect Loop Detection", func(t *testing.T) {
29+
redirectHandler.VisitedURLs = map[string]int{"http://example.com": 1}
30+
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
31+
assert.Equal(t, http.ErrUseLastResponse, err)
32+
// Verify that a warning log for redirect loop was recorded
33+
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Detected redirect loop")
34+
})
35+
36+
t.Run("Maximum Redirects Reached", func(t *testing.T) {
37+
redirectHandler.VisitedURLs = map[string]int{}
38+
redirectHandler.MaxRedirects = 1
39+
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
40+
assert.Equal(t, http.ErrUseLastResponse, err)
41+
// Verify that a warning log for max redirects was recorded
42+
assert.Contains(t, mockLogger.Calls[1].Arguments.String(0), "Stopped after maximum redirects")
43+
})
44+
45+
t.Run("Resolve Relative Redirects", func(t *testing.T) {
46+
redirectHandler.MaxRedirects = 10
47+
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
48+
assert.Nil(t, err)
49+
assert.Equal(t, "http://example.com/new", req.URL.String())
50+
})
51+
52+
t.Run("Cross-Domain Security Measures", func(t *testing.T) {
53+
reqURL, _ = url.Parse("http://example.com")
54+
req = &http.Request{URL: reqURL, Method: http.MethodPost}
55+
resp.Header.Set("Location", "http://anotherdomain.com/new")
56+
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
57+
assert.Nil(t, err)
58+
// Ensure sensitive headers are removed and corresponding log is recorded
59+
assert.Empty(t, req.Header.Get("Authorization"))
60+
assert.Contains(t, mockLogger.Calls[2].Arguments.String(0), "Removed sensitive header")
61+
})
62+
63+
t.Run("Handling 303 See Other", func(t *testing.T) {
64+
reqURL, _ = url.Parse("http://example.com")
65+
req = &http.Request{URL: reqURL, Method: http.MethodPost}
66+
resp.Header.Set("Location", "http://example.com/new")
67+
err := redirectHandler.checkRedirect(req, []*http.Request{{}, {}})
68+
assert.Nil(t, err)
69+
assert.Equal(t, http.MethodGet, req.Method)
70+
// Ensure no body, no GetBody, correct ContentLength, no Content-Type header, and a log is recorded
71+
assert.Nil(t, req.Body)
72+
assert.Nil(t, req.GetBody)
73+
assert.Equal(t, int64(0), req.ContentLength)
74+
assert.Empty(t, req.Header.Get("Content-Type"))
75+
assert.Contains(t, mockLogger.Calls[3].Arguments.String(0), "Changed request method to GET")
76+
})
77+
}
78+
79+
func TestRedirectHandler_ResolveRedirectURL(t *testing.T) {
80+
redirectHandler := RedirectHandler{}
81+
82+
t.Run("Absolute URL", func(t *testing.T) {
83+
reqURL, _ := url.Parse("http://example.com")
84+
redirectURL, _ := url.Parse("http://newexample.com/path")
85+
newReqURL, err := redirectHandler.resolveRedirectURL(reqURL, redirectURL)
86+
assert.Nil(t, err)
87+
assert.Equal(t, redirectURL.String(), newReqURL.String())
88+
})
89+
90+
t.Run("Relative URL", func(t *testing.T) {
91+
reqURL, _ := url.Parse("http://example.com/current")
92+
redirectURL, _ := url.Parse("/newpath")
93+
newReqURL, err := redirectHandler.resolveRedirectURL(reqURL, redirectURL)
94+
assert.Nil(t, err)
95+
assert.Equal(t, "http://example.com/newpath", newReqURL.String())
96+
})
97+
98+
t.Run("Relative URL with Query and Fragment", func(t *testing.T) {
99+
reqURL, _ := url.Parse("http://example.com/current?param=value#fragment")
100+
redirectURL, _ := url.Parse("newpath?newparam=newvalue#newfragment")
101+
newReqURL, err := redirectHandler.resolveRedirectURL(reqURL, redirectURL)
102+
assert.Nil(t, err)
103+
assert.Equal(t, "http://example.com/newpath?newparam=newvalue#newfragment", newReqURL.String())
104+
})
105+
}
106+
107+
func TestRedirectHandler_SecureRequest(t *testing.T) {
108+
mockLogger := mocklogger.NewMockLogger()
109+
mockLogger.SetLevel(mocklogger.LogLevelDebug)
110+
111+
redirectHandler := RedirectHandler{Logger: mockLogger}
112+
req := &http.Request{Header: http.Header{"Authorization": []string{"token"}, "Cookie": []string{"session"}}}
113+
114+
t.Run("Secure Cross-Domain Redirect", func(t *testing.T) {
115+
redirectHandler.secureRequest(req)
116+
// Ensure sensitive headers are removed and log messages were recorded
117+
assert.Empty(t, req.Header.Get("Authorization"))
118+
assert.Empty(t, req.Header.Get("Cookie"))
119+
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Removed sensitive header")
120+
})
121+
}

0 commit comments

Comments
 (0)