diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index 66fcef4e0ad9..7ab4cecf1dbc 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -68,6 +68,12 @@ type ( red.BitMapCmdable } + // RedisPubSub interface represents a redis pubsub node. + RedisPubSub interface { + Subscribe(ctx context.Context, channels ...string) *PubSub + PSubscribe(ctx context.Context, patterns ...string) *PubSub + } + // GeoLocation is used with GeoAdd to add geospatial location. GeoLocation = red.GeoLocation // GeoRadiusQuery is used with GeoRadius to query geospatial index. @@ -103,6 +109,9 @@ type ( // Cmder is an alias of redis.Cmder. Cmder = red.Cmder + + // PubSub is an alias of redis.PubSub. + PubSub = red.PubSub ) // MustNewRedis returns a Redis with given options. @@ -1218,7 +1227,6 @@ func (s *Redis) PipelinedCtx(ctx context.Context, fn func(Pipeliner) error) erro if err != nil { return err } - _, err = conn.Pipelined(ctx, fn) return err } @@ -2402,6 +2410,42 @@ func (s *Redis) ZunionstoreCtx(ctx context.Context, dest string, store *ZStore) return conn.ZUnionStore(ctx, dest, store).Result() } +// Subscribe is the implementation of redis subscribe command. +func (s *Redis) Subscribe(channels ...string) (*PubSub, error) { + conn, err := getPubSubRedis(s) + if err != nil { + return nil, err + } + return conn.Subscribe(context.Background(), channels...), nil +} + +// PSubscribe is the implementation of redis psubscribe command. +func (s *Redis) PSubscribe(patterns ...string) (*PubSub, error) { + conn, err := getPubSubRedis(s) + if err != nil { + return nil, err + } + return conn.PSubscribe(context.Background(), patterns...), nil +} + +// SubscribeCtx is the implementation of redis subscribe command. +func (s *Redis) SubscribeCtx(ctx context.Context, channels ...string) (*PubSub, error) { + conn, err := getPubSubRedis(s) + if err != nil { + return nil, err + } + return conn.Subscribe(ctx, channels...), nil +} + +// PSubscribeCtx is the implementation of redis psubscribe command. +func (s *Redis) PSubscribeCtx(ctx context.Context, patterns ...string) (*PubSub, error) { + conn, err := getPubSubRedis(s) + if err != nil { + return nil, err + } + return conn.PSubscribe(ctx, patterns...), nil +} + func (s *Redis) checkConnection(pingTimeout time.Duration) error { conn, err := getRedis(s) if err != nil { @@ -2474,6 +2518,17 @@ func getRedis(r *Redis) (RedisNode, error) { } } +func getPubSubRedis(r *Redis) (RedisPubSub, error) { + switch r.Type { + case ClusterType: + return getPubSubCluster(r) + case NodeType: + return getPubSubClient(r) + default: + return nil, fmt.Errorf("redis type '%s' is not supported", r.Type) + } +} + func toPairs(vals []red.Z) []Pair { pairs := make([]Pair, len(vals)) for i, val := range vals { diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index b34792ee83c3..00293330f644 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -2175,3 +2175,83 @@ func TestRedisTxPipeline(t *testing.T) { assert.Equal(t, hashValue, value) }) } +func TestRedisSubscribe(t *testing.T) { + runOnRedis(t, func(client *Redis) { + pubsub, err := client.Subscribe("TestChannel") + defer pubsub.Close() + assert.Nil(t, err) + _, err = client.Publish("TestChannel", "TestMessage") + assert.Nil(t, err) + msg, err := pubsub.ReceiveMessage(context.Background()) + assert.Nil(t, err) + assert.Equal(t, msg, &red.Message{ + Channel: "TestChannel", + Payload: "TestMessage", + }) + }) +} +func TestRedisSubscribeCtx(t *testing.T) { + runOnRedis(t, func(client *Redis) { + pubsub, err := client.SubscribeCtx(context.Background(), "TestChannel") + defer pubsub.Close() + assert.Nil(t, err) + _, err = client.Publish("TestChannel", "TestMessage") + assert.Nil(t, err) + msg, err := pubsub.ReceiveMessage(context.Background()) + assert.Nil(t, err) + assert.Equal(t, msg, &red.Message{ + Channel: "TestChannel", + Payload: "TestMessage", + }) + }) +} +func TestRedisPSubscribe(t *testing.T) { + runOnRedis(t, func(client *Redis) { + pubsub, err := client.PSubscribe("TestPattern*") + defer pubsub.Close() + assert.Nil(t, err) + + _, err = client.Publish("TestPattern1", "TestMessage") + assert.Nil(t, err) + + _, err = client.Publish("NoMatchChannel", "NoMatchMessage") + assert.Nil(t, err) + + msg, err := pubsub.ReceiveMessage(context.Background()) + assert.Nil(t, err) + assert.Equal(t, msg, &red.Message{ + Channel: "TestPattern1", + Pattern: "TestPattern*", + Payload: "TestMessage", + }) + }) +} +func TestRedisPSubscribeMultiplePatterns(t *testing.T) { + runOnRedis(t, func(client *Redis) { + pubsub, err := client.PSubscribe("TestPattern*", "AnotherPattern*") + defer pubsub.Close() + assert.Nil(t, err) + + _, err = client.Publish("TestPattern1", "MessageForPattern1") + assert.Nil(t, err) + + _, err = client.Publish("AnotherPattern1", "MessageForAnotherPattern") + assert.Nil(t, err) + + msg, err := pubsub.ReceiveMessage(context.Background()) + assert.Nil(t, err) + assert.Equal(t, msg, &red.Message{ + Channel: "TestPattern1", + Pattern: "TestPattern*", + Payload: "MessageForPattern1", + }) + + msg, err = pubsub.ReceiveMessage(context.Background()) + assert.Nil(t, err) + assert.Equal(t, msg, &red.Message{ + Channel: "AnotherPattern1", + Pattern: "AnotherPattern*", + Payload: "MessageForAnotherPattern", + }) + }) +} diff --git a/core/stores/redis/redispubsubclientmanager.go b/core/stores/redis/redispubsubclientmanager.go new file mode 100644 index 000000000000..02ab698603af --- /dev/null +++ b/core/stores/redis/redispubsubclientmanager.go @@ -0,0 +1,58 @@ +package redis + +import ( + "crypto/tls" + "io" + "runtime" + + red "github.com/redis/go-redis/v9" + "github.com/zeromicro/go-zero/core/syncx" +) + +var ( + pubSubClientManager = syncx.NewResourceManager() + pubSubNodePoolSize = 10 * runtime.GOMAXPROCS(0) +) + +func getPubSubClient(r *Redis) (*red.Client, error) { + val, err := pubSubClientManager.GetResource(r.Addr, func() (io.Closer, error) { + var tlsConfig *tls.Config + if r.tls { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + store := red.NewClient(&red.Options{ + Addr: r.Addr, + Username: r.User, + Password: r.Pass, + DB: defaultDatabase, + MaxRetries: maxRetries, + MinIdleConns: idleConns, + TLSConfig: tlsConfig, + }) + + hooks := append([]red.Hook{defaultDurationHook, breakerHook{ + brk: r.brk, + }}, r.hooks...) + for _, hook := range hooks { + store.AddHook(hook) + } + + connCollector.registerClient(&statGetter{ + clientType: NodeType, + key: r.Addr, + poolSize: pubSubNodePoolSize, + poolStats: func() *red.PoolStats { + return store.PoolStats() + }, + }) + + return store, nil + }) + if err != nil { + return nil, err + } + + return val.(*red.Client), nil +} diff --git a/core/stores/redis/redispubsubclustermanager.go b/core/stores/redis/redispubsubclustermanager.go new file mode 100644 index 000000000000..0039abfe3185 --- /dev/null +++ b/core/stores/redis/redispubsubclustermanager.go @@ -0,0 +1,57 @@ +package redis + +import ( + "crypto/tls" + red "github.com/redis/go-redis/v9" + "github.com/zeromicro/go-zero/core/syncx" + "io" + "runtime" +) + +var ( + pubSubClusterManager = syncx.NewResourceManager() + // pubSubClusterPoolSize is default pool size for cluster type of redis. + pubSubClusterPoolSize = 5 * runtime.GOMAXPROCS(0) +) + +func getPubSubCluster(r *Redis) (*red.ClusterClient, error) { + val, err := pubSubClusterManager.GetResource(r.Addr, func() (io.Closer, error) { + var tlsConfig *tls.Config + if r.tls { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + store := red.NewClusterClient(&red.ClusterOptions{ + Addrs: splitClusterAddrs(r.Addr), + Username: r.User, + Password: r.Pass, + MaxRetries: maxRetries, + MinIdleConns: idleConns, + TLSConfig: tlsConfig, + }) + + hooks := append([]red.Hook{defaultDurationHook, breakerHook{ + brk: r.brk, + }}, r.hooks...) + for _, hook := range hooks { + store.AddHook(hook) + } + + connCollector.registerClient(&statGetter{ + clientType: ClusterType, + key: r.Addr, + poolSize: pubSubClusterPoolSize, + poolStats: func() *red.PoolStats { + return store.PoolStats() + }, + }) + + return store, nil + }) + if err != nil { + return nil, err + } + + return val.(*red.ClusterClient), nil +}