diff --git a/.gitignore b/.gitignore
index 4b6d83ad904a..8d0750110b3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
 build
 __pycache__
 *.so
-test/simple.py
\ No newline at end of file
+test/simple.py
+tmp
\ No newline at end of file
diff --git a/frontend/config.py b/frontend/config.py
index fb259a03ed9d..5bc059a793f5 100644
--- a/frontend/config.py
+++ b/frontend/config.py
@@ -5,6 +5,7 @@
     "debug": True,
     "miss_threshold": 3,
     "dynshape": False,
+    "model_name": "",
     "enable_fallback": False,
 }
diff --git a/frontend/control_flow.py b/frontend/control_flow.py
index 0087b5986767..ad11fef7d967 100644
--- a/frontend/control_flow.py
+++ b/frontend/control_flow.py
@@ -40,8 +40,7 @@ def forward(self, *values: Any) -> Any:
         loop_carry = values[self.num_read_only_param:]
         while iter_num < self.num_iter:
             # and cond.item():
-            loop_carry = self.body(torch.tensor(iter_num), *read_only,
-                                   *loop_carry)
+            loop_carry = self.body(iter_num, *read_only, *loop_carry)
             # cond, *loop_carry = self.body(iter_num, cond, *read_only,
             #                               *loop_carry)
             iter_num += 1
@@ -296,6 +295,3 @@ def if_stmt(cond: bool, if_true: Callable[..., Any],
         break_at_callsite()
         recover()
     return if_run_branch()
-
-
-torch.Tensor.__iter__
diff --git a/frontend/fx_graph.py b/frontend/fx_graph.py
index 607e5d5b61b7..246eb3b730d4 100644
--- a/frontend/fx_graph.py
+++ b/frontend/fx_graph.py
@@ -34,6 +34,41 @@
 NodeArgs = Union[BaseArgumentTypes, torch.fx.Node]
 
 
+def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
+    target_atoms = target.split('.')
+    attr_itr = gm
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(
+                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
+            )
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
+
+
+def generate_real_tensors(
+        fake_tensors: list[torch.Tensor]) -> list[torch.Tensor]:
+    real_tensors = []
+    for x in fake_tensors:
+        if x.dtype == torch.float32:
+            real_tensors.append(
+                torch.rand(*x.shape,
+                           dtype=x.dtype,
+                           layout=x.layout,
+                           device=x.device))
+        elif x.dtype == torch.int64:
+            real_tensors.append(
+                torch.randint(0,
+                              2,
+                              size=x.shape,
+                              dtype=x.dtype,
+                              layout=x.layout,
+                              device=x.device))
+        else:
+            raise NotImplementedError
+    return real_tensors
+
+
 def backend_compile(gm: torch.fx.GraphModule,
                     example_inputs: list[torch.Tensor]) -> Any:
     backend = config.get_config('backend')
@@ -43,17 +78,6 @@ def backend_compile(gm: torch.fx.GraphModule,
         return gm
     elif backend == 'inductor':
 
-        def fetch_attr(gm: torch.fx.GraphModule, target: str) -> Any:
-            target_atoms = target.split('.')
-            attr_itr = gm
-            for i, atom in enumerate(target_atoms):
-                if not hasattr(attr_itr, atom):
-                    raise RuntimeError(
-                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
-                    )
-                attr_itr = getattr(attr_itr, atom)
-            return attr_itr
-
         def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
             if node.op == 'call_module':
@@ -78,26 +102,9 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
         module = importlib.import_module('tmp.fx_module_' + random_number)
         model = module.FxModule().cuda().eval()
 
-        real_inputs = []
-        for x in example_inputs:
-            if x.dtype == torch.float32:
-                real_inputs.append(
-                    torch.rand(*x.shape,
-                               dtype=x.dtype,
-                               layout=x.layout,
-                               device=x.device))
-            elif x.dtype == torch.int64:
-                real_inputs.append(
-                    torch.randint(0,
-                                  2,
-                                  size=x.shape,
-                                  dtype=x.dtype,
-                                  layout=x.layout,
-                                  device=x.device))
-            else:
-                raise NotImplementedError
+        real_inputs = generate_real_tensors(example_inputs)
         with torch.no_grad():
-            script_model = torch.jit.trace(model, real_inputs)
+            script_model = torch.jit.script(model, real_inputs)
         return script_model
     else:
         raise RuntimeError(f"Unknown backend: {backend}")
diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py
index 64e4d4ea5ec6..1edd478f98d0 100644
--- a/frontend/guard_tracker.py
+++ b/frontend/guard_tracker.py
@@ -145,7 +145,7 @@ def get_name(prefix: str, name: str) -> str:
             self.subparam_paths[param] = get_name(prefix, name)
 
     def add_submodule(self, module: torch.nn.Module) -> None:
-        new_module_name = "__external_module__" + str(len(self.submodule_paths))
+        new_module_name = "external_module__" + str(len(self.submodule_paths))
         self.root.add_module(new_module_name, module)
         self.update_subpath(module, new_module_name)
         # self.written = True  # not mark as written as graph break may happen
@@ -1261,6 +1261,15 @@ def commit(self) -> None:
             (name, x) for x, name in self.state.fx_graph.example_inputs
         ])
         print("graph", self.state.fx_graph.result_graph)
+        from .control_flow import CondModule
+        for node in self.state.fx_graph.result_graph.nodes:
+            if node.op == 'call_module' and '.' not in node.target:
+                mod = getattr(self.state.root, node.target)
+                if isinstance(mod, CondModule):
+                    print("CondModule:", node.target)
+                    print("true_body:", mod.true_body.graph)
+                    print("false_body:", mod.false_body.graph)
+
         graph_code = graph_codegen.get_code()
         compiled_graph = self.state.fx_graph.compile()
@@ -1824,7 +1833,9 @@ def set_if_inplace_return() -> None:
                 inplace_ref=inplace_ref,
                 force_new_value=(func in (float, int, min, max) or
                                  (hasattr(func, '__name__') and
-                                  func.__name__ == 'contiguous')))
+                                  func.__name__ == 'contiguous') or
+                                 (isinstance(func, torch.nn.Module) and
+                                  hasattr(func, 'inplace') and func.inplace)))
             return
         elif self.all_scalar_arg(args, kwargs) and self.all_static_arg(
                 args, kwargs):
diff --git a/test/test_model_lstm.py b/test/test_model_lstm.py
index 033e799a33b6..781d2f3d01b7 100644
--- a/test/test_model_lstm.py
+++ b/test/test_model_lstm.py
@@ -349,7 +349,7 @@
                   hidden_size,
                   device='cuda')
     expect_result = model(inputs)
-    for_iter_pc = 193
+    for_iter_pc = 32
     mark_dynamic_pc(get_next_frame_id(), for_iter_pc,
                     DynamicControlFlow(for_iter_pc, "FOR_ITER"))
    compiled = compile(model)
diff --git a/test/test_nnmodule.py b/test/test_nnmodule.py
index ce6716330137..92f8694a3025 100644
--- a/test/test_nnmodule.py
+++ b/test/test_nnmodule.py
@@ -127,6 +127,31 @@ def test_map_module(caplog):
     run_and_check(compiled, [HIT], 1, caplog, expect_result, x)
 
 
+class InplaceRelu(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 3)
+        self.bn = torch.nn.BatchNorm2d(3)
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        return x + 1.0
+
+
+def test_inplace_relu(caplog):
+    reset()
+    model = InplaceRelu().eval()
+    compiled = compile(model)
+    x = torch.randn(1, 3, 3, 3)
+    expect_result = model(x)
+    run_and_check(compiled, [MISS], 1, caplog, expect_result, x)
+    run_and_check(compiled, [HIT], 1, caplog, expect_result, x)
+
+
 if __name__ == "__main__":
     caplog = logging.getLogger(__name__)
     test_call_method(caplog)