Skip to content

Commit

Permalink
fix: propagate priority-class label for deploy and statefulset (#3980)
Browse files Browse the repository at this point in the history
Signed-off-by: Abirdcfly <[email protected]>
  • Loading branch information
Abirdcfly authored Jan 22, 2025
1 parent 6d9f745 commit 94a647f
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 9 deletions.
3 changes: 1 addition & 2 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ func MaximumExecutionTimeSeconds(job GenericJob) *int32 {
return ptr.To(int32(v))
}

func workloadPriorityClassName(job GenericJob) string {
object := job.Object()
func WorkloadPriorityClassName(object client.Object) string {
if workloadPriorityClassLabel := object.GetLabels()[constants.WorkloadPriorityClassLabel]; workloadPriorityClassLabel != "" {
return workloadPriorityClassLabel
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ func (r *JobReconciler) prepareWorkload(ctx context.Context, job GenericJob, wl
}

func (r *JobReconciler) extractPriority(ctx context.Context, podSets []kueue.PodSet, job GenericJob) (string, string, int32, error) {
if workloadPriorityClass := workloadPriorityClassName(job); len(workloadPriorityClass) > 0 {
if workloadPriorityClass := WorkloadPriorityClassName(job.Object()); len(workloadPriorityClass) > 0 {
return utilpriority.GetPriorityFromWorkloadPriorityClass(ctx, r.client, workloadPriorityClass)
}
if jobWithPriorityClass, isImplemented := job.(JobWithPriorityClass); isImplemented {
Expand Down
6 changes: 3 additions & 3 deletions pkg/controller/jobframework/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func ValidateJobOnCreate(job GenericJob) field.ErrorList {
// ValidateJobOnUpdate encapsulates all GenericJob validations that must be performed on a Update operation
func ValidateJobOnUpdate(oldJob, newJob GenericJob) field.ErrorList {
allErrs := validateUpdateForQueueName(oldJob, newJob)
allErrs = append(allErrs, validateUpdateForWorkloadPriorityClassName(oldJob, newJob)...)
allErrs = append(allErrs, ValidateUpdateForWorkloadPriorityClassName(oldJob.Object(), newJob.Object())...)
allErrs = append(allErrs, validateUpdateForMaxExecTime(oldJob, newJob)...)
return allErrs
}
Expand Down Expand Up @@ -123,8 +123,8 @@ func validateUpdateForQueueName(oldJob, newJob GenericJob) field.ErrorList {
return allErrs
}

func validateUpdateForWorkloadPriorityClassName(oldJob, newJob GenericJob) field.ErrorList {
allErrs := apivalidation.ValidateImmutableField(workloadPriorityClassName(newJob), workloadPriorityClassName(oldJob), workloadPriorityClassNamePath)
func ValidateUpdateForWorkloadPriorityClassName(oldObj, newObj client.Object) field.ErrorList {
allErrs := apivalidation.ValidateImmutableField(WorkloadPriorityClassName(newObj), WorkloadPriorityClassName(oldObj), workloadPriorityClassNamePath)
return allErrs
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/controller/jobs/deployment/deployment_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {
if queueName != "" {
deployment.Spec.Template.Labels[constants.QueueLabel] = queueName
}
if priorityClass := jobframework.WorkloadPriorityClassName(deployment.Object()); priorityClass != "" {
deployment.Spec.Template.Labels[constants.WorkloadPriorityClassLabel] = priorityClass
}
}

return nil
Expand Down Expand Up @@ -122,6 +125,7 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob

allErrs := field.ErrorList{}
allErrs = append(allErrs, jobframework.ValidateQueueName(newDeployment.Object())...)
allErrs = append(allErrs, jobframework.ValidateUpdateForWorkloadPriorityClassName(oldDeployment.Object(), newDeployment.Object())...)

// Prevents updating the queue-name if at least one Pod is not suspended
// or if the queue-name has been deleted.
Expand Down
55 changes: 55 additions & 0 deletions pkg/controller/jobs/deployment/deployment_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

"sigs.k8s.io/kueue/pkg/cache"
"sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobs/pod"
"sigs.k8s.io/kueue/pkg/features"
Expand Down Expand Up @@ -97,6 +98,44 @@ func TestDefault(t *testing.T) {
want: testingdeployment.MakeDeployment("test-pod", "").
Obj(),
},
"deployment with queue and priority class": {
deployment: testingdeployment.MakeDeployment("test-pod", "").
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "test").
Obj(),
want: testingdeployment.MakeDeployment("test-pod", "").
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "test").
PodTemplateSpecQueue("test-queue").
PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName).
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test").
Obj(),
},
"deployment with queue, priority class and pod template spec queue, priority class": {
deployment: testingdeployment.MakeDeployment("test-pod", "").
Queue("new-test-queue").
Label(constants.WorkloadPriorityClassLabel, "new-test").
PodTemplateSpecQueue("test-queue").
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test").
Obj(),
want: testingdeployment.MakeDeployment("test-pod", "").
Queue("new-test-queue").
Label(constants.WorkloadPriorityClassLabel, "new-test").
PodTemplateSpecQueue("new-test-queue").
PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName).
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "new-test").
Obj(),
},
"deployment without queue with pod template spec queue and priority class": {
deployment: testingdeployment.MakeDeployment("test-pod", "").
PodTemplateSpecQueue("test-queue").
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test").
Obj(),
want: testingdeployment.MakeDeployment("test-pod", "").
PodTemplateSpecQueue("test-queue").
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test").
Obj(),
},
}

for name, tc := range testCases {
Expand Down Expand Up @@ -246,6 +285,22 @@ func TestValidateUpdate(t *testing.T) {
},
}.ToAggregate(),
},
"update priority-class": {
oldDeployment: testingdeployment.MakeDeployment("test-pod", "").
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "test").
Obj(),
newDeployment: testingdeployment.MakeDeployment("test-pod", "").
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "new-test").
Obj(),
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: "metadata.labels[kueue.x-k8s.io/priority-class]",
},
}.ToAggregate(),
},
}

for name, tc := range testCases {
Expand Down
14 changes: 11 additions & 3 deletions pkg/controller/jobs/statefulset/statefulset_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error {
ss.Spec.Template.Annotations[pod.GroupServingAnnotation] = "true"
ss.Spec.Template.Annotations[kueuealpha.PodGroupPodIndexLabelAnnotation] = appsv1.PodIndexLabel
}
if priorityClass := jobframework.WorkloadPriorityClassName(ss.Object()); priorityClass != "" {
ss.Spec.Template.Labels[constants.WorkloadPriorityClassLabel] = priorityClass
}
}

return nil
Expand All @@ -112,9 +115,10 @@ func (wh *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (warn
}

var (
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
replicasPath = field.NewPath("spec", "replicas")
labelsPath = field.NewPath("metadata", "labels")
queueNameLabelPath = labelsPath.Key(constants.QueueLabel)
replicasPath = field.NewPath("spec", "replicas")
priorityClassNameLabelPath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
)

func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) {
Expand All @@ -134,6 +138,10 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob
if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" {
allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...)
}
allErrs = append(allErrs, jobframework.ValidateUpdateForWorkloadPriorityClassName(
oldStatefulSet.Object(),
newStatefulSet.Object(),
)...)

if jobframework.IsManagedByKueue(newStatefulSet.Object()) {
oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1)
Expand Down
45 changes: 45 additions & 0 deletions pkg/controller/jobs/statefulset/statefulset_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,27 @@ func TestDefault(t *testing.T) {
PodTemplateSpecPodGroupPodIndexLabelAnnotation(appsv1.PodIndexLabel).
Obj(),
},
"statefulset with queue and priority class": {
enableIntegrations: []string{"pod"},
statefulset: testingstatefulset.MakeStatefulSet("test-pod", "").
Replicas(10).
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "test").
Obj(),
want: testingstatefulset.MakeStatefulSet("test-pod", "").
Replicas(10).
Queue("test-queue").
Label(constants.WorkloadPriorityClassLabel, "test").
PodTemplateSpecQueue("test-queue").
PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName).
PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test").
PodTemplateSpecPodGroupNameLabel("test-pod", "", gvk).
PodTemplateSpecPodGroupTotalCountAnnotation(10).
PodTemplateSpecPodGroupFastAdmissionAnnotation(true).
PodTemplateSpecPodGroupServingAnnotation(true).
PodTemplateSpecPodGroupPodIndexLabelAnnotation(appsv1.PodIndexLabel).
Obj(),
},
"statefulset without replicas": {
enableIntegrations: []string{"pod"},
statefulset: testingstatefulset.MakeStatefulSet("test-pod", "").
Expand Down Expand Up @@ -307,6 +328,30 @@ func TestValidateUpdate(t *testing.T) {
},
}.ToAggregate(),
},
"change in priority class label": {
oldObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
constants.WorkloadPriorityClassLabel: "priority1",
},
},
},
newObj: &appsv1.StatefulSet{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
constants.QueueLabel: "queue1",
constants.WorkloadPriorityClassLabel: "priority2",
},
},
},
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: priorityClassNameLabelPath.String(),
},
}.ToAggregate(),
},
"change in replicas (scale down to zero)": {
oldObj: &appsv1.StatefulSet{
Spec: appsv1.StatefulSetSpec{
Expand Down

0 comments on commit 94a647f

Please sign in to comment.