Skip to content

Commit fe3d460

Browse files
committed
Create Slice for each PodSet in Workload.
1 parent 5ff57e1 commit fe3d460

File tree

8 files changed

+693
-152
lines changed

8 files changed

+693
-152
lines changed

slice/cmd/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func main() {
233233
}
234234

235235
ctx := ctrl.SetupSignalHandler()
236-
if err := controller.SetupWorkloadIndexer(ctx, mgr.GetFieldIndexer()); err != nil {
236+
if err := controller.SetupIndexer(ctx, mgr.GetFieldIndexer()); err != nil {
237237
setupLog.Error(err, "unable to setup indexes")
238238
os.Exit(1)
239239
}

slice/internal/controller/indexer.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,25 @@ import (
2424
"sigs.k8s.io/controller-runtime/pkg/client"
2525
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
2626

27+
"tpu-slice-controller/api/v1alpha1"
2728
"tpu-slice-controller/internal/util/slices"
2829
)
2930

3031
const (
3132
OwnerReferenceUID = "metadata.ownerReferences.uid"
3233
)
3334

34-
// SetupWorkloadIndexer configures the indexer to index specific fields for kueue.Workload resources.
35-
func SetupWorkloadIndexer(ctx context.Context, indexer client.FieldIndexer) error {
36-
if err := indexer.IndexField(ctx, &kueue.Workload{}, OwnerReferenceUID, func(obj client.Object) []string {
37-
return slices.Map(obj.GetOwnerReferences(), func(o *metav1.OwnerReference) string { return string(o.UID) })
38-
}); err != nil {
35+
func indexOwnerReferenceUID(obj client.Object) []string {
36+
return slices.Map(obj.GetOwnerReferences(), func(o *metav1.OwnerReference) string { return string(o.UID) })
37+
}
38+
39+
// SetupIndexer configures the indexer to index specific fields for kueue.Workload and v1alpha1.Slice resources.
40+
func SetupIndexer(ctx context.Context, indexer client.FieldIndexer) error {
41+
if err := indexer.IndexField(ctx, &kueue.Workload{}, OwnerReferenceUID, indexOwnerReferenceUID); err != nil {
3942
return fmt.Errorf("setting index on ownerReferences.uid for Workload: %w", err)
4043
}
44+
if err := indexer.IndexField(ctx, &v1alpha1.Slice{}, OwnerReferenceUID, indexOwnerReferenceUID); err != nil {
45+
return fmt.Errorf("setting index on ownerReferences.uid for Slice: %w", err)
46+
}
4147
return nil
4248
}

slice/internal/controller/workload_controller.go

Lines changed: 186 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ import (
2020
"context"
2121
"errors"
2222
"fmt"
23+
"sort"
24+
"strings"
2325
"time"
2426

25-
"github.com/go-logr/logr"
2627
corev1 "k8s.io/api/core/v1"
2728
apierrors "k8s.io/apimachinery/pkg/api/errors"
2829
"k8s.io/apimachinery/pkg/api/meta"
@@ -55,7 +56,7 @@ const (
5556
SliceControllerName = "accelerator.gke.io/slice"
5657
TPUReservationSubblockLabel = "cloud.google.com/gke-tpu-reservation-subblock"
5758

58-
SliceCreatedEventType = "SliceCreated"
59+
SlicesCreatedEventType = "SlicesCreated"
5960
FailedCreateSliceEventType = "FailedCreateSlice"
6061
AdmissionCheckUpdatedEventType = "AdmissionCheckUpdated"
6162
)
@@ -106,8 +107,8 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
106107

107108
if finalize, reason := shouldFinalize(wl); finalize {
108109
if controllerutil.ContainsFinalizer(wl, SliceControllerName) {
109-
log.V(3).Info(fmt.Sprintf("Cleaning up the Slice and finalize the Workload because %s", reason))
110-
cleanedUp, err := r.cleanupSlice(ctx, wl)
110+
log.V(3).Info(fmt.Sprintf("Cleaning up the Slices and finalizing the Workload because %s", reason))
111+
cleanedUp, err := r.cleanupSlices(ctx, wl)
111112
if err != nil {
112113
return ctrl.Result{}, err
113114
}
@@ -148,19 +149,27 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
148149
return ctrl.Result{}, nil
149150
}
150151

151-
slice := v1alpha1.Slice{}
152-
err = r.client.Get(ctx, core.SliceKeyFromWorkload(wl), &slice)
153-
if apierrors.IsNotFound(err) {
154-
// slice not found, create it and exit.
155-
err = r.createSlice(ctx, log, wl, ac)
152+
slices, err := r.findWorkloadSlices(ctx, wl)
153+
if err != nil {
154+
log.Error(err, "Failed to list Slices")
156155
return ctrl.Result{}, err
157-
} else if err != nil {
158-
// error fetching slice
159-
log.Error(err, "Failed to fetch the Slice")
156+
}
157+
158+
deleted, _, _ := r.groupSlices(slices)
159+
if len(deleted) > 0 {
160+
log.V(3).Info(
161+
"Waiting for deleted Slices to be cleaned up; skipping reconciliation for now",
162+
"deletedSlices", klog.KObjSlice(deleted),
163+
)
160164
return ctrl.Result{}, err
161165
}
162166

163-
err = r.syncAdmissionCheckStatus(ctx, wl, ac, &slice)
167+
changed, err := r.syncSlices(ctx, wl, ac, slices)
168+
if err != nil || changed {
169+
return ctrl.Result{}, err
170+
}
171+
172+
err = r.syncAdmissionCheckStatus(ctx, wl, ac, slices)
164173
return ctrl.Result{}, client.IgnoreNotFound(err)
165174
}
166175

@@ -204,48 +213,98 @@ func isJobSetOwner(wl *kueue.Workload) bool {
204213
return false
205214
}
206215

207-
func (r *WorkloadReconciler) cleanupSlice(ctx context.Context, wl *kueue.Workload) (bool, error) {
208-
slice := v1alpha1.Slice{}
209-
sliceKey := core.SliceKeyFromWorkload(wl)
210-
211-
log := ctrl.LoggerFrom(ctx).WithValues("slice", klog.KRef(sliceKey.Namespace, sliceKey.Name))
212-
ctrl.LoggerInto(ctx, log)
216+
func (r *WorkloadReconciler) cleanupSlices(ctx context.Context, wl *kueue.Workload) (bool, error) {
217+
log := ctrl.LoggerFrom(ctx)
213218

214-
err := r.client.Get(ctx, sliceKey, &slice)
215-
if apierrors.IsNotFound(err) {
216-
// slice not found
217-
return true, nil
218-
} else if err != nil {
219-
// error fetching slice
220-
log.Error(err, "Failed to fetch the Slice")
219+
slices, err := r.findWorkloadSlices(ctx, wl)
220+
if err != nil {
221+
log.Error(err, "Failed to fetch Slices")
221222
return false, err
222223
}
223224

224-
if !slice.DeletionTimestamp.IsZero() {
225-
log.V(3).Info("Slice already deleted, finishing cleanup")
225+
deleted, deformed, other := r.groupSlices(slices)
226+
227+
if len(deleted) == len(slices) {
228+
log.V(3).Info("All slices already deleted; finishing cleanup")
226229
return true, nil
227230
}
228231

229-
if !core.Deformed(&slice) {
232+
if len(deformed) > 0 {
233+
log.V(3).Info("Found Slices in deformed state; cleaning them up", "deformedSlices", klog.KObjSlice(deformed))
234+
// We still need to delete deformed Slices because requeueing causes a conflict error during Slice creation.
235+
err = r.deleteSlices(ctx, deformed)
236+
if err != nil {
237+
return false, err
238+
}
239+
}
240+
241+
if len(other) > 0 {
230242
terminated, err := r.ownerPodsFinished(ctx, wl)
231243
if err != nil || !terminated {
232244
return false, err
233245
}
234-
} else {
235-
log.V(3).Info("Slice in deformed state")
236-
// We still need to delete the Slice because requeueing causes a conflict error during Slice creation.
237246
}
238247

239-
log.V(3).Info("Deleting the Slice")
248+
log.V(3).Info("Deleting Slices", "slices", klog.KObjSlice(other))
249+
err = r.deleteSlices(ctx, other)
250+
if err != nil {
251+
return false, err
252+
}
240253

241-
err = r.client.Delete(ctx, &slice)
242-
if apierrors.IsNotFound(err) {
243-
return true, nil
244-
} else if err != nil {
245-
log.Error(err, "Failed to delete the Slice")
254+
return true, nil
255+
}
256+
257+
func (r *WorkloadReconciler) findWorkloadSlices(ctx context.Context, wl *kueue.Workload) ([]v1alpha1.Slice, error) {
258+
slices := &v1alpha1.SliceList{}
259+
opts := []client.ListOption{
260+
client.InNamespace(wl.Namespace),
261+
client.MatchingFields{OwnerReferenceUID: string(wl.UID)},
246262
}
263+
if err := r.client.List(ctx, slices, opts...); err != nil {
264+
return nil, err
265+
}
266+
return slices.Items, nil
267+
}
268+
269+
// groupSlices categorizes a list of Slice objects into three groups based on their state.
270+
// It separates slices into deleted (marked for deletion), deformed (being torn down),
271+
// and other (active) slices.
272+
//
273+
// Parameters:
274+
//
275+
// slices - A slice of v1alpha1.Slice objects to be categorized.
276+
//
277+
// Returns:
278+
// - A slice containing deleted Slice objects (with non-zero DeletionTimestamp).
279+
// - A slice containing deformed Slice objects (being torn down).
280+
// - A slice containing other Slice objects (active/valid slices).
281+
func (r *WorkloadReconciler) groupSlices(slices []v1alpha1.Slice) ([]v1alpha1.Slice, []v1alpha1.Slice, []v1alpha1.Slice) {
282+
var deleted, deformed, other []v1alpha1.Slice
283+
for _, slice := range slices {
284+
switch {
285+
case !slice.DeletionTimestamp.IsZero():
286+
deleted = append(deleted, slice)
287+
case core.Deformed(&slice):
288+
deformed = append(deformed, slice)
289+
default:
290+
other = append(other, slice)
291+
}
292+
}
293+
return deleted, deformed, other
294+
}
247295

248-
return true, err
296+
func (r *WorkloadReconciler) deleteSlices(ctx context.Context, slices []v1alpha1.Slice) error {
297+
log := ctrl.LoggerFrom(ctx)
298+
for _, slice := range slices {
299+
log = log.WithValues("slice", klog.KObj(&slice))
300+
log.V(3).Info("Deleting the Slice")
301+
err := r.client.Delete(ctx, &slice)
302+
if client.IgnoreNotFound(err) != nil {
303+
log.Error(err, "Failed to delete the Slice")
304+
return err
305+
}
306+
}
307+
return nil
249308
}
250309

251310
func (r *WorkloadReconciler) ownerPodsFinished(ctx context.Context, wl *kueue.Workload) (bool, error) {
@@ -361,46 +420,76 @@ func (r *WorkloadReconciler) sliceAC(ctx context.Context, wl *kueue.Workload) (*
361420
return workload.FindAdmissionCheck(wl.Status.AdmissionChecks, relevantChecks[0]), nil
362421
}
363422

364-
func parseTopologyAssignmentIntoNodeSelector(slice *v1alpha1.Slice, wl *kueue.Workload) {
365-
nodeSelectors := sets.New[string]()
423+
func (r *WorkloadReconciler) syncSlices(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slices []v1alpha1.Slice) (bool, error) {
424+
slicesGroupedByName := make(map[string]*v1alpha1.Slice, len(slices))
425+
for _, slice := range slices {
426+
slicesGroupedByName[slice.Name] = &slice
427+
}
428+
429+
changed := false
430+
366431
for _, psa := range wl.Status.Admission.PodSetAssignments {
367-
// we already validated that all assignments have a valid level,
368-
// in validateRelevantWorkload.
369-
subblockLevelIndex := topology.SubblockLevelIndex(&psa)
370-
for _, domain := range psa.TopologyAssignment.Domains {
371-
nodeSelectors.Insert(domain.Values[subblockLevelIndex])
432+
sliceName := core.SliceName(wl.Name, psa.Name)
433+
434+
if _, exist := slicesGroupedByName[sliceName]; exist {
435+
// Slice already exists, nothing to do.
436+
continue
372437
}
438+
439+
createdSlice, err := r.createSlice(ctx, wl, ac, &psa)
440+
if err != nil {
441+
return false, err
442+
}
443+
444+
slices = append(slices, *createdSlice)
445+
changed = true
446+
}
447+
448+
if changed {
449+
msg := fmt.Sprintf("The Slices %v have been created", joinSliceNames(slices))
450+
ctrl.LoggerFrom(ctx).V(3).Info(msg)
451+
r.record.Event(wl, corev1.EventTypeNormal, SlicesCreatedEventType, msg)
452+
ac.Message = msg
453+
return true, r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
454+
}
455+
456+
return false, nil
457+
}
458+
459+
func parseTopologyAssignmentIntoNodeSelector(slice *v1alpha1.Slice, psa *kueue.PodSetAssignment) {
460+
nodeSelectors := sets.New[string]()
461+
// we already validated that all assignments have a valid level,
462+
// in validateRelevantWorkload.
463+
subblockLevelIndex := topology.SubblockLevelIndex(psa)
464+
for _, domain := range psa.TopologyAssignment.Domains {
465+
nodeSelectors.Insert(domain.Values[subblockLevelIndex])
373466
}
374467
slice.Spec.NodeSelector = map[string][]string{
375468
TPUReservationSubblockLabel: sets.List(nodeSelectors),
376469
}
377470
}
378471

379-
func (r *WorkloadReconciler) createSlice(ctx context.Context, log logr.Logger, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error {
380-
slice := core.SliceWithMetadata(wl)
381-
log = log.WithValues("slice", klog.KObj(slice))
472+
func (r *WorkloadReconciler) createSlice(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, psa *kueue.PodSetAssignment) (*v1alpha1.Slice, error) {
473+
slice := core.SliceWithMetadata(wl, psa.Name)
474+
log := ctrl.LoggerFrom(ctx).WithValues("slice", klog.KObj(slice))
382475
log.V(3).Info("Creating Slice")
383476

384477
if err := controllerutil.SetControllerReference(wl, slice, r.client.Scheme()); err != nil {
385-
return err
478+
return nil, err
386479
}
387-
parseTopologyAssignmentIntoNodeSelector(slice, wl)
480+
parseTopologyAssignmentIntoNodeSelector(slice, psa)
388481

389482
if err := r.client.Create(ctx, slice); err != nil {
390-
msg := fmt.Sprintf("Error creating Slice %q: %v", slice.Name, err)
483+
msg := fmt.Sprintf("Error creating Slice %q: %v", client.ObjectKeyFromObject(slice), err)
391484
log.Error(err, msg)
392485
r.record.Event(wl, corev1.EventTypeWarning, FailedCreateSliceEventType, api.TruncateEventMessage(msg))
486+
ac.State = kueue.CheckStatePending
393487
ac.Message = api.TruncateConditionMessage(msg)
394488
patchErr := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
395-
return errors.Join(err, patchErr)
489+
return nil, errors.Join(err, patchErr)
396490
}
397491

398-
msg := fmt.Sprintf("The Slice %s has been created", client.ObjectKeyFromObject(slice))
399-
log.V(3).Info(msg)
400-
r.record.Event(wl, corev1.EventTypeNormal, SliceCreatedEventType, msg)
401-
ac.Message = msg
402-
403-
return r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
492+
return slice, nil
404493
}
405494

406495
func (r *WorkloadReconciler) updateWorkloadAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error {
@@ -413,27 +502,50 @@ func (r *WorkloadReconciler) updateWorkloadAdmissionCheckStatus(ctx context.Cont
413502
return err
414503
}
415504

416-
// syncAdmissionCheckStatus syncs the admission check status with the state of the slice.
417-
func (r *WorkloadReconciler) syncAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slice *v1alpha1.Slice) error {
505+
func joinSliceNames(slices []v1alpha1.Slice) string {
506+
sliceNames := make([]string, len(slices))
507+
for index, slice := range slices {
508+
sliceNames[index] = fmt.Sprintf("%q", client.ObjectKeyFromObject(&slice))
509+
}
510+
sort.Strings(sliceNames)
511+
return strings.Join(sliceNames, ", ")
512+
}
513+
514+
// syncAdmissionCheckStatus syncs the admission check status with the state of the Slices.
515+
func (r *WorkloadReconciler) syncAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slices []v1alpha1.Slice) error {
418516
originalState := ac.State
517+
originalMessage := ac.Message
419518

420-
errCond := meta.FindStatusCondition(slice.Status.Conditions, string(v1alpha1.Error))
519+
slicesByStatus := make(map[v1alpha1.SliceConditionType][]v1alpha1.Slice)
520+
for _, slice := range slices {
521+
for _, status := range core.SliceStatuses {
522+
if meta.IsStatusConditionTrue(slice.Status.Conditions, string(status)) {
523+
slicesByStatus[status] = append(slicesByStatus[status], slice)
524+
}
525+
}
526+
}
421527

422528
switch {
423-
case core.Forming(slice):
424-
ac.Message = fmt.Sprintf("The Slice %q is being formed", slice.Name)
425-
case core.Ready(slice):
426-
ac.State = kueue.CheckStateReady
427-
ac.Message = fmt.Sprintf("The Slice %q is fully operational", slice.Name)
428-
case core.Degraded(slice):
429-
ac.State = kueue.CheckStateReady
430-
ac.Message = fmt.Sprintf("The Slice %q is running with reduced capacity or performance", slice.Name)
431-
case core.Deformed(slice):
529+
case len(slicesByStatus[v1alpha1.Error]) > 0:
432530
ac.State = kueue.CheckStateRejected
433-
ac.Message = fmt.Sprintf("The Slice %q is being torn down", slice.Name)
434-
case errCond != nil && errCond.Status == metav1.ConditionTrue:
531+
ac.Message = fmt.Sprintf("The Slices %s are not operational due to errors", joinSliceNames(slicesByStatus[v1alpha1.Error]))
532+
case len(slicesByStatus[v1alpha1.Deformed]) > 0:
435533
ac.State = kueue.CheckStateRejected
436-
ac.Message = fmt.Sprintf("The Slice %q is not operational due to an error: %s", slice.Name, errCond.Message)
534+
ac.Message = fmt.Sprintf("The Slices %s are being torn down", joinSliceNames(slicesByStatus[v1alpha1.Deformed]))
535+
case len(slicesByStatus[v1alpha1.Forming]) > 0:
536+
ac.State = kueue.CheckStatePending
537+
ac.Message = fmt.Sprintf("The Slices %s are being formed", joinSliceNames(slicesByStatus[v1alpha1.Forming]))
538+
case len(slicesByStatus[v1alpha1.Degraded]) > 0:
539+
ac.State = kueue.CheckStateReady
540+
ac.Message = fmt.Sprintf("The Slices %s are running with reduced capacity or performance", joinSliceNames(slicesByStatus[v1alpha1.Degraded]))
541+
case len(slicesByStatus[v1alpha1.Ready]) > 0:
542+
ac.State = kueue.CheckStateReady
543+
ac.Message = fmt.Sprintf("The Slices %s are fully operational", joinSliceNames(slicesByStatus[v1alpha1.Ready]))
544+
}
545+
546+
// No changes.
547+
if originalState == ac.State && ac.Message == originalMessage {
548+
return nil
437549
}
438550

439551
err := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)

0 commit comments

Comments (0)