Skip to content

Commit 7b85ed0

Browse files
authored
[Slice] Handle ReplicatedJob with >1 Replicas
2 parents 4f00dab + 206b13c commit 7b85ed0

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

slice/internal/util/testingjobs/jobset/wrappers.go

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,9 @@ func (j *JobSetWrapper) ReplicatedJobs(replicatedJobs ...ReplicatedJobRequiremen
8787
},
8888
}
8989
}
90+
if req.Replicas == 0 {
91+
req.Replicas = 1
92+
}
9093
j.Spec.ReplicatedJobs[index] = jobsetutil.MakeReplicatedJob(req.Name).Job(jt).Replicas(req.Replicas).Obj()
9194
}
9295
return j

slice/internal/webhooks/jobset_webhook.go

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,9 +77,11 @@ func (r *JobSetWebhook) annotateReplicatedJobWithTopology(rj *v1alpha2.Replicate
7777
rj.Template.Spec.Template.Annotations[kueuealpha.PodSetRequiredTopologyAnnotation] = core.TPUBlockLabel
7878
rj.Template.Spec.Template.Annotations[kueuealpha.PodSetSliceRequiredTopologyAnnotation] = core.TPUSubBlockLabel
7979

80+
pods := ptr.Deref(rj.Template.Spec.Parallelism, 1) * rj.Replicas
81+
8082
size, err := r.podSetSliceSize(
8183
rj.Template.Spec.Template.Annotations[core.TPUTopologyAnnotation],
82-
ptr.Deref(rj.Template.Spec.Parallelism, 1),
84+
pods,
8385
)
8486
if err != nil {
8587
return err

slice/test/e2e/jobset_test.go

Lines changed: 48 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ var _ = ginkgo.Describe("JobSet", func() {
103103
type testCase struct {
104104
tpuTopology string
105105
parallelism int32
106+
replicas int32
106107
wantSliceSize int32
107108
tpuRequests string
108109
wantDomains []kueue.TopologyDomainAssignment
@@ -117,7 +118,7 @@ var _ = ginkgo.Describe("JobSet", func() {
117118
Name: "rj1",
118119
Image: utils.E2eTestAgnHostImage,
119120
Args: utils.BehaviorWaitForDeletion,
120-
Replicas: 1,
121+
Replicas: tc.replicas,
121122
Parallelism: tc.parallelism,
122123
Completions: tc.parallelism,
123124
PodAnnotations: map[string]string{
@@ -165,7 +166,7 @@ var _ = ginkgo.Describe("JobSet", func() {
165166
g.Expect(createdWorkload.Spec.PodSets[0].TopologyRequest).To(gomega.BeComparableTo(&kueue.PodSetTopologyRequest{
166167
Required: ptr.To(core.TPUBlockLabel),
167168
PodSetSliceRequiredTopology: ptr.To(core.TPUSubBlockLabel),
168-
SubGroupCount: ptr.To[int32](1),
169+
SubGroupCount: ptr.To(tc.replicas),
169170
PodSetSliceSize: ptr.To(tc.wantSliceSize),
170171
}, ignorePodSetTopologyRequestFields))
171172
}, utils.Timeout, utils.Interval).Should(gomega.Succeed())
@@ -283,6 +284,7 @@ var _ = ginkgo.Describe("JobSet", func() {
283284
tpuTopology: "4x4x4",
284285
tpuRequests: "4",
285286
parallelism: 16,
287+
replicas: 1,
286288
wantSliceSize: 16,
287289
wantDomains: []kueue.TopologyDomainAssignment{{
288290
Values: []string{"b1", "sb1"},
@@ -296,6 +298,7 @@ var _ = ginkgo.Describe("JobSet", func() {
296298
tpuTopology: "4x4x4",
297299
tpuRequests: "1",
298300
parallelism: 64,
301+
replicas: 1,
299302
wantSliceSize: 64,
300303
wantDomains: []kueue.TopologyDomainAssignment{{
301304
Values: []string{"b1", "sb1"},
@@ -309,6 +312,7 @@ var _ = ginkgo.Describe("JobSet", func() {
309312
tpuTopology: "4x4x12",
310313
tpuRequests: "4",
311314
parallelism: 48,
315+
replicas: 1,
312316
wantSliceSize: 16,
313317
wantDomains: []kueue.TopologyDomainAssignment{
314318
{
@@ -332,6 +336,7 @@ var _ = ginkgo.Describe("JobSet", func() {
332336
tpuTopology: "4x4x12",
333337
tpuRequests: "2",
334338
parallelism: 96,
339+
replicas: 1,
335340
wantSliceSize: 32,
336341
wantDomains: []kueue.TopologyDomainAssignment{
337342
{
@@ -355,6 +360,7 @@ var _ = ginkgo.Describe("JobSet", func() {
355360
tpuTopology: "4x4x8",
356361
tpuRequests: "1",
357362
parallelism: 128,
363+
replicas: 1,
358364
wantSliceSize: 64,
359365
wantDomains: []kueue.TopologyDomainAssignment{
360366
{
@@ -370,6 +376,46 @@ var _ = ginkgo.Describe("JobSet", func() {
370376
controller.TPUReservationSubblockLabel: {"sb2", "sb3"},
371377
},
372378
}),
379+
ginkgo.Entry("TPU topology 4x4x4 split across 2 replicas", testCase{
380+
tpuTopology: "4x4x4",
381+
tpuRequests: "4",
382+
parallelism: 8,
383+
replicas: 2,
384+
wantSliceSize: 16,
385+
wantDomains: []kueue.TopologyDomainAssignment{
386+
{
387+
Values: []string{"b1", "sb1"},
388+
Count: 16,
389+
},
390+
},
391+
wantNodeSelector: map[string][]string{
392+
controller.TPUReservationSubblockLabel: {"sb1"},
393+
},
394+
}),
395+
ginkgo.Entry("TPU topology 4x4x12 split across 3 replicas", testCase{
396+
tpuTopology: "4x4x12",
397+
tpuRequests: "4",
398+
parallelism: 16,
399+
replicas: 3,
400+
wantSliceSize: 16,
401+
wantDomains: []kueue.TopologyDomainAssignment{
402+
{
403+
Values: []string{"b2", "sb2"},
404+
Count: 16,
405+
},
406+
{
407+
Values: []string{"b2", "sb3"},
408+
Count: 16,
409+
},
410+
{
411+
Values: []string{"b2", "sb4"},
412+
Count: 16,
413+
},
414+
},
415+
wantNodeSelector: map[string][]string{
416+
controller.TPUReservationSubblockLabel: {"sb2", "sb3", "sb4"},
417+
},
418+
}),
373419
)
374420
})
375421
})

0 commit comments

Comments
 (0)