Commit
revert workspace changes
oelayan7 committed Sep 18, 2024
1 parent 04f8b50 commit 1d9aeaf
Showing 3 changed files with 57 additions and 28 deletions.
deepspeed/inference/engine.py (19 changes: 14 additions & 5 deletions)
@@ -31,7 +31,6 @@
 from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask
 from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention
 from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
-from ..ops.transformer.inference.op_binding.workspace import WorkspaceOp

 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -54,8 +53,13 @@ def __init__(self, model, config):
         DS_INFERENCE_ENABLED = True

         super().__init__()
-        self.workspace = WorkspaceOp()
-        self.destroy()
+
+        # Have to import here because inference_module is a global, but python
+        # globals only work at the module level and will not be updated unless
+        # we import it each time we init a new inference engine.
+        from ..model_implementations.transformers.ds_transformer import inference_module
+        if inference_module is not None:
+            self.destroy()

         self.module = model
         self._config = config
@@ -186,10 +190,15 @@ def __init__(self, model, config):
         self._is_compiled = False

     def destroy(self):
+        # Have to import here because inference_module is a global, but python
+        # globals only work at the module level and will not be updated unless
+        # we import it each time we init a new inference engine.
+        from ..model_implementations.transformers.ds_transformer import inference_module
         DeepSpeedTransformerInference.layer_id = 0
         DeepSpeedSelfAttention.num_layers = 0
-        if self.workspace.is_op_implemented():
-            self.workspace.release_workspace()
+        if inference_module is not None:
+            inference_module.release_workspace()
+            inference_module = None

     def profile_model_time(self, use_cuda_events=True):
         if not self.model_profile_enabled and not self._config.enable_cuda_graph:
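The comment restored in __init__ and destroy() above describes a standard Python pitfall: "from module import name" binds a snapshot of the value, so a later rebinding of the module-level global is not visible through the old reference. A minimal, self-contained sketch of that behavior (not DeepSpeed code, names are illustrative):

import types

# Stand-in for the ds_transformer module and its module-level global.
mod = types.ModuleType("ds_transformer_stub")
mod.inference_module = None

snapshot = mod.inference_module      # "from ... import inference_module" captures the value once
mod.inference_module = object()      # the module global is rebound later (e.g. after loading the op builder)

print(snapshot is None)              # True  -> the old snapshot never sees the update
print(mod.inference_module is None)  # False -> re-reading the module attribute does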
deepspeed/model_implementations/transformers/ds_llama2.py (15 changes: 13 additions & 2 deletions)
@@ -4,8 +4,11 @@
 # DeepSpeed Team

 import torch
+from deepspeed import comm as dist
 from deepspeed.model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference

+inference_module = None
+

 class DeepSpeedLlama2Inference(DeepSpeedTransformerInference):
     """Initialize the DeepSpeed OPT Transformer Layer.
@@ -24,9 +27,17 @@ def forward(self, *args, **kwargs):

         input = args[0]
         input_mask = None
+        get_present = True
+        # Allocate memory only on first layer forward
+        if self.config.layer_id == 0 and self._alloc_workspace:
+            self.allocate_workspace(self.config.hidden_size, self.config.heads,
+                                    input.size()[1],
+                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
+                                    self.config.bigscience_bloom,
+                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
+                                    self.config.min_out_tokens)
+            self._alloc_workspace = False

-        self.allocate_workspace(input.size())
-        get_present = True

         # We set the prev key/value to None when there is a prompt
         if input.shape[1] > 1:
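The block restored in forward() mirrors the base-class pattern: only layer 0 allocates the shared workspace buffers, and only on its first forward call. A rough sketch of that allocate-once guard, with stub names standing in for the real config fields and the compiled allocate_workspace_* kernel:

class LayerStub:
    total_layers = 0                      # plays the role of DeepSpeedTransformerInference.layer_id

    def __init__(self, layer_id):
        self.layer_id = layer_id
        self._alloc_workspace = True      # same flag name as in the diff
        LayerStub.total_layers += 1

    def _allocate(self, batch, seq_len):
        # stand-in for the inference_module.allocate_workspace_* call
        print(f"allocate once: layers={LayerStub.total_layers} batch={batch} seq={seq_len}")

    def forward(self, batch, seq_len):
        # Allocate memory only on first layer forward
        if self.layer_id == 0 and self._alloc_workspace:
            self._allocate(batch, seq_len)
            self._alloc_workspace = False

layers = [LayerStub(i) for i in range(4)]
for step in range(2):                     # the allocation runs exactly once across all steps and layers
    for layer in layers:
        layer.forward(batch=1, seq_len=128)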
deepspeed/model_implementations/transformers/ds_transformer.py (51 changes: 30 additions & 21 deletions)
@@ -6,18 +6,20 @@
 import torch
 import torch.nn as nn
 from deepspeed import comm as dist
-from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
 from deepspeed.utils.logging import log_dist

 from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
 from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
-from deepspeed.ops.transformer.inference.op_binding.workspace import WorkspaceOp
+from deepspeed.ops.transformer.inference.op_binding.layer_norm import LayerNormOp
 from deepspeed.accelerator import get_accelerator
+from deepspeed.ops.op_builder import InferenceBuilder
 import deepspeed
 if deepspeed.HAS_TRITON:
     from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
     from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention

+inference_module = None
+

 class DeepSpeedTransformerInference(nn.Module):
     """Initialize the DeepSpeed Transformer Layer.
@@ -36,7 +38,6 @@ class DeepSpeedTransformerInference(nn.Module):
                 for specific downstream tasks.
     """
     layer_id = 0
-    workspace = None

     def __init__(self,
                  config,
@@ -52,6 +53,10 @@ def __init__(self,
         DeepSpeedTransformerInference.layer_id += 1

         data_type = torch.half if self.config.dtype == torch.int8 else self.config.dtype
+        global inference_module
+        if inference_module is None:
+            builder = InferenceBuilder()
+            inference_module = builder.load()

         if DeepSpeedTransformerInference.layer_id == 1:
             log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
@@ -85,26 +90,22 @@ def __init__(self,
                                        requires_grad=False)
         self.layer_past = None
         self.layer_norm = LayerNormOp()
-        DeepSpeedTransformerInference.workspace = WorkspaceOp(self.config)
-        self._should_allocate_workspace = True
-        self.allocate_workspace_func = self.workspace.allocate_workspace
-
-    def allocate_workspace(self, size):
-        # Allocate memory only on first layer forward
-        if self.config.layer_id == 0 and self._should_allocate_workspace:
-            self.allocate_workspace_func(self.config.hidden_size, self.config.heads, size[1], size[0],
-                                         DeepSpeedTransformerInference.layer_id, self.config.mp_size,
-                                         self.config.bigscience_bloom,
-                                         dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
-                                         self.config.min_out_tokens)
-            self._should_allocate_workspace = False
+        try:
+            if config.dtype == torch.float32:
+                self.allocate_workspace = inference_module.allocate_workspace_fp32
+            elif config.dtype == torch.bfloat16:
+                self.allocate_workspace = inference_module.allocate_workspace_bf16
+            else:
+                self.allocate_workspace = inference_module.allocate_workspace_fp32
+            self._alloc_workspace = True
+        except AttributeError:
+            self.allocate_workspace = None
+            self._alloc_workspace = False

     @classmethod
     def reset_cache(cls):
-        if cls.workspace is None:
-            cls.workspace = WorkspaceOp()
-        if cls.workspace.is_op_implemented():
-            cls.workspace.reset_cache()
+        if inference_module is not None:
+            inference_module.reset_cache()

     def forward(
         self,
@@ -137,7 +138,15 @@ def forward(

         input_mask = (input_mask if attn_mask is None else attn_mask) if attention_mask is None else attention_mask

-        self.allocate_workspace(input.size())
+        # Allocate memory only on first layer forward
+        if self.config.layer_id == 0 and self._alloc_workspace:
+            self.allocate_workspace(self.config.hidden_size, self.config.heads,
+                                    input.size()[1],
+                                    input.size()[0], DeepSpeedTransformerInference.layer_id, self.config.mp_size,
+                                    self.config.bigscience_bloom,
+                                    dist.get_rank() if dist.is_initialized() else 0, self.config.max_out_tokens,
+                                    self.config.min_out_tokens)
+            self._alloc_workspace = False

         get_present = (get_present or get_key_value or use_cache)
         input_mask = input_mask if attention_mask is None else attention_mask
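The try/except AttributeError restored in __init__ is a feature-detection idiom: if the loaded builder module does not expose the allocate_workspace_* entry points (as can happen on some accelerators), workspace pre-allocation is simply disabled. A small sketch of the same dtype dispatch, using a dummy object in place of the compiled inference module (names are illustrative):

import torch

class _FakeOps:
    # stand-in for the object returned by InferenceBuilder().load()
    def allocate_workspace_fp32(self, *args): pass
    def allocate_workspace_bf16(self, *args): pass

def bind_allocate_workspace(ops, dtype):
    try:
        if dtype == torch.bfloat16:
            return ops.allocate_workspace_bf16, True
        return ops.allocate_workspace_fp32, True   # fp32 entry point also covers fp16/int8 here
    except AttributeError:
        return None, False                         # builder without these kernels -> skip pre-allocation

allocate_workspace, alloc_flag = bind_allocate_workspace(_FakeOps(), torch.bfloat16)
print(allocate_workspace, alloc_flag)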
