diff --git a/slice/.golangci.yml b/slice/.golangci.yml index 713805002..d14e88e6f 100644 --- a/slice/.golangci.yml +++ b/slice/.golangci.yml @@ -7,7 +7,6 @@ run: linters: enable: - copyloopvar - - dupl - dupword - durationcheck - fatcontext diff --git a/slice/cmd/main.go b/slice/cmd/main.go index 111a91573..f7f295cf4 100644 --- a/slice/cmd/main.go +++ b/slice/cmd/main.go @@ -254,8 +254,8 @@ func setupControllers(mgr ctrl.Manager, certsReady chan struct{}) { os.Exit(1) } - if err := controller.NewWorkloadReconciler(mgr.GetClient()).SetupWithManager(mgr); err != nil { - setupLog.Error(err, "unable to create controller", "controller", "Workload") + if failedCtrl, err := controller.SetupControllers(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", failedCtrl) os.Exit(1) } diff --git a/slice/config/rbac/role.yaml b/slice/config/rbac/role.yaml index 3e4eb5e21..bc18a0d35 100644 --- a/slice/config/rbac/role.yaml +++ b/slice/config/rbac/role.yaml @@ -4,6 +4,15 @@ kind: ClusterRole metadata: name: manager-role rules: +- apiGroups: + - "" + resources: + - events + verbs: + - create + - patch + - update + - watch - apiGroups: - "" resources: @@ -22,6 +31,23 @@ rules: - list - update - watch +- apiGroups: + - kueue.x-k8s.io + resources: + - admissionchecks + verbs: + - get + - list + - watch +- apiGroups: + - kueue.x-k8s.io + resources: + - admissionchecks/status + - workloads/status + verbs: + - get + - patch + - update - apiGroups: - kueue.x-k8s.io resources: diff --git a/slice/go.mod b/slice/go.mod index ac3d397b7..76ef19ed5 100644 --- a/slice/go.mod +++ b/slice/go.mod @@ -99,7 +99,7 @@ require ( k8s.io/apiserver v0.33.2 // indirect k8s.io/component-base v0.33.2 // indirect k8s.io/component-helpers v0.33.2 // indirect - k8s.io/klog/v2 v2.130.1 // indirect + k8s.io/klog/v2 v2.130.1 k8s.io/kube-openapi v0.0.0-20250318190949-c8a335a9a2ff // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 // indirect sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 // indirect diff --git a/slice/internal/controller/admissioncheck_controller.go b/slice/internal/controller/admissioncheck_controller.go new file mode 100644 index 000000000..6d4bbe2e3 --- /dev/null +++ b/slice/internal/controller/admissioncheck_controller.go @@ -0,0 +1,78 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package controller + +import ( + "context" + + apimeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" +) + +type AdmissionCheckReconciler struct { + client client.Client +} + +var _ reconcile.Reconciler = (*AdmissionCheckReconciler)(nil) + +func NewAdmissionCheckReconciler(cl client.Client) *AdmissionCheckReconciler { + return &AdmissionCheckReconciler{ + client: cl, + } +} + +// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=admissionchecks,verbs=get;list;watch +// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=admissionchecks/status,verbs=get;update;patch + +func (r *AdmissionCheckReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + ac := &kueue.AdmissionCheck{} + if err := r.client.Get(ctx, req.NamespacedName, ac); err != nil || ac.Spec.ControllerName != SliceControllerName { + return reconcile.Result{}, client.IgnoreNotFound(err) + } + + log := ctrl.LoggerFrom(ctx) + log.V(2).Info("Reconcile AdmissionCheck") + + currentCondition := ptr.Deref(apimeta.FindStatusCondition(ac.Status.Conditions, kueue.AdmissionCheckActive), metav1.Condition{}) + newCondition := metav1.Condition{ + Type: kueue.AdmissionCheckActive, + Status: metav1.ConditionTrue, + Reason: "Active", + Message: "The admission check is active", + ObservedGeneration: ac.Generation, + } + + if currentCondition.Status != newCondition.Status { + apimeta.SetStatusCondition(&ac.Status.Conditions, newCondition) + return reconcile.Result{}, client.IgnoreNotFound(r.client.Status().Update(ctx, ac)) + } + + return reconcile.Result{}, nil +} + +// SetupWithManager sets up the controller with the Manager. +func (r *AdmissionCheckReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&kueue.AdmissionCheck{}). + Named("admissioncheck_controller"). + Complete(r) +} diff --git a/slice/internal/controller/admissioncheck_controller_test.go b/slice/internal/controller/admissioncheck_controller_test.go new file mode 100644 index 000000000..b01db882c --- /dev/null +++ b/slice/internal/controller/admissioncheck_controller_test.go @@ -0,0 +1,95 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package controller + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + + utiltesting "tpu-slice-controller/internal/util/testing" +) + +func TestAdmissionCheckReconciler(t *testing.T) { + baseAdmissionCheckName := "ac" + baseGeneration := int64(1) + baseRequest := types.NamespacedName{Name: baseAdmissionCheckName, Namespace: corev1.NamespaceDefault} + baseAdmissionCheckWrapper := utiltesting.MakeAdmissionCheck(baseAdmissionCheckName). + Generation(baseGeneration). + ControllerName(SliceControllerName) + + testCases := map[string]struct { + request types.NamespacedName + admissionCheck *kueue.AdmissionCheck + wantAdmissionChecks []kueue.AdmissionCheck + wantErr error + }{ + "unrelated check": { + request: baseRequest, + admissionCheck: baseAdmissionCheckWrapper.Clone().ControllerName("other-controller").Obj(), + wantAdmissionChecks: []kueue.AdmissionCheck{ + *baseAdmissionCheckWrapper.Clone().ControllerName("other-controller").Obj(), + }, + }, + "should set Active status": { + request: baseRequest, + admissionCheck: baseAdmissionCheckWrapper.DeepCopy(), + wantAdmissionChecks: []kueue.AdmissionCheck{ + *baseAdmissionCheckWrapper.Clone(). + Condition(metav1.Condition{ + Type: kueue.AdmissionCheckActive, + Status: metav1.ConditionTrue, + Reason: "Active", + Message: "The admission check is active", + ObservedGeneration: baseGeneration, + }). + Obj(), + }, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + scheme := runtime.NewScheme() + utilruntime.Must(kueue.AddToScheme(scheme)) + utilruntime.Must(kueue.AddToScheme(scheme)) + + clientBuilder := fake.NewClientBuilder().WithScheme(scheme) + + if tc.admissionCheck != nil { + clientBuilder = clientBuilder.WithObjects(tc.admissionCheck) + } + + kClient := clientBuilder.Build() + reconciler := NewAdmissionCheckReconciler(kClient) + + ctx, _ := utiltesting.ContextWithLog(t) + + _, err := reconciler.Reconcile(ctx, reconcile.Request{NamespacedName: tc.request}) + if diff := cmp.Diff(tc.wantErr, err); diff != "" { + t.Errorf("Error after reconcile (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/slice/internal/controller/controller.go b/slice/internal/controller/controller.go new file mode 100644 index 000000000..d4f879d07 --- /dev/null +++ b/slice/internal/controller/controller.go @@ -0,0 +1,37 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package controller + +import ( + ctrl "sigs.k8s.io/controller-runtime" +) + +const ( + SliceWorkloadControllerName = "slice-workload-controller" +) + +func SetupControllers(mgr ctrl.Manager) (string, error) { + wlRec := NewWorkloadReconciler(mgr.GetClient(), mgr.GetEventRecorderFor(SliceWorkloadControllerName)) + if err := wlRec.SetupWithManager(mgr); err != nil { + return "Workload", err + } + acRec := NewAdmissionCheckReconciler(mgr.GetClient()) + if err := acRec.SetupWithManager(mgr); err != nil { + return "AdmissionCheck", err + } + return "", nil +} diff --git a/slice/internal/controller/workload_controller.go b/slice/internal/controller/workload_controller.go index a5cede327..727e5af5f 100644 --- a/slice/internal/controller/workload_controller.go +++ b/slice/internal/controller/workload_controller.go @@ -18,40 +18,74 @@ package controller import ( "context" + "errors" + "fmt" + "time" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/tools/record" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog/v2" + "k8s.io/utils/clock" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/reconcile" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/pkg/util/admissioncheck" "sigs.k8s.io/kueue/pkg/workload" "tpu-slice-controller/api/v1alpha1" + "tpu-slice-controller/internal/core" + "tpu-slice-controller/internal/util/api" ) const ( - CleanupSliceFinalizerName = "accelerator.gke.io/slice" + SliceControllerName = "accelerator.gke.io/slice" TPUReservationSubblockLabel = "cloud.google.com/gke-tpu-reservation-subblock" + + SliceCreatedEventType = "SliceCreated" + FailedCreateSliceEventType = "FailedCreateSlice" + AdmissionCheckUpdatedEventType = "AdmissionCheckUpdated" +) + +const ( + updatesBatchPeriod = time.Second +) + +var ( + realClock = clock.RealClock{} ) // WorkloadReconciler reconciles a Workload object type WorkloadReconciler struct { client client.Client + record record.EventRecorder + clock clock.Clock } var _ reconcile.Reconciler = (*WorkloadReconciler)(nil) -func NewWorkloadReconciler(cl client.Client) *WorkloadReconciler { +func NewWorkloadReconciler(cl client.Client, record record.EventRecorder) *WorkloadReconciler { return &WorkloadReconciler{ client: cl, + record: record, + clock: realClock, } } // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch +// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch // +kubebuilder:rbac:groups=slice.accelerator.gke.io,resources=slices,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=slice.accelerator.gke.io,resources=slices/finalizers,verbs=update +// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { wl := &kueue.Workload{} @@ -64,34 +98,113 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c log.V(2).Info("Reconcile Workload") if r.shouldFinalize(wl) { - if controllerutil.ContainsFinalizer(wl, CleanupSliceFinalizerName) { + if 
controllerutil.ContainsFinalizer(wl, SliceControllerName) { err = r.client.Delete(ctx, r.newEmptySlice(wl)) if client.IgnoreNotFound(err) != nil { return ctrl.Result{}, err } - controllerutil.RemoveFinalizer(wl, CleanupSliceFinalizerName) + controllerutil.RemoveFinalizer(wl, SliceControllerName) if err := r.client.Update(ctx, wl); err != nil { - return ctrl.Result{}, err + if !apierrors.IsNotFound(err) { + log.Error(err, "Failed to remove finalizer") + } + return ctrl.Result{}, client.IgnoreNotFound(err) } log.V(5).Info("Removed finalizer") } return ctrl.Result{}, nil } - if controllerutil.AddFinalizer(wl, CleanupSliceFinalizerName) { - if err := r.client.Update(ctx, wl); err != nil { - log.V(5).Info("Added finalizer") - return ctrl.Result{}, err + ac, err := r.sliceAC(ctx, wl) + if err != nil { + return reconcile.Result{}, err + } + if ac == nil { + log.V(5).Info("Admission check not found – ignoring reconciliation for now") + return reconcile.Result{}, nil + } + + log = log.WithValues("admissionCheck", ac.Name) + ctrl.LoggerInto(ctx, log) + + if !r.isRelevantWorkload(wl) { + return ctrl.Result{}, nil + } + + if controllerutil.AddFinalizer(wl, SliceControllerName) { + if err = r.client.Update(ctx, wl); err != nil { + if !apierrors.IsNotFound(err) { + log.Error(err, "Failed to add finalizer") + } + return ctrl.Result{}, client.IgnoreNotFound(err) } + log.V(5).Info("Added finalizer") + return ctrl.Result{}, nil } - return ctrl.Result{}, r.createSliceIfNotExist(ctx, wl) + slice := r.newEmptySlice(wl) + + err = r.client.Get(ctx, client.ObjectKeyFromObject(slice), slice) + if client.IgnoreNotFound(err) != nil { + log.Error(err, "Failed to fetch the Slice") + return ctrl.Result{}, err + } + if err != nil { + return ctrl.Result{}, r.createSlice(ctx, wl, ac) + } + + err = r.syncAdmissionCheckStatus(ctx, wl, ac, slice) + return ctrl.Result{}, client.IgnoreNotFound(err) } func (r *WorkloadReconciler) shouldFinalize(wl *kueue.Workload) bool { return !wl.DeletionTimestamp.IsZero() || workload.IsFinished(wl) || workload.IsEvicted(wl) || !workload.IsActive(wl) } +func (r *WorkloadReconciler) isRelevantWorkload(wl *kueue.Workload) bool { + return hasRelevantPodSet(wl.Spec.PodSets) && + workload.HasQuotaReservation(wl) && + wl.Status.Admission != nil && + hasRelevantPodSetAssignment(wl.Status.Admission.PodSetAssignments) +} + +func hasRelevantPodSet(podSets []kueue.PodSet) bool { + // At least one PodSet should be relevant. + for _, ps := range podSets { + if core.IsRelevantPodTemplateSpec(ps.Template) { + return true + } + } + return false +} + +func hasRelevantPodSetAssignment(podSetAssignments []kueue.PodSetAssignment) bool { + for _, psa := range podSetAssignments { + // Only podSet with a TopologyAssignment should be processed. + if psa.TopologyAssignment != nil { + return true + } + } + return false +} + +func (r *WorkloadReconciler) sliceAC(ctx context.Context, wl *kueue.Workload) (*kueue.AdmissionCheckState, error) { + relevantChecks, err := admissioncheck.FilterForController(ctx, r.client, wl.Status.AdmissionChecks, SliceControllerName) + if err != nil { + return nil, err + } + if len(relevantChecks) == 0 { + return nil, nil + } + if len(relevantChecks) > 1 { + ctrl.LoggerFrom(ctx).V(2).Info( + "WARNING: More than one AdmissionCheck found. 
Using the first one", + "selected", relevantChecks[0], + ) + } + return workload.FindAdmissionCheck(wl.Status.AdmissionChecks, relevantChecks[0]), nil +} + func (r *WorkloadReconciler) newEmptySlice(wl *kueue.Workload) *v1alpha1.Slice { return &v1alpha1.Slice{ ObjectMeta: metav1.ObjectMeta{ @@ -107,9 +220,6 @@ func (r *WorkloadReconciler) newSlice(wl *kueue.Workload) (*v1alpha1.Slice, erro if err := controllerutil.SetControllerReference(wl, slice, r.client.Scheme()); err != nil { return nil, err } - if wl.Status.Admission == nil || wl.Status.Admission.PodSetAssignments == nil { - return slice, nil - } nodeSelectors := sets.New[string]() for _, psa := range wl.Status.Admission.PodSetAssignments { @@ -123,29 +233,129 @@ func (r *WorkloadReconciler) newSlice(wl *kueue.Workload) (*v1alpha1.Slice, erro return slice, nil } -func (r *WorkloadReconciler) createSliceIfNotExist(ctx context.Context, wl *kueue.Workload) error { - slice := r.newEmptySlice(wl) +func (r *WorkloadReconciler) createSlice(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error { + log := ctrl.LoggerFrom(ctx) - err := r.client.Get(ctx, client.ObjectKeyFromObject(slice), slice) - if client.IgnoreNotFound(err) != nil { + slice, err := r.newSlice(wl) + if err != nil { return err } - if err == nil { - return nil - } - slice, err = r.newSlice(wl) + log = log.WithValues("slice", klog.KObj(slice)) + + err = r.client.Create(ctx, slice) if err != nil { - return err + msg := fmt.Sprintf("Error creating Slice %q: %v", slice.Name, err) + log.Error(err, msg) + r.record.Event(wl, corev1.EventTypeWarning, FailedCreateSliceEventType, api.TruncateEventMessage(msg)) + ac.Message = api.TruncateConditionMessage(msg) + patchErr := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac) + return errors.Join(err, patchErr) } - return r.client.Create(ctx, slice) + msg := fmt.Sprintf("The Slice %q has been created", slice.Name) + log.V(5).Info(msg) + r.record.Event(wl, corev1.EventTypeNormal, SliceCreatedEventType, msg) + ac.Message = msg + + return r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac) +} + +func (r *WorkloadReconciler) updateWorkloadAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState) error { + wlPatch := workload.BaseSSAWorkload(wl) + workload.SetAdmissionCheckState(&wlPatch.Status.AdmissionChecks, *ac, r.clock) + err := r.client.Status().Patch(ctx, wlPatch, client.Apply, client.FieldOwner(SliceControllerName), client.ForceOwnership) + if err != nil && !apierrors.IsNotFound(err) { + ctrl.LoggerFrom(ctx).Error(err, "Failed to patch the Workload's admission status") + } + return err +} + +// syncAdmissionCheckStatus syncs the admission check status with the state of the slice. 
+func (r *WorkloadReconciler) syncAdmissionCheckStatus(ctx context.Context, wl *kueue.Workload, ac *kueue.AdmissionCheckState, slice *v1alpha1.Slice) error { + originalState := ac.State + + errCond := meta.FindStatusCondition(slice.Status.Conditions, string(v1alpha1.Error)) + + switch { + case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Forming)): + ac.Message = fmt.Sprintf("The Slice %q is being formed", slice.Name) + case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Ready)): + ac.State = kueue.CheckStateReady + ac.Message = fmt.Sprintf("The Slice %q is fully operational", slice.Name) + case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Degraded)): + ac.State = kueue.CheckStateReady + ac.Message = fmt.Sprintf("The Slice %q is running with reduced capacity or performance", slice.Name) + case meta.IsStatusConditionTrue(slice.Status.Conditions, string(v1alpha1.Deformed)): + ac.State = kueue.CheckStateRejected + ac.Message = fmt.Sprintf("The Slice %q is being torn down", slice.Name) + case errCond != nil && errCond.Status == metav1.ConditionTrue: + ac.State = kueue.CheckStateRejected + ac.Message = fmt.Sprintf("The Slice %q is not operational due to an error: %s", slice.Name, errCond.Message) + } + + err := r.updateWorkloadAdmissionCheckStatus(ctx, wl, ac) + if err == nil && originalState != ac.State { + message := fmt.Sprintf("Admission check %q updated state from %q to %q", ac.Name, originalState, ac.State) + r.record.Event(wl, corev1.EventTypeNormal, AdmissionCheckUpdatedEventType, message) + } + + return err } // SetupWithManager sets up the controller with the Manager. func (r *WorkloadReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). For(&kueue.Workload{}). - Named("workload"). + Named("workload_controller"). + Watches(&v1alpha1.Slice{}, &sliceHandler{client: r.client}). Complete(r) } + +var _ handler.EventHandler = (*sliceHandler)(nil) + +type sliceHandler struct { + client client.Client +} + +func (h *sliceHandler) Generic(context.Context, event.GenericEvent, workqueue.TypedRateLimitingInterface[reconcile.Request]) { +} + +func (h *sliceHandler) Create(ctx context.Context, e event.CreateEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleEvent(ctx, e.Object, q) +} + +func (h *sliceHandler) Delete(ctx context.Context, e event.DeleteEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleEvent(ctx, e.Object, q) +} + +func (h *sliceHandler) Update(ctx context.Context, e event.UpdateEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleEvent(ctx, e.ObjectNew, q) +} + +func (h *sliceHandler) handleEvent(ctx context.Context, obj client.Object, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + slice, isSlice := obj.(*v1alpha1.Slice) + // Only Slice should be handled. 
+ if !isSlice { + return + } + + log := ctrl.LoggerFrom(ctx) + + owner := metav1.GetControllerOf(slice) + if owner == nil { + log.V(5).Info("Owner not found") + return + } + + log.V(5).Info("Handle Slice event", "workload", klog.KRef(slice.Namespace, slice.Name)) + + req := reconcile.Request{ + NamespacedName: types.NamespacedName{ + Name: owner.Name, + Namespace: slice.Namespace, + }, + } + + q.AddAfter(req, updatesBatchPeriod) +} diff --git a/slice/internal/controller/workload_controller_test.go b/slice/internal/controller/workload_controller_test.go index af6776822..8cf2d7ea7 100644 --- a/slice/internal/controller/workload_controller_test.go +++ b/slice/internal/controller/workload_controller_test.go @@ -17,6 +17,9 @@ limitations under the License. package controller import ( + "context" + "errors" + "fmt" "testing" "time" @@ -27,11 +30,18 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/client-go/util/workqueue" + testingclock "k8s.io/utils/clock/testing" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + "sigs.k8s.io/controller-runtime/pkg/controller/priorityqueue" "sigs.k8s.io/controller-runtime/pkg/reconcile" + jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" - "tpu-slice-controller/api/v1alpha1" + slice "tpu-slice-controller/api/v1alpha1" + "tpu-slice-controller/internal/core" utiltesting "tpu-slice-controller/internal/util/testing" ) @@ -41,143 +51,545 @@ var ( cmpopts.IgnoreFields(metav1.ObjectMeta{}, "ResourceVersion"), cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime"), } + errTest = errors.New("test error") ) func TestWorkloadReconciler(t *testing.T) { + now := time.Now().Truncate(time.Second) + fakeClock := testingclock.NewFakeClock(now) + baseWorkloadName := "workload" + baseAdmissionCheckName := "ac" baseRequest := types.NamespacedName{Name: baseWorkloadName, Namespace: corev1.NamespaceDefault} - baseWorkloadWrapper := utiltesting.MakeWorkload(baseWorkloadName, corev1.NamespaceDefault) - baseSliceWrapper := utiltesting.MakeSliceWrapper(baseWorkloadName, corev1.NamespaceDefault) + baseAdmissionCheckWrapper := utiltesting.MakeAdmissionCheck(baseAdmissionCheckName).ControllerName(SliceControllerName) + baseWorkloadWrapper := utiltesting.MakeWorkload(baseWorkloadName, corev1.NamespaceDefault). + UID(types.UID(baseWorkloadName)). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStatePending, + LastTransitionTime: metav1.NewTime(now), + Message: "", + }) + baseWorkloadWrapperWithPodSets := baseWorkloadWrapper.Clone(). + PodSets( + *utiltesting.MakePodSet("ps1", 2). + Annotation(core.TPUTopologyAnnotation, "4x4x12"). + NodeSelector(core.TPUAcceleratorLabel, "tpu-v7x"). + Obj(), + *utiltesting.MakePodSet("ps2", 2). + Annotation(core.TPUTopologyAnnotation, "4x4x12"). + NodeSelector(core.TPUAcceleratorLabel, "tpu-v7x"). + Obj(), + ) + baseWorkloadWrapperWithAdmission := baseWorkloadWrapperWithPodSets.Clone(). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1"). + TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain1", "domain2"}, + Count: 2, + }, + }).Obj(), + utiltesting.MakePodSetAssignment("ps2"). 
+ TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain2", "domain3"}, + Count: 2, + }, + }). + Obj(), + }, + }, now, + ) + baseSliceWrapper := utiltesting.MakeSliceWrapper(baseWorkloadName, corev1.NamespaceDefault). + ControllerReference(kueue.GroupVersion.WithKind("Workload"), baseWorkloadName, baseWorkloadName). + NodeSelector(map[string][]string{TPUReservationSubblockLabel: {"domain1", "domain2", "domain3"}}) cases := map[string]struct { - request types.NamespacedName - workload *kueue.Workload - slice *v1alpha1.Slice - wantWorkloads []kueue.Workload - wantSlices []v1alpha1.Slice - wantErr error + interceptorFuncsCreate func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error + request types.NamespacedName + objs []client.Object + wantWorkloads []kueue.Workload + wantSlices []slice.Slice + wantErr error + wantEvents []utiltesting.EventRecord }{ - "workload not found": { - request: types.NamespacedName{Name: "other-workload", Namespace: corev1.NamespaceDefault}, - workload: baseWorkloadWrapper.DeepCopy(), - slice: baseSliceWrapper.DeepCopy(), - wantWorkloads: []kueue.Workload{*baseWorkloadWrapper.DeepCopy()}, - wantSlices: []v1alpha1.Slice{*baseSliceWrapper.DeepCopy()}, - }, - "should delete finalizer because workload has DeletionTimestamp": { - request: baseRequest, - workload: baseWorkloadWrapper.Clone(). - DeletionTimestamp(time.Now()). - Finalizers(CleanupSliceFinalizerName). - Obj(), - slice: baseSliceWrapper.DeepCopy(), + "should skip reconciliation because the Workload was not found": { + request: types.NamespacedName{Name: "other-workload", Namespace: corev1.NamespaceDefault}, + objs: []client.Object{ + baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).DeletionTimestamp(now).Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).DeletionTimestamp(now).Obj(), + }, + }, + "should delete the finalizer because the Workload has a DeletionTimestamp": { + request: baseRequest, + objs: []client.Object{ + baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).DeletionTimestamp(now).Obj(), + baseSliceWrapper.DeepCopy(), + }, }, - "should delete finalizer because workload is finished": { - request: baseRequest, - workload: baseWorkloadWrapper.Clone().Finalizers(CleanupSliceFinalizerName).Finished().Obj(), - slice: baseSliceWrapper.DeepCopy(), + "should delete the finalizer because the Workload is finished": { + request: baseRequest, + objs: []client.Object{ + baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).Finished().Obj(), + baseSliceWrapper.DeepCopy(), + }, wantWorkloads: []kueue.Workload{*baseWorkloadWrapper.Clone().Finished().Obj()}, }, - "should delete finalizer because workload is evicted": { - request: baseRequest, - workload: baseWorkloadWrapper.Clone().Finalizers(CleanupSliceFinalizerName).Evicted().Obj(), - slice: baseSliceWrapper.DeepCopy(), + "should delete the finalizer because the Workload is evicted": { + request: baseRequest, + objs: []client.Object{ + baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).Evicted().Obj(), + baseSliceWrapper.DeepCopy(), + }, wantWorkloads: []kueue.Workload{*baseWorkloadWrapper.Clone().Evicted().Obj()}, }, - "should delete finalizer because workload is deactivated": { - request: baseRequest, - workload: baseWorkloadWrapper.Clone().Finalizers(CleanupSliceFinalizerName).Active(false).Obj(), - slice: baseSliceWrapper.DeepCopy(), + "should delete the finalizer because the 
Workload is deactivated": { + request: baseRequest, + objs: []client.Object{ + baseWorkloadWrapper.Clone().Finalizers(SliceControllerName).Active(false).Obj(), + baseSliceWrapper.DeepCopy(), + }, wantWorkloads: []kueue.Workload{*baseWorkloadWrapper.Clone().Active(false).Obj()}, }, - "should add finalizer and create slice": { - request: baseRequest, - workload: baseWorkloadWrapper.UID(types.UID(baseWorkloadName)).DeepCopy(), + "shouldn't add finalizer because invalid TPU topology annotation": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapper.Clone(). + PodSets( + *utiltesting.MakePodSet("ps", 2). + Annotation(core.TPUTopologyAnnotation, "4x4"). + NodeSelector(core.TPUAcceleratorLabel, "tpu-v7x"). + Obj(), + ). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1"). + TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain1", "domain2"}, + Count: 2, + }, + }).Obj(), + }, + }, now, + ). + Obj(), + }, wantWorkloads: []kueue.Workload{ *baseWorkloadWrapper.Clone(). - UID(types.UID(baseWorkloadName)). - Finalizers(CleanupSliceFinalizerName). + PodSets( + *utiltesting.MakePodSet("ps", 2). + Annotation(core.TPUTopologyAnnotation, "4x4"). + NodeSelector(core.TPUAcceleratorLabel, "tpu-v7x"). + Obj(), + ). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1"). + TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain1", "domain2"}, + Count: 2, + }, + }).Obj(), + }, + }, now, + ). + Obj(), + }, + }, + "shouldn't add finalizer because invalid TPU accelerator node selector": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapper.Clone(). + PodSets( + *utiltesting.MakePodSet("ps", 2). + Annotation(core.TPUTopologyAnnotation, "4x4x12"). + NodeSelector(core.TPUAcceleratorLabel, "invalid"). + Obj(), + ). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1"). + TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain1", "domain2"}, + Count: 2, + }, + }).Obj(), + }, + }, now, + ). Obj(), }, - wantSlices: []v1alpha1.Slice{ - *baseSliceWrapper.Clone(). - ControllerReference(kueue.GroupVersion.WithKind("Workload"), baseWorkloadName, baseWorkloadName). + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapper.Clone(). + PodSets( + *utiltesting.MakePodSet("ps", 2). + Annotation(core.TPUTopologyAnnotation, "4x4x12"). + NodeSelector(core.TPUAcceleratorLabel, "invalid"). + Obj(), + ). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1"). + TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ + { + Values: []string{"domain1", "domain2"}, + Count: 2, + }, + }).Obj(), + }, + }, now, + ). Obj(), }, }, - "parse TAS Assignment to populate NodeSelector in Slice": { + "shouldn't add finalizer because there’s no Admission": { request: baseRequest, - workload: baseWorkloadWrapper.Clone(). - UID(types.UID(baseWorkloadName)). - PodSetAssignments(utiltesting.MakePodSetAssignment("psa1"). - TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ - { - Values: []string{"domain1", "domain2"}, - Count: 2, - }, - }).Obj(), - utiltesting.MakePodSetAssignment("psa2"). 
- TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ - { - Values: []string{"domain2", "domain3"}, - Count: 2, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithPodSets.DeepCopy(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithPodSets.DeepCopy(), + }, + }, + "shouldn't add finalizer because there’s no TopologyAssignment": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithPodSets.Clone(). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1").Obj(), + utiltesting.MakePodSetAssignment("ps2").Obj(), }, - }). - Obj(), - ).Obj(), + }, now, + ). + Obj(), + }, wantWorkloads: []kueue.Workload{ - *baseWorkloadWrapper.Clone(). - UID(types.UID(baseWorkloadName)). - PodSetAssignments(utiltesting.MakePodSetAssignment("psa1"). - TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ - { - Values: []string{"domain1", "domain2"}, - Count: 2, + *baseWorkloadWrapperWithPodSets.Clone(). + ReserveQuota( + &kueue.Admission{ + PodSetAssignments: []kueue.PodSetAssignment{ + utiltesting.MakePodSetAssignment("ps1").Obj(), + utiltesting.MakePodSetAssignment("ps2").Obj(), }, - }).Obj(), - utiltesting.MakePodSetAssignment("psa2"). - TopologyAssignment(nil, []kueue.TopologyDomainAssignment{ - { - Values: []string{"domain2", "domain3"}, - Count: 2, - }, - }). - Obj(), + }, now, ). - Finalizers(CleanupSliceFinalizerName). Obj(), }, - wantSlices: []v1alpha1.Slice{ - *baseSliceWrapper.Clone(). - ControllerReference(kueue.GroupVersion.WithKind("Workload"), baseWorkloadName, baseWorkloadName). - NodeSelector(map[string][]string{ - TPUReservationSubblockLabel: {"domain1", "domain2", "domain3"}, + }, + "shouldn't add finalizer because there’s no AdmissionCheck": { + request: baseRequest, + objs: []client.Object{ + baseWorkloadWrapperWithAdmission.DeepCopy(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.DeepCopy(), + }, + }, + "should add finalizer": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.DeepCopy(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + Obj(), + }, + }, + "should create a Slice": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStatePending, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice "%s" has been created`, baseWorkloadName), + }). 
+ Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.DeepCopy()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: SliceCreatedEventType, + Message: fmt.Sprintf(`The Slice "%s" has been created`, baseWorkloadName), + }, + }, + }, + "parse TAS Assignment to populate NodeSelector in Slice": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Clone().Finalizers(SliceControllerName).Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStatePending, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice "%s" has been created`, baseWorkloadName), + }). + Obj(), + }, + wantSlices: []slice.Slice{ + *baseSliceWrapper.DeepCopy(), + }, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: SliceCreatedEventType, + Message: fmt.Sprintf(`The Slice "%s" has been created`, baseWorkloadName), + }, + }, + }, + "error on Slice creation": { + interceptorFuncsCreate: func(ctx context.Context, client client.WithWatch, obj client.Object, opts ...client.CreateOption) error { + if _, ok := obj.(*slice.Slice); ok { + return errTest + } + return client.Create(ctx, obj, opts...) + }, + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStatePending, + LastTransitionTime: metav1.NewTime(now), + Message: "Error creating Slice \"workload\": test error", + }). + Obj(), + }, + wantErr: errTest, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeWarning, + Reason: FailedCreateSliceEventType, + Message: `Error creating Slice "workload": test error`, + }, + }, + }, + "should update the Workload AdmissionCheckState when the Slice status is changed to Forming": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Forming().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStatePending, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is being formed`, baseWorkloadName), + }). 
+ Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Forming().Obj()}, + }, + "should update the Workload AdmissionCheckState when the Slice status is changed to Ready": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Ready().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStateReady, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is fully operational`, baseWorkloadName), + }). + Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Ready().Obj()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: AdmissionCheckUpdatedEventType, + Message: fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseAdmissionCheckName), + }, + }, + }, + "should update the Workload AdmissionCheckState when the Slice status is changed to Degraded": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Degraded().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStateReady, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is running with reduced capacity or performance`, baseWorkloadName), + }). + Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Degraded().Obj()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: AdmissionCheckUpdatedEventType, + Message: fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseAdmissionCheckName), + }, + }, + }, + "should update the Workload AdmissionCheckState when the Slice status is changed to Deformed": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Deformed().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStateRejected, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is being torn down`, baseWorkloadName), + }). 
+ Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Deformed().Obj()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: AdmissionCheckUpdatedEventType, + Message: fmt.Sprintf(`Admission check %q updated state from "Pending" to "Rejected"`, baseAdmissionCheckName), + }, + }, + }, + "should update the Workload AdmissionCheckState when the Slice status is changed to Error": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Error().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStateRejected, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is not operational due to an error: Error by test`, baseWorkloadName), + }). + Obj(), + }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Error().Obj()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: AdmissionCheckUpdatedEventType, + Message: fmt.Sprintf(`Admission check %q updated state from "Pending" to "Rejected"`, baseAdmissionCheckName), + }, + }, + }, + "should use the first AdmissionCheck if more than one is found": { + request: baseRequest, + objs: []client.Object{ + baseAdmissionCheckWrapper.DeepCopy(), + baseAdmissionCheckWrapper.Clone().Name(baseAdmissionCheckName + "2").Obj(), + baseWorkloadWrapperWithAdmission.Finalizers(SliceControllerName).DeepCopy(), + baseSliceWrapper.Clone().Ready().Obj(), + }, + wantWorkloads: []kueue.Workload{ + *baseWorkloadWrapperWithAdmission.Clone(). + Finalizers(SliceControllerName). + AdmissionCheck(kueue.AdmissionCheckState{ + Name: kueue.AdmissionCheckReference(baseAdmissionCheckName), + State: kueue.CheckStateReady, + LastTransitionTime: metav1.NewTime(now), + Message: fmt.Sprintf(`The Slice %q is fully operational`, baseWorkloadName), }). Obj(), }, + wantSlices: []slice.Slice{*baseSliceWrapper.Clone().Ready().Obj()}, + wantEvents: []utiltesting.EventRecord{ + { + Key: client.ObjectKeyFromObject(baseWorkloadWrapper), + EventType: corev1.EventTypeNormal, + Reason: AdmissionCheckUpdatedEventType, + Message: fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseAdmissionCheckName), + }, + }, }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { scheme := runtime.NewScheme() utilruntime.Must(kueue.AddToScheme(scheme)) - utilruntime.Must(v1alpha1.AddToScheme(scheme)) + utilruntime.Must(slice.AddToScheme(scheme)) - ctx, _ := utiltesting.ContextWithLog(t) - clientBuilder := fake.NewClientBuilder().WithScheme(scheme) - - if tc.workload != nil { - clientBuilder.WithObjects(tc.workload) - } - if tc.slice != nil { - clientBuilder.WithObjects(tc.slice) + interceptorFuncs := interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge} + if tc.interceptorFuncsCreate != nil { + interceptorFuncs.Create = tc.interceptorFuncsCreate } + ctx, _ := utiltesting.ContextWithLog(t) + clientBuilder := fake.NewClientBuilder().WithScheme(scheme). + WithStatusSubresource(&kueue.Workload{}). + WithObjects(tc.objs...). 
+ WithInterceptorFuncs(interceptorFuncs) + kClient := clientBuilder.Build() - reconciler := NewWorkloadReconciler(kClient) + recorder := &utiltesting.EventRecorder{} + reconciler := NewWorkloadReconciler(kClient, recorder) + reconciler.clock = fakeClock _, err := reconciler.Reconcile(ctx, reconcile.Request{NamespacedName: tc.request}) - if diff := cmp.Diff(tc.wantErr, err); diff != "" { + if diff := cmp.Diff(tc.wantErr, err, cmpopts.EquateErrors()); diff != "" { t.Errorf("Error after reconcile (-want,+got):\n%s", diff) } @@ -190,7 +602,7 @@ func TestWorkloadReconciler(t *testing.T) { t.Errorf("Workloads after reconcile (-want,+got):\n%s", diff) } - slices := &v1alpha1.SliceList{} + slices := &slice.SliceList{} err = kClient.List(ctx, slices) if err != nil { t.Errorf("Error listing slices: %v", err) @@ -198,6 +610,90 @@ func TestWorkloadReconciler(t *testing.T) { if diff := cmp.Diff(tc.wantSlices, slices.Items, baseCmpOpts); diff != "" { t.Errorf("Slices after reconcile (-want,+got):\n%s", diff) } + + if diff := cmp.Diff(tc.wantEvents, recorder.RecordedEvents); diff != "" { + t.Errorf("Unexpected events (-want/+got):\n%s", diff) + } }) } } + +func TestSliceHandlerHandleEvent(t *testing.T) { + const ( + baseWlName = "wl" + baseSliceName = "slice" + ) + + type requestDuration struct { + Request reconcile.Request + Duration time.Duration + } + + cases := map[string]struct { + obj client.Object + want []requestDuration + }{ + "invalid object": { + obj: utiltesting.MakeWorkload(baseWlName, corev1.NamespaceDefault).Obj(), + }, + "has a workload that should be handled": { + obj: utiltesting.MakeSliceWrapper(baseSliceName, corev1.NamespaceDefault). + ControllerReference(kueue.SchemeGroupVersion.WithKind("Workload"), baseWlName, baseWlName). + Obj(), + want: []requestDuration{ + { + Request: reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: corev1.NamespaceDefault, + Name: baseWlName, + }, + }, + Duration: updatesBatchPeriod, + }, + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + scheme := runtime.NewScheme() + utilruntime.Must(kueue.AddToScheme(scheme)) + utilruntime.Must(slice.AddToScheme(scheme)) + utilruntime.Must(jobset.AddToScheme(scheme)) + + ctx, _ := utiltesting.ContextWithLog(t) + clientBuilder := fake.NewClientBuilder().WithScheme(scheme) + + kClient := clientBuilder.Build() + testSliceHandler := &sliceHandler{client: kClient} + + var gotRequestDurations []requestDuration + testFakePriorityQueue := &fakePriorityQueue{ + addAfter: func(item reconcile.Request, duration time.Duration) { + gotRequestDurations = append(gotRequestDurations, requestDuration{Request: item, Duration: duration}) + }, + } + + testSliceHandler.handleEvent(ctx, tc.obj, testFakePriorityQueue) + if diff := cmp.Diff(tc.want, gotRequestDurations); diff != "" { + t.Errorf("Result after handleEvent (-want,+got):\n%s", diff) + } + }) + } +} + +type fakePriorityQueue struct { + workqueue.TypedRateLimitingInterface[reconcile.Request] + addAfter func(item reconcile.Request, duration time.Duration) +} + +func (f *fakePriorityQueue) AddAfter(item reconcile.Request, duration time.Duration) { + f.addAfter(item, duration) +} + +func (f *fakePriorityQueue) Add(reconcile.Request) {} + +func (f *fakePriorityQueue) AddWithOpts(priorityqueue.AddOpts, ...reconcile.Request) {} + +func (f *fakePriorityQueue) GetWithPriority() (item reconcile.Request, priority int, shutdown bool) { + panic("GetWithPriority is not expected to be called") +} diff --git a/slice/internal/core/constants.go 
b/slice/internal/core/constants.go new file mode 100644 index 000000000..3db7ad422 --- /dev/null +++ b/slice/internal/core/constants.go @@ -0,0 +1,24 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +const ( + TPUTopologyAnnotation = "cloud.google.com/gke-tpu-topology" + TPUAcceleratorLabel = "cloud.google.com/gke-tpu-accelerator" + TPUBlockLabel = "cloud.google.com/gke-tpu-block" + TPUSubBlockLabel = "cloud.google.com/gke-tpu-subblock" +) diff --git a/slice/internal/core/core.go b/slice/internal/core/core.go new file mode 100644 index 000000000..fcbc27a45 --- /dev/null +++ b/slice/internal/core/core.go @@ -0,0 +1,37 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package core + +import ( + "regexp" + + corev1 "k8s.io/api/core/v1" +) + +func IsValidTPUTopology(tpuTopology string) bool { + validTopology, _ := regexp.MatchString("[0-9]+x[0-9]+x[0-9]+", tpuTopology) + return validTopology +} + +func IsValidTPUAccelerator(tpuAccelerator string) bool { + return tpuAccelerator == "tpu-v7x" +} + +func IsRelevantPodTemplateSpec(spec corev1.PodTemplateSpec) bool { + return IsValidTPUTopology(spec.Annotations[TPUTopologyAnnotation]) && + IsValidTPUAccelerator(spec.Spec.NodeSelector[TPUAcceleratorLabel]) +} diff --git a/slice/internal/util/api/api.go b/slice/internal/util/api/api.go new file mode 100644 index 000000000..2e201df14 --- /dev/null +++ b/slice/internal/util/api/api.go @@ -0,0 +1,40 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package api + +const ( + maxEventMsgSize = 1024 + maxConditionMsgSize = 32 * 1024 +) + +// TruncateEventMessage truncates a message if it hits the maxEventMessage. +func TruncateEventMessage(message string) string { + return truncateMessage(message, maxEventMsgSize) +} + +func TruncateConditionMessage(message string) string { + return truncateMessage(message, maxConditionMsgSize) +} + +// truncateMessage truncates a message if it hits the NoteLengthLimit. 
+func truncateMessage(message string, limit int) string { + if len(message) <= limit { + return message + } + suffix := " ..." + return message[:limit-len(suffix)] + suffix +} diff --git a/slice/internal/util/testing/client.go b/slice/internal/util/testing/client.go new file mode 100644 index 000000000..64060ac11 --- /dev/null +++ b/slice/internal/util/testing/client.go @@ -0,0 +1,92 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "context" + "fmt" + "sync" + + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +type EventRecord struct { + Key types.NamespacedName + EventType string + Reason string + Message string + // add annotations if ever needed +} + +type EventRecorder struct { + lock sync.Mutex + RecordedEvents []EventRecord +} + +var _ record.EventRecorder = (*EventRecorder)(nil) + +func (tr *EventRecorder) Event(object runtime.Object, eventType, reason, message string) { + tr.generateEvent(object, eventType, reason, message) +} + +func (tr *EventRecorder) Eventf(object runtime.Object, eventType, reason, messageFmt string, args ...any) { + tr.AnnotatedEventf(object, nil, eventType, reason, messageFmt, args...) +} + +func (tr *EventRecorder) AnnotatedEventf(targetObject runtime.Object, _ map[string]string, eventType, reason, messageFmt string, args ...any) { + tr.generateEvent(targetObject, eventType, reason, fmt.Sprintf(messageFmt, args...)) +} + +func (tr *EventRecorder) generateEvent(targetObject runtime.Object, eventType, reason, message string) { + tr.lock.Lock() + defer tr.lock.Unlock() + key := types.NamespacedName{} + if cObj, isCObj := targetObject.(client.Object); isCObj { + key = client.ObjectKeyFromObject(cObj) + } + tr.RecordedEvents = append(tr.RecordedEvents, EventRecord{ + Key: key, + EventType: eventType, + Reason: reason, + Message: message, + }) +} + +type ssaPatchAsStrategicMerge struct { + client.Patch +} + +func (*ssaPatchAsStrategicMerge) Type() types.PatchType { + return types.StrategicMergePatchType +} + +func wrapSSAPatch(patch client.Patch) client.Patch { + if patch.Type() == types.ApplyPatchType { + return &ssaPatchAsStrategicMerge{Patch: patch} + } + return patch +} + +// TreatSSAAsStrategicMerge - can be used as a SubResourcePatch interceptor function to treat SSA patches as StrategicMergePatchType. +// Note: By doing so the values set in the patch will be updated but the call will have no knowledge of FieldManagement when it +// comes to detecting conflicts between managers or removing fields that are missing from the patch. +func TreatSSAAsStrategicMerge(ctx context.Context, clnt client.Client, subResourceName string, obj client.Object, patch client.Patch, opts ...client.SubResourcePatchOption) error { + return clnt.SubResource(subResourceName).Patch(ctx, obj, wrapSSAPatch(patch), opts...) 
+} diff --git a/slice/internal/util/testing/error_matchers.go b/slice/internal/util/testing/error_matchers.go index 7936fcfec..aa3492dcd 100644 --- a/slice/internal/util/testing/error_matchers.go +++ b/slice/internal/util/testing/error_matchers.go @@ -67,7 +67,6 @@ func BeError(errorType ErrorType) types.GomegaMatcher { func (matcher *errorMatcher) Match(actual any) (success bool, err error) { err, ok := actual.(error) if !ok { - //nolint:staticcheck // We keep the error capitalized for consistency with built-in matchers. return false, fmt.Errorf("Error matcher expects an error. Got:\n%s", format.Object(actual, 1)) } diff --git a/slice/internal/util/testing/wrappers.go b/slice/internal/util/testing/wrappers.go index f94e8f2a6..0c7765f1f 100644 --- a/slice/internal/util/testing/wrappers.go +++ b/slice/internal/util/testing/wrappers.go @@ -17,11 +17,12 @@ limitations under the License. package testing import ( + "fmt" "time" batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/meta" + apimeta "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" @@ -111,7 +112,7 @@ func (w *WorkloadWrapper) Finished() *WorkloadWrapper { Reason: "ByTest", Message: "Finished by test", } - meta.SetStatusCondition(&w.Status.Conditions, cond) + apimeta.SetStatusCondition(&w.Status.Conditions, cond) return w } @@ -123,7 +124,7 @@ func (w *WorkloadWrapper) Evicted() *WorkloadWrapper { Reason: "ByTest", Message: "Evicted by test", } - meta.SetStatusCondition(&w.Status.Conditions, cond) + apimeta.SetStatusCondition(&w.Status.Conditions, cond) return w } @@ -132,14 +133,63 @@ func (w *WorkloadWrapper) Active(a bool) *WorkloadWrapper { return w } -// PodSetAssignments sets the PodSetAssignments for the workload. 
-func (w *WorkloadWrapper) PodSetAssignments(assignments ...kueue.PodSetAssignment) *WorkloadWrapper { - if w.Status.Admission == nil { - w.Status.Admission = &kueue.Admission{ - PodSetAssignments: make([]kueue.PodSetAssignment, 0, len(assignments)), +func (w *WorkloadWrapper) AdmissionCheck(admissionCheckState kueue.AdmissionCheckState) *WorkloadWrapper { + var admissionCheckStates []kueue.AdmissionCheckState + for _, acs := range w.Status.AdmissionChecks { + if acs.Name != admissionCheckState.Name { + admissionCheckStates = append(admissionCheckStates, acs) } } - w.Status.Admission.PodSetAssignments = assignments + w.Status.AdmissionChecks = append(admissionCheckStates, admissionCheckState) + return w +} + +// ReserveQuota sets workload admission and adds a "QuotaReserved" status condition +func (w *WorkloadWrapper) ReserveQuota(a *kueue.Admission, now time.Time) *WorkloadWrapper { + w.Status.Admission = a + w.Status.Conditions = []metav1.Condition{{ + Type: kueue.WorkloadQuotaReserved, + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.NewTime(now), + Reason: "AdmittedByTest", + Message: fmt.Sprintf("Admitted by ClusterQueue %s", w.Status.Admission.ClusterQueue), + }} + return w +} + +func (w *WorkloadWrapper) PodSets(podSets ...kueue.PodSet) *WorkloadWrapper { + w.Spec.PodSets = podSets + return w +} + +type PodSetWrapper struct{ kueue.PodSet } + +func MakePodSet(name kueue.PodSetReference, count int) *PodSetWrapper { + return &PodSetWrapper{ + kueue.PodSet{ + Name: name, + Count: int32(count), + }, + } +} + +func (w *PodSetWrapper) Obj() *kueue.PodSet { + return &w.PodSet +} + +func (w *PodSetWrapper) Annotation(key, value string) *PodSetWrapper { + if w.Template.Annotations == nil { + w.Template.Annotations = make(map[string]string) + } + w.Template.Annotations[key] = value + return w +} + +func (w *PodSetWrapper) NodeSelector(key, value string) *PodSetWrapper { + if w.Template.Spec.NodeSelector == nil { + w.Template.Spec.NodeSelector = make(map[string]string) + } + w.Template.Spec.NodeSelector[key] = value return w } @@ -205,6 +255,66 @@ func (s *SliceWrapper) NodeSelector(ns map[string][]string) *SliceWrapper { return s } +func (s *SliceWrapper) Ready() *SliceWrapper { + cond := metav1.Condition{ + Type: string(v1alpha1.Ready), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Now(), + Reason: "ByTest", + Message: "Ready by test", + } + apimeta.SetStatusCondition(&s.Status.Conditions, cond) + return s +} + +func (s *SliceWrapper) Forming() *SliceWrapper { + cond := metav1.Condition{ + Type: string(v1alpha1.Forming), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Now(), + Reason: "ByTest", + Message: "Forming by test", + } + apimeta.SetStatusCondition(&s.Status.Conditions, cond) + return s +} + +func (s *SliceWrapper) Deformed() *SliceWrapper { + cond := metav1.Condition{ + Type: string(v1alpha1.Deformed), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Now(), + Reason: "ByTest", + Message: "Deformed by test", + } + apimeta.SetStatusCondition(&s.Status.Conditions, cond) + return s +} + +func (s *SliceWrapper) Degraded() *SliceWrapper { + cond := metav1.Condition{ + Type: string(v1alpha1.Degraded), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Now(), + Reason: "ByTest", + Message: "Degraded by test", + } + apimeta.SetStatusCondition(&s.Status.Conditions, cond) + return s +} + +func (s *SliceWrapper) Error() *SliceWrapper { + cond := metav1.Condition{ + Type: string(v1alpha1.Error), + Status: metav1.ConditionTrue, + 
LastTransitionTime: metav1.Now(), + Reason: "ByTest", + Message: "Error by test", + } + apimeta.SetStatusCondition(&s.Status.Conditions, cond) + return s +} + func AppendOwnerReference(obj client.Object, gvk schema.GroupVersionKind, name, uid string, controller, blockDeletion *bool) { obj.SetOwnerReferences(append(obj.GetOwnerReferences(), metav1.OwnerReference{ APIVersion: gvk.GroupVersion().String(), @@ -333,6 +443,12 @@ func (c *ClusterQueueWrapper) ResourceGroup(flavors ...kueue.FlavorQuotas) *Clus return c } +// AdmissionChecks replaces the queue additional checks +func (c *ClusterQueueWrapper) AdmissionChecks(checks ...kueue.AdmissionCheckReference) *ClusterQueueWrapper { + c.Spec.AdmissionChecks = checks + return c +} + // FlavorQuotasWrapper wraps a FlavorQuotas object. type FlavorQuotasWrapper struct{ kueue.FlavorQuotas } @@ -451,3 +567,44 @@ func (t *TopologyWrapper) Levels(levels ...string) *TopologyWrapper { func (t *TopologyWrapper) Obj() *kueuealpha.Topology { return &t.Topology } + +type AdmissionCheckWrapper struct{ kueue.AdmissionCheck } + +func MakeAdmissionCheck(name string) *AdmissionCheckWrapper { + return &AdmissionCheckWrapper{ + AdmissionCheck: kueue.AdmissionCheck{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + }, + }, + } +} + +func (ac *AdmissionCheckWrapper) Obj() *kueue.AdmissionCheck { + return &ac.AdmissionCheck +} + +func (ac *AdmissionCheckWrapper) Clone() *AdmissionCheckWrapper { + return &AdmissionCheckWrapper{AdmissionCheck: *ac.DeepCopy()} +} + +func (ac *AdmissionCheckWrapper) Name(name string) *AdmissionCheckWrapper { + ac.ObjectMeta.Name = name + return ac +} + +func (ac *AdmissionCheckWrapper) ControllerName(c string) *AdmissionCheckWrapper { + ac.Spec.ControllerName = c + return ac +} + +// Generation sets the generation of the AdmissionCheck. +func (ac *AdmissionCheckWrapper) Generation(num int64) *AdmissionCheckWrapper { + ac.ObjectMeta.Generation = num + return ac +} + +func (ac *AdmissionCheckWrapper) Condition(cond metav1.Condition) *AdmissionCheckWrapper { + apimeta.SetStatusCondition(&ac.Status.Conditions, cond) + return ac +} diff --git a/slice/internal/webhooks/jobset_webhook.go b/slice/internal/webhooks/jobset_webhook.go index cd1eea104..36c7134d6 100644 --- a/slice/internal/webhooks/jobset_webhook.go +++ b/slice/internal/webhooks/jobset_webhook.go @@ -18,7 +18,6 @@ package webhooks import ( "context" - "regexp" "strconv" "strings" @@ -29,13 +28,8 @@ import ( "sigs.k8s.io/jobset/api/jobset/v1alpha2" kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" kueueconstants "sigs.k8s.io/kueue/pkg/controller/constants" -) -const ( - TPUTopologyAnnotation = "cloud.google.com/gke-tpu-topology" - TPUAcceleratorLabel = "cloud.google.com/gke-tpu-accelerator" - TPUBlockLabel = "cloud.google.com/gke-tpu-block" - TPUSubBlockLabel = "cloud.google.com/gke-tpu-subblock" + "tpu-slice-controller/internal/core" ) // JobSetWebhook is the schema for your resource (ensure this matches your resource definition). 
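A minimal sketch of how the chainable wrappers added above in slice/internal/util/testing/wrappers.go might compose in a test, mirroring the e2e setup later in this diff. newSliceFixtures is a hypothetical helper introduced only for illustration, and the trailing Obj() on ClusterQueueWrapper is assumed from the existing wrapper convention; this is not part of the change itself.

package example

import (
	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"

	"tpu-slice-controller/internal/controller"
	"tpu-slice-controller/internal/util/testing"
)

// newSliceFixtures (hypothetical) builds an AdmissionCheck owned by the slice
// controller and a ClusterQueue that references it, using the chainable test
// wrappers added in this diff.
func newSliceFixtures() (*kueue.AdmissionCheck, *kueue.ClusterQueue) {
	ac := testing.MakeAdmissionCheck("ac").
		ControllerName(controller.SliceControllerName).
		Obj()
	cq := testing.MakeClusterQueue("cq").
		AdmissionChecks(kueue.AdmissionCheckReference(ac.Name)).
		Obj()
	return ac, cq
}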
@@ -72,11 +66,7 @@ func (r *JobSetWebhook) Default(ctx context.Context, obj runtime.Object) error { } func (r *JobSetWebhook) annotateReplicatedJobWithTopology(rj *v1alpha2.ReplicatedJob) error { - tpuTopology := rj.Template.Spec.Template.Annotations[TPUTopologyAnnotation] - tpuAccelerator := rj.Template.Spec.Template.Spec.NodeSelector[TPUAcceleratorLabel] - - validTopology, _ := regexp.MatchString("[0-9]+x[0-9]+x[0-9]+", tpuTopology) - if !validTopology || tpuAccelerator != "tpu-v7x" { + if !core.IsRelevantPodTemplateSpec(rj.Template.Spec.Template) { return nil } @@ -84,10 +74,13 @@ func (r *JobSetWebhook) annotateReplicatedJobWithTopology(rj *v1alpha2.Replicate rj.Template.Spec.Template.Annotations = make(map[string]string) } - rj.Template.Spec.Template.Annotations[kueuealpha.PodSetRequiredTopologyAnnotation] = TPUBlockLabel - rj.Template.Spec.Template.Annotations[kueuealpha.PodSetSliceRequiredTopologyAnnotation] = TPUSubBlockLabel + rj.Template.Spec.Template.Annotations[kueuealpha.PodSetRequiredTopologyAnnotation] = core.TPUBlockLabel + rj.Template.Spec.Template.Annotations[kueuealpha.PodSetSliceRequiredTopologyAnnotation] = core.TPUSubBlockLabel - size, err := r.podSetSliceSize(tpuTopology, ptr.Deref(rj.Template.Spec.Parallelism, 1)) + size, err := r.podSetSliceSize( + rj.Template.Spec.Template.Annotations[core.TPUTopologyAnnotation], + ptr.Deref(rj.Template.Spec.Parallelism, 1), + ) if err != nil { return err } diff --git a/slice/internal/webhooks/jobset_webhook_test.go b/slice/internal/webhooks/jobset_webhook_test.go index 7e7fa3195..b700ed79a 100644 --- a/slice/internal/webhooks/jobset_webhook_test.go +++ b/slice/internal/webhooks/jobset_webhook_test.go @@ -24,6 +24,7 @@ import ( jobset "sigs.k8s.io/jobset/api/jobset/v1alpha2" kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + "tpu-slice-controller/internal/core" testingjobjobset "tpu-slice-controller/internal/util/testingjobs/jobset" "tpu-slice-controller/test/utils" ) @@ -44,10 +45,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -56,10 +57,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -71,7 +72,7 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -81,7 +82,7 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -93,7 +94,7 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, }). Obj(), @@ -103,7 +104,7 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, }). 
Obj(), @@ -115,10 +116,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -128,13 +129,13 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", - kueuealpha.PodSetRequiredTopologyAnnotation: TPUBlockLabel, - kueuealpha.PodSetSliceRequiredTopologyAnnotation: TPUSubBlockLabel, + core.TPUTopologyAnnotation: "4x4x12", + kueuealpha.PodSetRequiredTopologyAnnotation: core.TPUBlockLabel, + kueuealpha.PodSetSliceRequiredTopologyAnnotation: core.TPUSubBlockLabel, kueuealpha.PodSetSliceSizeAnnotation: "4", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -146,10 +147,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "invalid", + core.TPUTopologyAnnotation: "invalid", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -159,10 +160,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "invalid", + core.TPUTopologyAnnotation: "invalid", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "tpu-v7x", + core.TPUAcceleratorLabel: "tpu-v7x", }, }). Obj(), @@ -174,10 +175,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "test", + core.TPUAcceleratorLabel: "test", }, }). Obj(), @@ -187,10 +188,10 @@ func TestDefault(t *testing.T) { Name: "rj1", Parallelism: 12, PodAnnotations: map[string]string{ - TPUTopologyAnnotation: "4x4x12", + core.TPUTopologyAnnotation: "4x4x12", }, NodeSelector: map[string]string{ - TPUAcceleratorLabel: "test", + core.TPUAcceleratorLabel: "test", }, }). 
Obj(), diff --git a/slice/test/e2e/jobset_test.go b/slice/test/e2e/jobset_test.go index cf81f6816..b6481dea9 100644 --- a/slice/test/e2e/jobset_test.go +++ b/slice/test/e2e/jobset_test.go @@ -23,6 +23,8 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" @@ -30,12 +32,13 @@ import ( kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" jobsetcontroller "sigs.k8s.io/kueue/pkg/controller/jobs/jobset" + "sigs.k8s.io/kueue/pkg/workload" slice "tpu-slice-controller/api/v1alpha1" "tpu-slice-controller/internal/controller" + "tpu-slice-controller/internal/core" "tpu-slice-controller/internal/util/testing" testingjobsjobset "tpu-slice-controller/internal/util/testingjobs/jobset" - "tpu-slice-controller/internal/webhooks" "tpu-slice-controller/test/utils" ) @@ -52,6 +55,7 @@ var _ = ginkgo.Describe("JobSet", func() { topology *kueuealpha.Topology ns *corev1.Namespace rf *kueue.ResourceFlavor + ac *kueue.AdmissionCheck cq *kueue.ClusterQueue lq *kueue.LocalQueue ) @@ -61,7 +65,7 @@ var _ = ginkgo.Describe("JobSet", func() { utils.MustCreate(ctx, k8sClient, ns) topology = testing.MakeTopology("topology"). - Levels(webhooks.TPUBlockLabel, webhooks.TPUSubBlockLabel). + Levels(core.TPUBlockLabel, core.TPUSubBlockLabel). Obj() utils.MustCreate(ctx, k8sClient, topology) @@ -71,7 +75,11 @@ var _ = ginkgo.Describe("JobSet", func() { Obj() utils.MustCreate(ctx, k8sClient, rf) + ac = testing.MakeAdmissionCheck("ac").ControllerName(controller.SliceControllerName).Obj() + utils.MustCreate(ctx, k8sClient, ac) + cq = testing.MakeClusterQueue("cq"). + AdmissionChecks(kueue.AdmissionCheckReference(ac.Name)). ResourceGroup(*testing.MakeFlavorQuotas(rf.Name). Resource(extraResource, "128"). Obj()). @@ -85,6 +93,7 @@ var _ = ginkgo.Describe("JobSet", func() { ginkgo.AfterEach(func() { gomega.Expect(utils.DeleteNamespace(ctx, k8sClient, ns)).To(gomega.Succeed()) utils.ExpectObjectToBeDeleted(ctx, k8sClient, cq, true) + utils.ExpectObjectToBeDeleted(ctx, k8sClient, ac, true) utils.ExpectObjectToBeDeleted(ctx, k8sClient, rf, true) utils.ExpectObjectToBeDeleted(ctx, k8sClient, topology, true) utils.ExpectAllPodsInNamespaceDeleted(ctx, k8sClient, ns) @@ -111,10 +120,10 @@ var _ = ginkgo.Describe("JobSet", func() { Parallelism: tc.parallelism, Completions: tc.parallelism, PodAnnotations: map[string]string{ - webhooks.TPUTopologyAnnotation: tc.tpuTopology, + core.TPUTopologyAnnotation: tc.tpuTopology, }, NodeSelector: map[string]string{ - webhooks.TPUAcceleratorLabel: tpuAccelerator, + core.TPUAcceleratorLabel: tpuAccelerator, }, }, ). @@ -134,9 +143,9 @@ var _ = ginkgo.Describe("JobSet", func() { for _, replicatedJob := range createdJobSet.Spec.ReplicatedJobs { annotations := replicatedJob.Template.Spec.Template.Annotations g.Expect(annotations[kueuealpha.PodSetRequiredTopologyAnnotation]). - Should(gomega.Equal(webhooks.TPUBlockLabel)) + Should(gomega.Equal(core.TPUBlockLabel)) g.Expect(annotations[kueuealpha.PodSetSliceRequiredTopologyAnnotation]). - Should(gomega.Equal(webhooks.TPUSubBlockLabel)) + Should(gomega.Equal(core.TPUSubBlockLabel)) g.Expect(annotations[kueuealpha.PodSetSliceSizeAnnotation]). 
Should(gomega.Equal(fmt.Sprint(tc.wantSliceSize))) } @@ -154,8 +163,8 @@ var _ = ginkgo.Describe("JobSet", func() { g.Expect(k8sClient.Get(ctx, wlKey, createdWorkload)).To(gomega.Succeed()) g.Expect(createdWorkload.Spec.PodSets).To(gomega.HaveLen(1)) g.Expect(createdWorkload.Spec.PodSets[0].TopologyRequest).To(gomega.BeComparableTo(&kueue.PodSetTopologyRequest{ - Required: ptr.To(webhooks.TPUBlockLabel), - PodSetSliceRequiredTopology: ptr.To(webhooks.TPUSubBlockLabel), + Required: ptr.To(core.TPUBlockLabel), + PodSetSliceRequiredTopology: ptr.To(core.TPUSubBlockLabel), SubGroupCount: ptr.To[int32](2), PodSetSliceSize: ptr.To[int32](tc.wantSliceSize), }, ignorePodSetTopologyRequestFields)) @@ -174,7 +183,7 @@ var _ = ginkgo.Describe("JobSet", func() { gomega.Expect(createdWorkload.Status.Admission.PodSetAssignments).Should(gomega.HaveLen(1)) gomega.Expect(createdWorkload.Status.Admission.PodSetAssignments[0].TopologyAssignment).Should(gomega.BeComparableTo( &kueue.TopologyAssignment{ - Levels: []string{webhooks.TPUBlockLabel, webhooks.TPUSubBlockLabel}, + Levels: []string{core.TPUBlockLabel, core.TPUSubBlockLabel}, Domains: tc.wantDomains, }, )) @@ -194,6 +203,74 @@ var _ = ginkgo.Describe("JobSet", func() { }, utils.Timeout, utils.Interval).Should(gomega.Succeed()) }) + ginkgo.By("Checking that the Workload waiting for admission", func() { + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, wlKey, createdWorkload)).Should(gomega.Succeed()) + g.Expect(workload.IsAdmitted(createdWorkload)).Should(gomega.BeFalse()) + g.Expect(createdWorkload.Status.AdmissionChecks).Should(gomega.BeComparableTo([]kueue.AdmissionCheckState{{ + Name: kueue.AdmissionCheckReference(ac.Name), + State: kueue.CheckStatePending, + Message: fmt.Sprintf("The Slice %q has been created", createdSlice.Name), + }}, cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime", "PodSetUpdates"))) + }, utils.Timeout, utils.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Adding Forming condition", func() { + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, sliceKey, createdSlice)).To(gomega.Succeed()) + meta.SetStatusCondition(&createdSlice.Status.Conditions, metav1.Condition{ + Type: string(slice.Forming), + Status: metav1.ConditionTrue, + Reason: "Test", + Message: "Test", + }) + g.Expect(k8sClient.Status().Update(ctx, createdSlice)).To(gomega.Succeed()) + }, utils.Timeout, utils.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Checking that the Workload still waiting for admission", func() { + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, wlKey, createdWorkload)).Should(gomega.Succeed()) + g.Expect(workload.IsAdmitted(createdWorkload)).Should(gomega.BeFalse()) + g.Expect(createdWorkload.Status.AdmissionChecks).Should(gomega.BeComparableTo([]kueue.AdmissionCheckState{{ + Name: kueue.AdmissionCheckReference(ac.Name), + State: kueue.CheckStatePending, + Message: fmt.Sprintf("The Slice %q is being formed", createdSlice.Name), + }}, cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime", "PodSetUpdates"))) + }, utils.LongTimeout, utils.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Adding Ready condition", func() { + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, sliceKey, createdSlice)).To(gomega.Succeed()) + meta.SetStatusCondition(&createdSlice.Status.Conditions, metav1.Condition{ + Type: string(slice.Forming), + Status: metav1.ConditionFalse, + Reason: "Test", + Message: "Test", + }) + 
meta.SetStatusCondition(&createdSlice.Status.Conditions, metav1.Condition{ + Type: string(slice.Ready), + Status: metav1.ConditionTrue, + Reason: "Test", + Message: "Test", + }) + g.Expect(k8sClient.Status().Update(ctx, createdSlice)).To(gomega.Succeed()) + }, utils.Timeout, utils.Interval).Should(gomega.Succeed()) + }) + + ginkgo.By("Checking that the Workload is admitted and admission check status is ready", func() { + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, wlKey, createdWorkload)).Should(gomega.Succeed()) + g.Expect(workload.IsAdmitted(createdWorkload)).Should(gomega.BeTrue()) + g.Expect(createdWorkload.Status.AdmissionChecks).Should(gomega.BeComparableTo([]kueue.AdmissionCheckState{{ + Name: kueue.AdmissionCheckReference(ac.Name), + State: kueue.CheckStateReady, + Message: fmt.Sprintf("The Slice %q is fully operational", createdWorkload.Name), + }}, cmpopts.IgnoreFields(kueue.AdmissionCheckState{}, "LastTransitionTime", "PodSetUpdates"))) + }, utils.LongTimeout, utils.Timeout).Should(gomega.Succeed()) + }) + ginkgo.By("Deleting JobSet", func() { utils.ExpectObjectToBeDeleted(ctx, k8sClient, jobSet, true) }) diff --git a/slice/test/e2e/manager.go b/slice/test/e2e/manager_test.go similarity index 100% rename from slice/test/e2e/manager.go rename to slice/test/e2e/manager_test.go diff --git a/slice/test/utils/constants.go b/slice/test/utils/constants.go index ba9e5f88b..3e8b95007 100644 --- a/slice/test/utils/constants.go +++ b/slice/test/utils/constants.go @@ -26,10 +26,12 @@ const ( ) const ( - Timeout = 10 * time.Second - LongTimeout = 45 * time.Second - StartUpTimeout = 5 * time.Minute - Interval = time.Millisecond * 250 + Timeout = 10 * time.Second + LongTimeout = 45 * time.Second + StartUpTimeout = 5 * time.Minute + ConsistentDuration = time.Second + ShortInterval = 10 * time.Millisecond + Interval = time.Millisecond * 250 ) var (
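Usage note for the TreatSSAAsStrategicMerge helper added earlier in this diff (slice/internal/util/testing/client.go): its signature matches the SubResourcePatch field of controller-runtime's interceptor.Funcs, so a test can wire it into the fake client to apply server-side-apply status patches as strategic-merge patches. A minimal sketch, assuming controller-runtime's fake and interceptor packages; newFakeClient is a hypothetical helper, not part of this change.

package example

import (
	"k8s.io/apimachinery/pkg/runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
	"sigs.k8s.io/controller-runtime/pkg/client/interceptor"

	"tpu-slice-controller/internal/util/testing"
)

// newFakeClient (hypothetical) wires TreatSSAAsStrategicMerge into the fake
// client so SSA patches against status subresources are handled as
// strategic-merge patches in tests.
func newFakeClient(scheme *runtime.Scheme, objs ...client.Object) client.Client {
	return fake.NewClientBuilder().
		WithScheme(scheme).
		WithObjects(objs...).
		WithStatusSubresource(objs...).
		WithInterceptorFuncs(interceptor.Funcs{
			SubResourcePatch: testing.TreatSSAAsStrategicMerge,
		}).
		Build()
}

The caveat from the helper's doc comment still applies: field ownership is not tracked, so manager conflicts are not detected and fields absent from the patch are not removed, unlike with a real apiserver.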