Commit 38c1d4f

Create Slice for each PodSet in Workload.
1 parent 4f00dab

10 files changed, +710 -112 lines changed

slice/cmd/main.go

Lines changed: 7 additions & 1 deletion
@@ -232,12 +232,18 @@ func main() {
         }
     }

+    ctx := ctrl.SetupSignalHandler()
+    if err := controller.SetupSliceIndexer(ctx, mgr.GetFieldIndexer()); err != nil {
+        setupLog.Error(err, "unable to setup indexes")
+        os.Exit(1)
+    }
+
     go setupControllers(mgr, certsReady)

     setupProbeEndpoints(mgr, certsReady)

     setupLog.Info("starting manager")
-    if err := mgr.Start(ctrl.SetupSignalHandler()); err != nil {
+    if err := mgr.Start(ctx); err != nil {
         setupLog.Error(err, "problem running manager")
         os.Exit(1)
     }
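Note on the reordering: controller-runtime's ctrl.SetupSignalHandler may only be called once per process, so the signal-handler context is now created up front, handed to controller.SetupSliceIndexer when registering the field index, and then reused for mgr.Start so shutdown signals still stop the manager.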

slice/internal/controller/indexer.go

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+/*
+Copyright The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package controller
+
+import (
+    "context"
+    "fmt"
+
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "sigs.k8s.io/controller-runtime/pkg/client"
+
+    "tpu-slice-controller/api/v1alpha1"
+    "tpu-slice-controller/internal/util/slices"
+)
+
+const (
+    OwnerReferenceUID = "metadata.ownerReferences.uid"
+)
+
+func indexOwnerReferenceUID(obj client.Object) []string {
+    return slices.Map(obj.GetOwnerReferences(), func(o *metav1.OwnerReference) string { return string(o.UID) })
+}
+
+// SetupSliceIndexer configures the indexer to index specific fields for v1alpha1.Slice resources.
+func SetupSliceIndexer(ctx context.Context, indexer client.FieldIndexer) error {
+    if err := indexer.IndexField(ctx, &v1alpha1.Slice{}, OwnerReferenceUID, indexOwnerReferenceUID); err != nil {
+        return fmt.Errorf("setting index on ownerReferences.uid for Slice: %w", err)
+    }
+    return nil
+}
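The indexOwnerReferenceUID extractor above relies on the project-internal tpu-slice-controller/internal/util/slices package, which is not part of this diff. A minimal standalone sketch of a Map helper with the assumed shape (generic over element and result, passing a pointer to each element), together with how the extractor would use it; the real helper may differ in details:

package main

import (
    "fmt"

    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/types"
)

// Map applies f to a pointer to every element of items and collects the
// results. This mirrors the call shape used by indexOwnerReferenceUID.
func Map[T, R any](items []T, f func(*T) R) []R {
    out := make([]R, 0, len(items))
    for i := range items {
        out = append(out, f(&items[i]))
    }
    return out
}

func main() {
    owners := []metav1.OwnerReference{
        {Kind: "Workload", Name: "wl-1", UID: types.UID("0000-aaaa")},
        {Kind: "Workload", Name: "wl-2", UID: types.UID("1111-bbbb")},
    }
    // One index key per owner UID, as in indexOwnerReferenceUID.
    fmt.Println(Map(owners, func(o *metav1.OwnerReference) string { return string(o.UID) }))
    // Output: [0000-aaaa 1111-bbbb]
}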

slice/internal/controller/workload_controller.go

Lines changed: 184 additions & 64 deletions
@@ -20,12 +20,12 @@ import (
     "context"
     "errors"
     "fmt"
+    "k8s.io/apimachinery/pkg/api/meta"
+    "strings"
     "time"

-    "github.com/go-logr/logr"
     corev1 "k8s.io/api/core/v1"
     apierrors "k8s.io/apimachinery/pkg/api/errors"
-    "k8s.io/apimachinery/pkg/api/meta"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/types"
     "k8s.io/apimachinery/pkg/util/sets"
@@ -100,22 +100,19 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
     log.V(3).Info("Reconcile Workload")

     if finalize, reason := shouldFinalize(wl); finalize {
-        if controllerutil.ContainsFinalizer(wl, SliceControllerName) {
-            log.V(3).Info(fmt.Sprintf("Cleaning up the Slice and finalize the Workload because %s", reason))
-            err = r.client.Delete(ctx, core.SliceWithMetadata(wl))
-            if client.IgnoreNotFound(err) != nil {
-                return ctrl.Result{}, err
-            }
-            controllerutil.RemoveFinalizer(wl, SliceControllerName)
-            if err := r.client.Update(ctx, wl); err != nil {
-                if !apierrors.IsNotFound(err) {
-                    log.Error(err, "Failed to remove finalizer")
-                }
-                return ctrl.Result{}, client.IgnoreNotFound(err)
-            }
-            log.V(3).Info("Removed finalizer")
+        if !controllerutil.ContainsFinalizer(wl, SliceControllerName) {
+            return ctrl.Result{}, nil
         }
-        return ctrl.Result{}, nil
+
+        log.V(3).Info(fmt.Sprintf("Cleaning up the Slice and finalize the Workload because %s", reason))
+
+        err := r.cleanupSlices(ctx, wl)
+        if err != nil {
+            return ctrl.Result{}, err
+        }
+
+        err = r.finalizeWorkload(ctx, wl)
+        return ctrl.Result{}, client.IgnoreNotFound(err)
     }

     if err = validateRelevantWorkload(wl); err != nil {
@@ -146,19 +143,12 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
         return ctrl.Result{}, nil
     }

-    slice := v1alpha1.Slice{}
-    err = r.client.Get(ctx, core.SliceKeyFromWorkload(wl), &slice)
-    if apierrors.IsNotFound(err) {
-        // slice not found, create it and exit.
-        err = r.createSlice(ctx, log, wl, ac)
-        return ctrl.Result{}, err
-    } else if err != nil {
-        // error fetching slice
-        log.Error(err, "Failed to fetch the Slice")
+    slices, changed, err := r.syncSlices(ctx, wl, ac)
+    if err != nil || changed {
         return ctrl.Result{}, err
     }

-    err = r.syncAdmissionCheckStatus(ctx, wl, ac, &slice)
+    err = r.syncAdmissionCheckStatus(ctx, wl, ac, slices)
     return ctrl.Result{}, client.IgnoreNotFound(err)
 }

@@ -182,6 +172,68 @@ func shouldFinalize(wl *kueue.Workload) (bool, string) {
     return false, ""
 }

+func (r *WorkloadReconciler) cleanupSlices(ctx context.Context, wl *kueue.Workload) error {
+    log := ctrl.LoggerFrom(ctx)
+
+    slices, err := r.findWorkloadSlices(ctx, wl)
+    if err != nil {
+        log.Error(err, "Failed to find Slices")
+        return err
+    }
+
+    for _, slice := range slices {
+        err = r.client.Delete(ctx, &slice)
+        if client.IgnoreNotFound(err) != nil {
+            log.Error(err, "Failed to delete the Slice", "slice", klog.KObj(&slice))
+            return err
+        }
+    }
+
+    return nil
+}
+
+func (r *WorkloadReconciler) findWorkloadSlices(ctx context.Context, wl *kueue.Workload) ([]v1alpha1.Slice, error) {
+    slices := &v1alpha1.SliceList{}
+    opts := []client.ListOption{
+        client.InNamespace(wl.Namespace),
+        client.MatchingFields{OwnerReferenceUID: string(wl.UID)},
+    }
+    if err := r.client.List(ctx, slices, opts...); err != nil {
+        return nil, err
+    }
+    return slices.Items, nil
+}
+
+func (r *WorkloadReconciler) findWorkloadSlicesByName(ctx context.Context, wl *kueue.Workload) (map[string]*v1alpha1.Slice, error) {
+    existingSlices, err := r.findWorkloadSlices(ctx, wl)
+    if err != nil {
+        return nil, err
+    }
+
+    existingSlicesByName := make(map[string]*v1alpha1.Slice, len(existingSlices))
+    for _, slice := range existingSlices {
+        existingSlicesByName[slice.Name] = &slice
+    }
+
+    return existingSlicesByName, nil
+}
+
+func (r *WorkloadReconciler) finalizeWorkload(ctx context.Context, wl *kueue.Workload) error {
+    log := ctrl.LoggerFrom(ctx)
+
+    controllerutil.RemoveFinalizer(wl, SliceControllerName)
+    if err := r.client.Update(ctx, wl); err != nil {
+        if !apierrors.IsNotFound(err) {
+            log.Error(err, "Failed to remove the finalizer")
+        }
+        return err
+    }
+
+    log.V(5).Info("Removed finalizer")
+
+    return nil
+}
+
 func validateRelevantWorkload(wl *kueue.Workload) error {
     if !hasRelevantPodSet(wl.Spec.PodSets) {
         return errors.New("does not have a relevant podset")
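The list in findWorkloadSlices only works because the metadata.ownerReferences.uid index registered in indexer.go is present on the client. A rough sketch of exercising that pairing with controller-runtime's fake client; it assumes a controller-runtime version whose fake ClientBuilder supports WithIndex and the usual kubebuilder-generated v1alpha1.AddToScheme, so treat it as illustrative rather than part of this change:

package controller_test

import (
    "context"
    "testing"

    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    "k8s.io/apimachinery/pkg/runtime"
    "k8s.io/apimachinery/pkg/types"
    "sigs.k8s.io/controller-runtime/pkg/client"
    "sigs.k8s.io/controller-runtime/pkg/client/fake"

    "tpu-slice-controller/api/v1alpha1"
    "tpu-slice-controller/internal/controller"
)

// ownerUIDs mirrors the unexported indexOwnerReferenceUID extractor.
func ownerUIDs(obj client.Object) []string {
    uids := make([]string, 0, len(obj.GetOwnerReferences()))
    for _, o := range obj.GetOwnerReferences() {
        uids = append(uids, string(o.UID))
    }
    return uids
}

func TestListSlicesByOwnerUID(t *testing.T) {
    scheme := runtime.NewScheme()
    if err := v1alpha1.AddToScheme(scheme); err != nil { // assumed generated helper
        t.Fatal(err)
    }

    ownerUID := types.UID("1234-abcd")
    slice := &v1alpha1.Slice{ObjectMeta: metav1.ObjectMeta{
        Name:      "wl-1-ps-0", // hypothetical name; the real one comes from core.SliceKeyFromWorkload
        Namespace: "default",
        OwnerReferences: []metav1.OwnerReference{{
            APIVersion: "kueue.x-k8s.io/v1beta1", Kind: "Workload", Name: "wl-1", UID: ownerUID,
        }},
    }}

    // Register the same field index the manager sets up via SetupSliceIndexer.
    c := fake.NewClientBuilder().
        WithScheme(scheme).
        WithObjects(slice).
        WithIndex(&v1alpha1.Slice{}, controller.OwnerReferenceUID, ownerUIDs).
        Build()

    // Same query shape as findWorkloadSlices.
    list := &v1alpha1.SliceList{}
    if err := c.List(context.Background(), list,
        client.InNamespace("default"),
        client.MatchingFields{controller.OwnerReferenceUID: string(ownerUID)}); err != nil {
        t.Fatal(err)
    }
    if len(list.Items) != 1 {
        t.Fatalf("expected 1 Slice, got %d", len(list.Items))
    }
}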
@@ -228,46 +280,93 @@ func (r *WorkloadReconciler) sliceAC(ctx context.Context, wl *kueue.Workload) (*
     return workload.FindAdmissionCheck(wl.Status.AdmissionChecks, relevantChecks[0]), nil
 }

-func parseTopologyAssignmentIntoNodeSelector(slice *v1alpha1.Slice, wl *kueue.Workload) {
-    nodeSelectors := sets.New[string]()
+func (r *WorkloadReconciler) syncSlices(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState) ([]v1alpha1.Slice, bool, error) {
+    log := ctrl.LoggerFrom(ctx)
+
+    existingSlices, err := r.findWorkloadSlicesByName(ctx, wl)
+    if err != nil {
+        log.Error(err, "Failed to list Slices")
+        return nil, false, err
+    }
+
+    createdSlices := make([]v1alpha1.Slice, 0, len(existingSlices))
+    changed := false
+
     for _, psa := range wl.Status.Admission.PodSetAssignments {
-        // we already validated that all assignments have a valid level,
-        // in validateRelevantWorkload.
-        subblockLevelIndex := topology.SubblockLevelIndex(&psa)
-        for _, domain := range psa.TopologyAssignment.Domains {
-            nodeSelectors.Insert(domain.Values[subblockLevelIndex])
+        sliceKey := core.SliceKeyFromWorkload(wl, psa.Name)
+
+        if createdSlice, ok := existingSlices[sliceKey.Name]; ok {
+            createdSlices = append(createdSlices, *createdSlice)
+            continue
+        }
+
+        createdSlice, err := r.createSlice(ctx, wl, ac, psa.Name)
+        if err != nil {
+            return nil, false, err
         }
+
+        createdSlices = append(createdSlices, *createdSlice)
+        changed = true
+    }
+
+    if changed {
+        msg := fmt.Sprintf("The Slices %v have been created", joinSliceNames(createdSlices))
+        log.V(3).Info(msg)
+        r.record.Event(wl, corev1.EventTypeNormal, SliceCreatedEventType, msg)
+        ac.Message = msg
+        return createdSlices, true, r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
+    }
+
+    return createdSlices, false, nil
+}
+
+func parseTopologyAssignmentIntoNodeSelector(slice *v1alpha1.Slice, wl *kueue.Workload, podSetName kueue.PodSetReference) {
+    psa := findPodSetAssignmentByName(wl.Status.Admission.PodSetAssignments, podSetName)
+    if psa == nil {
+        return
+    }
+    nodeSelectors := sets.New[string]()
+    // we already validated that all assignments have a valid level,
+    // in validateRelevantWorkload.
+    subblockLevelIndex := topology.SubblockLevelIndex(psa)
+    for _, domain := range psa.TopologyAssignment.Domains {
+        nodeSelectors.Insert(domain.Values[subblockLevelIndex])
     }
     slice.Spec.NodeSelector = map[string][]string{
         TPUReservationSubblockLabel: sets.List(nodeSelectors),
     }
 }

-func (r *WorkloadReconciler) createSlice(ctx context.Context, log logr.Logger, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error {
-    slice := core.SliceWithMetadata(wl)
-    log = log.WithValues("slice", klog.KObj(slice))
+func findPodSetAssignmentByName(podSetAssignments []kueue.PodSetAssignment, name kueue.PodSetReference) *kueue.PodSetAssignment {
+    for _, psa := range podSetAssignments {
+        if psa.Name == name {
+            return &psa
+        }
+    }
+    return nil
+}
+
+func (r *WorkloadReconciler) createSlice(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, podSetName kueue.PodSetReference) (*v1alpha1.Slice, error) {
+    slice := core.SliceWithMetadata(wl, podSetName)
+    log := ctrl.LoggerFrom(ctx).WithValues("slice", klog.KObj(slice))
     log.V(3).Info("Creating Slice")

     if err := controllerutil.SetControllerReference(wl, slice, r.client.Scheme()); err != nil {
-        return err
+        return nil, err
     }
-    parseTopologyAssignmentIntoNodeSelector(slice, wl)
+    parseTopologyAssignmentIntoNodeSelector(slice, wl, podSetName)

     if err := r.client.Create(ctx, slice); err != nil {
-        msg := fmt.Sprintf("Error creating Slice %q: %v", slice.Name, err)
+        msg := fmt.Sprintf("Error creating Slice %q: %v", client.ObjectKeyFromObject(slice), err)
         log.Error(err, msg)
         r.record.Event(wl, corev1.EventTypeWarning, FailedCreateSliceEventType, api.TruncateEventMessage(msg))
+        ac.State = kueue.CheckStatePending
         ac.Message = api.TruncateConditionMessage(msg)
         patchErr := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
-        return errors.Join(err, patchErr)
+        return nil, errors.Join(err, patchErr)
     }

-    msg := fmt.Sprintf("The Slice %s has been created", client.ObjectKeyFromObject(slice))
-    log.V(3).Info(msg)
-    r.record.Event(wl, corev1.EventTypeNormal, SliceCreatedEventType, msg)
-    ac.Message = msg
-
-    return r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
+    return slice, nil
 }

 func (r *WorkloadReconciler) updateWorkloadAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error {
@@ -280,27 +379,49 @@ func (r *WorkloadReconciler) updateWorkloadAdmissionCheckStatus(ctx context.Cont
     return err
 }

-// syncAdmissionCheckStatus syncs the admission check status with the state of the slice.
-func (r *WorkloadReconciler) syncAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slice *v1alpha1.Slice) error {
+func joinSliceNames(slices []v1alpha1.Slice) string {
+    sliceNames := make([]string, len(slices))
+    for index, slice := range slices {
+        sliceNames[index] = fmt.Sprintf("%q", client.ObjectKeyFromObject(&slice))
+    }
+    return strings.Join(sliceNames, ", ")
+}
+
+// syncAdmissionCheckStatus syncs the admission check status with the state of the Slices.
+func (r *WorkloadReconciler) syncAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slices []v1alpha1.Slice) error {
     originalState := ac.State
+    originalMessage := ac.Message

-    errCond := meta.FindStatusCondition(slice.Status.Conditions, string(v1alpha1.Error))
+    slicesByStatus := make(map[v1alpha1.SliceConditionType][]v1alpha1.Slice)
+    for _, slice := range slices {
+        for _, status := range core.SliceStatuses {
+            if meta.IsStatusConditionTrue(slice.Status.Conditions, string(status)) {
+                slicesByStatus[status] = append(slicesByStatus[status], slice)
+            }
+        }
+    }

     switch {
-    case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Forming)):
-        ac.Message = fmt.Sprintf("The Slice %q is being formed", slice.Name)
-    case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Ready)):
-        ac.State = kueue.CheckStateReady
-        ac.Message = fmt.Sprintf("The Slice %q is fully operational", slice.Name)
-    case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Degraded)):
-        ac.State = kueue.CheckStateReady
-        ac.Message = fmt.Sprintf("The Slice %q is running with reduced capacity or performance", slice.Name)
-    case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Deformed)):
+    case len(slicesByStatus[v1alpha1.Error]) > 0:
         ac.State = kueue.CheckStateRejected
-        ac.Message = fmt.Sprintf("The Slice %q is being torn down", slice.Name)
-    case errCond != nil && errCond.Status == metav1.ConditionTrue:
+        ac.Message = fmt.Sprintf("The Slices %s are not operational due to an errors", joinSliceNames(slicesByStatus[v1alpha1.Error]))
+    case len(slicesByStatus[v1alpha1.Deformed]) > 0:
         ac.State = kueue.CheckStateRejected
-        ac.Message = fmt.Sprintf("The Slice %q is not operational due to an error: %s", slice.Name, errCond.Message)
+        ac.Message = fmt.Sprintf("The Slices %s are being torn down", joinSliceNames(slicesByStatus[v1alpha1.Deformed]))
+    case len(slicesByStatus[v1alpha1.Forming]) > 0:
+        ac.State = kueue.CheckStatePending
+        ac.Message = fmt.Sprintf("The Slices %s are being formed", joinSliceNames(slicesByStatus[v1alpha1.Forming]))
+    case len(slicesByStatus[v1alpha1.Degraded]) > 0:
+        ac.State = kueue.CheckStateReady
+        ac.Message = fmt.Sprintf("The Slices %s are running with reduced capacity or performance", joinSliceNames(slicesByStatus[v1alpha1.Degraded]))
+    case len(slicesByStatus[v1alpha1.Ready]) > 0:
+        ac.State = kueue.CheckStateReady
+        ac.Message = fmt.Sprintf("The Slices %s are fully operational", joinSliceNames(slicesByStatus[v1alpha1.Ready]))
+    }
+
+    // No changes.
+    if originalState == ac.State && ac.Message == originalMessage {
+        return nil
     }

     err := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac)
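For readers scanning the new switch: the cases are ordered by severity, so when a Workload's Slices are in mixed states the most severe one determines the admission check (Error and Deformed reject, Forming keeps it pending, and only Degraded or Ready mark it ready). A self-contained toy illustration of that aggregation rule, using plain strings instead of the real v1alpha1 condition and Kueue check types:

package main

import "fmt"

// aggregate mimics the ordering of the switch in syncAdmissionCheckStatus:
// the first matching case, checked from most to least severe, wins.
func aggregate(states []string) string {
    present := map[string]bool{}
    for _, s := range states {
        present[s] = true
    }
    switch {
    case present["Error"]:
        return "Rejected"
    case present["Deformed"]:
        return "Rejected"
    case present["Forming"]:
        return "Pending"
    case present["Degraded"]:
        return "Ready"
    case present["Ready"]:
        return "Ready"
    }
    return "unchanged"
}

func main() {
    // One Slice is already Ready while another is still Forming:
    // the Workload's admission check stays Pending.
    fmt.Println(aggregate([]string{"Ready", "Forming"})) // Pending
}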
@@ -330,8 +451,7 @@ type sliceHandler struct {
 func (h *sliceHandler) Generic(context.Context, event.GenericEvent, workqueue.TypedRateLimitingInterface[reconcile.Request]) {
 }

-func (h *sliceHandler) Create(ctx context.Context, e event.CreateEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) {
-    h.handleEvent(ctx, e.Object, q)
+func (h *sliceHandler) Create(context.Context, event.CreateEvent, workqueue.TypedRateLimitingInterface[reconcile.Request]) {
 }

 func (h *sliceHandler) Delete(ctx context.Context, e event.DeleteEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) {
