Skip to content

Commit

Permalink
[TUZ-201] Support for end to end SpaceV5 (apache#58)
Browse files Browse the repository at this point in the history
Change to octo utility so that it will autopopulate available external
libs, so far we only check for `thrust`.

I also disable epilogue fusion of matmul for now since it restricts
batch size to 1.
  • Loading branch information
Josh Fromm authored Mar 21, 2023
1 parent f4f8bbe commit 36b9196
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
17 changes: 16 additions & 1 deletion python/tvm/octo/utils/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,22 @@ def get_cuda_target() -> tvm.target.Target:
# To do so, lowercase the name and replace spaces with dases.
target_name = "nvidia/" + product_name.replace(" ", "-").lower()

return tvm.target.Target(target_name)
target = tvm.target.Target(target_name)

# Attach libs if available.
# Check if thrust symbols are defined.
libs = []
if tvm._ffi.get_global_func("tvm.contrib.thrust.sum_scan", allow_missing=True):
libs.append("thrust")

# Append libs to target.
target = str(target)
if libs:
target += " -libs="
for lib in libs:
target += f"{lib},"

return tvm.target.Target(target)


def get_default_threads() -> int:
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def residual_block_patterns():
for activation, name_postfix in [(None, ""), ("relax.nn.relu", "_relu")]:
for check, base_patterns in [
(_check_conv2d, conv2d_patterns()),
(_check_matmul, matmul_patterns()),
# TODO(jwfromm) Reenable once epilogue fusion is supported for bs > 1.
# (_check_matmul, matmul_patterns()),
]:
for name, pat, arg_pat, _ in base_patterns:
# Append residual patterns only to those base patterns with bias add,
Expand Down

0 comments on commit 36b9196

Please sign in to comment.