Skip to content

Commit

Permalink
Add Go tests to neuron-inference
Browse files Browse the repository at this point in the history
  • Loading branch information
mattcjo committed Jan 13, 2025
1 parent 3730689 commit 5409163
Show file tree
Hide file tree
Showing 6 changed files with 518 additions and 0 deletions.
154 changes: 154 additions & 0 deletions test/cases/neuron-inference/bert_inference_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package inference

import (
"context"
_ "embed"
"fmt"
"io"
"log"
"testing"
"time"

fwext "github.com/aws/aws-k8s-tester/internal/e2e"
batchv1 "k8s.io/api/batch/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"sigs.k8s.io/e2e-framework/klient/wait"
"sigs.k8s.io/e2e-framework/pkg/envconf"
"sigs.k8s.io/e2e-framework/pkg/features"
)

// neuronBertInferenceManifest is the embedded, unrendered manifest template
// for the BERT inference Job. TestNeuronInference passes it to
// fwext.RenderManifests together with the test's render variables.
//
//go:embed manifests/neuron-bert-inference.yaml
var neuronBertInferenceManifest []byte

// renderedManifest holds the rendered form of neuronBertInferenceManifest.
// It is assigned in the feature's Setup step and read again in Teardown so
// that the same bytes that were applied are the ones that get deleted.
var renderedManifest []byte

// TestNeuronInference renders and applies the embedded BERT inference Job
// manifest, waits for the "neuron-inference" Job in the "default" namespace
// to succeed, prints the Job's pod logs, and finally deletes the applied
// resources.
//
// The test fails immediately when the bertInferenceImage flag is empty.
func TestNeuronInference(t *testing.T) {
	if *bertInferenceImage == "" {
		t.Fatal("bertInferenceImage must be set to run the test")
	}

	// applyTimeKey is a private context key type. Using an unexported struct
	// type instead of a bare string guarantees the key cannot collide with
	// values stored in the context by the framework or other packages.
	type applyTimeKey struct{}

	log.Printf("[INFO] Using nodeType=%s, inferenceMode=%s", *nodeType, *inferenceMode)
	log.Printf("[INFO] Discovered neuronPerNode=%d, neuronCorePerNode=%d", neuronPerNode, neuronCorePerNode)

	renderVars := map[string]string{
		"BertInferenceImage": *bertInferenceImage,
		"NodeType":           *nodeType,      // e.g. "inf2.xlarge"
		"InferenceMode":      *inferenceMode, // "throughput" or "latency"
		"NeuronPerNode":      fmt.Sprintf("%d", neuronPerNode),
		"NeuronCorePerNode":  fmt.Sprintf("%d", neuronCorePerNode),
	}

	feature := features.New("neuron-inference").
		WithLabel("suite", "neuron").
		WithLabel("hardware", "neuron").
		Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
			log.Println("[INFO] Rendering 'neuron-bert-inference.yaml' with dynamic vars...")
			var err error
			// Assign to the package-level variable so Teardown can delete
			// exactly what was applied here.
			renderedManifest, err = fwext.RenderManifests(neuronBertInferenceManifest, renderVars)
			if err != nil {
				t.Fatalf("[ERROR] Failed to render inference manifest: %v", err)
			}
			log.Println("[INFO] Successfully rendered inference YAML.")

			log.Println("[INFO] Applying neuron inference job manifest...")
			if applyErr := fwext.ApplyManifests(cfg.Client().RESTConfig(), renderedManifest); applyErr != nil {
				t.Fatalf("[ERROR] Failed to apply inference job manifest: %v", applyErr)
			}
			log.Println("[INFO] Inference job manifest applied successfully.")

			// Record when the Job was applied so the Assess step can report
			// the total time to completion.
			return context.WithValue(ctx, applyTimeKey{}, time.Now())
		}).
		Assess("BERT inference Job succeeds", func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
			log.Println("[INFO] Checking 'neuron-inference' job completion...")

			job := &batchv1.Job{
				ObjectMeta: metav1.ObjectMeta{Name: "neuron-inference", Namespace: "default"},
			}
			err := wait.For(
				fwext.NewConditionExtension(cfg.Client().Resources()).JobSucceeded(job),
				wait.WithTimeout(20*time.Minute),
			)
			if err != nil {
				t.Fatalf("[ERROR] Neuron inference job did not succeed: %v", err)
			}
			log.Println("[INFO] Neuron inference job succeeded. Gathering logs...")

			// The assertion is false both when the key is absent and when the
			// stored value is not a time.Time, so no separate nil check is needed.
			if start, ok := ctx.Value(applyTimeKey{}).(time.Time); ok {
				log.Printf("[INFO] Neuron inference job completed in %s", time.Since(start))
			}

			// Log retrieval is best-effort: the Job already succeeded, so a
			// failure here is only reported, not treated as a test failure.
			if err := printJobLogs(ctx, cfg, "default", "neuron-inference"); err != nil {
				t.Logf("[WARNING] Failed to retrieve neuron-inference job logs: %v", err)
			}
			return ctx
		}).
		Teardown(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
			log.Println("[INFO] Cleaning up neuron-inference job resources...")
			if err := fwext.DeleteManifests(cfg.Client().RESTConfig(), renderedManifest); err != nil {
				t.Fatalf("[ERROR] Failed to delete inference job resources: %v", err)
			}
			log.Println("[INFO] Inference job cleanup complete.")
			return ctx
		}).
		Feature()

	testenv.Test(t, feature)
}

// printJobLogs writes the logs of every pod belonging to the Job jobName in
// namespace to the standard logger.
//
// Pods are located through the "job-name=<jobName>" label selector. An error
// is returned when the clientset cannot be created, the pod list fails, no
// pods match, or any pod's log stream cannot be opened or read.
func printJobLogs(ctx context.Context, cfg *envconf.Config, namespace, jobName string) error {
	cs, err := getClientset(cfg.Client().RESTConfig())
	if err != nil {
		return fmt.Errorf("[ERROR] failed to create kubernetes client: %w", err)
	}

	pods, err := cs.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
		LabelSelector: fmt.Sprintf("job-name=%s", jobName),
	})
	if err != nil {
		return fmt.Errorf("[ERROR] failed to list pods for job %s: %w", jobName, err)
	}
	if len(pods.Items) == 0 {
		return fmt.Errorf("[ERROR] no pods found for job %s", jobName)
	}

	for _, pod := range pods.Items {
		log.Printf("[INFO] Pod %s is on node %s", pod.Name, pod.Spec.NodeName)
		// Streaming happens in a helper so that each pod's stream is closed
		// as soon as that pod is done, rather than staying open until this
		// function returns.
		if err := streamPodLogs(ctx, cs, pod.Namespace, pod.Name); err != nil {
			return err
		}
	}
	return nil
}

// streamPodLogs opens the log stream of a single pod, copies it to the
// standard logger in chunks of up to 4096 bytes, and closes the stream before
// returning.
func streamPodLogs(ctx context.Context, cs *kubernetes.Clientset, namespace, podName string) error {
	stream, err := cs.CoreV1().Pods(namespace).GetLogs(podName, &v1.PodLogOptions{}).Stream(ctx)
	if err != nil {
		return fmt.Errorf("[ERROR] failed to get logs from pod %s: %w", podName, err)
	}
	defer stream.Close()

	buf := make([]byte, 4096)
	for {
		n, readErr := stream.Read(buf)
		// A Read may return data together with an error, so emit whatever
		// was read before looking at readErr.
		if n > 0 {
			log.Printf("[INFO] Logs from Pod %s:\n%s", podName, string(buf[:n]))
		}
		if readErr == io.EOF {
			log.Printf("[INFO] Completed log stream for pod %s.", podName)
			return nil
		}
		if readErr != nil {
			return fmt.Errorf("[ERROR] reading logs from pod %s: %w", podName, readErr)
		}
	}
}

// getClientset builds a Kubernetes clientset from restConfig. On failure it
// returns a nil clientset and an error wrapping the underlying cause.
func getClientset(restConfig *rest.Config) (*kubernetes.Clientset, error) {
	clientset, newErr := kubernetes.NewForConfig(restConfig)
	if newErr != nil {
		wrapped := fmt.Errorf("cannot create kubernetes clientset: %w", newErr)
		return nil, wrapped
	}

	return clientset, nil
}
143 changes: 143 additions & 0 deletions test/cases/neuron-inference/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package inference

import (
"context"
_ "embed"
"flag"
"fmt"
"log"
"os"
"slices"
"testing"
"time"

fwext "github.com/aws/aws-k8s-tester/internal/e2e"
appsv1 "k8s.io/api/apps/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"sigs.k8s.io/e2e-framework/klient/wait"
"sigs.k8s.io/e2e-framework/pkg/env"
"sigs.k8s.io/e2e-framework/pkg/envconf"
)

// Embedded Neuron device plugin manifests. TestMain applies them during
// environment setup and deletes them, in reverse order, when the run finishes.
var (
// neuronDevicePluginRbacManifest is the embedded RBAC manifest for the
// Neuron device plugin.
//
//go:embed manifests/k8s-neuron-device-plugin-rbac.yml
neuronDevicePluginRbacManifest []byte
// neuronDevicePluginManifest is the embedded manifest for the Neuron
// device plugin itself.
//
//go:embed manifests/k8s-neuron-device-plugin.yml
neuronDevicePluginManifest []byte
)

// TestMain builds the shared e2e test environment and drives the test run.
//
// Setup applies the Neuron device plugin manifests, waits up to five minutes
// for the device plugin DaemonSet to become ready, and then discovers the
// cluster's Neuron capacity. Finish deletes the device plugin manifests in
// reverse order. The process exits with the code returned by the environment.
func TestMain(m *testing.M) {
	flag.Parse()

	envCfg, err := envconf.NewFromFlags()
	if err != nil {
		log.Fatalf("[ERROR] Failed to create test environment: %v", err)
	}
	testenv = env.NewWithConfig(envCfg)

	// RBAC first, so the service account exists before the DaemonSet uses it.
	manifests := [][]byte{
		neuronDevicePluginRbacManifest,
		neuronDevicePluginManifest,
	}

	// applyDevicePlugin installs the RBAC objects and the device plugin.
	applyDevicePlugin := func(ctx context.Context, config *envconf.Config) (context.Context, error) {
		log.Println("Applying Neuron device plugin RBAC and Neuron device plugin manifests.")
		if applyErr := fwext.ApplyManifests(config.Client().RESTConfig(), manifests...); applyErr != nil {
			return ctx, fmt.Errorf("failed to apply manifests: %w", applyErr)
		}
		log.Println("Successfully applied Neuron device plugin RBAC and Neuron device plugin manifests.")
		return ctx, nil
	}

	// waitForDevicePlugin blocks until the device plugin DaemonSet is ready.
	waitForDevicePlugin := func(ctx context.Context, config *envconf.Config) (context.Context, error) {
		log.Println("Waiting for Neuron Device Plugin daemonset to be ready.")
		ds := &appsv1.DaemonSet{
			ObjectMeta: metav1.ObjectMeta{
				Name:      "neuron-device-plugin-daemonset",
				Namespace: "kube-system",
			},
		}
		ready := fwext.NewConditionExtension(config.Client().Resources()).DaemonSetReady(ds)
		if waitErr := wait.For(ready, wait.WithTimeout(5*time.Minute)); waitErr != nil {
			return ctx, fmt.Errorf("Neuron Device Plugin daemonset is not ready: %w", waitErr)
		}
		log.Println("Neuron Device Plugin daemonset is ready.")
		return ctx, nil
	}

	// removeDevicePlugin deletes the manifests in the reverse of apply order.
	removeDevicePlugin := func(ctx context.Context, config *envconf.Config) (context.Context, error) {
		log.Println("[INFO] Cleaning up Neuron device plugin.")
		slices.Reverse(manifests)
		deleteErr := fwext.DeleteManifests(config.Client().RESTConfig(), manifests...)
		if deleteErr != nil {
			return ctx, fmt.Errorf("failed to delete neuron device plugin: %w", deleteErr)
		}
		log.Println("[INFO] Neuron device plugin cleanup complete.")
		return ctx, nil
	}

	// Setup steps: apply the device plugin, wait for DS readiness, discover capacity
	testenv.Setup(applyDevicePlugin, waitForDevicePlugin, discoverNeuronCoreCapacity)

	// Finish steps: remove device plugin if desired
	testenv.Finish(removeDevicePlugin)

	exitCode := testenv.Run(m)
	log.Printf("[INFO] Test environment finished with exit code %d", exitCode)
	os.Exit(exitCode)
}

// discoverNeuronCoreCapacity sets the package-level neuronPerNode and
// neuronCorePerNode by scanning the capacity reported by the cluster's nodes.
//
// Each value is the integer average of the corresponding resource
// ("aws.amazon.com/neuron" / "aws.amazon.com/neuroncore") over the nodes that
// actually advertise that resource. Nodes without the resource are logged and
// excluded, so non-Neuron nodes in a mixed cluster do not drag the per-node
// figure toward zero.
//
// It returns an error when the client cannot be created, nodes cannot be
// listed, the cluster has no nodes, or no Neuron core capacity is found.
func discoverNeuronCoreCapacity(ctx context.Context, config *envconf.Config) (context.Context, error) {
	log.Println("[INFO] Discovering cluster's Neuron capacity...")

	cs, err := kubernetes.NewForConfig(config.Client().RESTConfig())
	if err != nil {
		return ctx, fmt.Errorf("failed to create kubernetes client: %w", err)
	}

	nodes, err := cs.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
	if err != nil {
		return ctx, fmt.Errorf("failed to list nodes: %w", err)
	}
	nodeCount := len(nodes.Items)
	if nodeCount == 0 {
		return ctx, fmt.Errorf("no nodes found in the cluster")
	}

	var (
		totalNeuron, totalNeuronCore int
		// Number of nodes advertising each resource; used as the divisors.
		neuronNodes, neuronCoreNodes int
	)
	for _, node := range nodes.Items {
		instanceType := node.Labels["node.kubernetes.io/instance-type"]

		if neuronCap, ok := node.Status.Capacity["aws.amazon.com/neuron"]; ok {
			totalNeuron += int(neuronCap.Value())
			neuronNodes++
		} else {
			log.Printf("[WARN] Node %s (type=%s) lacks 'aws.amazon.com/neuron'.", node.Name, instanceType)
		}

		if neuronCoreCap, ok := node.Status.Capacity["aws.amazon.com/neuroncore"]; ok {
			totalNeuronCore += int(neuronCoreCap.Value())
			neuronCoreNodes++
		} else {
			log.Printf("[WARN] Node %s (type=%s) lacks 'aws.amazon.com/neuroncore'.", node.Name, instanceType)
		}
	}

	// Reset first so that a resource advertised by no node yields 0, then
	// average only over the nodes that expose the resource.
	neuronPerNode, neuronCorePerNode = 0, 0
	if neuronNodes > 0 {
		neuronPerNode = totalNeuron / neuronNodes
	}
	if neuronCoreNodes > 0 {
		neuronCorePerNode = totalNeuronCore / neuronCoreNodes
	}

	log.Printf("[INFO] Discovered neuronPerNode=%d, neuronCorePerNode=%d (neuron on %d, neuroncore on %d of %d node(s))",
		neuronPerNode, neuronCorePerNode, neuronNodes, neuronCoreNodes, nodeCount)

	if neuronCorePerNode <= 0 {
		return ctx, fmt.Errorf("discovered %d neuronCorePerNode => no Neuron capacity found", neuronCorePerNode)
	}

	log.Println("[INFO] Neuron capacity discovery complete.")
	return ctx, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Source: https://github.com/aws-neuron/aws-neuron-sdk/blob/master/src/k8/k8s-neuron-device-plugin-rbac.yml
# ClusterRole "neuron-device-plugin": the cluster-wide permissions granted to
# the Neuron device plugin. All rules target the core ("") API group.
kind: ClusterRole
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: neuron-device-plugin
rules:
# Read-only access to nodes.
- apiGroups:
- ""
resources:
- nodes
verbs:
- get
- list
- watch
# Create and patch events.
- apiGroups:
- ""
resources:
- events
verbs:
- create
- patch
# Read and modify pods.
- apiGroups:
- ""
resources:
- pods
verbs:
- update
- patch
- get
- list
- watch
# Write access to the node status subresource.
- apiGroups:
- ""
resources:
- nodes/status
verbs:
- patch
- update
---
# ServiceAccount "neuron-device-plugin" in the kube-system namespace; it is the
# subject of the ClusterRoleBinding defined in this file.
apiVersion: v1
kind: ServiceAccount
metadata:
name: neuron-device-plugin
namespace: kube-system
---
# ClusterRoleBinding "neuron-device-plugin": grants the ClusterRole
# "neuron-device-plugin" to the ServiceAccount "neuron-device-plugin" in
# kube-system.
kind: ClusterRoleBinding
apiVersion: rbac.authorization.k8s.io/v1
metadata:
name: neuron-device-plugin
# NOTE(review): ClusterRoleBinding is a cluster-scoped kind, so this namespace
# presumably has no effect — confirm before relying on it.
namespace: kube-system
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: neuron-device-plugin
subjects:
- kind: ServiceAccount
name: neuron-device-plugin
namespace: kube-system
Loading

0 comments on commit 5409163

Please sign in to comment.