address 1x16 and 16x1 issue
yzh119 committed Oct 26, 2022
1 parent ffd6ca9 commit f1887e9
Showing 3 changed files with 169 additions and 13 deletions.
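The substantive change is in bench_rgcn_tensorcore.py: each WMMA intrinsic factory (wmma_sync, wmma_load_a, wmma_fill, wmma_store) gains 16x1- and 1x16-specialized TVMScript descriptors and now dispatches on the tile shape (d0, d1); the benchmark schedule is also retuned (extra unrolling and a fused, vectorized cooperative fetch for W_shared). A minimal sketch of the dispatch pattern repeated in all four factories (the helper name is illustrative, not part of the commit):

    def select_desc(d0: int, d1: int, desc_1_16, desc_16_1, desc_generic):
        # A tile with one dimension collapsed to 1 gets a descriptor that
        # hardcodes that axis to index 0; otherwise the generic
        # (d0, d1, 16)-shaped descriptor is used.
        if d0 == 1:
            return desc_1_16
        elif d1 == 1:
            return desc_16_1
        return desc_generic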
5 changes: 3 additions & 2 deletions src/tir/schedule/primitive/blockize_tensorize.cc
@@ -262,7 +262,7 @@ Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
      arith::SubspaceDivide(block_realize->iter_values, collector.loop_var_domain,
                            collector.inner_loop_vars, block_realize->predicate,
                            /*require_bijective=*/false, analyzer);

  if (division.empty()) {
    // If we can't do perfect subspace division, check if it is a trivial case of subspace division.
    // In this case, we can still blockize.
@@ -315,7 +315,8 @@ class BlockizedBindingExtractor {
    } else {
      // create iter var for the outer block
      if (inner_init && iter_var->iter_type == kCommReduce) {
        CHECK(is_one(division[i][0]->extent)) << "When inner_init is set to true, outer reduction var length must be equal to one";
        CHECK(is_one(division[i][0]->extent))
            << "When inner_init is set to true, outer reduction var length must be equal to one";
      }
      const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent),
                              /*var=*/iter_var->var.copy_with_suffix("_o"),
2 changes: 1 addition & 1 deletion tests/python/sparsetir/bench_rgcn_composable.py
@@ -241,5 +241,5 @@ def test_rgcn_composable_format(
    # homograph
    ground_truth_y = get_ground_truth(g, type_pointers, feat, weight)
    test_rgcn_composable_format(
        g, type_pointers, feat_size, feat, weight, ground_truth_y, 4, 16, [1, 2, 4, 8, 16]
        g, type_pointers, feat_size, feat, weight, ground_truth_y, 16, 32, [1, 2, 4, 8, 16]
    )
175 changes: 165 additions & 10 deletions tests/python/sparsetir/bench_rgcn_tensorcore.py
@@ -42,6 +42,50 @@ def wmma_sync_desc(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None
                        C_frag[vio, vii, vj] + A_frag[vio, vii, vk] * B_frag[vk, vj]
                    )

    @T.prim_func
    def wmma_sync_16_1_desc(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None:
        A_frag = T.match_buffer(
            a_frag, (16, 1, 16), "float16", align=64, offset_factor=1, scope="wmma.matrix_a"
        )
        B_frag = T.match_buffer(
            b_frag, (16, 16), "float16", align=64, offset_factor=1, scope="wmma.matrix_b"
        )
        C_frag = T.match_buffer(
            c_frag, (16, 1, 16), "float16", align=64, offset_factor=1, scope="wmma.accumulator"
        )

        with T.block("root"):
            for io, ii, j, k in T.grid(16, 1, 16, 16):
                with T.block("update"):
                    vio, vj, vk = T.axis.remap("SSR", [io, j, k])
                    T.block_attr({"sparse": True})
                    C_frag[vio, 0, vj] = (
                        C_frag[vio, 0, vj] + A_frag[vio, 0, vk] * B_frag[vk, vj]
                    )

    @T.prim_func
    def wmma_sync_1_16_desc(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None:
        A_frag = T.match_buffer(
            a_frag, (1, 16, 16), "float16", align=64, offset_factor=1, scope="wmma.matrix_a"
        )
        B_frag = T.match_buffer(
            b_frag, (16, 16), "float16", align=64, offset_factor=1, scope="wmma.matrix_b"
        )
        C_frag = T.match_buffer(
            c_frag, (1, 16, 16), "float16", align=64, offset_factor=1, scope="wmma.accumulator"
        )

        with T.block("root"):
            for io, ii, j, k in T.grid(1, 16, 16, 16):
                with T.block("update"):
                    vii, vj, vk = T.axis.remap("SSR", [ii, j, k])
                    T.block_attr({"sparse": True})
                    C_frag[0, vii, vj] = (
                        C_frag[0, vii, vj] + A_frag[0, vii, vk] * B_frag[vk, vj]
                    )
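
    # In both specialized descriptors above, the unit-extent loop variable
    # is left out of T.axis.remap and the corresponding buffer axis is
    # hardcoded to index 0, so the descriptor's block structure matches what
    # tensorize sees for a degenerate (16, 1) or (1, 16) tile; presumably
    # this is the "1x16 and 16x1 issue" named in the commit title.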



    @T.prim_func
    def wmma_sync_impl(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None:
        A_frag = T.match_buffer(
@@ -82,7 +126,12 @@ def wmma_sync_impl(a_frag: T.handle, b_frag: T.handle, c_frag: T.handle) -> None
                    )
                )

    return wmma_sync_desc, wmma_sync_impl
    if d0 == 1:
        return wmma_sync_1_16_desc, wmma_sync_impl
    elif d1 == 1:
        return wmma_sync_16_1_desc, wmma_sync_impl
    else:
        return wmma_sync_desc, wmma_sync_impl


def wmma_load_a(d0: int, d1: int, scope: str):
@@ -94,12 +143,37 @@ def wmma_load_a_desc(a: T.handle, a_frag: T.handle) -> None:
        )

        with T.block("root"):
            T.reads(A[0:d0, 0:d1, 0:16])
            T.writes(A_frag[0:d0, 0:d1, 0:16])
            for io, ii, j in T.grid(d0, d1, 16):
                with T.block("load"):
                    vio, vii, vj = T.axis.remap("SSS", [io, ii, j])
                    A_frag[vio, vii, vj] = A[vio, vii, vj]

    @T.prim_func
    def wmma_load_a_16_1_desc(a: T.handle, a_frag: T.handle) -> None:
        A = T.match_buffer(a, (16, 1, 16), "float16", align=64, offset_factor=16, scope=scope)
        A_frag = T.match_buffer(
            a_frag, (16, 1, 16), "float16", align=64, offset_factor=16, scope="wmma.matrix_a"
        )

        with T.block("root"):
            for io, ii, j in T.grid(16, 1, 16):
                with T.block("load"):
                    vio, vj = T.axis.remap("SS", [io, j])
                    A_frag[vio, 0, vj] = A[vio, 0, vj]

    @T.prim_func
    def wmma_load_a_1_16_desc(a: T.handle, a_frag: T.handle) -> None:
        A = T.match_buffer(a, (1, 16, 16), "float16", align=64, offset_factor=16, scope=scope)
        A_frag = T.match_buffer(
            a_frag, (1, 16, 16), "float16", align=64, offset_factor=16, scope="wmma.matrix_a"
        )

        with T.block("root"):
            for io, ii, j in T.grid(1, 16, 16):
                with T.block("load"):
                    vii, vj = T.axis.remap("SS", [ii, j])
                    A_frag[0, vii, vj] = A[0, vii, vj]

    @T.prim_func
    def wmma_load_a_impl(a: T.handle, a_frag: T.handle) -> None:
@@ -138,7 +212,12 @@ def wmma_load_a_impl(a: T.handle, a_frag: T.handle) -> None:
                    )
                )

    return wmma_load_a_desc, wmma_load_a_impl
    if d0 == 1:
        return wmma_load_a_1_16_desc, wmma_load_a_impl
    elif d1 == 1:
        return wmma_load_a_16_1_desc, wmma_load_a_impl
    else:
        return wmma_load_a_desc, wmma_load_a_impl
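
# Usage sketch (illustrative, not part of this commit): each (d0, d1)
# descriptor pair is registered under the same name that the
# sch.tensorize(...) calls in the test below construct, e.g.
#   tir.TensorIntrin.register(
#       "wmma_{}_{}_{}_load_a".format(d0, d1, "shared"),
#       *wmma_load_a(d0, d1, "shared"),
#   )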


def wmma_load_b(scope: str):
@@ -198,6 +277,29 @@ def wmma_fill_desc(c_frag: T.handle) -> None:
                    vio, vii, vj = T.axis.remap("SSS", [io, ii, j])
                    C_frag[vio, vii, vj] = T.float16(0)

    @T.prim_func
    def wmma_fill_16_1_desc(c_frag: T.handle) -> None:
        C_frag = T.match_buffer(
            c_frag, (16, 1, 16), "float16", align=64, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            for io, ii, j in T.grid(16, 1, 16):
                with T.block("init"):
                    vio, vj = T.axis.remap("SS", [io, j])
                    C_frag[vio, 0, vj] = T.float16(0)

    @T.prim_func
    def wmma_fill_1_16_desc(c_frag: T.handle) -> None:
        C_frag = T.match_buffer(
            c_frag, (1, 16, 16), "float16", align=64, offset_factor=16, scope="wmma.accumulator"
        )
        with T.block("root"):
            for io, ii, j in T.grid(1, 16, 16):
                with T.block("init"):
                    vii, vj = T.axis.remap("SS", [ii, j])
                    C_frag[0, vii, vj] = T.float16(0)

    @T.prim_func
    def wmma_fill_impl(c_frag: T.handle) -> None:
        C_frag = T.match_buffer(
@@ -220,7 +322,12 @@ def wmma_fill_impl(c_frag: T.handle) -> None:
                    )
                )

    return wmma_fill_desc, wmma_fill_impl
    if d0 == 1:
        return wmma_fill_1_16_desc, wmma_fill_impl
    elif d1 == 1:
        return wmma_fill_16_1_desc, wmma_fill_impl
    else:
        return wmma_fill_desc, wmma_fill_impl


def wmma_store(d0: int, d1: int, scope: str):
@@ -236,6 +343,31 @@ def wmma_store_desc(c_frag: T.handle, c: T.handle) -> None:
                    vio, vii, vj = T.axis.remap("SSS", [io, ii, j])
                    C[vio, vii, vj] = C_frag[vio, vii, vj]

    @T.prim_func
    def wmma_store_desc_16_1(c_frag: T.handle, c: T.handle) -> None:
        C_frag = T.match_buffer(
            c_frag, (16, 1, 16), "float16", align=64, offset_factor=16, scope="wmma.accumulator"
        )
        C = T.match_buffer(c, (16, 1, 16), "float16", align=64, offset_factor=16, scope=scope)
        with T.block("root"):
            for io, ii, j in T.grid(16, 1, 16):
                with T.block("store"):
                    vio, vj = T.axis.remap("SS", [io, j])
                    C[vio, 0, vj] = C_frag[vio, 0, vj]

    @T.prim_func
    def wmma_store_desc_1_16(c_frag: T.handle, c: T.handle) -> None:
        C_frag = T.match_buffer(
            c_frag, (1, 16, 16), "float16", align=64, offset_factor=16, scope="wmma.accumulator"
        )
        C = T.match_buffer(c, (1, 16, 16), "float16", align=64, offset_factor=16, scope=scope)
        with T.block("root"):
            for io, ii, j in T.grid(1, 16, 16):
                with T.block("store"):
                    vii, vj = T.axis.remap("SS", [ii, j])
                    C[0, vii, vj] = C_frag[0, vii, vj]

    @T.prim_func
    def wmma_store_impl(c_frag: T.handle, c: T.handle) -> None:
        s0 = T.var("int32")
@@ -272,7 +404,12 @@ def wmma_store_impl(c_frag: T.handle, c: T.handle) -> None:
                    )
                )

    return wmma_store_desc, wmma_store_impl
    if d0 == 1:
        return wmma_store_desc_1_16, wmma_store_impl
    elif d1 == 1:
        return wmma_store_desc_16_1, wmma_store_impl
    else:
        return wmma_store_desc, wmma_store_impl


@T.prim_func
@@ -484,6 +621,7 @@ def test_rgcn_composable_format(
    sch = tir.Schedule(mod["main"])
    # register load_b
    tir.TensorIntrin.register("wmma_{}_load_b".format("shared"), *wmma_load_b("shared"))
    tir.TensorIntrin.register("wmma_{}_load_b".format("global"), *wmma_load_b("global"))

    for bucket_id, bucket_size in enumerate(buckets):
        d0 = group_size // bucket_size
@@ -508,14 +646,26 @@ def test_rgcn_composable_format(
        ax3_o, ax3_i = sch.split(ax3, [None, 16])
        sch.reorder(ax2_o, ax3_o, ax0, ax1, ax2_i, ax3_i)
        X_shared = sch.reverse_cache_read(blk_wx, 0, "shared")
        sch.compute_at(X_shared, ax3_o)
        sch.compute_at(X_shared, ax3_o, True)
        WX_accum = sch.reverse_cache_write(blk_wx, 0, "wmma.accumulator")
        W_wmma = sch.reverse_cache_read(blk_wx, 2, "wmma.matrix_b", [0, 1, 5, 4])
        # W_wmma = sch.cache_read(blk_wx, 2, "wmma.matrix_b")
        sch.compute_at(W_wmma, ax3_o, True)
        X_wmma = sch.reverse_cache_read(blk_wx, 0, "wmma.matrix_a")
        sch.bind(sch.get_loops(blk)[0], "blockIdx.x")
        sch.decompose_reduction(blk_wx, ax3_o)
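        # decompose_reduction splits the init statement of the reduction
        # block into a separate block at ax3_o, which is what allows the
        # fill intrinsic (wmma_fill's "init" block above) to be tensorized
        # independently of the update block.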

        # unroll
        ax2 = sch.get_loops(WX_accum)[-4]
        sch.unroll(ax2)
        ax5 = sch.get_loops(blk_wx)[-5]
        ax4 = sch.get_loops(blk_wx)[-6]
        sch.unroll(ax5)
        sch.unroll(ax4)

        # tensorize
        sch.tensorize(sch.get_loops(WX_accum)[-3], "wmma_{}_{}_{}_store".format(d0, d1, "shared"))
        print(sch.mod["main"].script())
        sch.tensorize(sch.get_loops(X_wmma)[-3], "wmma_{}_{}_{}_load_a".format(d0, d1, "shared"))
        sch.tensorize(sch.get_loops(W_wmma)[-2], "wmma_{}_load_b".format("shared"))
        sch.tensorize(
@@ -527,8 +677,11 @@ def test_rgcn_composable_format(

        # schedule W_shared
        ax2, ax3 = sch.get_loops(W_shared)[-2:]
        sch.unroll(ax2)
        sch.bind(ax3, "threadIdx.x")
        fused_ax = sch.fuse(ax2, ax3)
        ax0, ax1, ax2 = sch.split(fused_ax, [None, 32, 8])
        sch.vectorize(ax2)
        sch.bind(ax1, "threadIdx.x")
        sch.unroll(ax0)
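        # Sanity check on the cooperative fetch above (a sketch, assuming a
        # 16x16 fp16 W tile): the fuse yields 256 elements, and the
        # [None, 32, 8] split leaves 256 // (32 * 8) = 1 outer iteration,
        # i.e. 32 threads each issuing one 8-element (128-bit) vectorized
        # load.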

        # schedule X_shared
        ax0, ax1, ax2 = sch.get_loops(X_shared)[-3:]
@@ -544,12 +697,14 @@
        sch.reverse_compute_at(Y_local, ii, True)
        sch.bind(fo, "threadIdx.x")
        sch.unroll(j)
        sch.unroll(ii)
        sch.bind(sch.get_loops(Y_local)[-1], "threadIdx.x")

    mod = lower_sparse_buffer(sch.mod)
    mod = tvm.tir.transform.RemoveUnusedArgs()(mod)

    f = tvm.build(mod["main"], target="cuda")
    print(f.imported_modules[0].get_source())

    # prepare inputs
    dev = tvm.cuda(0)
@@ -576,7 +731,7 @@

if __name__ == "__main__":
    for feat_size in [32]:  # [4, 8, 16, 32]:
        for name in ["am"]:  # ["aifb", "mutag", "bgs", "am", "biokg"]:
        for name in ["biokg"]:  # ["aifb", "mutag", "bgs", "am", "biokg"]:
            print("dataset {}, feat_size={}:".format(name, feat_size))
            dataset = get_hetero_dataset(name)
            g = dataset[0]
