diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go index b8925cad4f..eaea90f23c 100644 --- a/async_handoff_integration_test.go +++ b/async_handoff_integration_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/logging" + "github.com/redis/go-redis/v9/maintnotifications" ) // mockNetConn implements net.Conn for testing diff --git a/internal/maintnotifications/logs/log_messages.go b/internal/maintnotifications/logs/log_messages.go index 34cb1692d9..ae6434d366 100644 --- a/internal/maintnotifications/logs/log_messages.go +++ b/internal/maintnotifications/logs/log_messages.go @@ -295,11 +295,11 @@ func RemovingConnectionFromPool(connID uint64, reason error) string { }) } -func NoPoolProvidedCannotRemove(connID uint64, reason error) string { - message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason) +func NoPoolProvidedCannotRemove(connID uint64) string { + message := fmt.Sprintf("conn[%d] %s", connID, NoPoolProvidedMessageCannotRemoveMessage) return appendJSONIfDebug(message, map[string]interface{}{ "connID": connID, - "reason": reason.Error(), + "reason": nil, }) } diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 0a6453c7c9..0a424d3cd1 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" ) var ( @@ -115,6 +116,9 @@ type Options struct { // DialerRetryTimeout is the backoff duration between retry attempts. // Default: 100ms DialerRetryTimeout time.Duration + + // Optional logger for connection pool operations. 
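+	// If nil, pool logging falls back to the package-level legacy logger (internal.Logger).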
+ Logger *logging.CustomLogger } type lastDialErrorWrap struct { @@ -223,7 +227,7 @@ func (p *ConnPool) checkMinIdleConns() { p.idleConnsLen.Add(-1) p.freeTurn() - internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) + p.logger().Errorf(context.Background(), "addIdleConn panic: %+v", err) } }() @@ -379,7 +383,7 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { return cn, nil } - internal.Logger.Printf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) + p.logger().Errorf(ctx, "redis: connection pool: failed to dial after %d attempts: %v", maxRetries, lastErr) // All retries failed - handle error tracking p.setLastDialError(lastErr) if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.cfg.PoolSize) { @@ -452,7 +456,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { for { if attempts >= getAttempts { - internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) + p.logger().Errorf(ctx, "redis: connection pool: was not able to get a healthy connection after %d attempts", attempts) break } attempts++ @@ -479,12 +483,12 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { if hookManager != nil { acceptConn, err := hookManager.ProcessOnGet(ctx, cn, false) if err != nil { - internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) + p.logger().Errorf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) _ = p.CloseConn(cn) continue } if !acceptConn { - internal.Logger.Printf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) + p.logger().Errorf(ctx, "redis: connection pool: conn[%d] rejected by hook, returning to pool", cn.GetID()) p.Put(ctx, cn) cn = nil continue @@ -509,7 +513,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { // this should not happen with a new connection, but we handle it gracefully if err != nil || !acceptConn { // Failed to process connection, discard it - internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) + p.logger().Errorf(ctx, "redis: connection pool: failed to process new connection conn[%d] by hook: accept=%v, err=%v", newcn.GetID(), acceptConn, err) _ = p.CloseConn(newcn) return nil, err } @@ -703,7 +707,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { // If we exhausted all attempts without finding a usable connection, return nil if attempts > 1 && attempts >= maxAttempts && int32(attempts) >= p.poolSize.Load() { - internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) + p.logger().Errorf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", attempts) return nil, nil } @@ -720,7 +724,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) { // Peek at the reply type to check if it's a push notification if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { // Not a push notification or error peeking, remove connection - internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + p.logger().Errorf(ctx, "Conn has unread data (not push notification), removing it") p.Remove(ctx, cn, err) } // It's a push notification, allow pooling (client will handle 
it)
@@ -733,7 +737,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
 	if hookManager != nil {
 		shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn)
 		if err != nil {
-			internal.Logger.Printf(ctx, "Connection hook error: %v", err)
+			p.logger().Errorf(ctx, "Connection hook error: %v", err)
 			p.Remove(ctx, cn, err)
 			return
 		}
@@ -835,7 +839,7 @@ func (p *ConnPool) removeConn(cn *Conn) {
 	// this can be idle conn
 	for idx, ic := range p.idleConns {
 		if ic.GetID() == cid {
-			internal.Logger.Printf(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
+			p.logger().Infof(context.Background(), "redis: connection pool: removing idle conn[%d]", cid)
 			p.idleConns = append(p.idleConns[:idx], p.idleConns[idx+1:]...)
 			p.idleConnsLen.Add(-1)
 			break
@@ -951,7 +955,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
 		if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush {
 			// For RESP3 connections with push notifications, we allow some buffered data
 			// The client will process these notifications before using the connection
-			internal.Logger.Printf(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
+			p.logger().Infof(context.Background(), "push: conn[%d] has buffered data, likely push notifications - will be processed by client", cn.GetID())
 			return true // Connection is healthy, client will handle notifications
 		}
 		return false // Unexpected data, not push notifications, connection is unhealthy
@@ -961,3 +965,11 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool {
 	}
 	return true
 }
+
+func (p *ConnPool) logger() *logging.CustomLogger {
+	var logger *logging.CustomLogger
+	if p.cfg != nil && p.cfg.Logger != nil {
+		logger = p.cfg.Logger
+	}
+	return logger
+}
diff --git a/logging/custom.go b/logging/custom.go
new file mode 100644
index 0000000000..4bd0b627a9
--- /dev/null
+++ b/logging/custom.go
@@ -0,0 +1,144 @@
+package logging
+
+import (
+	"context"
+	"fmt"
+)
+
+// CustomLogger wraps a LoggerWithLevel implementation and exposes both structured
+// (Error, Warn, Info, Debug) and Printf-style (Errorf, Warnf, Infof, Debugf) leveled methods.
+//
+// A nil *CustomLogger, or one constructed without an underlying logger, falls back to the
+// global legacy logger, so its methods are always safe to call.
+type CustomLogger struct {
+	logger        LoggerWithLevel
+	loggerLevel   *LogLevelT
+	printfAdapter PrintfAdapter
+}
+
+func NewCustomLogger(logger LoggerWithLevel, opts ...CustomLoggerOption) *CustomLogger {
+	cl := &CustomLogger{
+		logger: logger,
+	}
+	for _, opt := range opts {
+		opt(cl)
+	}
+	return cl
+}
+
+type CustomLoggerOption func(*CustomLogger)
+
+func WithPrintfAdapter(adapter PrintfAdapter) CustomLoggerOption {
+	return func(cl *CustomLogger) {
+		cl.printfAdapter = adapter
+	}
+}
+
+func WithLoggerLevel(level LogLevelT) CustomLoggerOption {
+	return func(cl *CustomLogger) {
+		cl.loggerLevel = &level
+	}
+}
+
+// PrintfAdapter is a function that converts Printf-style log messages into structured log messages.
+// It can be used to extract key-value pairs from the formatted message.
+type PrintfAdapter func(ctx context.Context, format string, v ...any) (context.Context, string, []any)
+
+// Error is a structured error level logging method with context and arguments.
+func (cl *CustomLogger) Error(ctx context.Context, msg string, args ...any) {
+	if cl == nil || cl.logger == nil {
+		// args are structured key-value pairs, so use the structured legacy fallback.
+		legacyLoggerWithLevel.ErrorContext(ctx, msg, args...)
+		return
+	}
+	cl.logger.ErrorContext(ctx, msg, args...)
+}
+
+func (cl *CustomLogger) Errorf(ctx context.Context, format string, v ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.Errorf(ctx, format, v...)
+		return
+	}
+	ctx, msg, args := cl.printfToStructured(ctx, format, v...)
+	cl.logger.ErrorContext(ctx, msg, args...)
+}
+
+// Warn is a structured warning level logging method with context and arguments.
+func (cl *CustomLogger) Warn(ctx context.Context, msg string, args ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.WarnContext(ctx, msg, args...)
+		return
+	}
+	cl.logger.WarnContext(ctx, msg, args...)
+}
+
+func (cl *CustomLogger) Warnf(ctx context.Context, format string, v ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.Warnf(ctx, format, v...)
+		return
+	}
+	ctx, msg, args := cl.printfToStructured(ctx, format, v...)
+	cl.logger.WarnContext(ctx, msg, args...)
+}
+
+// Info is a structured info level logging method with context and arguments.
+func (cl *CustomLogger) Info(ctx context.Context, msg string, args ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.InfoContext(ctx, msg, args...)
+		return
+	}
+	cl.logger.InfoContext(ctx, msg, args...)
+}
+
+// Debug is a structured debug level logging method with context and arguments.
+func (cl *CustomLogger) Debug(ctx context.Context, msg string, args ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.DebugContext(ctx, msg, args...)
+		return
+	}
+	cl.logger.DebugContext(ctx, msg, args...)
+}
+
+func (cl *CustomLogger) Infof(ctx context.Context, format string, v ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.Infof(ctx, format, v...)
+		return
+	}
+
+	ctx, msg, args := cl.printfToStructured(ctx, format, v...)
+	cl.logger.InfoContext(ctx, msg, args...)
+}
+
+func (cl *CustomLogger) Debugf(ctx context.Context, format string, v ...any) {
+	if cl == nil || cl.logger == nil {
+		legacyLoggerWithLevel.Debugf(ctx, format, v...)
+		return
+	}
+	ctx, msg, args := cl.printfToStructured(ctx, format, v...)
+	cl.logger.DebugContext(ctx, msg, args...)
+}
+
+func (cl *CustomLogger) printfToStructured(ctx context.Context, format string, v ...any) (context.Context, string, []any) {
+	if cl.printfAdapter != nil {
+		return cl.printfAdapter(ctx, format, v...)
+	}
+	return ctx, fmt.Sprintf(format, v...), nil
+}
+
+func (cl *CustomLogger) Enabled(ctx context.Context, level LogLevelT) bool {
+	if cl == nil {
+		// A nil logger defers to the global legacy log level.
+		return legacyLoggerWithLevel.Enabled(ctx, level)
+	}
+	if cl.loggerLevel != nil {
+		return level >= *cl.loggerLevel
+	}
+
+	return legacyLoggerWithLevel.Enabled(ctx, level)
+}
+
+// LoggerWithLevel is a logger interface with leveled logging methods.
+//
+// [slog.Logger] from the standard library satisfies this interface.
+type LoggerWithLevel interface {
+	// InfoContext logs an info level message.
+	InfoContext(ctx context.Context, msg string, args ...any)
+
+	// WarnContext logs a warning level message.
+	WarnContext(ctx context.Context, msg string, args ...any)
+
+	// DebugContext logs a debug level message.
+	DebugContext(ctx context.Context, msg string, args ...any)
+
+	// ErrorContext logs an error level message.
+	ErrorContext(ctx context.Context, msg string, args ...any)
+}
diff --git a/logging/legacy.go b/logging/legacy.go
new file mode 100644
index 0000000000..6a29023901
--- /dev/null
+++ b/logging/legacy.go
@@ -0,0 +1,91 @@
+package logging
+
+import (
+	"context"
+
+	"github.com/redis/go-redis/v9/internal"
+)
+
+// legacyLoggerAdapter is a logger that implements the [LoggerWithLevel] interface
+// using the global [internal.Logger] and [internal.LogLevel] variables.
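+//
+// Structured key-value arguments are flattened into the Printf-style message; for example,
+// ErrorContext(ctx, "dial failed", "addr", addr) is emitted as "dial failed addr=<addr>".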
+type legacyLoggerAdapter struct{} + +var _ LoggerWithLevel = (*legacyLoggerAdapter)(nil) + +// structuredToPrintf converts a structured log message and key-value pairs into something a Printf-style logger can understand. +func (l *legacyLoggerAdapter) structuredToPrintf(msg string, v ...any) (string, []any) { + format := msg + var args []any + + for i := 0; i < len(v); i += 2 { + if i+1 >= len(v) { + break + } + format += " %v=%v" + args = append(args, v[i], v[i+1]) + } + + return format, args +} + +func (l legacyLoggerAdapter) Errorf(ctx context.Context, format string, v ...any) { + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) ErrorContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Errorf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) WarnContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Warnf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Warnf(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.WarnOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) InfoContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Infof(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Infof(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.InfoOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) DebugContext(ctx context.Context, msg string, args ...any) { + format, v := l.structuredToPrintf(msg, args...) + l.Debugf(ctx, format, v...) +} + +func (l *legacyLoggerAdapter) Debugf(ctx context.Context, format string, v ...any) { + if !internal.LogLevel.DebugOrAbove() { + // Skip logging + return + } + internal.Logger.Printf(ctx, format, v...) 
+} + +func (l *legacyLoggerAdapter) Enabled(ctx context.Context, level LogLevelT) bool { + switch level { + case LogLevelDebug: + return internal.LogLevel.DebugOrAbove() + case LogLevelWarn: + return internal.LogLevel.WarnOrAbove() + case LogLevelInfo: + return internal.LogLevel.InfoOrAbove() + } + return true +} + +var legacyLoggerWithLevel = &legacyLoggerAdapter{} diff --git a/logging/logging.go b/logging/logging.go index b8453aa929..64e1a69ad5 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -89,3 +89,4 @@ func (l *filterLogger) Printf(ctx context.Context, format string, v ...interface return } } + diff --git a/maintnotifications/circuit_breaker.go b/maintnotifications/circuit_breaker.go index cb76b6447f..764300137c 100644 --- a/maintnotifications/circuit_breaker.go +++ b/maintnotifications/circuit_breaker.go @@ -6,8 +6,8 @@ import ( "sync/atomic" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" + "github.com/redis/go-redis/v9/logging" ) // CircuitBreakerState represents the state of a circuit breaker @@ -102,9 +102,7 @@ func (cb *CircuitBreaker) Execute(fn func() error) error { if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) { cb.requests.Store(0) cb.successes.Store(0) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) - } + cb.logger().Infof(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint)) // Fall through to half-open logic } else { return ErrCircuitBreakerOpen @@ -144,17 +142,13 @@ func (cb *CircuitBreaker) recordFailure() { case CircuitBreakerClosed: if failures >= int64(cb.failureThreshold) { if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) - } + cb.logger().Warnf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures)) } } case CircuitBreakerHalfOpen: // Any failure in half-open state immediately opens the circuit if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) - } + cb.logger().Warnf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint)) } } } @@ -176,9 +170,7 @@ func (cb *CircuitBreaker) recordSuccess() { if successes >= int64(cb.maxRequests) { if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) { cb.failures.Store(0) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) - } + cb.logger().Infof(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes)) } } } @@ -202,6 +194,14 @@ func (cb *CircuitBreaker) GetStats() CircuitBreakerStats { } } +func (cb *CircuitBreaker) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if cb.config != nil && cb.config.Logger != nil { + logger = cb.config.Logger + } + return logger +} + // CircuitBreakerStats provides statistics about a circuit breaker type CircuitBreakerStats struct { Endpoint string @@ -325,8 +325,8 @@ func (cbm *CircuitBreakerManager) cleanup() { } // Log cleanup results - if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), 
logs.CircuitBreakerCleanup(len(toDelete), count)) + if len(toDelete) > 0 { + cbm.logger().Infof(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count)) } cbm.lastCleanup.Store(now.Unix()) @@ -351,3 +351,11 @@ func (cbm *CircuitBreakerManager) Reset() { return true }) } + +func (cbm *CircuitBreakerManager) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if cbm.config != nil && cbm.config.Logger != nil { + logger = cbm.config.Logger + } + return logger +} diff --git a/maintnotifications/config.go b/maintnotifications/config.go index cbf4f6b22b..37df3fe710 100644 --- a/maintnotifications/config.go +++ b/maintnotifications/config.go @@ -7,9 +7,9 @@ import ( "strings" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" ) // Mode represents the maintenance notifications mode @@ -128,6 +128,9 @@ type Config struct { // After this many retries, the connection will be removed from the pool. // Default: 3 MaxHandoffRetries int + + // Logger is an optional custom logger for maintenance notifications. + Logger *logging.CustomLogger } func (c *Config) IsEnabled() bool { @@ -312,10 +315,9 @@ func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) * result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests } - if internal.LogLevel.DebugOrAbove() { - internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled()) - internal.Logger.Printf(context.Background(), logs.ConfigDebug(result)) - } + c.logger().Debugf(context.Background(), logs.DebugLoggingEnabled()) + c.logger().Debugf(context.Background(), logs.ConfigDebug(result)) + return result } @@ -341,6 +343,8 @@ func (c *Config) Clone() *Config { // Configuration fields MaxHandoffRetries: c.MaxHandoffRetries, + + Logger: c.Logger, } } @@ -365,6 +369,14 @@ func (c *Config) applyWorkerDefaults(poolSize int) { } } +func (c *Config) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.Logger != nil { + logger = c.Logger + } + return logger +} + // DetectEndpointType automatically detects the appropriate endpoint type // based on the connection address and TLS configuration. 
// diff --git a/maintnotifications/handoff_worker.go b/maintnotifications/handoff_worker.go index 22df2c8008..c40782f01e 100644 --- a/maintnotifications/handoff_worker.go +++ b/maintnotifications/handoff_worker.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) // handoffWorkerManager manages background workers and queue for connection handoffs @@ -121,7 +122,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { defer func() { // Handle panics to ensure proper cleanup if r := recover(); r != nil { - internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r)) + hwm.logger().Errorf(context.Background(), logs.WorkerPanicRecovered(r)) } // Decrement active worker count when exiting @@ -145,23 +146,17 @@ func (hwm *handoffWorkerManager) onDemandWorker() { select { case <-hwm.shutdown: - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown()) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToShutdown()) return case <-timer.C: // Worker has been idle for too long, exit to save resources - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout)) return case request := <-hwm.handoffQueue: // Check for shutdown before processing select { case <-hwm.shutdown: - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) - } + hwm.logger().Infof(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing()) // Clean up the request before exiting hwm.pending.Delete(request.ConnID) return @@ -177,9 +172,7 @@ func (hwm *handoffWorkerManager) onDemandWorker() { func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { // Remove from pending map defer hwm.pending.Delete(request.Conn.GetID()) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) - } + hwm.logger().Infof(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint)) // Create a context with handoff timeout from config handoffTimeout := 15 * time.Second // Default timeout @@ -219,20 +212,21 @@ func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) { afterTime = minRetryBackoff } - if internal.LogLevel.InfoOrAbove() { + // the HandoffRetries() requires locking resource via [atomic.Uint32.Load], + // so we check the log level first before calling it + if hwm.logger().Enabled(context.Background(), internal.LogLevelInfo) { + // Get current retry count for better logging currentRetries := request.Conn.HandoffRetries() maxRetries := 3 // Default fallback if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries } - internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) + hwm.logger().Infof(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err)) } time.AfterFunc(afterTime, func() { if err := hwm.queueHandoff(request.Conn); err != nil { - if internal.LogLevel.WarnOrAbove() { - 
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err)) - } + hwm.logger().Warnf(context.Background(), logs.CannotQueueHandoffForRetry(err)) hwm.closeConnFromRequest(context.Background(), request, err) } }) @@ -258,9 +252,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { // on retries the connection will not be marked for handoff, but it will have retries > 0 // if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff if !shouldHandoff && conn.HandoffRetries() == 0 { - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) - } + hwm.logger().Infof(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID())) return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID())) } @@ -301,9 +293,7 @@ func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error { // Queue is full - log and attempt scaling queueLen := len(hwm.handoffQueue) queueCap := cap(hwm.handoffQueue) - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) - } + hwm.logger().Warnf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap)) } } } @@ -356,7 +346,7 @@ func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, c // Check if circuit breaker is open before attempting handoff if circuitBreaker.IsOpen() { - internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) + hwm.logger().Infof(ctx, logs.CircuitBreakerOpen(connID, newEndpoint)) return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open } @@ -385,16 +375,14 @@ func (hwm *handoffWorkerManager) performHandoffInternal( connID uint64, ) (shouldRetry bool, err error) { retries := conn.IncrementAndGetHandoffRetries(1) - internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) + hwm.logger().Infof(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String())) maxRetries := 3 // Default fallback if hwm.config != nil { maxRetries = hwm.config.MaxHandoffRetries } if retries > maxRetries { - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) - } + hwm.logger().Warnf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries)) // won't retry on ErrMaxHandoffRetriesReached return false, ErrMaxHandoffRetriesReached } @@ -405,7 +393,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // Create new connection to the new endpoint newNetConn, err := endpointDialer(ctx) if err != nil { - internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) + hwm.logger().Errorf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err)) // will retry // Maybe a network error - retry after a delay return true, err @@ -424,9 +412,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration) conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) - } + hwm.logger().Infof(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000"))) } // 
Replace the connection and execute initialization @@ -447,7 +433,7 @@ func (hwm *handoffWorkerManager) performHandoffInternal( // - clear the handoff state (shouldHandoff, endpoint, seqID) // - reset the handoff retries to 0 conn.ClearHandoffState() - internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint)) + hwm.logger().Infof(ctx, logs.HandoffSucceeded(connID, newEndpoint)) // successfully completed the handoff, no retry needed and no error return false, nil @@ -477,16 +463,20 @@ func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, reque conn := request.Conn if pooler != nil { pooler.Remove(ctx, conn, err) - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) - } + hwm.logger().Warnf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err)) } else { err := conn.Close() // Close the connection if no pool provided if err != nil { - internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err) - } - if internal.LogLevel.WarnOrAbove() { - internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err)) + hwm.logger().Errorf(ctx, "redis: failed to close connection: %v", err) } + hwm.logger().Warnf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID())) + } +} + +func (hwm *handoffWorkerManager) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if hwm.config != nil && hwm.config.Logger != nil { + logger = hwm.config.Logger } + return logger } diff --git a/maintnotifications/manager.go b/maintnotifications/manager.go index 775c163e14..b5815abf0b 100644 --- a/maintnotifications/manager.go +++ b/maintnotifications/manager.go @@ -9,10 +9,10 @@ import ( "sync/atomic" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/interfaces" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -150,14 +150,10 @@ func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoi // Use LoadOrStore for atomic check-and-set operation if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { // Duplicate MOVING notification, ignore - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) - } + hm.logger().Debugf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID)) return nil } - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) - } + hm.logger().Debugf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID)) // Increment active operation count atomically hm.activeOperationCount.Add(1) @@ -175,15 +171,11 @@ func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) { // Remove from active operations atomically if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { - if internal.LogLevel.DebugOrAbove() { // Debug level - internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) - } + hm.logger().Debugf(context.Background(), logs.UntrackingMovingOperation(connID, seqID)) // Decrement active operation count only if operation existed hm.activeOperationCount.Add(-1) } else { - if internal.LogLevel.DebugOrAbove() { // Debug level - 
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID)) - } + hm.logger().Debugf(context.Background(), logs.OperationNotTracked(connID, seqID)) } } @@ -318,3 +310,11 @@ func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) { defer hm.hooksMu.Unlock() hm.hooks = append(hm.hooks, notificationHook) } + +func (hm *Manager) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if hm.config != nil && hm.config.Logger != nil { + logger = hm.config.Logger + } + return logger +} diff --git a/maintnotifications/pool_hook.go b/maintnotifications/pool_hook.go index 9fd24b4a7b..2e88eef885 100644 --- a/maintnotifications/pool_hook.go +++ b/maintnotifications/pool_hook.go @@ -6,9 +6,9 @@ import ( "sync" "time" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" ) // OperationsManagerInterface defines the interface for completing handoff operations @@ -150,7 +150,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool if err := ph.workerManager.queueHandoff(conn); err != nil { // Failed to queue handoff, remove the connection - internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) + ph.logger().Errorf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err)) // Don't pool, remove connection, no error to caller return false, true, nil } @@ -170,7 +170,7 @@ func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool // Other error - remove the connection return false, true, nil } - internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID())) + ph.logger().Errorf(ctx, logs.MarkedForHandoff(conn.GetID())) return true, false, nil } @@ -182,3 +182,11 @@ func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) { func (ph *PoolHook) Shutdown(ctx context.Context) error { return ph.workerManager.shutdownWorkers(ctx) } + +func (ph *PoolHook) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if ph.config != nil && ph.config.Logger != nil { + logger = ph.config.Logger + } + return logger +} diff --git a/maintnotifications/push_notification_handler.go b/maintnotifications/push_notification_handler.go index 937b4ae82e..a6cccd3547 100644 --- a/maintnotifications/push_notification_handler.go +++ b/maintnotifications/push_notification_handler.go @@ -9,6 +9,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/maintnotifications/logs" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -21,13 +22,13 @@ type NotificationHandler struct { // HandlePushNotification processes push notifications with hook support. 
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) == 0 { - internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification)) + snh.logger().Errorf(ctx, logs.InvalidNotificationFormat(notification)) return ErrInvalidNotification } notificationType, ok := notification[0].(string) if !ok { - internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) + snh.logger().Errorf(ctx, logs.InvalidNotificationTypeFormat(notification[0])) return ErrInvalidNotification } @@ -64,19 +65,19 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand // ["MOVING", seqNum, timeS, endpoint] - per-connection handoff func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) < 3 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MOVING", notification)) return ErrInvalidNotification } seqID, ok := notification[1].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) + snh.logger().Errorf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1])) return ErrInvalidNotification } // Extract timeS timeS, ok := notification[2].(int64) if !ok { - internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) + snh.logger().Errorf(ctx, logs.InvalidTimeSInMovingNotification(notification[2])) return ErrInvalidNotification } @@ -90,7 +91,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if notification[3] == nil || stringified == internal.RedisNull { newEndpoint = "" } else { - internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) + snh.logger().Errorf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3])) return ErrInvalidNotification } } @@ -99,7 +100,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Get the connection that received this notification conn := handlerCtx.Conn if conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MOVING")) return ErrInvalidNotification } @@ -108,7 +109,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if pc, ok := conn.(*pool.Conn); ok { poolConn = pc } else { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx)) return ErrInvalidNotification } @@ -124,9 +125,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus deadline := time.Now().Add(time.Duration(timeS) * time.Second) // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds if newEndpoint == "" || newEndpoint == internal.RedisNull { - if internal.LogLevel.DebugOrAbove() { - internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) - } + snh.logger().Debugf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2)) // same as current endpoint newEndpoint = snh.manager.options.GetAddr() // delay the handoff for timeS/2 seconds to the same endpoint @@ -139,7 
+138,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus } if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { // Log error but don't fail the goroutine - use background context since original may be cancelled - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) + snh.logger().Errorf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err)) } }) return nil @@ -150,7 +149,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { - internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) + snh.logger().Errorf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err)) // Connection is already marked for handoff, which is acceptable // This can happen if multiple MOVING notifications are received for the same connection return nil @@ -171,25 +170,23 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx // MIGRATING notifications indicate that a connection is about to be migrated // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MIGRATING", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MIGRATING")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection - if internal.LogLevel.InfoOrAbove() { - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) - } + snh.logger().Infof(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout)) conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil } @@ -199,26 +196,25 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p // MIGRATED notifications indicate that a connection migration has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("MIGRATED", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("MIGRATED")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", 
handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.UnrelaxedTimeout(connID)) + conn.ClearRelaxedTimeout() return nil } @@ -228,26 +224,25 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt // FAILING_OVER notifications indicate that a connection is about to failover // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("FAILING_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Apply relaxed timeout to this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout)) + conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) return nil } @@ -257,26 +252,32 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx // FAILED_OVER notifications indicate that a connection failover has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { - internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) + snh.logger().Errorf(ctx, logs.InvalidNotification("FAILED_OVER", notification)) return ErrInvalidNotification } if handlerCtx.Conn == nil { - internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) + snh.logger().Errorf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER")) return ErrInvalidNotification } conn, ok := handlerCtx.Conn.(*pool.Conn) if !ok { - internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) + snh.logger().Errorf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx)) return ErrInvalidNotification } // Clear relaxed timeout for this specific connection - if internal.LogLevel.InfoOrAbove() { - connID := conn.GetID() - internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID)) - } + connID := conn.GetID() + snh.logger().Infof(ctx, logs.UnrelaxedTimeout(connID)) conn.ClearRelaxedTimeout() return nil } + +func (snh *NotificationHandler) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if snh.manager != nil && snh.manager.config != nil && snh.manager.config.Logger != nil { + logger = snh.manager.config.Logger + } + return logger +} diff --git a/options.go b/options.go index e0dcb5eba6..2bb82fcf9d 
100644 --- a/options.go +++ b/options.go @@ -17,6 +17,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -267,6 +268,10 @@ type Options struct { // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. // If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it. MaintNotificationsConfig *maintnotifications.Config + + // Logger is the logger used by the client for logging. + // If none is provided, the global logger [internal.LegacyLoggerWithLevel] is used. + Logger *logging.CustomLogger } func (opt *Options) init() { diff --git a/osscluster.go b/osscluster.go index 7925d2c603..a5566210cd 100644 --- a/osscluster.go +++ b/osscluster.go @@ -20,6 +20,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -148,6 +149,9 @@ type ClusterOptions struct { // If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it. // The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications. MaintNotificationsConfig *maintnotifications.Config + + // Logger is an optional logger for logging cluster-related messages. + Logger *logging.CustomLogger } func (opt *ClusterOptions) init() { @@ -390,6 +394,8 @@ func (opt *ClusterOptions) clientOptions() *Options { UnstableResp3: opt.UnstableResp3, MaintNotificationsConfig: maintNotificationsConfig, PushNotificationProcessor: opt.PushNotificationProcessor, + + Logger: opt.Logger, } } @@ -703,6 +709,14 @@ func (c *clusterNodes) Random() (*clusterNode, error) { return c.GetOrCreate(addrs[n]) } +func (c *clusterNodes) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + //------------------------------------------------------------------------------ type clusterSlot struct { @@ -900,12 +914,12 @@ func (c *clusterState) slotClosestNode(slot int) (*clusterNode, error) { // if all nodes are failing, we will pick the temporarily failing node with lowest latency if minLatency < maximumNodeLatency && closestNode != nil { - internal.Logger.Printf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") + c.nodes.logger().Errorf(context.TODO(), "redis: all nodes are marked as failed, picking the temporarily failing node with lowest latency") return closestNode, nil } // If all nodes are having the maximum latency(all pings are failing) - return a random node across the cluster - internal.Logger.Printf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") + c.nodes.logger().Errorf(context.TODO(), "redis: pings to all nodes are failing, picking a random node across the cluster") return c.nodes.Random() } @@ -1740,7 +1754,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors 
shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { return err @@ -1751,7 +1765,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := statusCmd.readReply(rd) if err != nil { @@ -1770,7 +1784,7 @@ func (c *ClusterClient) txPipelineReadQueued( if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. line, err := rd.ReadLine() @@ -2022,13 +2036,13 @@ func (c *ClusterClient) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, func (c *ClusterClient) cmdInfo(ctx context.Context, name string) *CommandInfo { cmdsInfo, err := c.cmdsInfoCache.Get(ctx) if err != nil { - internal.Logger.Printf(context.TODO(), "getting command info: %s", err) + c.logger().Errorf(ctx, "getting command info: %s", err) return nil } info := cmdsInfo[name] if info == nil { - internal.Logger.Printf(context.TODO(), "info for cmd=%s not found", name) + c.logger().Errorf(ctx, "info for cmd=%s not found", name) } return info } @@ -2126,6 +2140,14 @@ func (c *ClusterClient) context(ctx context.Context) context.Context { return context.Background() } +func (c *ClusterClient) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + func appendIfNotExist[T comparable](vals []T, newVal T) []T { for _, v := range vals { if v == newVal { diff --git a/pubsub.go b/pubsub.go index 959a5c45b1..901d4b0bf3 100644 --- a/pubsub.go +++ b/pubsub.go @@ -10,6 +10,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -141,6 +142,17 @@ func mapKeys(m map[string]struct{}) []string { return s } +// logger is a wrapper around the logger to log messages with context. +// +// it uses the client logger if set, otherwise it uses the global logger. 
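+// A nil *logging.CustomLogger is safe to return here: its methods detect the nil
+// receiver and fall back to the legacy global logger.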
+func (c *PubSub) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + func (c *PubSub) _subscribe( ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, ) error { @@ -190,7 +202,7 @@ func (c *PubSub) reconnect(ctx context.Context, reason error) { // Update the address in the options oldAddr := c.cn.RemoteAddr().String() c.opt.Addr = newEndpoint - internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) + c.logger().Infof(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) } } _ = c.closeTheCn(reason) @@ -475,7 +487,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution // Push notification processing errors shouldn't break normal Redis operations - internal.Logger.Printf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) + c.logger().Errorf(ctx, "push: conn[%d] error processing pending notifications before reading reply: %v", cn.GetID(), err) } return c.cmd.readReply(rd) }) @@ -634,6 +646,9 @@ func WithChannelSendTimeout(d time.Duration) ChannelOption { type channel struct { pubSub *PubSub + // Optional logger for logging channel-related messages. + Logger *logging.CustomLogger + msgCh chan *Message allCh chan interface{} ping chan struct{} @@ -733,12 +748,10 @@ func (c *channel) initMsgChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( - ctx, "redis: %s channel is full for %s (message is dropped)", - c, c.chanSendTimeout) + c.logger().Errorf(ctx, "redis: %s channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logger().Errorf(ctx, "redis: unknown message type: %T", msg) } } }() @@ -787,13 +800,20 @@ func (c *channel) initAllChan() { <-timer.C } case <-timer.C: - internal.Logger.Printf( - ctx, "redis: %s channel is full for %s (message is dropped)", + c.logger().Errorf(ctx, "redis: %s channel is full for %s (message is dropped)", c, c.chanSendTimeout) } default: - internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) + c.logger().Errorf(ctx, "redis: unknown message type: %T", msg) } } }() } + +func (c *channel) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.Logger != nil { + logger = c.Logger + } + return logger +} diff --git a/redis.go b/redis.go index dcd7b59a78..2f72db62a8 100644 --- a/redis.go +++ b/redis.go @@ -15,6 +15,7 @@ import ( "github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/maintnotifications" "github.com/redis/go-redis/v9/push" ) @@ -228,6 +229,9 @@ type baseClient struct { // streamingCredentialsManager is used to manage streaming credentials streamingCredentialsManager *streaming.Manager + + // loggerWithLevel is used for logging + loggerWithLevel *logging.CustomLogger } func (c *baseClient) clone() *baseClient { @@ -242,6 +246,7 @@ func (c *baseClient) clone() *baseClient { pushProcessor: c.pushProcessor, maintNotificationsManager: maintNotificationsManager, streamingCredentialsManager: c.streamingCredentialsManager, + loggerWithLevel: c.loggerWithLevel, } 
return clone } @@ -330,16 +335,16 @@ func (c *baseClient) onAuthenticationErr() func(poolCn *pool.Conn, err error) { // Close the connection to force a reconnection. err := c.connPool.CloseConn(poolCn) if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close connection: %v", err) + c.logger().Errorf(context.Background(), "redis: failed to close connection: %v", err) // try to close the network connection directly // so that no resource is leaked err := poolCn.Close() if err != nil { - internal.Logger.Printf(context.Background(), "redis: failed to close network connection: %v", err) + c.logger().Errorf(context.Background(), "redis: failed to close network connection: %v", err) } } } - internal.Logger.Printf(context.Background(), "redis: re-authentication failed: %v", err) + c.logger().Errorf(context.Background(), "redis: re-authentication failed: %v", err) } } } @@ -475,13 +480,13 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { c.optLock.Unlock() return fmt.Errorf("failed to enable maintnotifications: %w", maintNotifHandshakeErr) default: // will handle auto and any other - internal.Logger.Printf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) + c.logger().Errorf(ctx, "auto mode fallback: maintnotifications disabled due to handshake error: %v", maintNotifHandshakeErr) c.opt.MaintNotificationsConfig.Mode = maintnotifications.ModeDisabled c.optLock.Unlock() // auto mode, disable maintnotifications and continue if err := c.disableMaintNotificationsUpgrades(); err != nil { // Log error but continue - auto mode should be resilient - internal.Logger.Printf(ctx, "failed to disable maintnotifications in auto mode: %v", err) + c.logger().Errorf(ctx, "failed to disable maintnotifications in auto mode: %v", err) } } } else { @@ -536,7 +541,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) } else { // process any pending push notifications before returning the connection to the pool if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before releasing connection: %v", err) } c.connPool.Put(ctx, cn) } @@ -603,7 +608,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the command if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before command: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -626,7 +631,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } 
return readReplyFunc(rd) }); err != nil { @@ -672,6 +677,16 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// logger returns the logger used for client log messages. +// It uses the client logger if set, otherwise the global logger is used. +func (c *baseClient) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + // createInitConnFunc creates a connection initialization function that can be used for reconnections. func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { return func(ctx context.Context, cn *pool.Conn) error { @@ -783,7 +798,7 @@ func (c *baseClient) generalProcessPipeline( lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before processing pipeline: %v", err) } var err error canRetry, err = p(ctx, cn, cmds) @@ -805,7 +820,7 @@ func (c *baseClient) pipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before writing pipeline: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -829,7 +844,7 @@ func (c *baseClient) pipelineReadCmds(ctx context.Context, cn *pool.Conn, rd *pr for i, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } err := cmd.readReply(rd) cmd.SetErr(err) @@ -847,7 +862,7 @@ func (c *baseClient) txPipelineProcessCmds( ) (bool, error) { // Process any pending push notifications before executing the transaction pipeline if err := c.processPushNotifications(ctx, cn); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before transaction: %v", err) } if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { @@ -881,7 +896,7 @@ func (c *baseClient) txPipelineProcessCmds( func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse +OK.
if err := statusCmd.readReply(rd); err != nil { @@ -892,7 +907,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd for _, cmd := range cmds { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } if err := statusCmd.readReply(rd); err != nil { cmd.SetErr(err) @@ -904,7 +919,7 @@ func (c *baseClient) txPipelineReadQueued(ctx context.Context, cn *pool.Conn, rd // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { - internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err) + c.logger().Errorf(ctx, "push: error processing pending notifications before reading reply: %v", err) } // Parse number of replies. line, err := rd.ReadLine() @@ -978,7 +993,7 @@ func NewClient(opt *Options) *Client { if opt.MaintNotificationsConfig != nil && opt.MaintNotificationsConfig.Mode != maintnotifications.ModeDisabled && opt.Protocol == 3 { err := c.enableMaintNotificationsUpgrades() if err != nil { - internal.Logger.Printf(context.Background(), "failed to initialize maintnotifications: %v", err) + c.logger().Errorf(context.Background(), "failed to initialize maintnotifications: %v", err) if opt.MaintNotificationsConfig.Mode == maintnotifications.ModeEnabled { /* Design decision: panic here to fail fast if maintnotifications cannot be enabled when explicitly requested. 
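Reviewer note, not part of the patch: a minimal usage sketch for the per-client logger wired in above. It assumes a *logging.CustomLogger value has already been constructed via the logging package (its constructor is not shown in this diff); per the comment on baseClient.logger(), leaving the option unset falls back to the global logger.

	// hypothetical sketch; assumes these imports:
	//   "github.com/redis/go-redis/v9"
	//   "github.com/redis/go-redis/v9/logging"
	var customLogger *logging.CustomLogger // assumed to be constructed via the logging package
	client := redis.NewClient(&redis.Options{
		Addr:   "localhost:6379",
		Logger: customLogger, // read by baseClient.logger() for client log messages
	})
	defer client.Close()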
diff --git a/ring.go b/ring.go index 3381460abd..c41a8b0f91 100644 --- a/ring.go +++ b/ring.go @@ -20,6 +20,7 @@ import ( "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" "github.com/redis/go-redis/v9/internal/rand" + "github.com/redis/go-redis/v9/logging" ) var errRingShardsDown = errors.New("redis: all ring shards are down") @@ -154,6 +155,8 @@ type RingOptions struct { DisableIdentity bool IdentitySuffix string UnstableResp3 bool + + Logger *logging.CustomLogger } func (opt *RingOptions) init() { @@ -345,7 +348,7 @@ func (c *ringSharding) SetAddrs(addrs map[string]string) { cleanup := func(shards map[string]*ringShard) { for addr, shard := range shards { if err := shard.Client.Close(); err != nil { - internal.Logger.Printf(context.Background(), "shard.Close %s failed: %s", addr, err) + c.logger().Errorf(context.Background(), "shard.Close %s failed: %s", addr, err) } } } @@ -490,7 +493,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { for _, shard := range c.List() { isUp := c.opt.HeartbeatFn(ctx, shard.Client) if shard.Vote(isUp) { - internal.Logger.Printf(ctx, "ring shard state changed: %s", shard) + c.logger().Infof(ctx, "ring shard state changed: %s", shard) rebalance = true } } @@ -559,6 +562,14 @@ func (c *ringSharding) Close() error { return firstErr } +func (c *ringSharding) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + //------------------------------------------------------------------------------ // Ring is a Redis client that uses consistent hashing to distribute diff --git a/sentinel.go b/sentinel.go index 6481e1ee84..828f2f0d39 100644 --- a/sentinel.go +++ b/sentinel.go @@ -13,10 +13,10 @@ import ( "time" "github.com/redis/go-redis/v9/auth" - "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/rand" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/logging" "github.com/redis/go-redis/v9/push" ) @@ -148,6 +148,9 @@ type FailoverOptions struct { // If nil, maintnotifications upgrades are disabled. 
// (however if Mode is nil, it defaults to "auto" - enable if server supports it) //MaintNotificationsConfig *maintnotifications.Config + + // Optional logger for logging + Logger *logging.CustomLogger } func (opt *FailoverOptions) clientOptions() *Options { @@ -194,6 +197,8 @@ func (opt *FailoverOptions) clientOptions() *Options { IdentitySuffix: opt.IdentitySuffix, UnstableResp3: opt.UnstableResp3, + + Logger: opt.Logger, } } @@ -238,6 +243,8 @@ func (opt *FailoverOptions) sentinelOptions(addr string) *Options { IdentitySuffix: opt.IdentitySuffix, UnstableResp3: opt.UnstableResp3, + + Logger: opt.Logger, } } @@ -287,6 +294,8 @@ func (opt *FailoverOptions) clusterOptions() *ClusterOptions { DisableIndentity: opt.DisableIndentity, IdentitySuffix: opt.IdentitySuffix, FailingTimeoutSeconds: opt.FailingTimeoutSeconds, + + Logger: opt.Logger, } } @@ -818,7 +827,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -836,7 +845,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { return "", err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", c.opt.MasterName, err) } else { return addr, nil @@ -865,7 +874,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { sentinelCli := NewSentinelClient(c.opt.sentinelOptions(addr)) addrVal, err := sentinelCli.GetMasterAddrByName(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: GetMasterAddrByName addr=%s, master=%q failed: %s", addr, c.opt.MasterName, err) _ = sentinelCli.Close() errCh <- err @@ -876,7 +885,7 @@ func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { // Push working sentinel to the top c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] c.setSentinel(ctx, sentinelCli) - internal.Logger.Printf(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) + c.logger().Infof(ctx, "sentinel: selected addr=%s masterAddr=%s", addr, masterAddr) cancel() }) }(i, sentinelAddr) @@ -921,7 +930,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -939,7 +948,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo return nil, err } // Continue on other errors - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) } else if len(addrs) > 0 { return addrs, nil @@ -960,7 +969,7 @@ func (c *sentinelFailover) replicaAddrs(ctx context.Context, useDisconnected boo if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - internal.Logger.Printf(ctx, "sentinel: Replicas master=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas master=%q failed: 
%s", c.opt.MasterName, err) continue } @@ -993,7 +1002,7 @@ func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *Sentinel func (c *sentinelFailover) getReplicaAddrs(ctx context.Context, sentinel *SentinelClient) ([]string, error) { addrs, err := sentinel.Replicas(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Replicas name=%q failed: %s", + c.logger().Errorf(ctx, "sentinel: Replicas name=%q failed: %s", c.opt.MasterName, err) return nil, err } @@ -1041,7 +1050,7 @@ func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { } c.masterAddr = addr - internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", + c.logger().Infof(ctx, "sentinel: new master=%q addr=%q", c.opt.MasterName, addr) if c.onFailover != nil { c.onFailover(ctx, addr) @@ -1062,7 +1071,7 @@ func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelCl func (c *sentinelFailover) discoverSentinels(ctx context.Context) { sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result() if err != nil { - internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) + c.logger().Errorf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) return } for _, sentinel := range sentinels { @@ -1077,7 +1086,7 @@ func (c *sentinelFailover) discoverSentinels(ctx context.Context) { if ip != "" && port != "" { sentinelAddr := net.JoinHostPort(ip, port) if !contains(c.sentinelAddrs, sentinelAddr) { - internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q", + c.logger().Infof(ctx, "sentinel: discovered new sentinel=%q for master=%q", sentinelAddr, c.opt.MasterName) c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) } @@ -1097,7 +1106,7 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { if msg.Channel == "+switch-master" { parts := strings.Split(msg.Payload, " ") if parts[0] != c.opt.MasterName { - internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) + c.logger().Infof(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) continue } addr := net.JoinHostPort(parts[3], parts[4]) @@ -1110,6 +1119,14 @@ func (c *sentinelFailover) listen(pubsub *PubSub) { } } +func (c *sentinelFailover) logger() *logging.CustomLogger { + var logger *logging.CustomLogger + if c.opt != nil && c.opt.Logger != nil { + logger = c.opt.Logger + } + return logger +} + func contains(slice []string, str string) bool { for _, s := range slice { if s == str {