2D RoPE + CLIP updates #1973

Merged (12 commits) Nov 17, 2024
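
The diff adds a new VisionRotaryPositionalEmbeddings module (2D RoPE over CLIP image patches) alongside the existing RotaryPositionalEmbeddings, moves the Phi3 RoPE test into tests/torchtune/models/phi3/, extends the CLIP vision encoder tests with an append_cls_token case, and switches RoPE re-initialization to rope_init(). As a quick orientation, here is a minimal sketch of how the new module is exercised by the tests below; the constructor arguments and the optional input_pos handling are taken from the test fixtures, so treat the exact signature as inferred rather than documented:

import torch
from torchtune.modules.position_embeddings import VisionRotaryPositionalEmbeddings

# Shapes mirror the TestVisionRotaryPositionEmbedding fixtures below.
bsz, num_heads, head_dim, seq_len = 2, 8, 4, 5
patch_size, tile_size = 4, 16

# 2D RoPE over the patch grid: half of the head dim is used per spatial axis,
# hence dim=head_dim // 2, as in the test fixture.
rope = VisionRotaryPositionalEmbeddings(
    patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2
)

x = torch.randn(bsz, seq_len, num_heads, head_dim)

# Default: positions 0..seq_len-1 for every sample.
x_out = rope(x)

# Packed samples: per-sample positions via input_pos of shape [bsz, seq_len].
input_pos = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
x_out_packed = rope(x, input_pos=input_pos)

assert x_out.shape == x.shape and x_out_packed.shape == x.shape
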
60 changes: 60 additions & 0 deletions tests/torchtune/models/phi3/test_phi3_position_embeddings.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(0)


class TestPhi3RotaryPositionalEmbeddings:
"""
Class for testing the Phi3 model's RoPE embeddings. The expected tensors are
computed from the reference implementation here:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
"""

@pytest.fixture
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 3072
seq_len = 60
max_seq_len = 4096
head_dim = embed_dim // num_heads
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope_phi3(self, input_params) -> Phi3RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@mps_ignored_test()
def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
x_out = rope_phi3(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)
assert_expected(x_out.sum(), tensor(-381.0620))

# check shapes
assert_expected(x_out.shape, input.shape)
130 changes: 95 additions & 35 deletions tests/torchtune/modules/test_position_embeddings.py
@@ -4,16 +4,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest
import torch

from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

from torchtune.modules.position_embeddings import RotaryPositionalEmbeddings
from torchtune.modules.position_embeddings import (
RotaryPositionalEmbeddings,
VisionRotaryPositionalEmbeddings,
)
from torchtune.training.seed import set_seed


@@ -35,7 +35,7 @@ class TestRotaryPositionEmbedding:
EXPECTED_X_OUT_MAX = tensor(5.4546)

@pytest.fixture
def input_params(self) -> Tuple[int, int, int, int]:
def input_params(self):
bsz = 4
num_heads = 32
embed_dim = 4096
@@ -45,14 +45,12 @@ def input_params(self) -> Tuple[int, int, int, int]:
return bsz, num_heads, head_dim, seq_len, max_seq_len

@pytest.fixture
def input(self, input_params: Tuple[int, int, int, int]) -> tensor:
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope(
self, input_params: Tuple[int, int, int, int]
) -> RotaryPositionalEmbeddings:
def rope(self, input_params) -> RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@@ -136,44 +134,106 @@ def test_rope_init_meta_device(self, input_params):
torch.testing.assert_close(p1, p2)


class TestPhi3RotaryPositionalEmbeddings:
"""
Class for testing the Phi3 models RoPE Embeddings. The expected tensors are
computed from the reference implementation here:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
"""
class TestVisionRotaryPositionEmbedding:

EXPECTED_X_OUT_MEAN = tensor(0.0789793)
EXPECTED_X_OUT_SUM = tensor(25.2733822)
EXPECTED_X_OUT_MAX = tensor(3.1225626)

@pytest.fixture
def input_params(self) -> Tuple[int, int, int, int]:
bsz = 4
num_heads = 32
embed_dim = 3072
seq_len = 60
max_seq_len = 4096
def input_params(self):
bsz = 2
num_heads = 8
embed_dim = 32
head_dim = embed_dim // num_heads
return bsz, num_heads, head_dim, seq_len, max_seq_len
seq_len = 5
patch_size = 4
tile_size = 16
return bsz, num_heads, head_dim, seq_len, patch_size, tile_size

@pytest.fixture
def input(self, input_params: Tuple[int, int, int, int]) -> tensor:
bsz, num_heads, head_dim, seq_len, _ = input_params
def input(self, input_params) -> tensor:
bsz, num_heads, head_dim, seq_len, *_ = input_params
return torch.randn(bsz, seq_len, num_heads, head_dim)

@pytest.fixture
def rope_phi3(
self, input_params: Tuple[int, int, int, int]
) -> Phi3RotaryPositionalEmbeddings:
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
def rope(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
return VisionRotaryPositionalEmbeddings(
patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2
)

@mps_ignored_test()
def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
x_out = rope_phi3(input)
def test_forward(self, input, rope) -> None:
x_out = rope(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)
assert_expected(x_out.sum(), tensor(-381.0620))
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_curr_pos(self, input, rope) -> None:
(
_,
seq_len,
_,
_,
) = input.shape
x_out = rope(input, input_pos=torch.arange(seq_len))

# these values should be exactly the same as test_forward
# since in this case input_pos covers the entire input
# sequence. This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_packed_pos(self, input, rope) -> None:
"""
Use input_pos to indicate positions of each token relative to its sequence
when sample is packed.
"""
(
bsz,
seq_len,
_,
_,
) = input.shape
x_out = rope(
input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
)

# these values should be exactly the same as test_forward
# AND test_forward_with_curr_pos. In this case input_pos
# covers the entire batch dim and is defined for each sample separately.
# This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions for each sample
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_rope_init_meta_device(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
rope_on_device = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
)
with torch.device("meta"):
meta_rope = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
)

meta_rope.rope_init()
for p1, p2 in zip(rope_on_device.buffers(), meta_rope.buffers()):
torch.testing.assert_close(p1, p2)
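
Both RoPE modules now expose rope_init() for (re)building their cached buffers; test_rope_init_meta_device above and the test_distributed.py change further down rely on this to materialize a module that was constructed under the meta device. A minimal sketch of that pattern, mirroring the test above and reusing its fixture values:

import torch
from torchtune.modules.position_embeddings import VisionRotaryPositionalEmbeddings

# Construct under the meta device: buffers carry no real storage yet.
with torch.device("meta"):
    rope = VisionRotaryPositionalEmbeddings(dim=4, patch_size=4, tile_size=16)

# Explicitly (re)build the RoPE cache; the test compares the result against a
# module constructed directly on the default device.
rope.rope_init()
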
24 changes: 24 additions & 0 deletions tests/torchtune/modules/test_vision_transformer.py
@@ -210,3 +210,27 @@ def test_vision_transformer_single_tile(self, transformer_config):
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(0.5458), atol=1e-3, rtol=1e-3)

@torch.no_grad()
def test_vision_transformer_append_cls_token(self, transformer_config):
transformer_config = transformer_config.copy()
transformer_config["append_cls_token"] = True

model_append_cls = clip_vision_encoder(**transformer_config).eval()
fixed_init_model(model_append_cls, min_val=-1, max_val=1)
output, _ = model_append_cls(self.image, self.aspect_ratio)

# assertion
expected_shape = (
self.batch_size,
self.n_imgs,
self.num_tiles,
model_append_cls.get_image_tokens_per_tile(),
transformer_config["embed_dim"],
)

assert (
output.shape == expected_shape
), f"Expected shape {expected_shape}, but got {output.shape}"

assert_expected(output.mean(), torch.tensor(1.0172), atol=1e-3, rtol=1e-3)
6 changes: 4 additions & 2 deletions tests/torchtune/training/test_distributed.py
@@ -487,8 +487,10 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool):
)
# init rope since it's not covered in state dict
for m in fsdp_model_to_load.modules():
if isinstance(m, modules.RotaryPositionalEmbeddings):
m.reset_parameters()
if isinstance(m, modules.RotaryPositionalEmbeddings) or isinstance(
m, modules.VisionRotaryPositionalEmbeddings
):
m.rope_init()
for m in fsdp_model_to_load.modules():
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):