Skip to content

feat: add sub for redis #4725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion core/stores/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ type (
red.BitMapCmdable
}

// RedisPubSub interface represents a redis pubsub node.
RedisPubSub interface {
Subscribe(ctx context.Context, channels ...string) *PubSub
PSubscribe(ctx context.Context, patterns ...string) *PubSub
}

// GeoLocation is used with GeoAdd to add geospatial location.
GeoLocation = red.GeoLocation
// GeoRadiusQuery is used with GeoRadius to query geospatial index.
Expand Down Expand Up @@ -103,6 +109,9 @@ type (

// Cmder is an alias of redis.Cmder.
Cmder = red.Cmder

// PubSub is an alias of redis.PubSub.
PubSub = red.PubSub
)

// MustNewRedis returns a Redis with given options.
Expand Down Expand Up @@ -1218,7 +1227,6 @@ func (s *Redis) PipelinedCtx(ctx context.Context, fn func(Pipeliner) error) erro
if err != nil {
return err
}

_, err = conn.Pipelined(ctx, fn)
return err
}
Expand Down Expand Up @@ -2402,6 +2410,42 @@ func (s *Redis) ZunionstoreCtx(ctx context.Context, dest string, store *ZStore)
return conn.ZUnionStore(ctx, dest, store).Result()
}

// Subscribe is the implementation of redis subscribe command.
func (s *Redis) Subscribe(channels ...string) (*PubSub, error) {
conn, err := getPubSubRedis(s)
if err != nil {
return nil, err
}
return conn.Subscribe(context.Background(), channels...), nil
}

// PSubscribe is the implementation of redis psubscribe command.
func (s *Redis) PSubscribe(patterns ...string) (*PubSub, error) {
conn, err := getPubSubRedis(s)
if err != nil {
return nil, err
}
return conn.PSubscribe(context.Background(), patterns...), nil
}

// SubscribeCtx is the implementation of redis subscribe command.
func (s *Redis) SubscribeCtx(ctx context.Context, channels ...string) (*PubSub, error) {
conn, err := getPubSubRedis(s)
if err != nil {
return nil, err
}
return conn.Subscribe(ctx, channels...), nil
}

// PSubscribeCtx is the implementation of redis psubscribe command.
func (s *Redis) PSubscribeCtx(ctx context.Context, patterns ...string) (*PubSub, error) {
conn, err := getPubSubRedis(s)
if err != nil {
return nil, err
}
return conn.PSubscribe(ctx, patterns...), nil
}

func (s *Redis) checkConnection(pingTimeout time.Duration) error {
conn, err := getRedis(s)
if err != nil {
Expand Down Expand Up @@ -2474,6 +2518,17 @@ func getRedis(r *Redis) (RedisNode, error) {
}
}

func getPubSubRedis(r *Redis) (RedisPubSub, error) {
switch r.Type {
case ClusterType:
return getPubSubCluster(r)
case NodeType:
return getPubSubClient(r)
default:
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
}
}

func toPairs(vals []red.Z) []Pair {
pairs := make([]Pair, len(vals))
for i, val := range vals {
Expand Down
80 changes: 80 additions & 0 deletions core/stores/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2175,3 +2175,83 @@ func TestRedisTxPipeline(t *testing.T) {
assert.Equal(t, hashValue, value)
})
}
func TestRedisSubscribe(t *testing.T) {
runOnRedis(t, func(client *Redis) {
pubsub, err := client.Subscribe("TestChannel")
defer pubsub.Close()
assert.Nil(t, err)
_, err = client.Publish("TestChannel", "TestMessage")
assert.Nil(t, err)
msg, err := pubsub.ReceiveMessage(context.Background())
assert.Nil(t, err)
assert.Equal(t, msg, &red.Message{
Channel: "TestChannel",
Payload: "TestMessage",
})
})
}
func TestRedisSubscribeCtx(t *testing.T) {
runOnRedis(t, func(client *Redis) {
pubsub, err := client.SubscribeCtx(context.Background(), "TestChannel")
defer pubsub.Close()
assert.Nil(t, err)
_, err = client.Publish("TestChannel", "TestMessage")
assert.Nil(t, err)
msg, err := pubsub.ReceiveMessage(context.Background())
assert.Nil(t, err)
assert.Equal(t, msg, &red.Message{
Channel: "TestChannel",
Payload: "TestMessage",
})
})
}
func TestRedisPSubscribe(t *testing.T) {
runOnRedis(t, func(client *Redis) {
pubsub, err := client.PSubscribe("TestPattern*")
defer pubsub.Close()
assert.Nil(t, err)

_, err = client.Publish("TestPattern1", "TestMessage")
assert.Nil(t, err)

_, err = client.Publish("NoMatchChannel", "NoMatchMessage")
assert.Nil(t, err)

msg, err := pubsub.ReceiveMessage(context.Background())
assert.Nil(t, err)
assert.Equal(t, msg, &red.Message{
Channel: "TestPattern1",
Pattern: "TestPattern*",
Payload: "TestMessage",
})
})
}
func TestRedisPSubscribeMultiplePatterns(t *testing.T) {
runOnRedis(t, func(client *Redis) {
pubsub, err := client.PSubscribe("TestPattern*", "AnotherPattern*")
defer pubsub.Close()
assert.Nil(t, err)

_, err = client.Publish("TestPattern1", "MessageForPattern1")
assert.Nil(t, err)

_, err = client.Publish("AnotherPattern1", "MessageForAnotherPattern")
assert.Nil(t, err)

msg, err := pubsub.ReceiveMessage(context.Background())
assert.Nil(t, err)
assert.Equal(t, msg, &red.Message{
Channel: "TestPattern1",
Pattern: "TestPattern*",
Payload: "MessageForPattern1",
})

msg, err = pubsub.ReceiveMessage(context.Background())
assert.Nil(t, err)
assert.Equal(t, msg, &red.Message{
Channel: "AnotherPattern1",
Pattern: "AnotherPattern*",
Payload: "MessageForAnotherPattern",
})
})
}
58 changes: 58 additions & 0 deletions core/stores/redis/redispubsubclientmanager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package redis

import (
"crypto/tls"
"io"
"runtime"

red "github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/syncx"
)

var (
pubSubClientManager = syncx.NewResourceManager()
pubSubNodePoolSize = 10 * runtime.GOMAXPROCS(0)
)

func getPubSubClient(r *Redis) (*red.Client, error) {
val, err := pubSubClientManager.GetResource(r.Addr, func() (io.Closer, error) {
var tlsConfig *tls.Config
if r.tls {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,

Check failure

Code scanning / CodeQL

Disabled TLS certificate check High

InsecureSkipVerify should not be used in production code.
}
}
store := red.NewClient(&red.Options{
Addr: r.Addr,
Username: r.User,
Password: r.Pass,
DB: defaultDatabase,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
TLSConfig: tlsConfig,
})

hooks := append([]red.Hook{defaultDurationHook, breakerHook{
brk: r.brk,
}}, r.hooks...)
for _, hook := range hooks {
store.AddHook(hook)
}

connCollector.registerClient(&statGetter{
clientType: NodeType,
key: r.Addr,
poolSize: pubSubNodePoolSize,
poolStats: func() *red.PoolStats {
return store.PoolStats()
},
})

return store, nil
})
if err != nil {
return nil, err
}

return val.(*red.Client), nil
}
57 changes: 57 additions & 0 deletions core/stores/redis/redispubsubclustermanager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package redis

import (
"crypto/tls"
red "github.com/redis/go-redis/v9"
"github.com/zeromicro/go-zero/core/syncx"
"io"
"runtime"
)

var (
pubSubClusterManager = syncx.NewResourceManager()
// pubSubClusterPoolSize is default pool size for cluster type of redis.
pubSubClusterPoolSize = 5 * runtime.GOMAXPROCS(0)
)

func getPubSubCluster(r *Redis) (*red.ClusterClient, error) {
val, err := pubSubClusterManager.GetResource(r.Addr, func() (io.Closer, error) {
var tlsConfig *tls.Config
if r.tls {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,

Check failure

Code scanning / CodeQL

Disabled TLS certificate check High

InsecureSkipVerify should not be used in production code.
}
}
store := red.NewClusterClient(&red.ClusterOptions{
Addrs: splitClusterAddrs(r.Addr),
Username: r.User,
Password: r.Pass,
MaxRetries: maxRetries,
MinIdleConns: idleConns,
TLSConfig: tlsConfig,
})

hooks := append([]red.Hook{defaultDurationHook, breakerHook{
brk: r.brk,
}}, r.hooks...)
for _, hook := range hooks {
store.AddHook(hook)
}

connCollector.registerClient(&statGetter{
clientType: ClusterType,
key: r.Addr,
poolSize: pubSubClusterPoolSize,
poolStats: func() *red.PoolStats {
return store.PoolStats()
},
})

return store, nil
})
if err != nil {
return nil, err
}

return val.(*red.ClusterClient), nil
}