Skip to content

Commit f328d98

Browse files
wwwfengkevwan
authored andcommitted
feat
1 parent 6edfce6 commit f328d98

File tree

4 files changed

+252
-1
lines changed

4 files changed

+252
-1
lines changed

core/stores/redis/redis.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ type (
6868
red.BitMapCmdable
6969
}
7070

71+
// RedisPubSub interface represents a redis pubsub node.
72+
RedisPubSub interface {
73+
Subscribe(ctx context.Context, channels ...string) *PubSub
74+
PSubscribe(ctx context.Context, patterns ...string) *PubSub
75+
}
76+
7177
// GeoLocation is used with GeoAdd to add geospatial location.
7278
GeoLocation = red.GeoLocation
7379
// GeoRadiusQuery is used with GeoRadius to query geospatial index.
@@ -103,6 +109,9 @@ type (
103109

104110
// Cmder is an alias of redis.Cmder.
105111
Cmder = red.Cmder
112+
113+
// PubSub is an alias of redis.PubSub.
114+
PubSub = red.PubSub
106115
)
107116

108117
// MustNewRedis returns a Redis with given options.
@@ -1218,7 +1227,6 @@ func (s *Redis) PipelinedCtx(ctx context.Context, fn func(Pipeliner) error) erro
12181227
if err != nil {
12191228
return err
12201229
}
1221-
12221230
_, err = conn.Pipelined(ctx, fn)
12231231
return err
12241232
}
@@ -2402,6 +2410,42 @@ func (s *Redis) ZunionstoreCtx(ctx context.Context, dest string, store *ZStore)
24022410
return conn.ZUnionStore(ctx, dest, store).Result()
24032411
}
24042412

2413+
// Subscribe is the implementation of redis subscribe command.
2414+
func (s *Redis) Subscribe(channels ...string) (*PubSub, error) {
2415+
conn, err := getPubSubRedis(s)
2416+
if err != nil {
2417+
return nil, err
2418+
}
2419+
return conn.Subscribe(context.Background(), channels...), nil
2420+
}
2421+
2422+
// PSubscribe is the implementation of redis psubscribe command.
2423+
func (s *Redis) PSubscribe(patterns ...string) (*PubSub, error) {
2424+
conn, err := getPubSubRedis(s)
2425+
if err != nil {
2426+
return nil, err
2427+
}
2428+
return conn.PSubscribe(context.Background(), patterns...), nil
2429+
}
2430+
2431+
// SubscribeCtx is the implementation of redis subscribe command.
2432+
func (s *Redis) SubscribeCtx(ctx context.Context, channels ...string) (*PubSub, error) {
2433+
conn, err := getPubSubRedis(s)
2434+
if err != nil {
2435+
return nil, err
2436+
}
2437+
return conn.Subscribe(ctx, channels...), nil
2438+
}
2439+
2440+
// PSubscribeCtx is the implementation of redis psubscribe command.
2441+
func (s *Redis) PSubscribeCtx(ctx context.Context, patterns ...string) (*PubSub, error) {
2442+
conn, err := getPubSubRedis(s)
2443+
if err != nil {
2444+
return nil, err
2445+
}
2446+
return conn.PSubscribe(ctx, patterns...), nil
2447+
}
2448+
24052449
func (s *Redis) checkConnection(pingTimeout time.Duration) error {
24062450
conn, err := getRedis(s)
24072451
if err != nil {
@@ -2474,6 +2518,17 @@ func getRedis(r *Redis) (RedisNode, error) {
24742518
}
24752519
}
24762520

2521+
func getPubSubRedis(r *Redis) (RedisPubSub, error) {
2522+
switch r.Type {
2523+
case ClusterType:
2524+
return getPubSubClient(r)
2525+
case NodeType:
2526+
return getPubSubCluster(r)
2527+
default:
2528+
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
2529+
}
2530+
}
2531+
24772532
func toPairs(vals []red.Z) []Pair {
24782533
pairs := make([]Pair, len(vals))
24792534
for i, val := range vals {

core/stores/redis/redis_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,3 +2175,84 @@ func TestRedisTxPipeline(t *testing.T) {
21752175
assert.Equal(t, hashValue, value)
21762176
})
21772177
}
2178+
func TestRedisSubscribe(t *testing.T) {
2179+
runOnRedis(t, func(client *Redis) {
2180+
2181+
pubsub, err := client.Subscribe("TestChannel")
2182+
defer pubsub.Close()
2183+
assert.Nil(t, err)
2184+
_, err = client.Publish("TestChannel", "TestMessage")
2185+
2186+
msg, err := pubsub.ReceiveMessage(context.Background())
2187+
assert.Nil(t, err)
2188+
assert.Equal(t, msg, &red.Message{
2189+
Channel: "TestChannel",
2190+
Payload: "TestMessage",
2191+
})
2192+
})
2193+
}
2194+
func TestRedisSubscribeCtx(t *testing.T) {
2195+
runOnRedis(t, func(client *Redis) {
2196+
pubsub, err := client.SubscribeCtx(context.Background(), "TestChannel")
2197+
defer pubsub.Close()
2198+
assert.Nil(t, err)
2199+
_, err = client.Publish("TestChannel", "TestMessage")
2200+
2201+
msg, err := pubsub.ReceiveMessage(context.Background())
2202+
assert.Nil(t, err)
2203+
assert.Equal(t, msg, &red.Message{
2204+
Channel: "TestChannel",
2205+
Payload: "TestMessage",
2206+
})
2207+
})
2208+
}
2209+
func TestRedisPSubscribe(t *testing.T) {
2210+
runOnRedis(t, func(client *Redis) {
2211+
pubsub, err := client.PSubscribe("TestPattern*")
2212+
defer pubsub.Close()
2213+
assert.Nil(t, err)
2214+
2215+
_, err = client.Publish("TestPattern1", "TestMessage")
2216+
assert.Nil(t, err)
2217+
2218+
_, err = client.Publish("NoMatchChannel", "NoMatchMessage")
2219+
assert.Nil(t, err)
2220+
2221+
msg, err := pubsub.ReceiveMessage(context.Background())
2222+
assert.Nil(t, err)
2223+
assert.Equal(t, msg, &red.Message{
2224+
Channel: "TestPattern1",
2225+
Pattern: "TestPattern*",
2226+
Payload: "TestMessage",
2227+
})
2228+
})
2229+
}
2230+
func TestRedisPSubscribeMultiplePatterns(t *testing.T) {
2231+
runOnRedis(t, func(client *Redis) {
2232+
pubsub, err := client.PSubscribe("TestPattern*", "AnotherPattern*")
2233+
defer pubsub.Close()
2234+
assert.Nil(t, err)
2235+
2236+
_, err = client.Publish("TestPattern1", "MessageForPattern1")
2237+
assert.Nil(t, err)
2238+
2239+
_, err = client.Publish("AnotherPattern1", "MessageForAnotherPattern")
2240+
assert.Nil(t, err)
2241+
2242+
msg, err := pubsub.ReceiveMessage(context.Background())
2243+
assert.Nil(t, err)
2244+
assert.Equal(t, msg, &red.Message{
2245+
Channel: "TestPattern1",
2246+
Pattern: "TestPattern*",
2247+
Payload: "MessageForPattern1",
2248+
})
2249+
2250+
msg, err = pubsub.ReceiveMessage(context.Background())
2251+
assert.Nil(t, err)
2252+
assert.Equal(t, msg, &red.Message{
2253+
Channel: "AnotherPattern1",
2254+
Pattern: "AnotherPattern*",
2255+
Payload: "MessageForAnotherPattern",
2256+
})
2257+
})
2258+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package redis
2+
3+
import (
4+
"crypto/tls"
5+
"io"
6+
"runtime"
7+
8+
red "github.com/redis/go-redis/v9"
9+
"github.com/zeromicro/go-zero/core/syncx"
10+
)
11+
12+
var (
13+
pubSubClientManager = syncx.NewResourceManager()
14+
pubSubNodePoolSize = 10 * runtime.GOMAXPROCS(0)
15+
)
16+
17+
func getPubSubClient(r *Redis) (*red.Client, error) {
18+
val, err := pubSubClientManager.GetResource(r.Addr, func() (io.Closer, error) {
19+
var tlsConfig *tls.Config
20+
if r.tls {
21+
tlsConfig = &tls.Config{
22+
InsecureSkipVerify: true,
23+
}
24+
}
25+
store := red.NewClient(&red.Options{
26+
Addr: r.Addr,
27+
Username: r.User,
28+
Password: r.Pass,
29+
DB: defaultDatabase,
30+
MaxRetries: maxRetries,
31+
MinIdleConns: idleConns,
32+
TLSConfig: tlsConfig,
33+
})
34+
35+
hooks := append([]red.Hook{defaultDurationHook, breakerHook{
36+
brk: r.brk,
37+
}}, r.hooks...)
38+
for _, hook := range hooks {
39+
store.AddHook(hook)
40+
}
41+
42+
connCollector.registerClient(&statGetter{
43+
clientType: NodeType,
44+
key: r.Addr,
45+
poolSize: pubSubNodePoolSize,
46+
poolStats: func() *red.PoolStats {
47+
return store.PoolStats()
48+
},
49+
})
50+
51+
return store, nil
52+
})
53+
if err != nil {
54+
return nil, err
55+
}
56+
57+
return val.(*red.Client), nil
58+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package redis
2+
3+
import (
4+
"crypto/tls"
5+
red "github.com/redis/go-redis/v9"
6+
"github.com/zeromicro/go-zero/core/syncx"
7+
"io"
8+
"runtime"
9+
)
10+
11+
var (
12+
pubSubClusterManager = syncx.NewResourceManager()
13+
// clusterPoolSize is default pool size for cluster type of redis.
14+
pubSubClusterPoolSize = 5 * runtime.GOMAXPROCS(0)
15+
)
16+
17+
func getPubSubCluster(r *Redis) (*red.ClusterClient, error) {
18+
val, err := pubSubClusterManager.GetResource(r.Addr, func() (io.Closer, error) {
19+
var tlsConfig *tls.Config
20+
if r.tls {
21+
tlsConfig = &tls.Config{
22+
InsecureSkipVerify: true,
23+
}
24+
}
25+
store := red.NewClusterClient(&red.ClusterOptions{
26+
Addrs: splitClusterAddrs(r.Addr),
27+
Username: r.User,
28+
Password: r.Pass,
29+
MaxRetries: maxRetries,
30+
MinIdleConns: idleConns,
31+
TLSConfig: tlsConfig,
32+
})
33+
34+
hooks := append([]red.Hook{defaultDurationHook, breakerHook{
35+
brk: r.brk,
36+
}}, r.hooks...)
37+
for _, hook := range hooks {
38+
store.AddHook(hook)
39+
}
40+
41+
connCollector.registerClient(&statGetter{
42+
clientType: ClusterType,
43+
key: r.Addr,
44+
poolSize: pubSubClusterPoolSize,
45+
poolStats: func() *red.PoolStats {
46+
return store.PoolStats()
47+
},
48+
})
49+
50+
return store, nil
51+
})
52+
if err != nil {
53+
return nil, err
54+
}
55+
56+
return val.(*red.ClusterClient), nil
57+
}

0 commit comments

Comments
 (0)