Skip to content

Commit 6992f74

Browse files
committed
Change from offset and limit to cursor based paging
1 parent 54c7cff commit 6992f74

File tree

4 files changed

+296
-23
lines changed

4 files changed

+296
-23
lines changed

internal/cursor/cursor.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Package cursor implements logic for paging.
2+
package cursor
3+
4+
import (
5+
"encoding/base64"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
)
10+
11+
// Cursor represents pagination state with offset for next offset.
12+
type Cursor struct {
13+
Offset int32 `json:"offset"`
14+
}
15+
16+
// New creates and validates a new Cursor.
17+
func New(offset int32) (*Cursor, error) {
18+
cursor := &Cursor{
19+
Offset: offset,
20+
}
21+
22+
if err := cursor.validate(); err != nil {
23+
return nil, err
24+
}
25+
26+
return cursor, nil
27+
}
28+
29+
// Encode serializes the cursor to a Base64-encoded string.
30+
func (c *Cursor) Encode() (string, error) {
31+
if err := c.validate(); err != nil {
32+
return "", err
33+
}
34+
35+
jsonBytes, err := json.Marshal(c)
36+
if err != nil {
37+
return "", fmt.Errorf("failed to marshal cursor: %w", err)
38+
}
39+
40+
encoded := base64.StdEncoding.EncodeToString(jsonBytes)
41+
42+
return encoded, nil
43+
}
44+
45+
// Decode deserializes a Base64-encoded string to a Cursor.
46+
func Decode(encoded string) (*Cursor, error) {
47+
if encoded == "" {
48+
return nil, errors.New("encoded cursor cannot be empty")
49+
}
50+
51+
decoded, err := base64.StdEncoding.DecodeString(encoded)
52+
if err != nil {
53+
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
54+
}
55+
56+
var cursor Cursor
57+
if err := json.Unmarshal(decoded, &cursor); err != nil {
58+
return nil, fmt.Errorf("invalid cursor format: %w", err)
59+
}
60+
61+
if err := cursor.validate(); err != nil {
62+
return nil, err
63+
}
64+
65+
return &cursor, nil
66+
}
67+
68+
// GetOffset returns offset that can be used for API call.
69+
func (c *Cursor) GetOffset() int32 {
70+
return c.Offset
71+
}
72+
73+
// GetNextCursor returns cursor for the next offset.
74+
func (c *Cursor) GetNextCursor(limit int32) *Cursor {
75+
if limit < 0 || c.Offset+limit < 0 {
76+
limit = 0
77+
}
78+
79+
return &Cursor{
80+
Offset: c.Offset + limit,
81+
}
82+
}
83+
84+
// validate checks if the cursor has valid values.
85+
func (c *Cursor) validate() error {
86+
if c.Offset < 0 {
87+
return errors.New("offset must be non-negative")
88+
}
89+
90+
return nil
91+
}

internal/cursor/cursor_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package cursor
2+
3+
import (
4+
"encoding/base64"
5+
"math"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestNew_ValidInput(t *testing.T) {
13+
tests := map[string]struct {
14+
offset int32
15+
}{
16+
"offset 0": {
17+
offset: 0,
18+
},
19+
"offset 5": {
20+
offset: 5,
21+
},
22+
"large values": {
23+
offset: 1000000,
24+
},
25+
}
26+
27+
for testName, testCase := range tests {
28+
t.Run(testName, func(t *testing.T) {
29+
cursor, err := New(testCase.offset)
30+
require.NoError(t, err)
31+
require.NotNil(t, cursor)
32+
assert.Equal(t, testCase.offset, cursor.Offset)
33+
})
34+
}
35+
}
36+
37+
func TestNew_InvalidInput(t *testing.T) {
38+
cursor, err := New(-1)
39+
require.Error(t, err)
40+
assert.Nil(t, cursor)
41+
assert.Contains(t, err.Error(), "offset must be non-negative")
42+
}
43+
44+
func TestDecode_Success(t *testing.T) {
45+
original := &Cursor{Offset: 1}
46+
encoded, err := original.Encode()
47+
require.NoError(t, err)
48+
49+
decoded, err := Decode(encoded)
50+
require.NoError(t, err)
51+
require.NotNil(t, decoded)
52+
assert.Equal(t, original.Offset, decoded.Offset)
53+
}
54+
55+
func TestDecode_InvalidInput(t *testing.T) {
56+
tests := map[string]struct {
57+
encoded string
58+
expectedError string
59+
}{
60+
"empty string": {
61+
encoded: "",
62+
expectedError: "encoded cursor cannot be empty",
63+
},
64+
"invalid base64": {
65+
encoded: "not-base64!@#$%",
66+
expectedError: "invalid base64 encoding",
67+
},
68+
"invalid json": {
69+
encoded: base64.StdEncoding.EncodeToString([]byte("not json")),
70+
expectedError: "invalid cursor format",
71+
},
72+
"valid json but invalid cursor - negative offset": {
73+
encoded: base64.StdEncoding.EncodeToString([]byte(`{"offset":-1}`)),
74+
expectedError: "offset must be non-negative",
75+
},
76+
}
77+
78+
for testName, testCase := range tests {
79+
t.Run(testName, func(t *testing.T) {
80+
decoded, err := Decode(testCase.encoded)
81+
require.Error(t, err)
82+
assert.Nil(t, decoded)
83+
assert.Contains(t, err.Error(), testCase.expectedError)
84+
})
85+
}
86+
}
87+
88+
func TestEncodeDecode_RoundTrip(t *testing.T) {
89+
tests := map[string]struct {
90+
offset int32
91+
}{
92+
"zero offset": {
93+
offset: 0,
94+
},
95+
"non-zero offset": {
96+
offset: 5,
97+
},
98+
"large offset": {
99+
offset: 10000,
100+
},
101+
}
102+
103+
for testName, testCase := range tests {
104+
t.Run(testName, func(t *testing.T) {
105+
original, err := New(testCase.offset)
106+
require.NoError(t, err)
107+
108+
encoded, err := original.Encode()
109+
require.NoError(t, err)
110+
assert.NotEmpty(t, encoded)
111+
112+
decoded, err := Decode(encoded)
113+
require.NoError(t, err)
114+
require.NotNil(t, decoded)
115+
116+
assert.Equal(t, original.Offset, decoded.Offset)
117+
})
118+
}
119+
}
120+
121+
func TestEncode_InvalidInput(t *testing.T) {
122+
cursor := &Cursor{Offset: -1}
123+
encoded, err := cursor.Encode()
124+
125+
require.Error(t, err)
126+
assert.Empty(t, encoded)
127+
assert.Contains(t, err.Error(), "offset must be non-negative")
128+
}
129+
130+
func TestGetOffset(t *testing.T) {
131+
cursor := &Cursor{Offset: 1}
132+
assert.Equal(t, cursor.Offset, cursor.GetOffset())
133+
}
134+
135+
func TestGetNextCursor(t *testing.T) {
136+
cursor := &Cursor{Offset: 0}
137+
138+
cursorStep1 := cursor.GetNextCursor(10)
139+
assert.Equal(t, int32(10), cursorStep1.GetOffset())
140+
141+
cursorStep2 := cursorStep1.GetNextCursor(5)
142+
assert.Equal(t, int32(15), cursorStep2.GetOffset())
143+
144+
cursorNegativeLimit := cursorStep2.GetNextCursor(-1)
145+
assert.Equal(t, int32(15), cursorNegativeLimit.GetOffset(), "negative limit should not change offset")
146+
147+
cursorOverflow := cursorStep2.GetNextCursor(math.MaxInt32)
148+
assert.Equal(t, int32(15), cursorOverflow.GetOffset(), "overflow paging should not change offset")
149+
}

internal/toolsets/vulnerability/tools.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ import (
1111
v1 "github.com/stackrox/rox/generated/api/v1"
1212
"github.com/stackrox/stackrox-mcp/internal/client"
1313
"github.com/stackrox/stackrox-mcp/internal/client/auth"
14+
"github.com/stackrox/stackrox-mcp/internal/cursor"
1415
"github.com/stackrox/stackrox-mcp/internal/logging"
1516
"github.com/stackrox/stackrox-mcp/internal/toolsets"
1617
)
1718

1819
const (
19-
defaultLimit = 50
20-
maximumLimit = 200.0
20+
defaultLimit = 100
2121
)
2222

2323
type filterPlatformType string
@@ -34,8 +34,7 @@ type getDeploymentsForCVEInput struct {
3434
FilterClusterID string `json:"filterClusterId,omitempty"`
3535
FilterNamespace string `json:"filterNamespace,omitempty"`
3636
FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"`
37-
Offset int32 `json:"offset,omitempty"`
38-
Limit int32 `json:"limit,omitempty"`
37+
Cursor string `json:"cursor,omitempty"`
3938
}
4039

4140
func (input *getDeploymentsForCVEInput) validate() error {
@@ -57,6 +56,7 @@ type DeploymentResult struct {
5756
// getDeploymentsForCVEOutput defines the output structure for get_deployments_for_cve tool.
5857
type getDeploymentsForCVEOutput struct {
5958
Deployments []DeploymentResult `json:"deployments"`
59+
NextCursor string `json:"nextCursor"`
6060
}
6161

6262
// getDeploymentsForCVETool implements the get_deployments_for_cve tool.
@@ -118,14 +118,7 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema {
118118
filterPlatformPlatform,
119119
}
120120

121-
schema.Properties["offset"].Description = "Pagination offset (default: 0)"
122-
schema.Properties["offset"].Default = toolsets.MustJSONMarshal(0)
123-
schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0)
124-
125-
schema.Properties["limit"].Description = "Pagination limit: minimum: 1, maximum: 200 (default: 50)"
126-
schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit)
127-
schema.Properties["limit"].Minimum = jsonschema.Ptr(1.0)
128-
schema.Properties["limit"].Maximum = jsonschema.Ptr(maximumLimit)
121+
schema.Properties["cursor"].Description = "Cursor for next page provided by server"
129122

130123
return schema
131124
}
@@ -160,6 +153,21 @@ func buildQuery(input getDeploymentsForCVEInput) string {
160153
return strings.Join(queryParts, "+")
161154
}
162155

156+
func getCursor(input *getDeploymentsForCVEInput) (*cursor.Cursor, error) {
157+
if input.Cursor == "" {
158+
startCursor, err := cursor.New(0)
159+
160+
return startCursor, errors.Wrap(err, "error creating starting cursor")
161+
}
162+
163+
currCursor, err := cursor.Decode(input.Cursor)
164+
if err != nil {
165+
return nil, errors.Wrap(err, "error decoding cursor")
166+
}
167+
168+
return currCursor, nil
169+
}
170+
163171
// handle is the handler for get_deployments_for_cve tool.
164172
func (t *getDeploymentsForCVETool) handle(
165173
ctx context.Context,
@@ -171,6 +179,11 @@ func (t *getDeploymentsForCVETool) handle(
171179
return nil, nil, err
172180
}
173181

182+
currCursor, err := getCursor(&input)
183+
if err != nil {
184+
return nil, nil, err
185+
}
186+
174187
conn, err := t.client.ReadyConn(ctx)
175188
if err != nil {
176189
return nil, nil, errors.Wrap(err, "unable to connect to server")
@@ -182,8 +195,8 @@ func (t *getDeploymentsForCVETool) handle(
182195
listReq := &v1.RawQuery{
183196
Query: buildQuery(input),
184197
Pagination: &v1.Pagination{
185-
Offset: input.Offset,
186-
Limit: input.Limit,
198+
Offset: currCursor.GetOffset(),
199+
Limit: defaultLimit + 1,
187200
},
188201
}
189202

@@ -202,8 +215,19 @@ func (t *getDeploymentsForCVETool) handle(
202215
})
203216
}
204217

218+
// We always fetch limit+1 - if we do not have one additional element we can end paging.
219+
if len(deployments) <= defaultLimit {
220+
return nil, &getDeploymentsForCVEOutput{Deployments: deployments}, nil
221+
}
222+
223+
nextCursorStr, err := currCursor.GetNextCursor(defaultLimit).Encode()
224+
if err != nil {
225+
return nil, nil, errors.Wrap(err, "unable to create next cursor")
226+
}
227+
205228
output := &getDeploymentsForCVEOutput{
206-
Deployments: deployments,
229+
Deployments: deployments[:len(deployments)-1],
230+
NextCursor: nextCursorStr,
207231
}
208232

209233
return nil, output, nil

0 commit comments

Comments
 (0)