Skip to content

Commit 714a959

Browse files
authored
feat(Scheduler): Deal with model replicas being set to 0 (#6557)
1 parent c43fd13 commit 714a959

File tree

16 files changed

+936
-613
lines changed

16 files changed

+936
-613
lines changed

apis/go/mlops/scheduler/scheduler.pb.go

Lines changed: 566 additions & 562 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apis/mlops/scheduler/scheduler.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ message ModelStatus {
157157
ModelTerminated = 5;
158158
ModelTerminateFailed = 6;
159159
ScheduleFailed = 7;
160+
ModelScaledDown = 8;
160161
}
161162
ModelState state = 1;
162163
string reason = 2;

operator/apis/mlops/v1alpha1/model_types_test.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ package v1alpha1
1111

1212
import (
1313
"encoding/json"
14+
"math"
1415
"testing"
1516

17+
"github.com/gotidy/ptr"
1618
. "github.com/onsi/gomega"
1719
"github.com/tidwall/gjson"
1820
v1 "k8s.io/api/core/v1"
@@ -89,8 +91,7 @@ func TestAsModelDetails(t *testing.T) {
8991
modelpb *scheduler.Model
9092
error bool
9193
}
92-
replicas := int32(4)
93-
replicas1 := int32(1)
94+
9495
secret := "secret"
9596
modelType := "sklearn"
9697
server := "server"
@@ -127,7 +128,7 @@ func TestAsModelDetails(t *testing.T) {
127128
DeploymentSpec: &scheduler.DeploymentSpec{
128129
Replicas: 1,
129130
MinReplicas: 0,
130-
MaxReplicas: 0,
131+
MaxReplicas: math.MaxUint32,
131132
},
132133
},
133134
},
@@ -147,7 +148,7 @@ func TestAsModelDetails(t *testing.T) {
147148
},
148149
Logger: &LoggingSpec{},
149150
Requirements: []string{"a", "b"},
150-
ScalingSpec: ScalingSpec{Replicas: &replicas},
151+
ScalingSpec: ScalingSpec{Replicas: ptr.Int32(4)},
151152
Server: &server,
152153
Explainer: &ExplainerSpec{
153154
Type: "anchor_tabular",
@@ -199,7 +200,7 @@ func TestAsModelDetails(t *testing.T) {
199200
Replicas: 4,
200201
LogPayloads: true,
201202
MinReplicas: 0,
202-
MaxReplicas: 0,
203+
MaxReplicas: math.MaxUint32,
203204
},
204205
},
205206
},
@@ -219,7 +220,7 @@ func TestAsModelDetails(t *testing.T) {
219220
},
220221
Logger: &LoggingSpec{},
221222
Requirements: []string{"a", "b"},
222-
ScalingSpec: ScalingSpec{Replicas: &replicas},
223+
ScalingSpec: ScalingSpec{Replicas: ptr.Int32(4)},
223224
Server: &server,
224225
Llm: &LlmSpec{
225226
ModelRef: &llmModel,
@@ -269,7 +270,7 @@ func TestAsModelDetails(t *testing.T) {
269270
Replicas: 4,
270271
LogPayloads: true,
271272
MinReplicas: 0,
272-
MaxReplicas: 0,
273+
MaxReplicas: math.MaxUint32,
273274
},
274275
},
275276
},
@@ -285,7 +286,8 @@ func TestAsModelDetails(t *testing.T) {
285286
InferenceArtifactSpec: InferenceArtifactSpec{
286287
StorageURI: "gs://test",
287288
},
288-
Memory: &m1,
289+
Memory: &m1,
290+
ScalingSpec: ScalingSpec{Replicas: ptr.Int32(1)},
289291
},
290292
},
291293
modelpb: &scheduler.Model{
@@ -303,7 +305,7 @@ func TestAsModelDetails(t *testing.T) {
303305
DeploymentSpec: &scheduler.DeploymentSpec{
304306
Replicas: 1,
305307
MinReplicas: 0,
306-
MaxReplicas: 0,
308+
MaxReplicas: math.MaxUint32,
307309
},
308310
},
309311
},
@@ -320,7 +322,7 @@ func TestAsModelDetails(t *testing.T) {
320322
InferenceArtifactSpec: InferenceArtifactSpec{
321323
StorageURI: "gs://test",
322324
},
323-
ScalingSpec: ScalingSpec{MinReplicas: &replicas},
325+
ScalingSpec: ScalingSpec{MinReplicas: ptr.Int32(4)},
324326
},
325327
},
326328
modelpb: &scheduler.Model{
@@ -337,7 +339,7 @@ func TestAsModelDetails(t *testing.T) {
337339
DeploymentSpec: &scheduler.DeploymentSpec{
338340
Replicas: 4,
339341
MinReplicas: 4,
340-
MaxReplicas: 0,
342+
MaxReplicas: math.MaxUint32,
341343
},
342344
},
343345
},
@@ -354,7 +356,7 @@ func TestAsModelDetails(t *testing.T) {
354356
InferenceArtifactSpec: InferenceArtifactSpec{
355357
StorageURI: "gs://test",
356358
},
357-
ScalingSpec: ScalingSpec{MaxReplicas: &replicas},
359+
ScalingSpec: ScalingSpec{Replicas: ptr.Int32(1), MaxReplicas: ptr.Int32(4)},
358360
},
359361
},
360362
modelpb: &scheduler.Model{
@@ -388,7 +390,7 @@ func TestAsModelDetails(t *testing.T) {
388390
InferenceArtifactSpec: InferenceArtifactSpec{
389391
StorageURI: "gs://test",
390392
},
391-
ScalingSpec: ScalingSpec{MinReplicas: &replicas, Replicas: &replicas1},
393+
ScalingSpec: ScalingSpec{MinReplicas: ptr.Int32(4), Replicas: ptr.Int32(1)},
392394
},
393395
},
394396
modelpb: &scheduler.Model{
@@ -423,7 +425,7 @@ func TestAsModelDetails(t *testing.T) {
423425
InferenceArtifactSpec: InferenceArtifactSpec{
424426
StorageURI: "gs://test",
425427
},
426-
ScalingSpec: ScalingSpec{Replicas: &replicas, MaxReplicas: &replicas1},
428+
ScalingSpec: ScalingSpec{Replicas: ptr.Int32(4), MaxReplicas: ptr.Int32(1)},
427429
},
428430
},
429431
modelpb: &scheduler.Model{

operator/apis/mlops/v1alpha1/utils.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ the Change License after the Change Date as each is defined in accordance with t
99

1010
package v1alpha1
1111

12-
import "fmt"
12+
import (
13+
"fmt"
14+
"math"
15+
)
1316

1417
type ValidatedScalingSpec struct {
1518
Replicas uint32
@@ -18,35 +21,43 @@ type ValidatedScalingSpec struct {
1821
}
1922

2023
func GetValidatedScalingSpec(replicas *int32, minReplicas *int32, maxReplicas *int32) (*ValidatedScalingSpec, error) {
21-
var validatedSpec ValidatedScalingSpec
24+
validatedSpec := ValidatedScalingSpec{
25+
Replicas: 1,
26+
MinReplicas: 0,
27+
MaxReplicas: math.MaxUint32,
28+
}
29+
30+
if replicas == nil && minReplicas == nil && maxReplicas == nil {
31+
return &validatedSpec, nil
32+
}
2233

2334
if replicas != nil && *replicas > 0 {
2435
validatedSpec.Replicas = uint32(*replicas)
25-
} else {
26-
if minReplicas != nil && *minReplicas > 0 {
27-
// set replicas to the min replicas when replicas is not set explicitly
28-
validatedSpec.Replicas = uint32(*minReplicas)
29-
} else {
30-
validatedSpec.Replicas = 1
36+
} else if replicas != nil && *replicas == 0 && minReplicas != nil && *minReplicas == 0 {
37+
validatedSpec.Replicas = uint32(0)
38+
} else if minReplicas != nil && *minReplicas > 0 {
39+
if replicas != nil && *replicas < *minReplicas {
40+
return nil, fmt.Errorf("number of replicas %d cannot be less than minimum replica %d", *replicas, *minReplicas)
3141
}
42+
// set replicas to the min replicas when replicas is not set explicitly
43+
validatedSpec.Replicas = uint32(*minReplicas)
3244
}
3345

3446
if minReplicas != nil && *minReplicas > 0 {
3547
validatedSpec.MinReplicas = uint32(*minReplicas)
3648
if validatedSpec.Replicas < validatedSpec.MinReplicas {
37-
return nil, fmt.Errorf("number of replicas %d must be >= min replicas %d", validatedSpec.Replicas, validatedSpec.MinReplicas)
49+
return nil, fmt.Errorf("number of replicas %d must be >= min replicas %d", validatedSpec.Replicas, validatedSpec.MinReplicas)
3850
}
39-
} else {
40-
validatedSpec.MinReplicas = 0
4151
}
4252

4353
if maxReplicas != nil && *maxReplicas > 0 {
4454
validatedSpec.MaxReplicas = uint32(*maxReplicas)
55+
if validatedSpec.MinReplicas > validatedSpec.MaxReplicas {
56+
return nil, fmt.Errorf("min number of replicas %d must be <= max number of replicas %d", validatedSpec.MinReplicas, validatedSpec.MaxReplicas)
57+
}
4558
if validatedSpec.Replicas > validatedSpec.MaxReplicas {
46-
return nil, fmt.Errorf("number of replicas %d must be <= min replicas %d", validatedSpec.Replicas, validatedSpec.MaxReplicas)
59+
return nil, fmt.Errorf("number of replicas %d must be <= max replicas %d", validatedSpec.Replicas, validatedSpec.MaxReplicas)
4760
}
48-
} else {
49-
validatedSpec.MaxReplicas = 0
5061
}
5162

5263
return &validatedSpec, nil
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package v1alpha1
2+
3+
import (
4+
"math"
5+
"testing"
6+
7+
"github.com/gotidy/ptr"
8+
. "github.com/onsi/gomega"
9+
)
10+
11+
func TestGetValidatedScalingSpec(t *testing.T) {
12+
type test struct {
13+
name string
14+
replicas *int32
15+
minReplicas *int32
16+
maxReplicas *int32
17+
expected *ValidatedScalingSpec
18+
wantErr string
19+
}
20+
21+
g := NewGomegaWithT(t)
22+
23+
tests := []test{
24+
{
25+
name: "success - replicas is higher than min replicas and lower than max replicas",
26+
replicas: ptr.Int32(2),
27+
minReplicas: ptr.Int32(1),
28+
maxReplicas: ptr.Int32(3),
29+
expected: &ValidatedScalingSpec{
30+
Replicas: 2,
31+
MinReplicas: 1,
32+
MaxReplicas: 3,
33+
},
34+
wantErr: "",
35+
},
36+
{
37+
name: "error - replicas is less than min replicas",
38+
replicas: ptr.Int32(1),
39+
minReplicas: ptr.Int32(2),
40+
maxReplicas: ptr.Int32(4),
41+
expected: nil,
42+
wantErr: "number of replicas 1 must be >= min replicas 2",
43+
},
44+
{
45+
name: "error - replicas is bigger than max replicas",
46+
replicas: ptr.Int32(5),
47+
minReplicas: ptr.Int32(1),
48+
maxReplicas: ptr.Int32(4),
49+
expected: nil,
50+
wantErr: "number of replicas 5 must be <= max replicas 4",
51+
},
52+
{
53+
name: "error - replica is less than min replicas",
54+
replicas: ptr.Int32(0),
55+
minReplicas: ptr.Int32(1),
56+
maxReplicas: nil,
57+
expected: nil,
58+
wantErr: "number of replicas 0 cannot be less than minimum replica 1",
59+
},
60+
{
61+
name: "error - min replica is bigger than max replicas",
62+
replicas: ptr.Int32(6),
63+
minReplicas: ptr.Int32(6),
64+
maxReplicas: ptr.Int32(4),
65+
expected: nil,
66+
wantErr: "min number of replicas 6 must be <= max number of replicas 4",
67+
},
68+
{
69+
name: "success - replicas stays at 0 when min replicas is 0 and max replicas is 4",
70+
replicas: ptr.Int32(0),
71+
minReplicas: ptr.Int32(0),
72+
maxReplicas: ptr.Int32(4),
73+
expected: &ValidatedScalingSpec{
74+
Replicas: 0,
75+
MinReplicas: 0,
76+
MaxReplicas: 4,
77+
},
78+
wantErr: "",
79+
},
80+
{
81+
name: "success - all replica params are the same",
82+
replicas: ptr.Int32(4),
83+
minReplicas: ptr.Int32(4),
84+
maxReplicas: ptr.Int32(4),
85+
expected: &ValidatedScalingSpec{
86+
Replicas: 4,
87+
MinReplicas: 4,
88+
MaxReplicas: 4,
89+
},
90+
wantErr: "",
91+
},
92+
{
93+
name: "success - min and max replicas default to right params when only replicas is set",
94+
replicas: ptr.Int32(2),
95+
minReplicas: nil,
96+
maxReplicas: nil,
97+
expected: &ValidatedScalingSpec{
98+
Replicas: 2,
99+
MinReplicas: 0,
100+
MaxReplicas: math.MaxUint32,
101+
},
102+
wantErr: "",
103+
},
104+
{
105+
name: "success - unset replica params defaults to 1",
106+
replicas: nil,
107+
minReplicas: nil,
108+
maxReplicas: nil,
109+
expected: &ValidatedScalingSpec{
110+
Replicas: 1,
111+
MinReplicas: 0,
112+
MaxReplicas: math.MaxUint32,
113+
},
114+
wantErr: "",
115+
},
116+
{
117+
name: "success - unset replica params defaults to value of min replicas",
118+
replicas: nil,
119+
minReplicas: ptr.Int32(2),
120+
maxReplicas: nil,
121+
expected: &ValidatedScalingSpec{
122+
Replicas: 2,
123+
MinReplicas: 2,
124+
MaxReplicas: math.MaxUint32,
125+
},
126+
wantErr: "",
127+
},
128+
{
129+
name: "success - unset replica params defaults to 1 when min replicas is set to 0",
130+
replicas: nil,
131+
minReplicas: ptr.Int32(0),
132+
maxReplicas: nil,
133+
expected: &ValidatedScalingSpec{
134+
Replicas: 1,
135+
MinReplicas: 0,
136+
MaxReplicas: math.MaxUint32,
137+
},
138+
wantErr: "",
139+
},
140+
{
141+
name: "success - unset minReplicas defaults to 0, unset replicas defaults to 1 when max replicas is set",
142+
replicas: nil,
143+
minReplicas: nil,
144+
maxReplicas: ptr.Int32(2),
145+
expected: &ValidatedScalingSpec{
146+
Replicas: 1,
147+
MinReplicas: 0,
148+
MaxReplicas: 2,
149+
},
150+
wantErr: "",
151+
},
152+
}
153+
154+
for _, test := range tests {
155+
t.Run(test.name, func(t *testing.T) {
156+
scalingSpec, err := GetValidatedScalingSpec(test.replicas, test.minReplicas, test.maxReplicas)
157+
158+
if test.wantErr != "" {
159+
if err == nil {
160+
t.Errorf("expected error: %v, got nil", test.wantErr)
161+
return
162+
}
163+
if err.Error() != test.wantErr {
164+
t.Errorf("expected error: %q, got: %q", test.wantErr, err.Error())
165+
}
166+
return
167+
}
168+
169+
if err != nil {
170+
t.Errorf("unexpected error: %v", err)
171+
return
172+
}
173+
174+
g.Expect(scalingSpec).To(Equal(test.expected))
175+
})
176+
}
177+
}

operator/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ require (
1111
github.com/confluentinc/confluent-kafka-go/v2 v2.10.1
1212
github.com/ghodss/yaml v1.0.0
1313
github.com/go-logr/logr v1.4.3
14+
github.com/gotidy/ptr v1.4.0
1415
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
1516
github.com/json-iterator/go v1.1.12
1617
github.com/onsi/ginkgo v1.16.5

0 commit comments

Comments
 (0)