Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the Qwen positional emebdding config code #4955

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

// DeepSpeed Team

#include <cassert>
#include "blocked_kv_rotary.cuh"
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
Expand Down
31 changes: 8 additions & 23 deletions deepspeed/inference/v2/model_implementations/qwen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def norm_type(self) -> NormTypeEnum:
def positional_embedding_type(self) -> PositionalEmbeddingType:
return PositionalEmbeddingType.rotate_half

@property
def positional_embedding_config(self) -> Optional[RotateHalfConfig]:
return RotateHalfConfig(theta_base=self._config.rotary_emb_base)

def make_norm_layer(self) -> None:
"""
Instantiates the normalization layer for the model. This sets the `self.norm` attribute.
Expand All @@ -119,27 +123,6 @@ def make_norm_layer(self) -> None:

self.norm = heuristics.instantiate_pre_norm(norm_config, self._engine_config)

def make_attn_layer(self) -> None:
"""
Builds the attention layer for the model. This sets the `self.attn` attribute.
"""
softmax_scale = 1.0 / (self.head_size**0.5)

rotary_config = RotateHalfConfig(theta_base=self._config.rotary_emb_base)

attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size,
n_heads_q=self.n_heads_q_local,
n_heads_kv=self.n_heads_kv_local,
head_size=self.head_size,
max_sequences=self._engine_config.state_manager.max_ragged_sequence_count,
scale_factor=softmax_scale,
input_dtype=self.activation_dtype,
output_dtype=self.activation_dtype,
positional_embedding_type=self.positional_embedding_type,
positional_embedding_config=rotary_config)

self.attn = heuristics.instantiate_attention(attn_config, self._engine_config)

"""
Forward implementations
"""
Expand Down Expand Up @@ -210,8 +193,10 @@ def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: Ragge
Performs unembedding of the hidden states to logits. This will only sample the final
token of each sequence.
"""
logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info,
self._non_transformer.final_norm)
logits = self.unembed(hidden_states,
self._non_transformer.word_unembed,
ragged_batch_info,
gamma=self._non_transformer.final_norm)

if self.tp_size > 1:
comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1]))
Expand Down