Skip to content

Commit c75be04

Browse files
committed
Create Slice for each PodSet in Workload.
1 parent 9c1c13b commit c75be04

File tree

12 files changed

+865
-152
lines changed

12 files changed

+865
-152
lines changed

slice/cmd/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,18 @@ func main() {
232232
}
233233
}
234234

235+
ctx := ctrl.SetupSignalHandler()
236+
if err := controller.SetupIndexer(ctx, mgr.GetFieldIndexer()); err != nil {
237+
setupLog.Error(err, "unable to setup indexes")
238+
os.Exit(1)
239+
}
240+
235241
go setupControllers(mgr, certsReady)
236242

237243
setupProbeEndpoints(mgr, certsReady)
238244

239245
setupLog.Info("starting manager")
240-
if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
246+
if err := mgr.Start(ctx); err != nil {
241247
setupLog.Error(err, "problem running manager")
242248
os.Exit(1)
243249
}

slice/hack/kind-cluster.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,15 @@ nodes:
2424
- role: worker
2525
labels:
2626
instance-type: on-demand
27+
cloud.google.com/gke-tpu-topology: 2x2x2
28+
cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
29+
cloud.google.com/gke-tpu-reservation-subblock: tpu-subblock-1
2730
- role: worker
2831
labels:
29-
instance-type: spot
30-
32+
instance-type: on-demand
33+
cloud.google.com/gke-tpu-topology: 2x2x2
34+
cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
35+
cloud.google.com/gke-nodepool: tpu-v4-pool
3136
kubeadmConfigPatches:
3237
- |
3338
kind: JoinConfiguration

slice/internal/controller/indexer.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
Copyright The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package controller
18+
19+
import (
20+
"context"
21+
"fmt"
22+
23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
24+
"sigs.k8s.io/controller-runtime/pkg/client"
25+
26+
"tpu-slice-controller/api/v1alpha1"
27+
"tpu-slice-controller/internal/util/slices"
28+
)
29+
30+
const (
31+
OwnerReferenceUID = "metadata.ownerReferences.uid"
32+
)
33+
34+
func IndexOwnerReferenceUID(obj client.Object) []string {
35+
return slices.Map(obj.GetOwnerReferences(), func(o *metav1.OwnerReference) string { return string(o.UID) })
36+
}
37+
38+
// SetupIndexer sets the index with the given fields for core apis.
39+
func SetupIndexer(ctx context.Context, indexer client.FieldIndexer) error {
40+
if err := indexer.IndexField(ctx, &v1alpha1.Slice{}, OwnerReferenceUID, IndexOwnerReferenceUID); err != nil {
41+
return fmt.Errorf("setting index on ownerReferences.uid for Slice: %w", err)
42+
}
43+
return nil
44+
}

slice/internal/controller/workload_controller.go

Lines changed: 194 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@ package controller
1818

1919
import (
2020
"context"
21+
"errors"
22+
"fmt"
23+
"maps"
2124
"slices"
2225

26+
apierrors "k8s.io/apimachinery/pkg/api/errors"
2327
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28+
"k8s.io/apimachinery/pkg/util/sets"
29+
"k8s.io/klog/v2"
2430
ctrl "sigs.k8s.io/controller-runtime"
2531
"sigs.k8s.io/controller-runtime/pkg/client"
2632
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
@@ -33,7 +39,18 @@ import (
3339

3440
const (
3541
CleanupSliceFinalizerName = "accelerator.gke.io/slice"
36-
TPUReservationSubblockLabel = "cloud.google.com/gke-tpu-reservation-subblock"
42+
TPUReservationSubBlockLabel = "cloud.google.com/gke-tpu-reservation-subblock"
43+
NodePoolLabel = "cloud.google.com/gke-nodepool"
44+
TPUTopologyLabel = "cloud.google.com/gke-tpu-topology"
45+
TPUAcceleratorLabel = "cloud.google.com/gke-tpu-accelerator"
46+
)
47+
48+
var (
49+
errPodSetNotFound = errors.New("PodSet not found")
50+
errPodSetAssignmentNotFound = errors.New("PodSetAssignment not found")
51+
errTPUTopologyLabelNotFound = fmt.Errorf("%s label not found", TPUTopologyLabel)
52+
errTPUAcceleratorLabelNotFound = fmt.Errorf("%s label not found", TPUAcceleratorLabel)
53+
errTopologyAssignmentNotFound = errors.New("TopologyAssignment not found")
3754
)
3855

3956
// WorkloadReconciler reconciles a Workload object
@@ -64,18 +81,7 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
6481
log.V(2).Info("Reconcile Workload")
6582

6683
if r.shouldFinalize(wl) {
67-
if controllerutil.ContainsFinalizer(wl, CleanupSliceFinalizerName) {
68-
err = r.client.Delete(ctx, r.newEmptySlice(wl))
69-
if client.IgnoreNotFound(err) != nil {
70-
return ctrl.Result{}, err
71-
}
72-
controllerutil.RemoveFinalizer(wl, CleanupSliceFinalizerName)
73-
if err := r.client.Update(ctx, wl); err != nil {
74-
return ctrl.Result{}, err
75-
}
76-
log.V(5).Info("Removed finalizer")
77-
}
78-
return ctrl.Result{}, nil
84+
return ctrl.Result{}, client.IgnoreNotFound(r.finalize(ctx, wl))
7985
}
8086

8187
if controllerutil.AddFinalizer(wl, CleanupSliceFinalizerName) {
@@ -85,75 +91,212 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
8591
}
8692
}
8793

88-
return ctrl.Result{}, r.createSliceIfNotExist(ctx, wl)
94+
return ctrl.Result{}, r.createSlicesIfNotExist(ctx, wl)
8995
}
9096

9197
// shouldFinalize reports whether the Workload's Slices must be cleaned up:
// the Workload is being deleted, is finished, is evicted, or is not active.
// The checks run in that order and stop at the first one that holds.
func (r *WorkloadReconciler) shouldFinalize(wl *kueue.Workload) bool {
	switch {
	case !wl.DeletionTimestamp.IsZero():
		return true
	case workload.IsFinished(wl):
		return true
	case workload.IsEvicted(wl):
		return true
	default:
		return !workload.IsActive(wl)
	}
}
94100

95-
func (r *WorkloadReconciler) newEmptySlice(wl *kueue.Workload) *v1alpha1.Slice {
96-
return &v1alpha1.Slice{
97-
ObjectMeta: metav1.ObjectMeta{
98-
Name: wl.Name,
99-
Namespace: wl.Namespace,
100-
},
101+
func (r *WorkloadReconciler) finalize(ctx context.Context, wl *kueue.Workload) error {
102+
if !controllerutil.ContainsFinalizer(wl, CleanupSliceFinalizerName) {
103+
return nil
104+
}
105+
106+
log := ctrl.LoggerFrom(ctx)
107+
108+
slices, err := r.findWorkloadSlices(ctx, wl)
109+
if err != nil {
110+
log.Error(err, "Failed to find Slices")
111+
return err
101112
}
102-
}
103113

104-
func (r *WorkloadReconciler) newSlice(wl *kueue.Workload) (*v1alpha1.Slice, error) {
105-
slice := r.newEmptySlice(wl)
114+
for _, slice := range slices {
115+
err = r.client.Delete(ctx, &slice)
116+
if client.IgnoreNotFound(err) != nil {
117+
log.Error(err, "Failed to delete the Slice", "slice", klog.KObj(&slice))
118+
return err
119+
}
120+
}
121+
122+
controllerutil.RemoveFinalizer(wl, CleanupSliceFinalizerName)
123+
if err := r.client.Update(ctx, wl); err != nil {
124+
if !apierrors.IsNotFound(err) {
125+
log.Error(err, "Failed to remove finalizer")
126+
}
127+
return err
128+
}
106129

107-
if err := controllerutil.SetControllerReference(wl, slice, r.client.Scheme()); err != nil {
130+
log.V(5).Info("Removed finalizer")
131+
132+
return nil
133+
}
134+
135+
func (r *WorkloadReconciler) findWorkloadSlices(ctx context.Context, wl *kueue.Workload) ([]v1alpha1.Slice, error) {
136+
slices := &v1alpha1.SliceList{}
137+
opts := []client.ListOption{
138+
client.InNamespace(wl.Namespace),
139+
client.MatchingFields{OwnerReferenceUID: string(wl.UID)},
140+
}
141+
if err := r.client.List(ctx, slices, opts...); err != nil {
108142
return nil, err
109143
}
144+
return slices.Items, nil
145+
}
146+
147+
func (r *WorkloadReconciler) createSlicesIfNotExist(ctx context.Context, wl *kueue.Workload) error {
148+
log := ctrl.LoggerFrom(ctx)
149+
150+
createdSlices, err := r.findWorkloadSlices(ctx, wl)
151+
if err != nil {
152+
log.Error(err, "Failed to find Slices")
153+
return err
154+
}
110155

111-
if wl.Status.Admission != nil && wl.Status.Admission.PodSetAssignments != nil {
156+
createdSlicesByName := make(map[string]*v1alpha1.Slice, len(createdSlices))
157+
for _, slice := range createdSlices {
158+
createdSlicesByName[slice.Name] = &slice
159+
}
160+
161+
var slicesToCreate []*v1alpha1.Slice
162+
163+
if wl.Status.Admission != nil {
112164
for _, psa := range wl.Status.Admission.PodSetAssignments {
113-
for _, domain := range psa.TopologyAssignment.Domains {
114-
if slice.Spec.NodeSelector == nil {
115-
slice.Spec.NodeSelector = make(map[string][]string)
116-
}
117-
if slice.Spec.NodeSelector[TPUReservationSubblockLabel] == nil {
118-
slice.Spec.NodeSelector[TPUReservationSubblockLabel] = []string{}
119-
}
120-
// make sure there are no duplicates in the nodeSelector
121-
for _, v := range domain.Values {
122-
exists := slices.Contains(slice.Spec.NodeSelector[TPUReservationSubblockLabel], v)
123-
if !exists {
124-
slice.Spec.NodeSelector[TPUReservationSubblockLabel] = append(slice.Spec.NodeSelector[TPUReservationSubblockLabel], v)
125-
}
165+
sliceName := GetSliceName(wl.Name, psa.Name)
166+
167+
if _, ok := createdSlicesByName[sliceName]; ok {
168+
delete(createdSlicesByName, sliceName)
169+
continue
170+
}
171+
172+
slice, err := newSlice(wl, psa.Name)
173+
if err != nil {
174+
if !isUnsupportedPodSetError(err) {
175+
log.Error(err, "Failed to create a Slice")
176+
return err
126177
}
178+
log.V(8).Info("Failed to create Slice", "error", err)
179+
continue
180+
}
181+
182+
if err := controllerutil.SetControllerReference(wl, slice, r.client.Scheme()); err != nil {
183+
return err
127184
}
185+
186+
slicesToCreate = append(slicesToCreate, slice)
128187
}
129188
}
130189

131-
return slice, nil
190+
for _, slice := range slicesToCreate {
191+
err = r.client.Create(ctx, slice)
192+
if err != nil {
193+
log.Error(err, "Failed to create a Slice", "slice", klog.KObj(slice))
194+
return err
195+
}
196+
}
197+
198+
for _, slice := range slices.Collect(maps.Values(createdSlicesByName)) {
199+
err = r.client.Delete(ctx, slice)
200+
if client.IgnoreNotFound(err) != nil {
201+
log.Error(err, "Failed to delete the redundant Slice", "slice", klog.KObj(slice))
202+
return err
203+
}
204+
}
205+
206+
return nil
132207
}
133208

134-
func (r *WorkloadReconciler) createSliceIfNotExist(ctx context.Context, wl *kueue.Workload) error {
135-
slice := r.newEmptySlice(wl)
209+
func GetSliceName(workloadName string, podSetName kueue.PodSetReference) string {
210+
return fmt.Sprintf("%s-%s", workloadName, podSetName)
211+
}
136212

137-
err := r.client.Get(ctx, client.ObjectKeyFromObject(slice), slice)
138-
if client.IgnoreNotFound(err) != nil {
139-
return err
213+
func newSlice(wl *kueue.Workload, podSetName kueue.PodSetReference) (*v1alpha1.Slice, error) {
214+
ps := findPodSet(wl, podSetName)
215+
if findPodSet(wl, podSetName) == nil {
216+
return nil, errPodSetNotFound
140217
}
141-
if err == nil {
142-
return nil
218+
219+
if ps.Template.Spec.NodeSelector[TPUTopologyLabel] == "" {
220+
return nil, errTPUTopologyLabelNotFound
143221
}
144222

145-
slice, err = r.newSlice(wl)
146-
if err != nil {
147-
return err
223+
if ps.Template.Spec.NodeSelector[TPUAcceleratorLabel] == "" {
224+
return nil, errTPUAcceleratorLabelNotFound
148225
}
149226

150-
return r.client.Create(ctx, slice)
227+
psa := findPodSetAssignment(wl, podSetName)
228+
if psa == nil {
229+
return nil, errPodSetAssignmentNotFound
230+
}
231+
232+
if psa.TopologyAssignment == nil {
233+
return nil, errTopologyAssignmentNotFound
234+
}
235+
236+
slice := &v1alpha1.Slice{
237+
ObjectMeta: metav1.ObjectMeta{
238+
Name: GetSliceName(wl.Name, podSetName),
239+
Namespace: wl.Namespace,
240+
},
241+
Spec: v1alpha1.SliceSpec{
242+
AcceleratorTopology: ps.Template.Spec.NodeSelector[TPUTopologyLabel],
243+
AcceleratorType: ps.Template.Spec.NodeSelector[TPUAcceleratorLabel],
244+
NodeSelector: make(map[string][]string),
245+
},
246+
}
247+
248+
for _, domain := range psa.TopologyAssignment.Domains {
249+
if ps.Template.Spec.NodeSelector[TPUReservationSubBlockLabel] != "" {
250+
subBlockDomains := sets.New[string]()
251+
for _, v := range domain.Values {
252+
subBlockDomains.Insert(v)
253+
}
254+
if subBlockDomains.Len() > 0 {
255+
slice.Spec.NodeSelector[TPUReservationSubBlockLabel] = sets.List(subBlockDomains)
256+
}
257+
}
258+
if ps.Template.Spec.NodeSelector[NodePoolLabel] != "" {
259+
nodePoolDomains := sets.New[string]()
260+
for _, v := range domain.Values {
261+
nodePoolDomains.Insert(v)
262+
}
263+
if nodePoolDomains.Len() > 0 {
264+
slice.Spec.NodeSelector[NodePoolLabel] = sets.List(nodePoolDomains)
265+
}
266+
}
267+
}
268+
269+
return slice, nil
270+
}
271+
272+
func findPodSet(wl *kueue.Workload, podSetName kueue.PodSetReference) *kueue.PodSet {
273+
for _, ps := range wl.Spec.PodSets {
274+
if ps.Name == podSetName {
275+
return &ps
276+
}
277+
}
278+
return nil
279+
}
280+
281+
func findPodSetAssignment(wl *kueue.Workload, podSetName kueue.PodSetReference) *kueue.PodSetAssignment {
282+
for _, psa := range wl.Status.Admission.PodSetAssignments {
283+
if psa.Name == podSetName {
284+
return &psa
285+
}
286+
}
287+
return nil
288+
}
289+
290+
func isUnsupportedPodSetError(err error) bool {
291+
return errors.Is(err, errTPUTopologyLabelNotFound) ||
292+
errors.Is(err, errTPUAcceleratorLabelNotFound) ||
293+
errors.Is(err, errTopologyAssignmentNotFound)
151294
}
152295

153296
// SetupWithManager sets up the controller with the Manager.
154297
func (r *WorkloadReconciler) SetupWithManager(mgr ctrl.Manager) error {
155298
return ctrl.NewControllerManagedBy(mgr).
156299
For(&kueue.Workload{}).
157-
Named("workload").
300+
Named("workload_controller").
158301
Complete(r)
159302
}

0 commit comments

Comments
 (0)