Skip to content

Commit

Permalink
Merge pull request neo-ai#1 from trevor-m/trevmorr-fix-deeplab
Browse files Browse the repository at this point in the history
Fixes for deeplabv3
  • Loading branch information
jianzhong-xu authored Jun 2, 2020
2 parents da33fa5 + 7c0a692 commit f6cd6a1
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions python/tvm/relay/backend/contrib/tidl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,23 @@ def __init__(self, compiler):
# Will map index in output to subgraph param name.
self.name_map = {}

def add_new_output(self, name, expr):
self.name_map[self.num_original_outputs + len(self.additional_outputs)] = name
self.additional_outputs.append(expr)
def add_new_outputs(self, subgraph_name, expr, was_input=True):
    """Register ``expr`` as an additional output to be generated by the module.

    If ``expr`` is a relay ``Tuple``, one output is added per tuple field;
    otherwise a single output is added.

    Parameters
    ----------
    subgraph_name : str
        Base name used to derive the key stored in ``self.name_map``.
    expr : relay.Expr
        Expression to expose as an extra module output.
    was_input : bool
        True when ``expr`` was an input to the subgraph (name kept as-is,
        or suffixed ``_<i>`` per tuple field); False when it is a subgraph
        output (suffixed ``_o<i>``, or ``_o0`` for a single output).
    """
    # Normalize to a list of (name, expr) pairs first, so the name_map /
    # additional_outputs bookkeeping is written once rather than being
    # duplicated in both branches.
    if isinstance(expr, Tuple):
        suffix = "_" if was_input else "_o"
        new_outputs = [(subgraph_name + suffix + str(i), field)
                       for i, field in enumerate(expr.fields)]
    else:
        name = subgraph_name if was_input else subgraph_name + "_o0"
        new_outputs = [(name, expr)]
    for name, out in new_outputs:
        # Key is the index this output will occupy in the module's
        # final output tuple.
        self.name_map[self.num_original_outputs + len(self.additional_outputs)] = name
        self.additional_outputs.append(out)

def visit_call(self, call):
if isinstance(call.op, Function) and "Compiler" in call.op.attrs and call.op.attrs["Compiler"] == self.compiler:
Expand All @@ -1103,14 +1117,10 @@ def visit_call(self, call):
subgraph_name = "_".join(param.name_hint.split("_")[:2])
arg = super().visit(arg)
var_map[param] = arg
self.add_new_output(param.name_hint, arg)
self.add_new_outputs(param.name_hint, arg, was_input=True)
new_body = VarReplacer(var_map).visit(call.op.body)
# Add subgraph outputs as well
if isinstance(new_body, Tuple):
for i, out in enumerate(new_body.fields):
self.add_new_output(subgraph_name + "_o" + str(i), out)
else:
self.add_new_output(subgraph_name + "_o0", new_body)
self.add_new_outputs(subgraph_name, new_body, was_input=False)
return new_body
return super().visit_call(call)

Expand Down Expand Up @@ -1250,6 +1260,21 @@ def visit_call(self, call):
return subgraph_gv(*args)
return super().visit_call(call)

def PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl"):
    """Remove offloaded subgraphs that take more than one input, or whose
    single input is a tuple, by inlining them back into ``main``.

    Parameters
    ----------
    mod : tvm.IRModule
        Module containing partitioned subgraph functions.
    compiler : str
        Value of the ``"Compiler"`` function attribute identifying which
        subgraphs belong to this backend.

    Returns
    -------
    tvm.IRModule
        New module with the offending subgraphs removed (inlined).
    """
    subgraph_names_to_remove = []
    for subgraph in mod.get_global_vars():
        name = subgraph.name_hint
        # Skip functions not partitioned for this compiler (e.g. "main").
        if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
            continue
        # NOTE(review): the backend appears to support only a single,
        # non-tuple input per subgraph; anything else is pruned.
        if len(mod[name].params) != 1 or isinstance(mod[name].params[0].checked_type, relay.TupleType):
            subgraph_names_to_remove.append(name)
    # Removed leftover debug statement printing every subgraph's params;
    # keep only the summary of what is being pruned.
    print("Removing subgraphs due to having more than one input:", subgraph_names_to_remove)
    new_mod = tvm.IRModule()
    new_mod["main"] = SubgraphRemover(subgraph_names_to_remove, mod, new_mod).visit(mod["main"])
    return new_mod

def PruneSubgraphs(mod, compiler="tidl", num_subgraphs_to_keep=4):
subgraph_with_macs = []
for subgraph in mod.get_global_vars():
Expand Down Expand Up @@ -1297,6 +1322,7 @@ def EnableTIDL(mod, params, num_tidl_subgraphs,
print("---------- Unpack composite ops in the graph ----------")
mod = UnpackComposites(mod, "tidl")
print("---------- Prune Graph ----------")
mod = PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl")
mod = PruneSubgraphs(mod, compiler="tidl", num_subgraphs_to_keep=num_tidl_subgraphs)
print(mod.astext(show_meta_data=False))

Expand Down

0 comments on commit f6cd6a1

Please sign in to comment.