@@ -18,9 +18,15 @@ package controller
18
18
19
19
import (
20
20
"context"
21
+ "errors"
22
+ "fmt"
23
+ "maps"
21
24
"slices"
22
25
26
+ apierrors "k8s.io/apimachinery/pkg/api/errors"
23
27
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28
+ "k8s.io/apimachinery/pkg/util/sets"
29
+ "k8s.io/klog/v2"
24
30
ctrl "sigs.k8s.io/controller-runtime"
25
31
"sigs.k8s.io/controller-runtime/pkg/client"
26
32
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
@@ -33,7 +39,18 @@ import (
33
39
34
40
const (
35
41
CleanupSliceFinalizerName = "accelerator.gke.io/slice"
36
- TPUReservationSubblockLabel = "cloud.google.com/gke-tpu-reservation-subblock"
42
+ TPUReservationSubBlockLabel = "cloud.google.com/gke-tpu-reservation-subblock"
43
+ NodePoolLabel = "cloud.google.com/gke-nodepool"
44
+ TPUTopologyLabel = "cloud.google.com/gke-tpu-topology"
45
+ TPUAcceleratorLabel = "cloud.google.com/gke-tpu-accelerator"
46
+ )
47
+
48
+ var (
49
+ errPodSetNotFound = errors .New ("PodSet not found" )
50
+ errPodSetAssignmentNotFound = errors .New ("PodSetAssignment not found" )
51
+ errTPUTopologyLabelNotFound = fmt .Errorf ("%s label not found" , TPUTopologyLabel )
52
+ errTPUAcceleratorLabelNotFound = fmt .Errorf ("%s label not found" , TPUAcceleratorLabel )
53
+ errTopologyAssignmentNotFound = errors .New ("TopologyAssignment not found" )
37
54
)
38
55
39
56
// WorkloadReconciler reconciles a Workload object
@@ -64,18 +81,7 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
64
81
log .V (2 ).Info ("Reconcile Workload" )
65
82
66
83
if r .shouldFinalize (wl ) {
67
- if controllerutil .ContainsFinalizer (wl , CleanupSliceFinalizerName ) {
68
- err = r .client .Delete (ctx , r .newEmptySlice (wl ))
69
- if client .IgnoreNotFound (err ) != nil {
70
- return ctrl.Result {}, err
71
- }
72
- controllerutil .RemoveFinalizer (wl , CleanupSliceFinalizerName )
73
- if err := r .client .Update (ctx , wl ); err != nil {
74
- return ctrl.Result {}, err
75
- }
76
- log .V (5 ).Info ("Removed finalizer" )
77
- }
78
- return ctrl.Result {}, nil
84
+ return ctrl.Result {}, client .IgnoreNotFound (r .finalize (ctx , wl ))
79
85
}
80
86
81
87
if controllerutil .AddFinalizer (wl , CleanupSliceFinalizerName ) {
@@ -85,75 +91,212 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
85
91
}
86
92
}
87
93
88
- return ctrl.Result {}, r .createSliceIfNotExist (ctx , wl )
94
+ return ctrl.Result {}, r .createSlicesIfNotExist (ctx , wl )
89
95
}
90
96
91
97
func (r * WorkloadReconciler ) shouldFinalize (wl * kueue.Workload ) bool {
92
98
return ! wl .DeletionTimestamp .IsZero () || workload .IsFinished (wl ) || workload .IsEvicted (wl ) || ! workload .IsActive (wl )
93
99
}
94
100
95
- func (r * WorkloadReconciler ) newEmptySlice (wl * kueue.Workload ) * v1alpha1.Slice {
96
- return & v1alpha1.Slice {
97
- ObjectMeta : metav1.ObjectMeta {
98
- Name : wl .Name ,
99
- Namespace : wl .Namespace ,
100
- },
101
+ func (r * WorkloadReconciler ) finalize (ctx context.Context , wl * kueue.Workload ) error {
102
+ if ! controllerutil .ContainsFinalizer (wl , CleanupSliceFinalizerName ) {
103
+ return nil
104
+ }
105
+
106
+ log := ctrl .LoggerFrom (ctx )
107
+
108
+ slices , err := r .findWorkloadSlices (ctx , wl )
109
+ if err != nil {
110
+ log .Error (err , "Failed to find Slices" )
111
+ return err
101
112
}
102
- }
103
113
104
- func (r * WorkloadReconciler ) newSlice (wl * kueue.Workload ) (* v1alpha1.Slice , error ) {
105
- slice := r .newEmptySlice (wl )
114
+ for _ , slice := range slices {
115
+ err = r .client .Delete (ctx , & slice )
116
+ if client .IgnoreNotFound (err ) != nil {
117
+ log .Error (err , "Failed to delete the Slice" , "slice" , klog .KObj (& slice ))
118
+ return err
119
+ }
120
+ }
121
+
122
+ controllerutil .RemoveFinalizer (wl , CleanupSliceFinalizerName )
123
+ if err := r .client .Update (ctx , wl ); err != nil {
124
+ if ! apierrors .IsNotFound (err ) {
125
+ log .Error (err , "Failed to remove finalizer" )
126
+ }
127
+ return err
128
+ }
106
129
107
- if err := controllerutil .SetControllerReference (wl , slice , r .client .Scheme ()); err != nil {
130
+ log .V (5 ).Info ("Removed finalizer" )
131
+
132
+ return nil
133
+ }
134
+
135
+ func (r * WorkloadReconciler ) findWorkloadSlices (ctx context.Context , wl * kueue.Workload ) ([]v1alpha1.Slice , error ) {
136
+ slices := & v1alpha1.SliceList {}
137
+ opts := []client.ListOption {
138
+ client .InNamespace (wl .Namespace ),
139
+ client.MatchingFields {OwnerReferenceUID : string (wl .UID )},
140
+ }
141
+ if err := r .client .List (ctx , slices , opts ... ); err != nil {
108
142
return nil , err
109
143
}
144
+ return slices .Items , nil
145
+ }
146
+
147
+ func (r * WorkloadReconciler ) createSlicesIfNotExist (ctx context.Context , wl * kueue.Workload ) error {
148
+ log := ctrl .LoggerFrom (ctx )
149
+
150
+ createdSlices , err := r .findWorkloadSlices (ctx , wl )
151
+ if err != nil {
152
+ log .Error (err , "Failed to find Slices" )
153
+ return err
154
+ }
110
155
111
- if wl .Status .Admission != nil && wl .Status .Admission .PodSetAssignments != nil {
156
+ createdSlicesByName := make (map [string ]* v1alpha1.Slice , len (createdSlices ))
157
+ for _ , slice := range createdSlices {
158
+ createdSlicesByName [slice .Name ] = & slice
159
+ }
160
+
161
+ var slicesToCreate []* v1alpha1.Slice
162
+
163
+ if wl .Status .Admission != nil {
112
164
for _ , psa := range wl .Status .Admission .PodSetAssignments {
113
- for _ , domain := range psa .TopologyAssignment .Domains {
114
- if slice .Spec .NodeSelector == nil {
115
- slice .Spec .NodeSelector = make (map [string ][]string )
116
- }
117
- if slice .Spec .NodeSelector [TPUReservationSubblockLabel ] == nil {
118
- slice .Spec .NodeSelector [TPUReservationSubblockLabel ] = []string {}
119
- }
120
- // make sure there are no duplicates in the nodeSelector
121
- for _ , v := range domain .Values {
122
- exists := slices .Contains (slice .Spec .NodeSelector [TPUReservationSubblockLabel ], v )
123
- if ! exists {
124
- slice .Spec .NodeSelector [TPUReservationSubblockLabel ] = append (slice .Spec .NodeSelector [TPUReservationSubblockLabel ], v )
125
- }
165
+ sliceName := GetSliceName (wl .Name , psa .Name )
166
+
167
+ if _ , ok := createdSlicesByName [sliceName ]; ok {
168
+ delete (createdSlicesByName , sliceName )
169
+ continue
170
+ }
171
+
172
+ slice , err := newSlice (wl , psa .Name )
173
+ if err != nil {
174
+ if ! isUnsupportedPodSetError (err ) {
175
+ log .Error (err , "Failed to create a Slice" )
176
+ return err
126
177
}
178
+ log .V (8 ).Info ("Failed to create Slice" , "error" , err )
179
+ continue
180
+ }
181
+
182
+ if err := controllerutil .SetControllerReference (wl , slice , r .client .Scheme ()); err != nil {
183
+ return err
127
184
}
185
+
186
+ slicesToCreate = append (slicesToCreate , slice )
128
187
}
129
188
}
130
189
131
- return slice , nil
190
+ for _ , slice := range slicesToCreate {
191
+ err = r .client .Create (ctx , slice )
192
+ if err != nil {
193
+ log .Error (err , "Failed to create a Slice" , "slice" , klog .KObj (slice ))
194
+ return err
195
+ }
196
+ }
197
+
198
+ for _ , slice := range slices .Collect (maps .Values (createdSlicesByName )) {
199
+ err = r .client .Delete (ctx , slice )
200
+ if client .IgnoreNotFound (err ) != nil {
201
+ log .Error (err , "Failed to delete the redundant Slice" , "slice" , klog .KObj (slice ))
202
+ return err
203
+ }
204
+ }
205
+
206
+ return nil
132
207
}
133
208
134
- func (r * WorkloadReconciler ) createSliceIfNotExist (ctx context.Context , wl * kueue.Workload ) error {
135
- slice := r .newEmptySlice (wl )
209
+ func GetSliceName (workloadName string , podSetName kueue.PodSetReference ) string {
210
+ return fmt .Sprintf ("%s-%s" , workloadName , podSetName )
211
+ }
136
212
137
- err := r .client .Get (ctx , client .ObjectKeyFromObject (slice ), slice )
138
- if client .IgnoreNotFound (err ) != nil {
139
- return err
213
+ func newSlice (wl * kueue.Workload , podSetName kueue.PodSetReference ) (* v1alpha1.Slice , error ) {
214
+ ps := findPodSet (wl , podSetName )
215
+ if findPodSet (wl , podSetName ) == nil {
216
+ return nil , errPodSetNotFound
140
217
}
141
- if err == nil {
142
- return nil
218
+
219
+ if ps .Template .Spec .NodeSelector [TPUTopologyLabel ] == "" {
220
+ return nil , errTPUTopologyLabelNotFound
143
221
}
144
222
145
- slice , err = r .newSlice (wl )
146
- if err != nil {
147
- return err
223
+ if ps .Template .Spec .NodeSelector [TPUAcceleratorLabel ] == "" {
224
+ return nil , errTPUAcceleratorLabelNotFound
148
225
}
149
226
150
- return r .client .Create (ctx , slice )
227
+ psa := findPodSetAssignment (wl , podSetName )
228
+ if psa == nil {
229
+ return nil , errPodSetAssignmentNotFound
230
+ }
231
+
232
+ if psa .TopologyAssignment == nil {
233
+ return nil , errTopologyAssignmentNotFound
234
+ }
235
+
236
+ slice := & v1alpha1.Slice {
237
+ ObjectMeta : metav1.ObjectMeta {
238
+ Name : GetSliceName (wl .Name , podSetName ),
239
+ Namespace : wl .Namespace ,
240
+ },
241
+ Spec : v1alpha1.SliceSpec {
242
+ AcceleratorTopology : ps .Template .Spec .NodeSelector [TPUTopologyLabel ],
243
+ AcceleratorType : ps .Template .Spec .NodeSelector [TPUAcceleratorLabel ],
244
+ NodeSelector : make (map [string ][]string ),
245
+ },
246
+ }
247
+
248
+ for _ , domain := range psa .TopologyAssignment .Domains {
249
+ if ps .Template .Spec .NodeSelector [TPUReservationSubBlockLabel ] != "" {
250
+ subBlockDomains := sets .New [string ]()
251
+ for _ , v := range domain .Values {
252
+ subBlockDomains .Insert (v )
253
+ }
254
+ if subBlockDomains .Len () > 0 {
255
+ slice .Spec .NodeSelector [TPUReservationSubBlockLabel ] = sets .List (subBlockDomains )
256
+ }
257
+ }
258
+ if ps .Template .Spec .NodeSelector [NodePoolLabel ] != "" {
259
+ nodePoolDomains := sets .New [string ]()
260
+ for _ , v := range domain .Values {
261
+ nodePoolDomains .Insert (v )
262
+ }
263
+ if nodePoolDomains .Len () > 0 {
264
+ slice .Spec .NodeSelector [NodePoolLabel ] = sets .List (nodePoolDomains )
265
+ }
266
+ }
267
+ }
268
+
269
+ return slice , nil
270
+ }
271
+
272
+ func findPodSet (wl * kueue.Workload , podSetName kueue.PodSetReference ) * kueue.PodSet {
273
+ for _ , ps := range wl .Spec .PodSets {
274
+ if ps .Name == podSetName {
275
+ return & ps
276
+ }
277
+ }
278
+ return nil
279
+ }
280
+
281
+ func findPodSetAssignment (wl * kueue.Workload , podSetName kueue.PodSetReference ) * kueue.PodSetAssignment {
282
+ for _ , psa := range wl .Status .Admission .PodSetAssignments {
283
+ if psa .Name == podSetName {
284
+ return & psa
285
+ }
286
+ }
287
+ return nil
288
+ }
289
+
290
+ func isUnsupportedPodSetError (err error ) bool {
291
+ return errors .Is (err , errTPUTopologyLabelNotFound ) ||
292
+ errors .Is (err , errTPUAcceleratorLabelNotFound ) ||
293
+ errors .Is (err , errTopologyAssignmentNotFound )
151
294
}
152
295
153
296
// SetupWithManager sets up the controller with the Manager.
154
297
func (r * WorkloadReconciler ) SetupWithManager (mgr ctrl.Manager ) error {
155
298
return ctrl .NewControllerManagedBy (mgr ).
156
299
For (& kueue.Workload {}).
157
- Named ("workload " ).
300
+ Named ("workload_controller " ).
158
301
Complete (r )
159
302
}
0 commit comments