Skip to content

Commit

Permalink
Replace shape_ranges to min_shape/opt_shape/max_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
lanluo-nvidia committed Jun 5, 2024
1 parent cacb23b commit ed18394
Showing 1 changed file with 42 additions and 8 deletions.
50 changes: 42 additions & 8 deletions tests/py/dynamo/conversion/test_cat_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def forward(self, x, y):

input_specs = [
Input(
shape=(16, -1, 3),
dtype=torch.float32,
shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))],
min_shape=(16, 2, 3),
opt_shape=(16, 3, 3),
max_shape=(16, 32, 3),
),
Input(
shape=(16, -1, 3),
dtype=torch.float32,
shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))],
min_shape=(16, 2, 3),
opt_shape=(16, 16, 3),
max_shape=(16, 32, 3),
),
]
self.run_test_with_dynamic_shape(
Expand All @@ -71,14 +73,46 @@ def forward(self, x, y):

input_specs = [
Input(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
min_shape=(2, 16, 3),
opt_shape=(3, 16, 3),
max_shape=(32, 16, 3),
),
Input(
shape=(-1, 16, 3),
dtype=torch.float32,
shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))],
min_shape=(2, 16, 3),
opt_shape=(3, 16, 3),
max_shape=(32, 16, 3),
),
]
self.run_test_with_dynamic_shape(
Cat(),
input_specs,
)

@parameterized.expand(
[
("pos", 1),
("neg", -2),
]
)
def test_cat_dynamic_shape_dim(self, _, dim):
class Cat(nn.Module):
def forward(self, x, y):
return torch.ops.aten.cat.default((x, y), dim)

input_specs = [
Input(
dtype=torch.float32,
min_shape=(2, 1, 1),
opt_shape=(3, 1, 2),
max_shape=(4, 1, 3),
),
Input(
dtype=torch.float32,
min_shape=(2, 2, 1),
opt_shape=(3, 3, 2),
max_shape=(4, 4, 3),
),
]
self.run_test_with_dynamic_shape(
Expand Down

0 comments on commit ed18394

Please sign in to comment.