From 8ac8c9f5aae6720d99c2eee9fcdda01cab4885ea Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Wed, 26 Feb 2025 11:59:49 -0800 Subject: [PATCH] Improve inference tutorial docs (#7083) Fixes: #7082 --------- Signed-off-by: Logan Adams --- docs/_tutorials/inference-tutorial.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/_tutorials/inference-tutorial.md b/docs/_tutorials/inference-tutorial.md index 1d5899204f53..ddf287f24b96 100644 --- a/docs/_tutorials/inference-tutorial.md +++ b/docs/_tutorials/inference-tutorial.md @@ -21,18 +21,22 @@ if args.pre_load_checkpoint: model = model_class.from_pretrained(args.model_name_or_path) else: model = model_class() + +# create the tokenizer +tokenizer = model_class.from_pretrained(args.model_name_or_path) ... import deepspeed # Initialize the DeepSpeed-Inference engine ds_engine = deepspeed.init_inference(model, - tensor_parallel={"tp_size": 2}, - dtype=torch.half, - checkpoint=None if args.pre_load_checkpoint else args.checkpoint_json, - replace_with_kernel_inject=True) + tensor_parallel={"tp_size": world_size}, + dtype=torch.half, + checkpoint=None if args.pre_load_checkpoint else args.checkpoint_json, + replace_with_kernel_inject=True) model = ds_engine.module -output = model('Input String') +pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) +output = pipe('Input String') ``` To run inference with only model-parallelism for the models that we don't support kernels, you can pass an injection policy that shows the two specific linear layers on a Transformer Encoder/Decoder layer: 1) the attention output GeMM and 2) layer output GeMM. We need these part of the layer to add the required all-reduce communication between GPUs to merge the partial results across model-parallel ranks. Below, we bring an example that shows how you can use deepspeed-inference with a T5 model: