2D RoPE + CLIP updates #1973
Merged · 12 commits · Nov 17, 2024
63 changes: 43 additions & 20 deletions torchtune/models/clip/_component_builders.py
@@ -39,11 +39,13 @@ def clip_vision_encoder(
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = True,
rope_base: Optional[int] = None,
encoder_max_seq_len: Optional[int] = None,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
intermediate_act: torch.nn.Module = torch.nn.SiLU(),
append_cls_token: bool = False,
) -> VisionTransformer:
"""
Builds the vision encoder associated with the clip model. This includes:
@@ -67,6 +69,11 @@ def clip_vision_encoder(
activation (Callable): The activation function to use in the MLP layer.
cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
attn_bias (bool): Whether to use bias in the attention module. Default: True.
rope_base (Optional[int]): base for the rotary positional embeddings. CLIP does not include RoPE by default;
if a value is passed, RoPE is added to the multihead attention. Default: None
encoder_max_seq_len (Optional[int]): maximum sequence length the encoder will be run with, as used
by :func:`~torchtune.modules.RotaryPositionalEmbeddings`. This is required if ``rope_base``
is specified. Default: None.
out_indices (Optional[List[int]]): The indices of hidden layers to return.
If provided, it will return the intermediate results of the transformer layers
before they go through a next layer. For example, ``out_indices=[0,3]`` will
@@ -76,33 +83,51 @@ def clip_vision_encoder(
max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
determine the size of the positional embeddings.
in_channels (int): The number of image input channels.
intermediate_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder.
append_cls_token (bool): If True, adds CLS token embedding to the end of the sequence in the vision transformer.
Default is False, which adds CLS token to the beginning of the sequence.

Returns:
A `VisionTransformer` object.

Raises:
AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
"""
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim must be divisible by num_heads, got {embed_dim} and {num_heads}"
)
if rope_base is not None and encoder_max_seq_len is None:
raise ValueError(
"encoder_max_seq_len must be provided if rope_base is specified. "
"This is used to determine the maximum sequence length for the rotary positional embeddings."
)

head_dim = embed_dim // num_heads

cls_projection = (
CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim)
if output_cls_projection
else None
)
rope = (
RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=encoder_max_seq_len, base=rope_base
)
if rope_base is not None
else None
)

# transformer layer
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
head_dim=head_dim,
q_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
k_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
v_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
output_proj=nn.Linear(embed_dim, embed_dim, bias=attn_bias),
pos_embeddings=None,
pos_embeddings=rope,
attn_dropout=0.0,
is_causal=False,
)
@@ -154,6 +179,7 @@ def clip_vision_encoder(
patch_size=patch_size,
embed_dim=embed_dim,
in_channels=in_channels,
append_cls_token=append_cls_token,
)
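For reference, a minimal usage sketch of the updated builder. The hyperparameter values below are illustrative assumptions, not taken from this PR; setting ``rope_base`` adds rotary embeddings to each attention layer, and the new check above then requires ``encoder_max_seq_len`` as well.

from torchtune.models.clip._component_builders import clip_vision_encoder

encoder = clip_vision_encoder(
    tile_size=448,
    patch_size=14,
    embed_dim=1280,
    num_layers=32,
    num_heads=16,
    rope_base=10_000,           # enables RotaryPositionalEmbeddings inside MultiHeadAttention
    encoder_max_seq_len=1_600,  # assumed upper bound on tokens seen by the encoder; required when rope_base is set
    append_cls_token=True,      # new flag: CLS embedding goes at the end of the sequence
)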


@@ -188,7 +214,6 @@ def clip_mlp(
def lora_clip_vision_encoder(
lora_modules: List[LORA_ATTN_MODULES],
apply_lora_to_mlp: bool = False,
apply_lora_to_output: bool = False,
*,
# clip encoder parameters
tile_size: int,
@@ -198,12 +223,11 @@ def lora_clip_vision_encoder(
num_heads: int,
activation: Callable = nn.SiLU,
cls_output_dim: int = 512,
attn_bias: bool = True,
attn_bias: bool = False,
out_indices: Optional[List[int]] = None,
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
intermediate_act: torch.nn.Module = torch.nn.SiLU(),
# LoRA parameters
lora_rank: int = 8,
lora_alpha: float = 16,
@@ -220,8 +244,6 @@ def lora_clip_vision_encoder(
``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
Default: False
apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
Default: False
tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
the size of the input image. In this case, the function will consider your image as a single tile.
patch_size (int): The size of each patch. Used to divide the tiles into patches.
@@ -232,7 +254,7 @@ def lora_clip_vision_encoder(
num_heads (int): The number of attention heads in each transformer layer.
activation (Callable): The activation function to use in the MLP layer.
cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module.
attn_bias (bool): Whether to use bias in the attention module. Default: True.
attn_bias (bool): Whether to use bias in the attention module. Default: False.
out_indices (Optional[List[int]]): The indices of hidden layers to return.
If provided, it will return the intermediate results of the transformer layers
before they go through a next layer. For example, ``out_indices=[0,3]`` will
@@ -242,7 +264,6 @@ def lora_clip_vision_encoder(
max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
determine the size of the positional embeddings.
in_channels (int): The number of image input channels.
intermediate_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder.
lora_rank (int): rank of each low-rank approximation
lora_alpha (float): scaling factor for the low-rank approximation
lora_dropout (float): LoRA dropout probability. Default: 0.0
@@ -277,6 +298,7 @@ def lora_clip_vision_encoder(
lora_dropout=lora_dropout,
use_dora=use_dora,
quantize_base=quantize_base,
attn_bias=attn_bias,
)
if apply_lora_to_mlp:
mlp = lora_clip_mlp(
@@ -361,6 +383,7 @@ def lora_clip_attention(
num_heads: int,
num_kv_heads: int,
attn_dropout: float = 0.0,
attn_bias: bool = False,
# LoRA args
lora_rank: int,
lora_alpha: float,
@@ -417,9 +440,9 @@ def lora_clip_attention(
)
if "q_proj" in lora_modules
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=False)
nn.Linear(embed_dim, num_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=attn_bias)
)
)
k_proj = (
@@ -433,9 +456,9 @@ def lora_clip_attention(
)
if "k_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
)
)
v_proj = (
@@ -449,9 +472,9 @@ def lora_clip_attention(
)
if "v_proj" in lora_modules
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=attn_bias)
)
)
output_proj = (
@@ -465,9 +488,9 @@ def lora_clip_attention(
)
if "output_proj" in lora_modules
else (
nn.Linear(embed_dim, embed_dim, bias=False)
nn.Linear(embed_dim, embed_dim, bias=attn_bias)
if not quantize_base
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
else FrozenNF4Linear(embed_dim, embed_dim, bias=attn_bias)
)
)
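The same ``attn_bias`` flag is now threaded through the LoRA path as well. A hedged sketch of calling the LoRA builder after this change; argument values are illustrative, and parameters not visible in the hunks above (``embed_dim``, ``num_layers``) are assumed from the non-LoRA builder.

from torchtune.models.clip._component_builders import lora_clip_vision_encoder

lora_encoder = lora_clip_vision_encoder(
    lora_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=True,
    tile_size=448,
    patch_size=14,
    embed_dim=1280,
    num_layers=32,
    num_heads=16,
    attn_bias=False,   # new default; now forwarded to lora_clip_attention
    lora_rank=8,
    lora_alpha=16.0,
)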

54 changes: 42 additions & 12 deletions torchtune/models/llama3/_component_builders.py
@@ -12,9 +12,9 @@
from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp

from torchtune.modules import (
MultiHeadAttention,
FeedForward,
FrozenNF4Linear,
MultiHeadAttention,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
@@ -40,6 +40,7 @@

# ------------------ Vanilla Llama3 ------------------


def llama3(
vocab_size: int,
num_layers: int,
@@ -48,7 +49,7 @@ def llama3(
embed_dim: int,
max_seq_len: int,
attn_dropout: float = 0.0,
rope_base: int = 500000.0,
rope_base: int = 500_000,
intermediate_dim: Optional[int] = None,
norm_eps: float = 1e-5,
) -> TransformerDecoder:
@@ -72,6 +73,7 @@ def llama3(
by :func:`~torchtune.modules.KVCache`
attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
Default: 0.0
rope_base (int): base for the rotary positional embeddings. Default: 500_000
intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified,
this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`
norm_eps (float): epsilon in RMS norms.
@@ -81,7 +83,9 @@ def llama3(
"""
head_dim = embed_dim // num_heads
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
@@ -95,7 +99,9 @@ def llama3(
max_seq_len=max_seq_len,
attn_dropout=attn_dropout,
)
hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim)
layer = TransformerSelfAttentionLayer(
attn=self_attn,
@@ -116,17 +122,29 @@ def llama3(
output=output_proj,
)
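For context, a sketch of calling the builder with the defaults touched above. The sizes are illustrative (roughly Llama-3-8B-like) and are not part of this diff.

from torchtune.models.llama3._component_builders import llama3

model = llama3(
    vocab_size=128_256,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    embed_dim=4096,
    max_seq_len=8192,
    rope_base=500_000,        # default, now written as an int literal
    intermediate_dim=14_336,
)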


def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
"""
Build the MLP layer associated with the Llama model.
"""
gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False)
gate_proj = (
nn.Linear(dim, hidden_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(dim, hidden_dim, bias=False)
)
down_proj = (
nn.Linear(hidden_dim, dim, bias=False)
if not quantize_base
else FrozenNF4Linear(hidden_dim, dim, bias=False)
)
up_proj = (
nn.Linear(dim, hidden_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(dim, hidden_dim, bias=False)
)
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
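The three projections feed torchtune's ``FeedForward``, i.e. a SwiGLU-style gated MLP. A reference sketch of the computation it performs, assuming the default SiLU activation (the sketch itself is not part of this diff):

import torch
import torch.nn.functional as F

def gated_mlp_reference(x: torch.Tensor, gate_proj, up_proj, down_proj) -> torch.Tensor:
    # down(silu(gate(x)) * up(x)) -- the pattern wired up by llama3_mlp
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))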



# ------------------ LoRA Llama3 ------------------


@@ -211,7 +229,9 @@ def lora_llama3(
use_dora=use_dora,
)

hidden_dim = intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
hidden_dim = (
intermediate_dim if intermediate_dim else scale_hidden_dim_for_mlp(embed_dim)
)
if apply_lora_to_mlp:
mlp = lora_llama3_mlp(
dim=embed_dim,
@@ -223,7 +243,9 @@ def lora_llama3(
use_dora=use_dora,
)
else:
mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)
mlp = llama3_mlp(
dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base
)

layer = TransformerSelfAttentionLayer(
attn=self_attn,
@@ -237,7 +259,13 @@ def lora_llama3(
# TODO: quantize_base is not applied to final output_proj currently.
adapter_cls = DoRALinear if use_dora else LoRALinear
output_proj = (
adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout)
adapter_cls(
embed_dim,
vocab_size,
rank=lora_rank,
alpha=lora_alpha,
dropout=lora_dropout,
)
if apply_lora_to_output
else nn.Linear(embed_dim, vocab_size, bias=False)
)
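A hedged usage sketch of the LoRA builder with the output projection adapted; the sizes are illustrative, and the first argument name ``lora_attn_modules`` is assumed from torchtune's other LoRA builders since it falls outside the visible hunks.

from torchtune.models.llama3._component_builders import lora_llama3

lora_model = lora_llama3(
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=True,  # routes the final projection through adapter_cls above
    vocab_size=128_256,
    num_layers=32,
    num_heads=32,
    num_kv_heads=8,
    embed_dim=4096,
    max_seq_len=8192,
    lora_rank=8,
    lora_alpha=16.0,
    use_dora=True,              # selects DoRALinear instead of LoRALinear
)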
@@ -382,7 +410,9 @@ def lora_llama3_self_attention(
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
)
)
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
rope = RotaryPositionalEmbeddings(
dim=head_dim, max_seq_len=max_seq_len, base=rope_base
)
self_attn = MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,