Skip to content

Commit 6edfce6

Browse files
authored
feat: add rest.WithSSE to build SSE route easier (#4729)
1 parent cdb0098 commit 6edfce6

File tree

13 files changed

+106
-34
lines changed

13 files changed

+106
-34
lines changed

rest/engine.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/zeromicro/go-zero/rest/handler"
1616
"github.com/zeromicro/go-zero/rest/httpx"
1717
"github.com/zeromicro/go-zero/rest/internal"
18+
"github.com/zeromicro/go-zero/rest/internal/header"
1819
"github.com/zeromicro/go-zero/rest/internal/response"
1920
)
2021

@@ -54,6 +55,9 @@ func newEngine(c RestConf) *engine {
5455
}
5556

5657
func (ng *engine) addRoutes(r featuredRoutes) {
58+
if r.sse {
59+
r.routes = buildSSERoutes(r.routes)
60+
}
5761
ng.routes = append(ng.routes, r)
5862

5963
// need to guarantee the timeout is the max of all routes
@@ -63,6 +67,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
6367
}
6468
}
6569

70+
func buildSSERoutes(routes []Route) []Route {
71+
for i, route := range routes {
72+
h := route.Handler
73+
routes[i].Handler = func(w http.ResponseWriter, r *http.Request) {
74+
w.Header().Set(header.ContentType, header.ContentTypeEventStream)
75+
w.Header().Set(header.CacheControl, header.CacheControlNoCache)
76+
w.Header().Set(header.Connection, header.ConnectionKeepAlive)
77+
h(w, r)
78+
}
79+
}
80+
81+
return routes
82+
}
83+
6684
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
6785
verifier func(chain.Chain) chain.Chain) chain.Chain {
6886
if fr.jwt.enabled {

rest/httpc/requests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func buildRequest(ctx context.Context, method, url string, data any) (*http.Requ
105105
req.URL.RawQuery = buildFormQuery(u, val[formKey])
106106
fillHeader(req, val[headerKey])
107107
if hasJsonBody {
108-
req.Header.Set(header.ContentType, header.JsonContentType)
108+
req.Header.Set(header.ContentType, header.ContentTypeJson)
109109
}
110110

111111
return req, nil

rest/httpc/requests_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestDoRequest_NotFound(t *testing.T) {
4545
defer svr.Close()
4646
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
4747
assert.Nil(t, err)
48-
req.Header.Set(header.ContentType, header.JsonContentType)
48+
req.Header.Set(header.ContentType, header.ContentTypeJson)
4949
resp, err := DoRequest(req)
5050
assert.Nil(t, err)
5151
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

rest/httpc/responses_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestParse(t *testing.T) {
1818
}
1919
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2020
w.Header().Set("foo", "bar")
21-
w.Header().Set(header.ContentType, header.JsonContentType)
21+
w.Header().Set(header.ContentType, header.ContentTypeJson)
2222
w.Write([]byte(`{"name":"kevin","value":100}`))
2323
}))
2424
defer svr.Close()
@@ -38,7 +38,7 @@ func TestParseHeaderError(t *testing.T) {
3838
}
3939
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4040
w.Header().Set("foo", "bar")
41-
w.Header().Set(header.ContentType, header.JsonContentType)
41+
w.Header().Set(header.ContentType, header.ContentTypeJson)
4242
}))
4343
defer svr.Close()
4444
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -54,7 +54,7 @@ func TestParseNoBody(t *testing.T) {
5454
}
5555
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5656
w.Header().Set("foo", "bar")
57-
w.Header().Set(header.ContentType, header.JsonContentType)
57+
w.Header().Set(header.ContentType, header.ContentTypeJson)
5858
}))
5959
defer svr.Close()
6060
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -72,7 +72,7 @@ func TestParseWithZeroValue(t *testing.T) {
7272
}
7373
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
7474
w.Header().Set("foo", "0")
75-
w.Header().Set(header.ContentType, header.JsonContentType)
75+
w.Header().Set(header.ContentType, header.ContentTypeJson)
7676
w.Write([]byte(`{"bar":0}`))
7777
}))
7878
defer svr.Close()
@@ -90,7 +90,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
9090
Bar int `json:"bar"`
9191
}
9292
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
93-
w.Header().Set(header.ContentType, header.JsonContentType)
93+
w.Header().Set(header.ContentType, header.ContentTypeJson)
9494
w.Write([]byte(`{"bar":0}`))
9595
}))
9696
defer svr.Close()
@@ -124,7 +124,7 @@ func TestParseWithNegativeContentLength(t *testing.T) {
124124
func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
125125
var val struct{}
126126
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
127-
w.Header().Set(header.ContentType, header.JsonContentType)
127+
w.Header().Set(header.ContentType, header.ContentTypeJson)
128128
}))
129129
defer svr.Close()
130130
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)
@@ -156,7 +156,7 @@ func TestParseWithNegativeContentLengthNoBody(t *testing.T) {
156156
func TestParseJsonBody_BodyError(t *testing.T) {
157157
var val struct{}
158158
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159-
w.Header().Set(header.ContentType, header.JsonContentType)
159+
w.Header().Set(header.ContentType, header.ContentTypeJson)
160160
}))
161161
defer svr.Close()
162162
req, err := http.NewRequest(http.MethodGet, svr.URL, nil)

rest/httpc/service_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func TestNamedService_DoRequestPost(t *testing.T) {
4444
service := NewService("foo")
4545
req, err := http.NewRequest(http.MethodPost, svr.URL, nil)
4646
assert.Nil(t, err)
47-
req.Header.Set(header.ContentType, header.JsonContentType)
47+
req.Header.Set(header.ContentType, header.ContentTypeJson)
4848
resp, err := service.DoRequest(req)
4949
assert.Nil(t, err)
5050
assert.Equal(t, http.StatusNotFound, resp.StatusCode)

rest/httpx/requests_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ func TestParseJsonBody(t *testing.T) {
476476

477477
body := `{"name":"kevin", "age": 18}`
478478
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
479-
r.Header.Set(ContentType, header.JsonContentType)
479+
r.Header.Set(ContentType, header.ContentTypeJson)
480480

481481
if assert.NoError(t, Parse(r, &v)) {
482482
assert.Equal(t, "kevin", v.Name)
@@ -492,7 +492,7 @@ func TestParseJsonBody(t *testing.T) {
492492

493493
body := `{"name":"kevin", "ag": 18}`
494494
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
495-
r.Header.Set(ContentType, header.JsonContentType)
495+
r.Header.Set(ContentType, header.ContentTypeJson)
496496

497497
assert.Error(t, Parse(r, &v))
498498
})
@@ -517,7 +517,7 @@ func TestParseJsonBody(t *testing.T) {
517517

518518
body := `[{"name":"kevin", "age": 18}]`
519519
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
520-
r.Header.Set(ContentType, header.JsonContentType)
520+
r.Header.Set(ContentType, header.ContentTypeJson)
521521

522522
assert.NoError(t, Parse(r, &v))
523523
assert.Equal(t, 1, len(v))
@@ -537,7 +537,7 @@ func TestParseJsonBody(t *testing.T) {
537537

538538
body := `[{"name":"apple", "age": 18}]`
539539
r := httptest.NewRequest(http.MethodPost, "/a?product=tree", strings.NewReader(body))
540-
r.Header.Set(ContentType, header.JsonContentType)
540+
r.Header.Set(ContentType, header.ContentTypeJson)
541541

542542
assert.NoError(t, Parse(r, &v))
543543
assert.Equal(t, 1, len(v))
@@ -555,7 +555,7 @@ func TestParseJsonBody(t *testing.T) {
555555
body, _ := json.Marshal(v1)
556556
t.Logf("body:%s", string(body))
557557
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(body)))
558-
r.Header.Set(ContentType, header.JsonContentType)
558+
r.Header.Set(ContentType, header.ContentTypeJson)
559559
var v2 v
560560
err := ParseJsonBody(r, &v2)
561561
if assert.NoError(t, err) {
@@ -609,7 +609,7 @@ func TestParseHeaders(t *testing.T) {
609609
request.Header.Add("addrs", "addr2")
610610
request.Header.Add("X-Forwarded-For", "10.0.10.11")
611611
request.Header.Add("x-real-ip", "10.0.11.10")
612-
request.Header.Add("Accept", header.JsonContentType)
612+
request.Header.Add("Accept", header.ContentTypeJson)
613613
err = ParseHeaders(request, &v)
614614
if err != nil {
615615
t.Fatal(err)
@@ -619,7 +619,7 @@ func TestParseHeaders(t *testing.T) {
619619
assert.Equal(t, []string{"addr1", "addr2"}, v.Addrs)
620620
assert.Equal(t, "10.0.10.11", v.XForwardedFor)
621621
assert.Equal(t, "10.0.11.10", v.XRealIP)
622-
assert.Equal(t, header.JsonContentType, v.Accept)
622+
assert.Equal(t, header.ContentTypeJson, v.Accept)
623623
}
624624

625625
func TestParseHeaders_Error(t *testing.T) {
@@ -711,7 +711,7 @@ func TestParseWithFloatPtr(t *testing.T) {
711711
}
712712
body := `{"weightFloat32": 3.2}`
713713
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
714-
r.Header.Set(ContentType, header.JsonContentType)
714+
r.Header.Set(ContentType, header.ContentTypeJson)
715715

716716
if assert.NoError(t, Parse(r, &v)) {
717717
assert.Equal(t, float32(3.2), *v.WeightFloat32)

rest/httpx/responses.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ func doWriteJson(w http.ResponseWriter, code int, v any) error {
179179
return fmt.Errorf("marshal json failed, error: %w", err)
180180
}
181181

182-
w.Header().Set(ContentType, header.JsonContentType)
182+
w.Header().Set(ContentType, header.ContentTypeJson)
183183
w.WriteHeader(code)
184184

185185
if n, err := w.Write(bs); err != nil {

rest/httpx/vars.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ const (
1010
// ContentType means Content-Type.
1111
ContentType = header.ContentType
1212
// JsonContentType means application/json.
13-
JsonContentType = header.JsonContentType
13+
JsonContentType = header.ContentTypeJson
1414
// KeyField means key.
1515
KeyField = "key"
1616
// SecretField means secret.

rest/internal/header/headers.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,18 @@ package header
33
const (
44
// ApplicationJson stands for application/json.
55
ApplicationJson = "application/json"
6+
// CacheControl is the header key for Cache-Control.
7+
CacheControl = "Cache-Control"
8+
// CacheControlNoCache is the value for Cache-Control: no-cache.
9+
CacheControlNoCache = "no-cache"
10+
// Connection is the header key for Connection.
11+
Connection = "Connection"
12+
// ConnectionKeepAlive is the value for Connection: keep-alive.
13+
ConnectionKeepAlive = "keep-alive"
614
// ContentType is the header key for Content-Type.
715
ContentType = "Content-Type"
8-
// JsonContentType is the content type for JSON.
9-
JsonContentType = "application/json; charset=utf-8"
16+
// ContentTypeJson is the content type for JSON.
17+
ContentTypeJson = "application/json; charset=utf-8"
18+
// ContentTypeEventStream is the content type for event stream.
19+
ContentTypeEventStream = "text/event-stream"
1020
)

rest/router/patrouter_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ func TestParseWrappedRequest(t *testing.T) {
628628
func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
629629
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", bytes.NewReader(nil))
630630
assert.Nil(t, err)
631-
r.Header.Set(httpx.ContentType, header.JsonContentType)
631+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
632632

633633
type (
634634
Request struct {
@@ -661,7 +661,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) {
661661
func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) {
662662
r, err := http.NewRequest(http.MethodHead, "http://hello.com/kevin/2017", bytes.NewReader(nil))
663663
assert.Nil(t, err)
664-
r.Header.Set(httpx.ContentType, header.JsonContentType)
664+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
665665

666666
type (
667667
Request struct {
@@ -758,7 +758,7 @@ func TestParseWithAllUtf8(t *testing.T) {
758758
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
759759
bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`))
760760
assert.Nil(t, err)
761-
r.Header.Set(httpx.ContentType, header.JsonContentType)
761+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
762762

763763
router := NewRouter()
764764
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -948,7 +948,7 @@ func TestParseWithMissingAllPaths(t *testing.T) {
948948
func TestParseGetWithContentLengthHeader(t *testing.T) {
949949
r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil)
950950
assert.Nil(t, err)
951-
r.Header.Set(httpx.ContentType, header.JsonContentType)
951+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
952952
r.Header.Set(contentLength, "1024")
953953

954954
router := NewRouter()
@@ -976,7 +976,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) {
976976
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000",
977977
bytes.NewBufferString(`{"time": "20170912"}`))
978978
assert.Nil(t, err)
979-
r.Header.Set(httpx.ContentType, header.JsonContentType)
979+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
980980

981981
router := NewRouter()
982982
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(
@@ -1002,7 +1002,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) {
10021002
r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017",
10031003
bytes.NewBufferString(`{"time": 20170912}`))
10041004
assert.Nil(t, err)
1005-
r.Header.Set(httpx.ContentType, header.JsonContentType)
1005+
r.Header.Set(httpx.ContentType, header.ContentTypeJson)
10061006

10071007
router := NewRouter()
10081008
err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(

0 commit comments

Comments
 (0)