1
+ package codegen
2
+
3
+ import (
4
+ "strings"
5
+ "testing"
6
+
7
+ . "goa.design/goa/v3/dsl"
8
+ "goa.design/goa/v3/grpc/codegen/testdata"
9
+ "github.com/stretchr/testify/assert"
10
+ "github.com/stretchr/testify/require"
11
+ )
12
+
13
+ // TestStreamingWithErrors tests that streaming endpoints properly handle
14
+ // custom errors defined in the service DSL.
15
+ func TestStreamingWithErrors (t * testing.T ) {
16
+ cases := []struct {
17
+ name string
18
+ dsl func ()
19
+ testFunc func (t * testing.T , code string )
20
+ }{
21
+ {
22
+ name : "server streaming with custom errors" ,
23
+ dsl : testdata .ServerStreamingWithCustomErrorsDSL ,
24
+ testFunc : func (t * testing.T , code string ) {
25
+ // Verify error decoding is present
26
+ assert .Contains (t , code , "goagrpc.DecodeError(err)" ,
27
+ "should decode errors from stream" )
28
+
29
+ // Verify custom error types are handled
30
+ assert .Contains (t , code , "case *streaming_error_servicepb.ServerStreamCustomErrorError:" ,
31
+ "should handle custom error type" )
32
+ assert .Contains (t , code , "case *streaming_error_servicepb.ServerStreamValidationErrorError:" ,
33
+ "should handle validation error type" )
34
+
35
+ // Verify generic errors are handled
36
+ assert .Contains (t , code , "case *goapb.ErrorResponse:" ,
37
+ "should handle generic goa errors" )
38
+
39
+ // Verify proper error construction
40
+ assert .Contains (t , code , "NewServerStreamCustomErrorError(message" ,
41
+ "should construct custom error" )
42
+ assert .Contains (t , code , "NewServerStreamValidationErrorError(message" ,
43
+ "should construct validation error" )
44
+ },
45
+ },
46
+ {
47
+ name : "bidirectional streaming with errors" ,
48
+ dsl : testdata .BidirectionalStreamingRPCWithErrorsDSL ,
49
+ testFunc : func (t * testing.T , code string ) {
50
+ // Bidirectional streaming with simple errors should still decode
51
+ assert .Contains (t , code , "goagrpc.DecodeError(err)" ,
52
+ "should decode errors from bidirectional stream" )
53
+ assert .Contains (t , code , "case *goapb.ErrorResponse:" ,
54
+ "should handle generic errors in bidirectional streaming" )
55
+ },
56
+ },
57
+ }
58
+
59
+ for _ , c := range cases {
60
+ t .Run (c .name , func (t * testing.T ) {
61
+ root := RunGRPCDSL (t , c .dsl )
62
+ services := CreateGRPCServices (root )
63
+ clientfs := ClientFiles ("" , services )
64
+ require .Greater (t , len (clientfs ), 0 )
65
+
66
+ // Get recv method implementations
67
+ recvSections := clientfs [0 ].Section ("client-stream-recv" )
68
+ require .Greater (t , len (recvSections ), 0 )
69
+
70
+ // Build complete recv method code
71
+ var codeBuilder strings.Builder
72
+ for _ , section := range recvSections {
73
+ require .NoError (t , section .Write (& codeBuilder ))
74
+ }
75
+ code := codeBuilder .String ()
76
+
77
+ // Run test-specific assertions
78
+ c .testFunc (t , code )
79
+ })
80
+ }
81
+ }
82
+
83
+ // TestStreamingErrorsWithValidation verifies that custom errors with
84
+ // validation rules are properly validated in streaming recv methods.
85
+ func TestStreamingErrorsWithValidation (t * testing.T ) {
86
+ root := RunGRPCDSL (t , testdata .ServerStreamingWithCustomErrorsDSL )
87
+ services := CreateGRPCServices (root )
88
+
89
+ // Verify the DSL has errors with validation
90
+ require .Len (t , root .Services , 1 )
91
+ svc := root .Services [0 ]
92
+ require .Len (t , svc .Methods , 1 )
93
+ method := svc .Methods [0 ]
94
+ require .Greater (t , len (method .Errors ), 0 , "method should have errors defined" )
95
+
96
+ // Generate client code
97
+ clientfs := ClientFiles ("" , services )
98
+ require .Greater (t , len (clientfs ), 0 )
99
+
100
+ // Check recv implementations
101
+ recvSections := clientfs [0 ].Section ("client-stream-recv" )
102
+ var code strings.Builder
103
+ for _ , section := range recvSections {
104
+ require .NoError (t , section .Write (& code ))
105
+ }
106
+ recvCode := code .String ()
107
+
108
+ // For errors with validation, verify validation is called
109
+ if strings .Contains (recvCode , "ValidateServerStreamCustomErrorError" ) {
110
+ assert .Contains (t , recvCode , "if err := ValidateServerStreamCustomErrorError(message); err != nil {" ,
111
+ "should validate custom error before returning" )
112
+ }
113
+ }
114
+
115
+ // TestStreamingErrorComparison compares error handling between unary and
116
+ // streaming methods to ensure consistency.
117
+ func TestStreamingErrorComparison (t * testing.T ) {
118
+ // DSL with both unary and streaming methods with errors
119
+ dsl := func () {
120
+ var CustomError = Type ("CustomError" , func () {
121
+ ErrorName ("name" , String , "error name" )
122
+ Attribute ("message" , String , "error message" )
123
+ Required ("name" , "message" )
124
+ })
125
+
126
+ Service ("MixedService" , func () {
127
+ // Unary method with custom error
128
+ Method ("UnaryMethod" , func () {
129
+ Payload (String )
130
+ Result (String )
131
+ Error ("custom_error" , CustomError )
132
+ GRPC (func () {
133
+ Response ("custom_error" , CodeInvalidArgument )
134
+ })
135
+ })
136
+
137
+ // Streaming method with same error
138
+ Method ("StreamingMethod" , func () {
139
+ Payload (String )
140
+ StreamingResult (String )
141
+ Error ("custom_error" , CustomError )
142
+ GRPC (func () {
143
+ Response ("custom_error" , CodeInvalidArgument )
144
+ })
145
+ })
146
+ })
147
+ }
148
+
149
+ root := RunGRPCDSL (t , dsl )
150
+ services := CreateGRPCServices (root )
151
+ clientfs := ClientFiles ("" , services )
152
+ require .Greater (t , len (clientfs ), 0 , "should have client files" )
153
+
154
+ // Find unary and streaming code in different sections
155
+ var unaryCode , streamCode string
156
+
157
+ // For unary, look in client-endpoint-init
158
+ if sections := clientfs [0 ].Section ("client-endpoint-init" ); len (sections ) > 0 {
159
+ var code strings.Builder
160
+ for _ , section := range sections {
161
+ require .NoError (t , section .Write (& code ))
162
+ }
163
+ unaryCode = code .String ()
164
+ }
165
+
166
+ // For streaming, look in client-stream-recv
167
+ if sections := clientfs [0 ].Section ("client-stream-recv" ); len (sections ) > 0 {
168
+ var code strings.Builder
169
+ for _ , section := range sections {
170
+ require .NoError (t , section .Write (& code ))
171
+ }
172
+ streamCode = code .String ()
173
+ }
174
+
175
+ // If no sections found, skip test with explanation
176
+ if unaryCode == "" || streamCode == "" {
177
+ t .Skip ("Cannot compare unary and streaming - sections not found in generated code" )
178
+ }
179
+
180
+ // Both should decode errors
181
+ assert .Contains (t , unaryCode , "goagrpc.DecodeError(err)" ,
182
+ "unary methods should decode errors" )
183
+ assert .Contains (t , streamCode , "goagrpc.DecodeError(err)" ,
184
+ "streaming methods should decode errors" )
185
+
186
+ // Both should handle the custom error type
187
+ assert .Contains (t , unaryCode , "case *mixed_servicepb." ,
188
+ "unary should handle custom error types" )
189
+ assert .Contains (t , streamCode , "case *mixed_servicepb." ,
190
+ "streaming should handle custom error types" )
191
+ }
0 commit comments