From 04d6fe9a97882608c5c638b88f50dc95ee68decb Mon Sep 17 00:00:00 2001 From: Abirdcfly Date: Thu, 16 Jan 2025 14:19:47 +0800 Subject: [PATCH] fix: propagate priority-class label for deploy and statefulset Signed-off-by: Abirdcfly --- pkg/controller/jobframework/interface.go | 3 +- pkg/controller/jobframework/reconciler.go | 2 +- pkg/controller/jobframework/validation.go | 6 +- .../jobs/deployment/deployment_webhook.go | 4 ++ .../deployment/deployment_webhook_test.go | 55 +++++++++++++++++++ .../jobs/statefulset/statefulset_webhook.go | 14 ++++- .../statefulset/statefulset_webhook_test.go | 45 +++++++++++++++ 7 files changed, 120 insertions(+), 9 deletions(-) diff --git a/pkg/controller/jobframework/interface.go b/pkg/controller/jobframework/interface.go index 67e2b70bd8..ba600ad1b4 100644 --- a/pkg/controller/jobframework/interface.go +++ b/pkg/controller/jobframework/interface.go @@ -170,8 +170,7 @@ func MaximumExecutionTimeSeconds(job GenericJob) *int32 { return ptr.To(int32(v)) } -func workloadPriorityClassName(job GenericJob) string { - object := job.Object() +func WorkloadPriorityClassName(object client.Object) string { if workloadPriorityClassLabel := object.GetLabels()[constants.WorkloadPriorityClassLabel]; workloadPriorityClassLabel != "" { return workloadPriorityClassLabel } diff --git a/pkg/controller/jobframework/reconciler.go b/pkg/controller/jobframework/reconciler.go index 79a5ef1944..ddc1a4abee 100644 --- a/pkg/controller/jobframework/reconciler.go +++ b/pkg/controller/jobframework/reconciler.go @@ -988,7 +988,7 @@ func (r *JobReconciler) prepareWorkload(ctx context.Context, job GenericJob, wl } func (r *JobReconciler) extractPriority(ctx context.Context, podSets []kueue.PodSet, job GenericJob) (string, string, int32, error) { - if workloadPriorityClass := workloadPriorityClassName(job); len(workloadPriorityClass) > 0 { + if workloadPriorityClass := WorkloadPriorityClassName(job.Object()); len(workloadPriorityClass) > 0 { return utilpriority.GetPriorityFromWorkloadPriorityClass(ctx, r.client, workloadPriorityClass) } if jobWithPriorityClass, isImplemented := job.(JobWithPriorityClass); isImplemented { diff --git a/pkg/controller/jobframework/validation.go b/pkg/controller/jobframework/validation.go index 982f5b5ec0..418f6a6805 100644 --- a/pkg/controller/jobframework/validation.go +++ b/pkg/controller/jobframework/validation.go @@ -64,7 +64,7 @@ func ValidateJobOnCreate(job GenericJob) field.ErrorList { // ValidateJobOnUpdate encapsulates all GenericJob validations that must be performed on a Update operation func ValidateJobOnUpdate(oldJob, newJob GenericJob) field.ErrorList { allErrs := validateUpdateForQueueName(oldJob, newJob) - allErrs = append(allErrs, validateUpdateForWorkloadPriorityClassName(oldJob, newJob)...) + allErrs = append(allErrs, ValidateUpdateForWorkloadPriorityClassName(oldJob.Object(), newJob.Object())...) allErrs = append(allErrs, validateUpdateForMaxExecTime(oldJob, newJob)...) return allErrs } @@ -123,8 +123,8 @@ func validateUpdateForQueueName(oldJob, newJob GenericJob) field.ErrorList { return allErrs } -func validateUpdateForWorkloadPriorityClassName(oldJob, newJob GenericJob) field.ErrorList { - allErrs := apivalidation.ValidateImmutableField(workloadPriorityClassName(newJob), workloadPriorityClassName(oldJob), workloadPriorityClassNamePath) +func ValidateUpdateForWorkloadPriorityClassName(oldObj, newObj client.Object) field.ErrorList { + allErrs := apivalidation.ValidateImmutableField(WorkloadPriorityClassName(newObj), WorkloadPriorityClassName(oldObj), workloadPriorityClassNamePath) return allErrs } diff --git a/pkg/controller/jobs/deployment/deployment_webhook.go b/pkg/controller/jobs/deployment/deployment_webhook.go index 9704f3e0d4..7cb0d99c99 100644 --- a/pkg/controller/jobs/deployment/deployment_webhook.go +++ b/pkg/controller/jobs/deployment/deployment_webhook.go @@ -85,6 +85,9 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error { if queueName != "" { deployment.Spec.Template.Labels[constants.QueueLabel] = queueName } + if priorityClass := jobframework.WorkloadPriorityClassName(deployment.Object()); priorityClass != "" { + deployment.Spec.Template.Labels[constants.WorkloadPriorityClassLabel] = priorityClass + } } return nil @@ -122,6 +125,7 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob allErrs := field.ErrorList{} allErrs = append(allErrs, jobframework.ValidateQueueName(newDeployment.Object())...) + allErrs = append(allErrs, jobframework.ValidateUpdateForWorkloadPriorityClassName(oldDeployment.Object(), newDeployment.Object())...) // Prevents updating the queue-name if at least one Pod is not suspended // or if the queue-name has been deleted. diff --git a/pkg/controller/jobs/deployment/deployment_webhook_test.go b/pkg/controller/jobs/deployment/deployment_webhook_test.go index 19d17ff35b..daa79fc504 100644 --- a/pkg/controller/jobs/deployment/deployment_webhook_test.go +++ b/pkg/controller/jobs/deployment/deployment_webhook_test.go @@ -26,6 +26,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" "sigs.k8s.io/kueue/pkg/cache" + "sigs.k8s.io/kueue/pkg/controller/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" "sigs.k8s.io/kueue/pkg/controller/jobs/pod" "sigs.k8s.io/kueue/pkg/features" @@ -97,6 +98,44 @@ func TestDefault(t *testing.T) { want: testingdeployment.MakeDeployment("test-pod", ""). Obj(), }, + "deployment with queue and priority class": { + deployment: testingdeployment.MakeDeployment("test-pod", ""). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + want: testingdeployment.MakeDeployment("test-pod", ""). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "test"). + PodTemplateSpecQueue("test-queue"). + PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + }, + "deployment with queue, priority class and pod template spec queue, priority class": { + deployment: testingdeployment.MakeDeployment("test-pod", ""). + Queue("new-test-queue"). + Label(constants.WorkloadPriorityClassLabel, "new-test"). + PodTemplateSpecQueue("test-queue"). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + want: testingdeployment.MakeDeployment("test-pod", ""). + Queue("new-test-queue"). + Label(constants.WorkloadPriorityClassLabel, "new-test"). + PodTemplateSpecQueue("new-test-queue"). + PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "new-test"). + Obj(), + }, + "deployment without queue with pod template spec queue and priority class": { + deployment: testingdeployment.MakeDeployment("test-pod", ""). + PodTemplateSpecQueue("test-queue"). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + want: testingdeployment.MakeDeployment("test-pod", ""). + PodTemplateSpecQueue("test-queue"). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + }, } for name, tc := range testCases { @@ -246,6 +285,22 @@ func TestValidateUpdate(t *testing.T) { }, }.ToAggregate(), }, + "update priority-class": { + oldDeployment: testingdeployment.MakeDeployment("test-pod", ""). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + newDeployment: testingdeployment.MakeDeployment("test-pod", ""). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "new-test"). + Obj(), + wantErr: field.ErrorList{ + &field.Error{ + Type: field.ErrorTypeInvalid, + Field: "metadata.labels[kueue.x-k8s.io/priority-class]", + }, + }.ToAggregate(), + }, } for name, tc := range testCases { diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook.go b/pkg/controller/jobs/statefulset/statefulset_webhook.go index 09193badb0..71bb5ec867 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook.go @@ -91,6 +91,9 @@ func (wh *Webhook) Default(ctx context.Context, obj runtime.Object) error { ss.Spec.Template.Annotations[pod.GroupServingAnnotation] = "true" ss.Spec.Template.Annotations[kueuealpha.PodGroupPodIndexLabelAnnotation] = appsv1.PodIndexLabel } + if priorityClass := jobframework.WorkloadPriorityClassName(ss.Object()); priorityClass != "" { + ss.Spec.Template.Labels[constants.WorkloadPriorityClassLabel] = priorityClass + } } return nil @@ -112,9 +115,10 @@ func (wh *Webhook) ValidateCreate(ctx context.Context, obj runtime.Object) (warn } var ( - labelsPath = field.NewPath("metadata", "labels") - queueNameLabelPath = labelsPath.Key(constants.QueueLabel) - replicasPath = field.NewPath("spec", "replicas") + labelsPath = field.NewPath("metadata", "labels") + queueNameLabelPath = labelsPath.Key(constants.QueueLabel) + replicasPath = field.NewPath("spec", "replicas") + priorityClassNameLabelPath = labelsPath.Key(constants.WorkloadPriorityClassLabel) ) func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (warnings admission.Warnings, err error) { @@ -134,6 +138,10 @@ func (wh *Webhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Ob if oldStatefulSet.Status.ReadyReplicas > 0 || newQueueName == "" { allErrs = append(allErrs, apivalidation.ValidateImmutableField(oldQueueName, newQueueName, queueNameLabelPath)...) } + allErrs = append(allErrs, jobframework.ValidateUpdateForWorkloadPriorityClassName( + oldStatefulSet.Object(), + newStatefulSet.Object(), + )...) if jobframework.IsManagedByKueue(newStatefulSet.Object()) { oldReplicas := ptr.Deref(oldStatefulSet.Spec.Replicas, 1) diff --git a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go index 7770eaf0d9..2eca0f730a 100644 --- a/pkg/controller/jobs/statefulset/statefulset_webhook_test.go +++ b/pkg/controller/jobs/statefulset/statefulset_webhook_test.go @@ -70,6 +70,27 @@ func TestDefault(t *testing.T) { PodTemplateSpecPodGroupPodIndexLabelAnnotation(appsv1.PodIndexLabel). Obj(), }, + "statefulset with queue and priority class": { + enableIntegrations: []string{"pod"}, + statefulset: testingstatefulset.MakeStatefulSet("test-pod", ""). + Replicas(10). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "test"). + Obj(), + want: testingstatefulset.MakeStatefulSet("test-pod", ""). + Replicas(10). + Queue("test-queue"). + Label(constants.WorkloadPriorityClassLabel, "test"). + PodTemplateSpecQueue("test-queue"). + PodTemplateAnnotation(pod.SuspendedByParentAnnotation, FrameworkName). + PodTemplateSpecLabel(constants.WorkloadPriorityClassLabel, "test"). + PodTemplateSpecPodGroupNameLabel("test-pod", "", gvk). + PodTemplateSpecPodGroupTotalCountAnnotation(10). + PodTemplateSpecPodGroupFastAdmissionAnnotation(true). + PodTemplateSpecPodGroupServingAnnotation(true). + PodTemplateSpecPodGroupPodIndexLabelAnnotation(appsv1.PodIndexLabel). + Obj(), + }, "statefulset without replicas": { enableIntegrations: []string{"pod"}, statefulset: testingstatefulset.MakeStatefulSet("test-pod", ""). @@ -307,6 +328,30 @@ func TestValidateUpdate(t *testing.T) { }, }.ToAggregate(), }, + "change in priority class label": { + oldObj: &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + constants.QueueLabel: "queue1", + constants.WorkloadPriorityClassLabel: "priority1", + }, + }, + }, + newObj: &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + constants.QueueLabel: "queue1", + constants.WorkloadPriorityClassLabel: "priority2", + }, + }, + }, + wantErr: field.ErrorList{ + &field.Error{ + Type: field.ErrorTypeInvalid, + Field: priorityClassNameLabelPath.String(), + }, + }.ToAggregate(), + }, "change in replicas (scale down to zero)": { oldObj: &appsv1.StatefulSet{ Spec: appsv1.StatefulSetSpec{