diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index 360113b78a3d..7e3c81b714c0 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -14,7 +14,7 @@ from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp from deepspeed.accelerator import get_accelerator import deepspeed -if deepspeed.HAS_TRITON: +if deepspeed.HAS_TRITON and get_accelerator().is_triton_supported(): from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention