Add TVM backend #232

Closed
interesaaat opened this issue Aug 14, 2020 · 107 comments

@interesaaat
Collaborator

No description provided.

@interesaaat interesaaat self-assigned this Aug 14, 2020
@interesaaat
Collaborator Author

To add some documentation on what is happening: we had a first version of TVM integrated. For tree models, only GEMM worked because index_select was missing in TVM; this PR added index_select.

Afterwards we saw that there was a typing error; this PR fixed that. The related issue explains the problem in detail. Now we have GEMM and the two tree-traversal strategies working on TVM.

@masahi
Collaborator

masahi commented Aug 20, 2020

One thing I noticed is that models coming from Hummingbird are extremely slow to compile on TVM. The following test with data size (20, 10) doesn't finish compiling in 30 min on my laptop. The Relay model doesn't look complicated, and for standard ImageNet models compilation doesn't take more than 1 min. I have a feeling that there is an infinite loop going on inside TVM codegen.

Do you have a similar observation?

import numpy as np
import lightgbm as lgb
import hummingbird.ml

num_classes = 2
size = 20
X = np.random.rand(size, 10).astype(np.float32)
y = np.random.randint(num_classes, size=size)

model = lgb.LGBMClassifier(n_estimators=3, min_child_samples=1)
model.fit(X, y)

# tree_implementation = "gemm"
tree_implementation = "tree_trav"
# tree_implementation = "perf_tree_trav"

tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})

@interesaaat
Collaborator Author

interesaaat commented Aug 20, 2020

I was actually testing it myself this morning and had trouble with the lightgbm notebook (it crashed the kernel; I didn't have the chance to debug further, but I guess it's the same problem). The blog notebook worked well.

I will try with your test and let you know. Any idea what is going on?

@masahi
Collaborator

masahi commented Aug 20, 2020

I have no idea, but since conversion to Relay is instant, this is not a frontend problem but something deep inside codegen.

It's likely that workloads from Hummingbird are novel ones that TVM hasn't seen. It would be really interesting if debugging this problem uncovers a bug in TVM codegen.

@interesaaat
Collaborator Author

It actually worked on my Mac. It took a few minutes (probably 3 or 4).

@masahi
Collaborator

masahi commented Aug 20, 2020

Interesting, but taking a few minutes for such a simple model still doesn't seem right. I've seen a report of compiling Mask R-CNN taking 20 min: https://discuss.tvm.ai/t/vm-slow-compilation-of-tf-object-detection-models/7479/8

I'm assuming making the dataset bigger also makes the trees, and hence the converted models, bigger? I wonder how long it would take to compile for a real-world dataset.

(I read your OSDI paper yesterday and it says TVM compilation is slow.)

@interesaaat
Collaborator Author

Yes, for some models on GPU I think we saw something in the ballpark of 40 minutes. That was a GBDT with 500 trees on a dataset with a few hundred features. In the paper, however, we used custom Relay implementations (not the PyTorch frontend, which should make no difference), a few custom operators, and TVM 0.6.

Our goal for the moment is to replicate the speedups we saw with the custom models using your frontend, and fix the compilation time while we do it :)

@interesaaat
Collaborator Author

Now I am having the same problem. I increased the number of trees to 10 and the input dataset to 2000x28, and it is still compiling after 10 minutes. I tried with the dataset we were using in the lightgbm notebook (200k records) and compilation throws a segfault.

@interesaaat
Collaborator Author

Ok, still running after 1 hour. My setup:

import numpy as np
import lightgbm as lgb
import hummingbird.ml

num_classes = 2
X = np.random.rand(2000, 28).astype(np.float32)
y = np.random.randint(num_classes, size=2000)

model = lgb.LGBMClassifier(n_estimators=10, min_child_samples=1)
model.fit(X, y)

# tree_implementation = "gemm"
tree_implementation = "tree_trav"
# tree_implementation = "perf_tree_trav"

tvm_model = hummingbird.ml.convert(model, "tvm", X, extra_config={"tree_implementation": tree_implementation})

@masahi do you think we can have someone look into this? It gets unmanageable quite fast; this was not happening with our internal Relay models on 0.6.

@masahi
Collaborator

masahi commented Aug 21, 2020

We can open an issue there once we understand what is going on inside codegen. I have no idea yet, but I know where to look. If this is something I can fix, I'll do it.

One thing worth checking now is comparing the output of the PyTorch frontend with the hand-written one you have.
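For reference, dumping the frontend's output looks something like this (a sketch; torch_model and example_input stand in for the Hummingbird-generated module and a sample batch, which are not shown in this thread):

import torch
from tvm import relay

# Trace the PyTorch module, then import it into Relay. The input name
# "input" matches the Relay graphs dumped later in this thread.
traced = torch.jit.trace(torch_model, example_input)
mod, params = relay.frontend.from_pytorch(traced, [("input", list(example_input.shape))])
print(mod["main"])  # diff this against the hand-written Relay graph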

@interesaaat
Collaborator Author

Getting that version will take a while, but I will try. I'll keep you posted once I get there; please share if you find anything in the meantime.

@masahi
Collaborator

masahi commented Aug 21, 2020

I did some digging. It seems the problem is due to the very aggressive operator fusion TVM does and the very large inlined expressions that result from it.

First, this is the Relay graph for a dataset of size (5, 10).

fn (%v_operator_map.SklearnLGBMClassifier.values: Tensor[(3, 1), float32], %input: Tensor[(5, 10), float32], %v_operator_map.SklearnLGBMClassifier.features: Tensor[(3), int64], %v_operator_map.SklearnLGBMClassifier.nodes_offset: Tensor[(1, 1), int64], %v_operator_map.SklearnLGBMClassifier.thresholds: Tensor[(3), float32], %v_operator_map.SklearnLGBMClassifier.rights: Tensor[(3), float32], %v_operator_map.SklearnLGBMClassifier.lefts: Tensor[(3), float32]) -> (Tensor[(5), int32], Tensor[(5, 2), float32]) {
  %0 = (%v_operator_map.SklearnLGBMClassifier.nodes_offset, %v_operator_map.SklearnLGBMClassifier.nodes_offset, %v_operator_map.SklearnLGBMClassifier.nodes_offset, %v_operator_map.SklearnLGBMClassifier.nodes_offset, %v_operator_map.SklearnLGBMClassifier.nodes_offset);
  %1 = concatenate(%0) /* ty=Tensor[(5, 1), int64] */;
  %2 = reshape(%1, newshape=[-1]) /* ty=Tensor[(5), int64] */;
  %3 = take(%v_operator_map.SklearnLGBMClassifier.features, %2, axis=0) /* ty=Tensor[(5), int64] */;
  %4 = reshape(%3, newshape=[-1, 1]) /* ty=Tensor[(5, 1), int64] */;
  %5 = gather(%input, %4, axis=1) /* ty=Tensor[(5, 1), float32] */;
  %6 = take(%v_operator_map.SklearnLGBMClassifier.thresholds, %2, axis=0) /* ty=Tensor[(5), float32] */;
  %7 = reshape(%6, newshape=[-1, 1]) /* ty=Tensor[(5, 1), float32] */;
  %8 = greater_equal(%5, %7) /* ty=Tensor[(5, 1), bool] */;
  %9 = take(%v_operator_map.SklearnLGBMClassifier.rights, %2, axis=0) /* ty=Tensor[(5), float32] */;
  %10 = reshape(%9, newshape=[-1, 1]) /* ty=Tensor[(5, 1), float32] */;
  %11 = take(%v_operator_map.SklearnLGBMClassifier.lefts, %2, axis=0) /* ty=Tensor[(5), float32] */;
  %12 = reshape(%11, newshape=[-1, 1]) /* ty=Tensor[(5, 1), float32] */;
  %13 = where(%8, %10, %12) /* ty=Tensor[(5, 1), float32] */;
  %14 = cast(%13, dtype="int64") /* ty=Tensor[(5, 1), int64] */;
  %15 = add(%14, %v_operator_map.SklearnLGBMClassifier.nodes_offset) /* ty=Tensor[(5, 1), int64] */;
  %16 = reshape(%15, newshape=[-1]) /* ty=Tensor[(5), int64] */;
  %17 = take(%v_operator_map.SklearnLGBMClassifier.values, %16, axis=0) /* ty=Tensor[(5, 1), float32] */;
  %18 = reshape(%17, newshape=[-1, 1, 1]) /* ty=Tensor[(5, 1, 1), float32] */;
  %19 = reshape(%18, newshape=[-1, 1, 1]) /* ty=Tensor[(5, 1, 1), float32] */;
  %20 = sum(%19, axis=[2]) /* ty=Tensor[(5, 1), float32] */;
  %21 = sigmoid(%20) /* ty=Tensor[(5, 1), float32] */;
  %22 = multiply(1f /* ty=float32 */, %21) /* ty=Tensor[(5, 1), float32] */;
  %23 = subtract(1f /* ty=float32 */, %22) /* ty=Tensor[(5, 1), float32] */;
  %24 = (%23, %21);
  %25 = concatenate(%24, axis=1) /* ty=Tensor[(5, 2), float32] */;
  %26 = argmax(%25, axis=[1]) /* ty=Tensor[(5), int32] */;
  %27 = (%26, %25);
  %28 = %27.0;
  %29 = %27.1;
  (%28, %29)
}

This is the same graph after operator fusion. Pay particular attention to the first fused op, named %7, which contains gather, where, take, etc. This fused op has 7 ops inside it. If we increase the size of the dataset, we of course end up with more operators in the graph, but all of these added ops will be fused into this first fused function. A concrete example follows below.

def @main(%input: Tensor[(5, 10), float32]) -> (Tensor[(5), int32], Tensor[(5, 2), float32]) {
  %7 = fn (%p0: Tensor[(3, 1), float32], %p1: Tensor[(5, 10), float32], %p2: Tensor[(5, 1), int64], %p3: Tensor[(5, 1), float32], %p4: Tensor[(5, 1), float32], %p5: Tensor[(5, 1), float32], %p6: Tensor[(1, 1), int64], Primitive=1) -> Tensor[(5, 1, 1), float32] {
    %0 = gather(%p1, %p2, axis=1) /* ty=Tensor[(5, 1), float32] */;
    %1 = greater_equal(%0, %p3) /* ty=Tensor[(5, 1), bool] */;
    %2 = where(%1, %p4, %p5) /* ty=Tensor[(5, 1), float32] */;
    %3 = cast(%2, dtype="int64") /* ty=Tensor[(5, 1), int64] */;
    %4 = add(%3, %p6) /* ty=Tensor[(5, 1), int64] */;
    %5 = reshape(%4, newshape=[-1]) /* ty=Tensor[(5), int64] */;
    %6 = take(%p0, %5, axis=0) /* ty=Tensor[(5, 1), float32] */;
    reshape(%6, newshape=[5, 1, 1]) /* ty=Tensor[(5, 1, 1), float32] */
  };
  %8 = %7(meta[relay.Constant][0] /* ty=Tensor[(3, 1), float32] */ /* ty=Tensor[(3, 1), float32] */, %input, meta[relay.Constant][1] /* ty=Tensor[(5, 1), int64] */ /* ty=Tensor[(5, 1), int64] */, meta[relay.Constant][2] /* ty=Tensor[(5, 1), float32] */ /* ty=Tensor[(5, 1), float32] */, meta[relay.Constant][3] /* ty=Tensor[(5, 1), float32] */ /* ty=Tensor[(5, 1), float32] */, meta[relay.Constant][4] /* ty=Tensor[(5, 1), float32] */ /* ty=Tensor[(5, 1), float32] */, meta[relay.Constant][5] /* ty=Tensor[(1, 1), int64] */ /* ty=Tensor[(1, 1), int64] */) /* ty=Tensor[(5, 1, 1), float32] */;
  %9 = fn (%p01: Tensor[(5, 1, 1), float32], Primitive=1) -> Tensor[(5, 1), float32] {
    sum(%p01, axis=[2]) /* ty=Tensor[(5, 1), float32] */
  };
  %10 = %9(%8) /* ty=Tensor[(5, 1), float32] */;
  %15 = fn (%p02: Tensor[(5, 1), float32], Primitive=1) -> Tensor[(5, 2), float32] {
    %11 = sigmoid(%p02) /* ty=Tensor[(5, 1), float32] */;
    %12 = multiply(1f /* ty=float32 */, %11) /* ty=Tensor[(5, 1), float32] */;
    %13 = subtract(1f /* ty=float32 */, %12) /* ty=Tensor[(5, 1), float32] */;
    %14 = (%13, %11);
    concatenate(%14, axis=1) /* ty=Tensor[(5, 2), float32] */
  };
  %16 = %15(%10) /* ty=Tensor[(5, 2), float32] */;
  %17 = fn (%p03: Tensor[(5, 2), float32], Primitive=1) -> Tensor[(5), int32] {
    argmax(%p03, axis=[1]) /* ty=Tensor[(5), int32] */
  };
  %18 = %17(%16) /* ty=Tensor[(5), int32] */;
  (%18, %16)
}

This is the low-level TVM IR for the first fused function above with 7 operators in it. gather, where, add, reshape, take, etc. are fused into a one-liner expression. It is not complicated, and compilation is instant.

buffer_realize T_reshape([0, 5], [0, 1], [0, 1]) {
  parallel (ax0.ax1.fused, 0, 5) {
    T_reshape[ax0.ax1.fused, 0, 0] = placeholder[min(max((int64)0, 
(int64(select((int32((placeholder[floormod(floormod(((ax0.ax1.fused + 0) + 0), 5), 5), 
placeholder[floormod(floormod(((ax0.ax1.fused + 0) + 0), 5), 5), 0]] >= 
placeholder[floormod(floormod(((ax0.ax1.fused + 0) + 0), 5), 5), 0])) != 0), 
placeholder[floormod(floormod(((ax0.ax1.fused + 0) + 0), 5), 5), 0], 
placeholder[floormod(floormod(((ax0.ax1.fused + 0) + 0), 5), 5), 0])) + placeholder[0, 0])), (int64)2), 0]

  }
}

Now, if we increase the size of the dataset to (6, 10), we get this fused graph. Note that the first fused function now contains more than 20 ops.

def @main(%input: Tensor[(6, 10), float32]) -> (Tensor[(6), int32], Tensor[(6, 2), float32]) {
  %21 = fn (%p0: Tensor[(21, 1), float32], %p1: Tensor[(6, 10), float32], %p2: Tensor[(21), int64], %p3: Tensor[(6, 3), int64], %p4: Tensor[(6, 3), float32], %p5: Tensor[(6, 3), float32], %p6: Tensor[(6, 3), float32], %p7: Tensor[(1, 3), int64], %p8: Tensor[(21), float32], %p9: Tensor[(21), float32], %p10: Tensor[(21), float32], Primitive=1) -> Tensor[(6, 1, 3), float32] {
    %0 = gather(%p1, %p3, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %1 = greater_equal(%0, %p4) /* ty=Tensor[(6, 3), bool] */;
    %2 = where(%1, %p5, %p6) /* ty=Tensor[(6, 3), float32] */;
    %3 = cast(%2, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %4 = add(%3, %p7) /* ty=Tensor[(6, 3), int64] */;
    %5 = reshape(%4, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %6 = take(%p2, %5, axis=0) /* ty=Tensor[(18), int64] */;
    %7 = reshape(%6, newshape=[-1, 3]) /* ty=Tensor[(6, 3), int64] */;
    %8 = gather(%p1, %7, axis=1) /* ty=Tensor[(6, 3), float32] */;
    %9 = take(%p8, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %10 = reshape(%9, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %11 = greater_equal(%8, %10) /* ty=Tensor[(6, 3), bool] */;
    %12 = take(%p9, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %13 = reshape(%12, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %14 = take(%p10, %5, axis=0) /* ty=Tensor[(18), float32] */;
    %15 = reshape(%14, newshape=[-1, 3]) /* ty=Tensor[(6, 3), float32] */;
    %16 = where(%11, %13, %15) /* ty=Tensor[(6, 3), float32] */;
    %17 = cast(%16, dtype="int64") /* ty=Tensor[(6, 3), int64] */;
    %18 = add(%17, %p7) /* ty=Tensor[(6, 3), int64] */;
    %19 = reshape(%18, newshape=[-1]) /* ty=Tensor[(18), int64] */;
    %20 = take(%p0, %19, axis=0) /* ty=Tensor[(18, 1), float32] */;
    reshape(%20, newshape=[6, 1, 3]) /* ty=Tensor[(6, 1, 3), float32] */
  };
  %22 = %21(meta[relay.Constant][0] /* ty=Tensor[(21, 1), float32] */ /* ty=Tensor[(21, 1), float32] */, %input, meta[relay.Constant][1] /* ty=Tensor[(21), int64] */ /* ty=Tensor[(21), int64] */, meta[relay.Constant][2] /* ty=Tensor[(6, 3), int64] */ /* ty=Tensor[(6, 3), int64] */, meta[relay.Constant][3] /* ty=Tensor[(6, 3), float32] */ /* ty=Tensor[(6, 3), float32] */, meta[relay.Constant][4] /* ty=Tensor[(6, 3), float32] */ /* ty=Tensor[(6, 3), float32] */, meta[relay.Constant][5] /* ty=Tensor[(6, 3), float32] */ /* ty=Tensor[(6, 3), float32] */, meta[relay.Constant][6] /* ty=Tensor[(1, 3), int64] */ /* ty=Tensor[(1, 3), int64] */, meta[relay.Constant][7] /* ty=Tensor[(21), float32] */ /* ty=Tensor[(21), float32] */, meta[relay.Constant][8] /* ty=Tensor[(21), float32] */ /* ty=Tensor[(21), float32] */, meta[relay.Constant][9] /* ty=Tensor[(21), float32] */ /* ty=Tensor[(21), float32] */) /* ty=Tensor[(6, 1, 3), float32] */;
  %23 = fn (%p01: Tensor[(6, 1, 3), float32], Primitive=1) -> Tensor[(6, 1), float32] {
    sum(%p01, axis=[2]) /* ty=Tensor[(6, 1), float32] */
  };
  %24 = %23(%22) /* ty=Tensor[(6, 1), float32] */;
  %29 = fn (%p02: Tensor[(6, 1), float32], Primitive=1) -> Tensor[(6, 2), float32] {
    %25 = sigmoid(%p02) /* ty=Tensor[(6, 1), float32] */;
    %26 = multiply(1f /* ty=float32 */, %25) /* ty=Tensor[(6, 1), float32] */;
    %27 = subtract(1f /* ty=float32 */, %26) /* ty=Tensor[(6, 1), float32] */;
    %28 = (%27, %25);
    concatenate(%28, axis=1) /* ty=Tensor[(6, 2), float32] */
  };
  %30 = %29(%24) /* ty=Tensor[(6, 2), float32] */;
  %31 = fn (%p03: Tensor[(6, 2), float32], Primitive=1) -> Tensor[(6), int32] {
    argmax(%p03, axis=[1]) /* ty=Tensor[(6), int32] */
  };
  %32 = %31(%30) /* ty=Tensor[(6), int32] */;
  (%32, %30)
}

This is the corresponding low-level IR for the first fused function with 20 ops. The 20 ops are inlined into a one-liner expression, and it is huge.

buffer_realize T_reshape([0, 6], [0, 1], [0, 3]) {
  parallel (ax0.ax1.fused, 0, 6) {
    vectorized (ax2.inner, 0, 3) {
      T_reshape[ax0.ax1.fused, 0, ax2.inner] = placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6),
      placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((
      floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
      ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)),
      18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +  ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)]] >= placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((
      floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3),
      6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3)
      + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) != 0), placeholder[min(max((int64)0,
      (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), placeholder[floormod(floordiv(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18),
      3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) != 0),
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) +
      floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused +
      0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)],
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18)
      , 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) + placeholder[0, floormod(floormod(((floormod(floordiv(
      floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])),
      (int64)20)], placeholder[min(max((int64)0, (int64(select((int32((placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod(((
      (ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)]] >= placeholder[floormod(floordiv(floormod(((floormod(
      floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)
      , 6), floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)
      *3) + ax2.inner), 18), 3)), 18), 3)])) != 0), placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) +
      ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6),
      floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3)
      + ax2.inner), 18), 3)), 18), 3)], placeholder[floormod(floordiv(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)
      , 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3), 6), floormod(floormod(((floormod(floordiv(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])) +
      placeholder[0, floormod(floormod(((floormod(floordiv(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3), 6)*3) + floormod(floormod((((
      ax0.ax1.fused + 0)*3) + ax2.inner), 18), 3)), 18), 3)])), (int64)20)])) + placeholder[0, floormod(floormod((((ax0.ax1.fused + 0)*3) + ax2.inner),
      18), 3)])), (int64)20), 0]
    }
  }
}

If the dataset size is (20, 10), I get a fused function with 70-90 ops inside. The resulting one-liner expression is so huge it takes forever to compile.

What I don't understand yet is why TVM generates such a huge, messy expression from your subgraph. My guess is that it is due to the complex indexing logic that is unique to your use case. Usually the operations involved in operator fusion are simple elementwise ops whose indexing logic is trivial, so the aggressive op fusion TVM does is normally beneficial.
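To make the blow-up concrete, here is a toy illustration in plain Python (not TVM code) of how fully inlining a level-by-level traversal grows the expression:

# Each level's node-index expression is reused four times (feature lookup,
# threshold lookup, left child, right child), so fully inlining it roughly
# quadruples the expression size per tree level.
def inlined_index_expr(depth):
    expr = "n0"  # root node index
    for _ in range(depth):
        expr = f"where(f[{expr}] >= t[{expr}], l[{expr}], r[{expr}])"
    return expr

for d in range(1, 6):
    print(d, len(inlined_index_expr(d)))  # grows ~4x per level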

@interesaaat
Collaborator Author

This is really interesting, thanks for digging into it! I have a few questions:

  • What changed between TVM v0.6 and 0.7 such that this happens?
  • In our custom implementation directly using Relay, I remember that where was implemented as a custom op. (@scnakandala can weigh in on why we did this; I guess it was because where was not implemented back then.) Do you think this "aggressive inline" was not triggered at the time because of this custom op?
  • Is there any way to bypass this problem, besides disabling operator fusion? Something like a configuration knob telling TVM not to fuse more than X operators together?

@masahi
Collaborator

masahi commented Aug 21, 2020

I opened a thread on their forum to discuss this issue: https://discuss.tvm.ai/t/aggressive-operator-fusion-and-its-consequence-of-huge-inlined-tir-expression/7687

* What changed between TVM v0.6 and 0.7 such that this happens?

I'm pretty sure the op fusion logic hasn't changed between 0.6 and 0.7. But I do know that there has been a lot of effort around refactoring their lower-level IR. I'm not sure if that change is related to this problem.

* In our custom implementation directly using Relay, I remember that `where` was implemented as a custom op. (@scnakandala can weigh in on why we did this; I guess it was because `where` was not implemented back then.) Do you think this "aggressive inline" was not triggered at the time because of this custom op?

Yes. If there is an op that is foreign to TVM, TVM treats it as an "opaque" op, and that breaks the fusion there.

* Is there any way to bypass this problem, besides disabling operator fusion? Something like a configuration knob telling TVM not to fuse more than X operators together?

I think disabling op fusion is not possible because later passes depend on op fusion being applied. It is certainly possible to make the maximum size of fused functions configurable; we don't have that at the moment. Actually, there is a hardcoded constant https://github.com/apache/incubator-tvm/blob/master/src/relay/transforms/fuse_ops.cc#L82 that limits the size of fused functions. I tried changing it to something like 10, but somehow it didn't change anything. I'll look more into this later today.

@interesaaat
Collaborator Author

Thanks for opening the thread on the forum! Looking forward to a possible generic solution (i.e., without having to resort to custom ops). For the models where compilation works, the numbers look on par with what we saw using the custom model implementation, so this is exciting!

@masahi
Collaborator

masahi commented Aug 21, 2020

According to https://discuss.tvm.ai/t/aggressive-operator-fusion-and-its-consequence-of-huge-inlined-tir-expression/7687/2, the reason for the code blow-up is that indexing operations are duplicated at the lower IR level. I think this makes a lot of sense, although it's not obvious to me exactly where the duplication is happening.

Do you have an idea where this could happen? If the graph has a diamond-like pattern (i.e. a node with multiple incoming edges going back to a common ancestor), this could certainly happen.

@interesaaat
Collaborator Author

interesaaat commented Aug 21, 2020

Ok, I think that diamond pattern appears because the model iterates over the trees (all batched together) in a level-by-level fashion, so the indices computed at level i are used at level i+1 (code here). A minimal sketch of that traversal pattern is below.
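A sketch of the batched level-by-level traversal (assumed names and shapes, not Hummingbird's actual implementation): the idx tensor computed at one level is consumed by every gather at the next level, which is the diamond.

import torch

def tree_trav(x, feature_ids, thresholds, lefts, rights, depth):
    # x: (batch, num_features); feature_ids, lefts, rights are int64
    # tensors and thresholds a float tensor, all indexed by node id
    idx = torch.zeros(x.shape[0], dtype=torch.long)  # every row starts at the root
    for _ in range(depth):
        feat = feature_ids[idx]                      # gather per-node feature ids
        thresh = thresholds[idx]                     # gather per-node thresholds
        cond = x[torch.arange(x.shape[0]), feat] >= thresh
        # idx feeds the next iteration: this reuse is what TVM's fusion
        # pass ends up inlining over and over
        idx = torch.where(cond, rights[idx], lefts[idx])
    return idx  # leaf indices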

Is there a way we can tell TVM to fuse all ops within a level, but not across levels? I.e., do fusion within each for-loop iteration but not across iterations (I think that's why we are seeing this diamond pattern).

@masahi
Collaborator

masahi commented Aug 21, 2020

No, because at the Relay IR level the loop is unrolled, so TVM cannot find a loop boundary. Tracing by Torch already does the unrolling (TorchScript tracing cannot produce JIT modules containing for loops); a quick demonstration is below.
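A standalone check (not from this repo's code) showing that tracing unrolls Python loops, leaving no loop construct for a backend to key on:

import torch

def three_levels(x):
    for _ in range(3):  # a Python loop, e.g. over tree levels
        x = x + 1
    return x

traced = torch.jit.trace(three_levels, torch.zeros(2))
print(traced.graph)  # three separate aten::add nodes, no prim::Loop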

What we can do immediately to unblock your development is to make the where op opaque to the fusion pass. This should have the same effect as where being implemented as a custom op.

You can do that by changing kBroadcast at

https://github.com/apache/incubator-tvm/blob/9b8eb816116aeaf235a42e334e3ad42de099557c/src/relay/op/tensor/transform.cc#L1719

to kOpaque. This seems to break the original big fused function into small pieces. It works, but performance would be suboptimal.

With this change, I can compile for a dataset of (500, 10) on my laptop. With (1000, 10) I got a segfault.

@interesaaat
Collaborator Author

Is the segfault due to the same problem?

Ok, I will do that and see what it unblocks.

@scnakandala
Contributor

scnakandala commented Aug 23, 2020

Ok, I think that diamond pattern appears because the model iterates over the trees (all batched together) in a level-by-level fashion, so the indices computed at level i are used at level i+1.

@interesaaat I think the reason we didn't hit this issue in the prototype system is the custom WhereOp, which broke the fusion at every tree-depth level. Back then there was no support for where in Relay.

From some previous experiments, I remember operator fusion was the main reason for the low runtimes in TVM compared to TorchScript. When I disabled optimizations in TVM, it was on par with TorchScript, which is expected. But operator fusion only at the tree-depth level already gave good performance; across levels, we need to bring in new weight and index tensors from memory anyway.

In the ideal case, if TVM supported a configurable fusion depth, we could use heuristic values for tree translation that take the shape of the trees into account.

@masahi
Collaborator

masahi commented Aug 24, 2020

Configurable fuse depth is supported in apache/tvm#6327. With this change, there is no need to change the where op to kOpaque (mentioned above).

But note that, due to the way fusion works in TVM, the resulting sizes of fused subgraphs are not "tight" w.r.t. the max fuse depth: if the max depth is 20 and, in the middle of the fusion pass, there are two consecutive subgraphs of size 11 each, they will remain size-11 subgraphs. We won't move nodes from one to the other to give one of the subgraphs the maximum possible number of nodes. So you should treat the max fuse depth as a conservative bound.

To enable this config:

import tvm
from tvm import relay

# model, target and params come from the earlier frontend conversion step
max_fuse_depth = 20
with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": max_fuse_depth}):
    graph, lib, params = relay.build(model, target=target, params=params)

To dump the graph after fusion, you can do:

with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": max_fuse_depth}):
    opt_mod, opt_params = relay.optimize(model, target=target, params=params)
    print(opt_mod["main"])

@masahi
Collaborator

masahi commented Aug 24, 2020

Is the segfault due to the same problem?

@interesaaat It seems the reason I got the segfault with the big dataset is that there is a very large concatenation of parameter tensors at the beginning of the graph. For example, for a dataset of size (10, 10), there is this function that can be evaluated at compile time (since all param tensors are constant):

fn () -> Tensor[(10, 50), int64] {
  %0 = (meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */, meta[relay.Constant][0] /* ty=Tensor[(1, 50), int64] */ /* ty=Tensor[(1, 50), int64] */);
  %1 = fn (%p0: (Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64], Tensor[(1, 50), int64]), Primitive=1) -> Tensor[(10, 50), int64] {
    concatenate(%p0)
  };
  %1(%0)
}

%0 above is a tuple of 10 (1, 50) tensors; all the function does is create a (10, 50) tensor by concatenating the tuple. For a dataset of size (1000, 10), the tuple contains 1000 (1, 50) tensors, and TVM segfaults while evaluating the concatenate.

Do you know where this concatenation of param tensors happens? Since this is entirely a compile-time operation, it's better to do the concat in Torch and pass the concatenated tensor to TVM.

@scnakandala
Contributor

@masahi I think it is coming from here: https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/_tree_implementations.py#L224

We rely on the PyTorch input to get the batch size, since the TorchScript runtime can support any batch size.
@interesaaat I think for the TVM backend we can take an additional batch_size argument at compile time and create this tensor then; see the sketch below.
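A rough sketch of that idea (hypothetical shapes and names, not the actual Hummingbird code): with a known batch size, the offset tensor can be materialized once in Torch and constant-folded, instead of emitting a batch_size-way concatenate in Relay:

import torch

batch_size, num_trees = 1000, 50  # hypothetical sizes
# per-tree node offsets, shape (1, num_trees), as in the Relay dump above
nodes_offset = torch.arange(num_trees, dtype=torch.int64).unsqueeze(0)
# built once at conversion time, so TVM never sees the huge concatenate
batched_offset = nodes_offset.expand(batch_size, -1).contiguous()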

@interesaaat
Collaborator Author

Thanks @masahi! I am going to try your PR and see what we get. From your comments, I gather it pretty much solves the problem, and we only need to fix the batch size as @scnakandala suggested (I will send a fix shortly). Question: do you think we can fix the diamond-pattern issue? Setting the limit is good to unblock us, but it's suboptimal. I am more than happy to help if needed.

@interesaaat
Collaborator Author

Ok, I have pushed a commit fixing the max depth and the batch size, and now TVM works, even with large datasets :)

@masahi
Collaborator

masahi commented Aug 24, 2020

Question: do you think we can fix the diamond-pattern issue? Setting the limit is good to unblock us, but it's suboptimal. I am more than happy to help if needed.

Do you mean you want to fuse only the inner loop (one tree level)? That would create many smaller functions, so compilation would certainly be faster, but the runtime perf might be slower than with big functions whose size is bounded by the max depth param.

That said, the big fused functions generated by TVM contain many duplicated expressions, so the performance relies on common subexpression elimination (CSE) being done by LLVM or NVCC.

@interesaaat
Collaborator Author

Ok, so bottom line: this is the best we can do.

@masahi
Collaborator

masahi commented Aug 24, 2020

Yes, I think so. We should certainly benchmark the runtime performance with different fuse depths, to see how much fusion helps and whether LLVM and NVCC do CSE as expected.
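Something like this sketch could drive that comparison (assuming model, params and input_data from the earlier snippets, and the pre-0.8 graph_runtime API):

import time
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

for depth in (5, 10, 20, 50):
    # compile the same module with a different fusion bound each time
    with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": depth}):
        graph, lib, built_params = relay.build(model, target="llvm", params=params)
    module = graph_runtime.create(graph, lib, tvm.cpu())
    module.set_input("input", input_data, **built_params)
    start = time.time()
    for _ in range(10):
        module.run()
    print(depth, (time.time() - start) / 10, "s per run")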

@interesaaat
Collaborator Author

Once I finish covering the other operators, I can work on the benchmark.

One question: to remove warnings such as Cannot find config for target=llvm -keys=cpu, workload=('batch_matmul.x86', ('TENSOR', (1, 100, 99), 'float32'), ('TENSOR', (1, 100, 99), 'float32')). A fallback configuration is used, which may bring great performance regression., I need to add something like this, right?
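If that link is to the AutoTVM tuning flow, the gist would be something like the following (a sketch with mod and params assumed from the conversion step; the fallback warning goes away once a tuned config for the workload is in the log):

import tvm
from tvm import autotvm, relay

# extract tunable tasks (e.g. batch_matmul) from the converted module
tasks = autotvm.task.extract_from_program(mod["main"], target="llvm", params=params)
measure = autotvm.measure_option(builder=autotvm.LocalBuilder(),
                                 runner=autotvm.LocalRunner(number=10))
for task in tasks:
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=100, measure_option=measure,
               callbacks=[autotvm.callback.log_to_file("tuning.log")])

# compile with the best configs found, silencing the fallback warnings
with autotvm.apply_history_best("tuning.log"):
    with tvm.transform.PassContext(opt_level=3):
        graph, lib, built_params = relay.build(mod, target="llvm", params=params)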

@interesaaat
Collaborator Author

interesaaat commented Sep 15, 2020

Same, 1.7s.

fn (%input: Tensor[(10000, 28), float32]) {
  %0 = hb.tree_beampp_root(%input, meta[relay.Constant][1], meta[relay.Constant][2]);
  %1 = hb.tree_beampp_internal(%input, %0, meta[relay.Constant][3], meta[relay.Constant][4]);
  %2 = hb.tree_beampp_internal(%input, %1, meta[relay.Constant][5], meta[relay.Constant][6]);
  %3 = hb.tree_beampp_internal(%input, %2, meta[relay.Constant][7], meta[relay.Constant][8]);
  %4 = hb.tree_beampp_internal(%input, %3, meta[relay.Constant][9], meta[relay.Constant][10]);
  %5 = hb.tree_beampp_internal(%input, %4, meta[relay.Constant][11], meta[relay.Constant][12]);
  %6 = hb.tree_beampp_internal(%input, %5, meta[relay.Constant][13], meta[relay.Constant][14]);
  %7 = hb.tree_beampp_internal(%input, %6, meta[relay.Constant][15], meta[relay.Constant][16]);
  %8 = reshape(%7, newshape=[-1]);
  %9 = take(meta[relay.Constant][0], %8, axis=0);
  %10 = reshape(%9, newshape=[-1, 500, 1]);
  %11 = reshape(%10, newshape=[-1, 1, 500]);
  %12 = sum(%11, axis=[2]);
  %13 = sigmoid(%12);
  %14 = subtract(1f, %13);
  %15 = (%14, %13);
  %16 = concatenate(%15, axis=1);
  %17 = argmax(%16, axis=[1]);
  (%17, %16)
}

@masahi
Collaborator

masahi commented Sep 15, 2020

It would be interesting to see the difference between the custom-op and open-source versions in more detail. If they are indeed equivalent, the lowered IR should also be equivalent. We need to debug; I think such a large perf delta shouldn't exist.

To me it is also a bit surprising to see prediction on 10000 rows take more than 1 second.

@interesaaat
Collaborator Author

That's for 500 trees of depth 8. Regular xgb takes about 2s. What can I do to debug this? I have everything already set up.

@masahi
Collaborator

masahi commented Sep 15, 2020

I want to take a look at this if you are open-sourcing this stuff. If you like, I'd start by looking at the lowered IR by uncommenting this line (_build.lower needs to be replaced with tvm.driver.lower):

https://github.com/apache/incubator-tvm/blob/e63e08febd682f40a536075998a6839bccccd3c6/python/tvm/relay/backend/_backend.py#L51

@interesaaat
Collaborator Author

interesaaat commented Sep 16, 2020

Ok, trying now. Open-sourcing the code can be a bit problematic because I need to take a cut of the whole codebase, and most likely the cut won't work end to end. Is this the expected output? I can do the same for the custom code.

primfn(placeholder_1: handle, placeholder_red_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {placeholder_red: Buffer(placeholder_red_2: Pointer(int32), int32, [10000], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 2], [])}
  buffer_map = {placeholder_1: placeholder, placeholder_red_1: placeholder_red} {
  attr [placeholder_red_temp.v0: Pointer(int32)] "storage_scope" = "global";
  allocate(placeholder_red_temp.v0, int32, [10000]);
  attr [placeholder_red_temp.v1: Pointer(float32)] "storage_scope" = "global";
  allocate(placeholder_red_temp.v1, float32, [10000]) {
    for (ax0: int32, 0, 10000) "parallel" {
      placeholder_red_temp.v0[ax0] = -1
      placeholder_red_temp.v1[ax0] = -3.40282e+38f32
      for (k1: int32, 0, 2) {
        placeholder_red_temp.v0[ax0] = @tir.if_then_else(((float32*)placeholder_2[((ax0*2) + k1)] <= (float32*)placeholder_red_temp.v1[ax0]), (int32*)placeholder_red_temp.v0[ax0], k1, dtype=int32)
        placeholder_red_temp.v1[ax0] = @tir.if_then_else(((float32*)placeholder_2[((ax0*2) + k1)] <= (float32*)placeholder_red_temp.v1[ax0]), (float32*)placeholder_red_temp.v1[ax0], (float32*)placeholder_2[((ax0*2) + k1)], dtype=float32)
      }
    }
    for (ax0_1: int32, 0, 10000) {
      placeholder_red_2[ax0_1] = (int32*)placeholder_red_temp.v0[ax0_1]
    }
  }
}


primfn(placeholder_1: handle, T_concat_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {T_concat: Buffer(T_concat_2: Pointer(float32), float32, [10000, 2], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 1], [])}
  buffer_map = {placeholder_1: placeholder, T_concat_1: T_concat} {
  for (ax0: int32, 0, 10000) "parallel" {
    for (ax1: int32, 0, 2) {
      T_concat_2[((ax0*2) + ax1)] = @tir.if_then_else((1 <= ax1), @tir.sigmoid((float32*)placeholder_2[((ax0 + ax1) - 1)], dtype=float32), (1f32 - @tir.sigmoid((float32*)placeholder_2[(ax0 + ax1)], dtype=float32)), dtype=float32)
    }
  }
}


primfn(placeholder_1: handle, placeholder_red_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {placeholder_red: Buffer(placeholder_red_2: Pointer(float32), float32, [10000, 1], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 1, 500], [])}
  buffer_map = {placeholder_1: placeholder, placeholder_red_1: placeholder_red} {
  for (ax0: int32, 0, 10000) "parallel" {
    placeholder_red_2[ax0] = 0f32
    for (k2: int32, 0, 500) {
      placeholder_red_2[ax0] = ((float32*)placeholder_red_2[ax0] + (float32*)placeholder_2[((ax0*500) + k2)])
    }
  }
}

@masahi
Collaborator

masahi commented Sep 16, 2020

Hmm, isn't that the custom-op version? For the open-source one we should have more ops fused into one function, so the lowered IR would be more complicated. The IR makes sense for the custom-op version.

@interesaaat
Collaborator Author

Sorry, I didn't paste the whole output. Here is the full output.

primfn(placeholder_1: handle, placeholder_red_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {placeholder_red: Buffer(placeholder_red_2: Pointer(int32), int32, [10000], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 2], [])}
  buffer_map = {placeholder_1: placeholder, placeholder_red_1: placeholder_red} {
  attr [placeholder_red_temp.v0: Pointer(int32)] "storage_scope" = "global";
  allocate(placeholder_red_temp.v0, int32, [10000]);
  attr [placeholder_red_temp.v1: Pointer(float32)] "storage_scope" = "global";
  allocate(placeholder_red_temp.v1, float32, [10000]) {
    for (ax0: int32, 0, 10000) "parallel" {
      placeholder_red_temp.v0[ax0] = -1
      placeholder_red_temp.v1[ax0] = -3.40282e+38f32
      for (k1: int32, 0, 2) {
        placeholder_red_temp.v0[ax0] = @tir.if_then_else(((float32*)placeholder_2[((ax0*2) + k1)] <= (float32*)placeholder_red_temp.v1[ax0]), (int32*)placeholder_red_temp.v0[ax0], k1, dtype=int32)
        placeholder_red_temp.v1[ax0] = @tir.if_then_else(((float32*)placeholder_2[((ax0*2) + k1)] <= (float32*)placeholder_red_temp.v1[ax0]), (float32*)placeholder_red_temp.v1[ax0], (float32*)placeholder_2[((ax0*2) + k1)], dtype=float32)
      }
    }
    for (ax0_1: int32, 0, 10000) {
      placeholder_red_2[ax0_1] = (int32*)placeholder_red_temp.v0[ax0_1]
    }
  }
}


primfn(placeholder_1: handle, T_concat_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {T_concat: Buffer(T_concat_2: Pointer(float32), float32, [10000, 2], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 1], [])}
  buffer_map = {placeholder_1: placeholder, T_concat_1: T_concat} {
  for (ax0: int32, 0, 10000) "parallel" {
    for (ax1: int32, 0, 2) {
      T_concat_2[((ax0*2) + ax1)] = @tir.if_then_else((1 <= ax1), @tir.sigmoid((float32*)placeholder_2[((ax0 + ax1) - 1)], dtype=float32), (1f32 - @tir.sigmoid((float32*)placeholder_2[(ax0 + ax1)], dtype=float32)), dtype=float32)
    }
  }
}


primfn(placeholder_1: handle, placeholder_red_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {placeholder_red: Buffer(placeholder_red_2: Pointer(float32), float32, [10000, 1], []),
             placeholder: Buffer(placeholder_2: Pointer(float32), float32, [10000, 1, 500], [])}
  buffer_map = {placeholder_1: placeholder, placeholder_red_1: placeholder_red} {
  for (ax0: int32, 0, 10000) "parallel" {
    placeholder_red_2[ax0] = 0f32
    for (k2: int32, 0, 500) {
      placeholder_red_2[ax0] = ((float32*)placeholder_red_2[ax0] + (float32*)placeholder_2[((ax0*500) + k2)])
    }
  }
}


primfn(placeholder_19: handle, placeholder_20: handle, placeholder_21: handle, placeholder_22: handle, placeholder_23: handle, placeholder_24: handle, placeholder_25: handle, placeholder_26: handle, placeholder_27: handle, placeholder_28: handle, placeholder_29: handle, placeholder_30: handle, placeholder_31: handle, placeholder_32: handle, placeholder_33: handle, placeholder_34: handle, placeholder_35: handle, placeholder_36: handle, placeholder_37: handle, T_reshape_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
 buffers = {placeholder_18: Buffer(placeholder_38: Pointer(float32), float32, [128000, 1], []),
             placeholder_17: Buffer(placeholder_39: Pointer(int64), int64, [500], []),
             placeholder_16: Buffer(placeholder_40: Pointer(float32), float32, [4000], []),
             placeholder_10: Buffer(placeholder_41: Pointer(float32), float32, [16000], []),
             placeholder_3: Buffer(placeholder_42: Pointer(int64), int64, [500], []),
             placeholder_1: Buffer(placeholder_43: Pointer(float32), float32, [1000], []),
             placeholder_5: Buffer(placeholder_44: Pointer(int64), int64, [32000], []),
             placeholder_7: Buffer(placeholder_45: Pointer(int64), int64, [2000], []),
             placeholder: Buffer(placeholder_46: Pointer(int64), int64, [4000], []),
             T_reshape: Buffer(T_reshape_2: Pointer(float32), float32, [10000, 1, 500], []),
             placeholder_11: Buffer(placeholder_47: Pointer(int64), int64, [8000], []),
             placeholder_6: Buffer(placeholder_48: Pointer(float32), float32, [2000], []),
             placeholder_15: Buffer(placeholder_49: Pointer(float32), float32, [32000], []),
             placeholder_2: Buffer(placeholder_50: Pointer(int64), int64, [64000], []),
             placeholder_4: Buffer(placeholder_51: Pointer(float32), float32, [8000], []),
             placeholder_8: Buffer(placeholder_52: Pointer(int64), int64, [1000], []),
             placeholder_14: Buffer(placeholder_53: Pointer(float32), float32, [500], []),
             placeholder_13: Buffer(placeholder_54: Pointer(int64), int64, [16000], []),
             placeholder_12: Buffer(placeholder_55: Pointer(float32), float32, [64000], []),
             placeholder_9: Buffer(placeholder_56: Pointer(float32), float32, [10000, 28], [])}
  buffer_map = {placeholder_28: placeholder, T_reshape_1: T_reshape, placeholder_25: placeholder_1, placeholder_36: placeholder_2, placeholder_21: placeholder_3, placeholder_31: placeholder_4, placeholder_34: placeholder_5, placeholder_27: placeholder_6, placeholder_26: placeholder_7, placeholder_24: placeholder_8, placeholder_20: placeholder_9, placeholder_33: placeholder_10, placeholder_30: placeholder_11, placeholder_37: placeholder_12, placeholder_32: placeholder_13, placeholder_22: placeholder_14, placeholder_35: placeholder_15, placeholder_29: placeholder_16, placeholder_23: placeholder_17, placeholder_19: placeholder_18} {
  for (ax0.ax1.fused: int32, 0, 10000) "parallel" {
    for (ax2.outer: int32, 0, 32) {
      for (ax2.inner.s: int32, 0, 16) {
        if @tir.likely((((ax2.outer*16) + ax2.inner.s) < 500), dtype=bool) {
T_reshape_2[(((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s)] = (float32*)placeholder_38[min(max(0i64, (((((((((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*128i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*128i64)) + (cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])]))*64i64)) + (cast(int64, ((float32*)placeholder_48[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_45[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, 
(int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)])]))*32i64)) + (cast(int64, ((float32*)placeholder_40[min(max(0i64, ((((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*4i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*4i64)) + (cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])]))*2i64)) + cast(int64, ((float32*)placeholder_48[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 
999i64)])])))), 1999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_45[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)])])))), 3999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_46[min(max(0i64, ((((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*4i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*4i64)) + (cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])]))*2i64)) + cast(int64, ((float32*)placeholder_48[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, 
((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_45[min(max(0i64, (((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*2i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*2i64)) + cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)])])))), 3999i64)])]))*16i64)) + (cast(int64, ((float32*)placeholder_51[min(max(0i64, (((((cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))]))*8i64) + ((int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)]*8i64)) + (cast(int64, ((float32*)placeholder_43[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= (float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)] <= (float32*)placeholder_56[(cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)) + (int64*)placeholder_52[min(max(0i64, (cast(int64, ((float32*)placeholder_53[((ax2.outer*16) + ax2.inner.s)] <= 
(float32*)placeholder_56[(min(max(0i64, (int64*)placeholder_42[((ax2.outer*16) + ax2.inner.s)]), 27i64) + cast(int64, (floordiv(floormod((((ax0.ax1.fused*500) + (ax2.outer*16)) + ax2.inner.s), 5000000), 500)*28)))])) + (int64*)placeholder_39[((ax2.outer*16) + ax2.inner.s)])), 999i64)])])))), 1999i64)])])))), 3999i64)])])))), 7999i64)])])))), 15999i64)])])))), 31999i64)])])))), 63999i64)])])))), 127999i64)]
}
}
}

@interesaaat
Collaborator Author

interesaaat commented Sep 16, 2020

OK, I finally managed to get it working with TVM 0.6 and custom ops.

// attr [placeholder_red_temp.v0] storage_scope = "global"
allocate placeholder_red_temp.v0[int32 * 10000]
// attr [placeholder_red_temp.v1] storage_scope = "global"
allocate placeholder_red_temp.v1[float32 * 10000]
produce placeholder_red_temp {
  for (ax0, 0, 10000) {
    placeholder_red_temp.v0[ax0] = -1
    placeholder_red_temp.v1[ax0] = -3.40282e+38f
    for (k1, 0, 2) {
      placeholder_red_temp.v0[ax0] = tvm_if_then_else((placeholder[((ax0*2) + k1)] <= placeholder_red_temp.v1[ax0]), placeholder_red_temp.v0[ax0], k1)
      placeholder_red_temp.v1[ax0] = tvm_if_then_else((placeholder[((ax0*2) + k1)] <= placeholder_red_temp.v1[ax0]), placeholder_red_temp.v1[ax0], placeholder[((ax0*2) + k1)])
    }
  }
}
produce placeholder_red {
  for (ax0, 0, 10000) {
    placeholder_red[ax0] = placeholder_red_temp.v0[ax0]
  }
}

produce T_concat {
  parallel (ax0, 0, 10000) {
    for (ax1, 0, 2) {
      T_concat[((ax0*2) + ax1)] = tvm_if_then_else((1 <= ax1), sigmoid(placeholder[((ax0 + ax1) - 1)]), (1f - sigmoid(placeholder[(ax0 + ax1)])))
    }
  }
}

produce placeholder_red {
  for (ax0.ax1.fused, 0, 10000) {
    placeholder_red[ax0.ax1.fused] = 0f
    for (k2, 0, 500) {
      placeholder_red[ax0.ax1.fused] = (placeholder_red[ax0.ax1.fused] + placeholder[((ax0.ax1.fused*500) + k2)])
    }
  }
}

produce T_reshape {
  parallel (ax0.ax1.fused, 0, 10000) {
    for (ax2.outer, 0, 63) {
      for (ax2.inner.s, 0, 8) {
        if (likely((((ax2.outer*8) + ax2.inner.s) < 500))) {
          if (likely((((ax2.outer*8) + ax2.inner.s) < 500))) {
r*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), 
((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), 
(floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < 
placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))])] < placeholder[tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))]), (tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2), ((tvm_if_then_else((placeholder[((floormod((floordiv(((ax2.outer*8) + ax2.inner.s), 500) + ax0.ax1.fused), 6962)*28) + placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)])] < placeholder[floormod(((ax2.outer*8) + ax2.inner.s), 500)]), (floormod(((ax2.outer*8) + ax2.inner.s), 500)*2), ((floormod(((ax2.outer*8) + ax2.inner.s), 500)*2) + 1))*2) + 1))*2) + 1))*2) + 1))*2) + 1))*2) + 1))*2) + 1))*2) + 1))), 127999)]
}}}}}}

v0.0.4
fn (%input: Tensor[(10000, 28), float32]) {
  %0 = hb.tree_beampp_root(%input, meta[relay.Constant][1], meta[relay.Constant][2]);
  %1 = hb.tree_beampp_internal(%input, %0, meta[relay.Constant][3], meta[relay.Constant][4]);
  %2 = hb.tree_beampp_internal(%input, %1, meta[relay.Constant][5], meta[relay.Constant][6]);
  %3 = hb.tree_beampp_internal(%input, %2, meta[relay.Constant][7], meta[relay.Constant][8]);
  %4 = hb.tree_beampp_internal(%input, %3, meta[relay.Constant][9], meta[relay.Constant][10]);
  %5 = hb.tree_beampp_internal(%input, %4, meta[relay.Constant][11], meta[relay.Constant][12]);
  %6 = hb.tree_beampp_internal(%input, %5, meta[relay.Constant][13], meta[relay.Constant][14]);
  %7 = hb.tree_beampp_internal(%input, %6, meta[relay.Constant][15], meta[relay.Constant][16]);
  %8 = reshape(%7, newshape=[-1]);
  %9 = take(meta[relay.Constant][0], %8, axis=0);
  %10 = reshape(%9, newshape=[-1, 500, 1]);
  %11 = reshape(%10, newshape=[-1, 1, 500]);
  %12 = sum(%11, axis=[2]);
  %13 = sigmoid(%12);
  %14 = subtract(1f, %13);
  %15 = (%14, %13);
  %16 = concatenate(%15, axis=1);
  %17 = argmax(%16, axis=[1]);
  (%17, %16)
}

@masahi
Collaborator

masahi commented Sep 16, 2020

Interesting. One thing that is immediately obvious is that there are a lot of memory buffers (the placeholder_ stuff) in the current version. I think they correspond to all the tree parameters (node indices, thresholds, etc.) coming from PyTorch. If you look at the Relay graph of the current version, there are tensors of size (5000000,) or (10000, 500), but I don't see them in the Relay graph or lowered IR of the custom version.

Are you sure the input PyTorch modules are the same? Or does having a custom op allow representing parameters more compactly? Please try comparing the number and size of the parameter tensors passed to TVM (the params in mod, lib, params = relay.build() or mod, params = from_pytorch()).

@masahi
Collaborator

masahi commented Sep 16, 2020

Also, if it is possible, it would be interesting to run your custom-op version with the current TVM. As you can see from how the lowered IR looks in the two versions, there has been a major change in how the lower-level IR is represented since v0.6. We should make sure the perf regression comes solely from the custom op and not from internal changes in the lower-level IR.

@interesaaat
Collaborator Author

interesaaat commented Sep 16, 2020

Do you think the placeholders are because of the reshape ops that copy all the time? There are no reshapes in the custom ops.

I will start with your first suggestion, since I already tried months back to switch to 0.7 and miserably failed. I am quite sure that the params are the same.

@masahi
Collaborator

masahi commented Sep 16, 2020

Do you think the placeholders are because of the reshape ops that copy all the time?

No, since the reshape ops are fused into one memory-movement operation.

I'm pretty sure they come from the tree parameters. In the Relay graph of the custom version, there are also many parameters represented as meta[relay.Constant][1] etc. I wonder why they don't appear in the lowered IR the way the parameters of the current version do (as the various placeholder_ buffers).

@interesaaat
Collaborator Author

The two __init__s are as follows:

def __init__(self, tree_parameters, n_features, n_classes):

        tree_depths = [tree_parameter[0] for tree_parameter in tree_parameters]
        max_tree_depth = max(tree_depths)
        self.max_tree_depth = max_tree_depth
        self.num_trees = len(tree_depths)
        self.n_features = n_features
        self.n_classes = n_classes

        node_maps = [tp[1] for tp in tree_parameters]

        weight_0 = np.zeros((self.num_trees, 2 ** max_tree_depth - 1))
        bias_0 = np.zeros((self.num_trees, 2 ** max_tree_depth - 1))
        weight_1 = np.zeros((self.num_trees, 2 ** max_tree_depth, n_classes))

        for i, node_map in enumerate(node_maps):
            self._get_weights_and_biases(node_map, max_tree_depth, weight_0[i], weight_1[i], bias_0[i])

        node_by_levels = [set() for _ in range(max_tree_depth)]
        self._traverse_by_level(node_by_levels, 0, -1, max_tree_depth)

        self.root_nodes = relay.Constant(tvm.nd.array(weight_0[:, 0].flatten().astype('int32')))
        self.root_biases = relay.Constant(tvm.nd.array(-1*bias_0[:, 0].astype('float32')))

        self.nodes = []
        self.biases = []
        for i in range(1, max_tree_depth):
            nodes = relay.Constant(tvm.nd.array(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype('int32')))
            biases = relay.Constant(tvm.nd.array(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype('float32')))
            self.nodes.append(nodes)
            self.biases.append(biases)

        self.leaf_nodes = relay.Constant(tvm.nd.array(weight_1.reshape((-1, n_classes)).astype('float32')))

and

def __init__(self, tree_parameters, max_depth, n_features, classes, n_classes=None, **kwargs):
        """
        Args:
            tree_parameters: The parameters defining the tree structure
            max_depth: The maximum tree-depth in the model
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            n_classes: The total number of used classes
        """
        super(PerfectTreeTraversalTreeImpl, self).__init__(tree_parameters, n_features, classes, n_classes, **kwargs)

        # Initialize the actual model.
        self.max_tree_depth = max_depth
        self.num_trees = len(tree_parameters)
        self.n_features = n_features

        node_maps = [tp[0] for tp in tree_parameters]

        weight_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
        bias_0 = np.zeros((self.num_trees, 2 ** max_depth - 1))
        weight_1 = np.zeros((self.num_trees, 2 ** max_depth, self.n_classes))

        for i, node_map in enumerate(node_maps):
            self._get_weights_and_biases(node_map, max_depth, weight_0[i], weight_1[i], bias_0[i])

        node_by_levels = [set() for _ in range(max_depth)]
        self._traverse_by_level(node_by_levels, 0, -1, max_depth)

        self.root_nodes = torch.nn.Parameter(torch.from_numpy(weight_0[:, 0].flatten().astype("int64")), requires_grad=False)
        self.root_biases = torch.nn.Parameter(-1 * torch.from_numpy(bias_0[:, 0].astype("float32")), requires_grad=False)

        tree_indices = np.array([i for i in range(0, 2 * self.num_trees, 2)]).astype("int64")
        self.tree_indices = torch.nn.Parameter(torch.from_numpy(tree_indices), requires_grad=False)

        self.nodes = []
        self.biases = []
        for i in range(1, max_depth):
            nodes = torch.nn.Parameter(
                torch.from_numpy(weight_0[:, list(sorted(node_by_levels[i]))].flatten().astype("int64")), requires_grad=False
            )
            biases = torch.nn.Parameter(
                torch.from_numpy(-1 * bias_0[:, list(sorted(node_by_levels[i]))].flatten().astype("float32")),
                requires_grad=False,
            )
            self.nodes.append(nodes)
            self.biases.append(biases)

        self.nodes = torch.nn.ParameterList(self.nodes)
        self.biases = torch.nn.ParameterList(self.biases)

        self.leaf_nodes = torch.nn.Parameter(
            torch.from_numpy(weight_1.reshape((-1, self.n_classes)).astype("float32")), requires_grad=False
        )

@interesaaat
Collaborator Author

I don't see any major difference besides the fact that we were previously casting to int32 instead of int64.

@interesaaat
Collaborator Author

interesaaat commented Sep 17, 2020

This is a print of the params in the current HB.

{'_operator_map.SklearnXGBClassifier.root_nodes': <tvm.nd.NDArray shape=(500,), cpu(0)>
array([1...]), '_operator_map.SklearnXGBClassifier.root_biases': <tvm.nd.NDArray shape=(500,), cpu(0)>
array([...], dtype=float32), '_operator_map.SklearnXGBClassifier.tree_indices': <tvm.nd.NDArray shape=(500,), cpu(0)>
array([  ...]), '_operator_map.SklearnXGBClassifier.leaf_nodes': <tvm.nd.NDArray shape=(128000, 1), cpu(0)>
array([[-0.16190477],
       [-0.16190477],
       [-0.16190477],
       ...,
       [-0.00740473],
       [-0.00740473],
       [-0.00740473]], dtype=float32), '_operator_map.SklearnXGBClassifier.nodes.0': <tvm.nd.NDArray shape=(1000,), cpu(0)>
array([1...), '_operator_map.SklearnXGBClassifier.nodes.1': <tvm.nd.NDArray shape=(2000,), cpu(0)>
array([ 2,  3, 19, ...,  0,  0,  0]), '_operator_map.SklearnXGBClassifier.nodes.2': <tvm.nd.NDArray shape=(4000,), cpu(0)>
array([ 9,  0, 25, ...,  0,  0,  0]), '_operator_map.SklearnXGBClassifier.nodes.3': <tvm.nd.NDArray shape=(8000,), cpu(0)>
array([0, 0, 0, ..., 0, 0, 0]), '_operator_map.SklearnXGBClassifier.nodes.4': <tvm.nd.NDArray shape=(16000,), cpu(0)>
array([10,  0,  0, ...,  0,  0,  0]), '_operator_map.SklearnXGBClassifier.nodes.5': <tvm.nd.NDArray shape=(32000,), cpu(0)>
array([ 0, 13,  0, ...,  0,  0,  0]), '_operator_map.SklearnXGBClassifier.nodes.6': <tvm.nd.NDArray shape=(64000,), cpu(0)>
array([0, 0, 0, ..., 0, 0, 0]), '_operator_map.SklearnXGBClassifier.biases.0': <tvm.nd.NDArray shape=(1000,), cpu(0)>
array([...],
      dtype=float32), '_operator_map.SklearnXGBClassifier.biases.1': <tvm.nd.NDArray shape=(2000,), cpu(0)>
array([-0.47300318,  1.8892708 , -0.9837453 , ..., -0.        ,
       -0.        , -0.        ], dtype=float32), '_operator_map.SklearnXGBClassifier.biases.2': <tvm.nd.NDArray shape=(4000,), cpu(0)>
array([ 1.2566634 , -0.        , -0.12022677, ..., -0.        ,
       -0.        , -0.        ], dtype=float32), '_operator_map.SklearnXGBClassifier.biases.3': <tvm.nd.NDArray shape=(8000,), cpu(0)>
array([ 1.8571583, -0.       , -0.       , ..., -0.       , -0.       ,
       -0.       ], dtype=float32), '_operator_map.SklearnXGBClassifier.biases.4': <tvm.nd.NDArray shape=(16000,), cpu(0)>
array([-0.6967661, -0.       , -0.       , ..., -0.       , -0.       ,
       -0.       ], dtype=float32), '_operator_map.SklearnXGBClassifier.biases.5': <tvm.nd.NDArray shape=(32000,), cpu(0)>
array([-0.      , -2.643936, -0.      , ..., -0.      , -0.      ,
       -0.      ], dtype=float32), '_operator_map.SklearnXGBClassifier.biases.6': <tvm.nd.NDArray shape=(64000,), cpu(0)>
array([-0., -0., -0., ..., -0., -0., -0.], dtype=float32)}
[01:11:01] /home/hummingbird/tvm/src/te/schedule/bound.cc:119: not in feed graph consumer = compute(placeholder_red_temp, 0x561f240f8820)

@masahi
Collaborator

masahi commented Sep 17, 2020

I realized that, since in the custom version you are not converting from PyTorch, comparing the two PyTorch code snippets doesn't tell us much. What I want to see is the params you pass to relay.build(mod, params=params). For example, for a data set of (20000, 28), if I do

model, params = relay.frontend.from_pytorch(ts_model, test_input)

for k, v in params.items():
    print(k, v.shape)

I get this

_operator_map.SklearnXGBClassifier.leaf_nodes (3200, 1)
_operator_map.SklearnXGBClassifier.biases.4 (1600,)
_operator_map.SklearnXGBClassifier.nodes.4 (1600,)
_operator_map.SklearnXGBClassifier.biases.3 (800,)
_operator_map.SklearnXGBClassifier.nodes.3 (800,)
_operator_map.SklearnXGBClassifier.biases.2 (400,)
_operator_map.SklearnXGBClassifier.nodes.2 (400,)
_operator_map.SklearnXGBClassifier.biases.1 (200,)
_operator_map.SklearnXGBClassifier.nodes.1 (200,)
_operator_map.SklearnXGBClassifier.biases.0 (100,)
_operator_map.SklearnXGBClassifier.nodes.0 (100,)
_operator_map.SklearnXGBClassifier.tree_indices (50,)
_operator_map.SklearnXGBClassifier.root_biases (50,)
_operator_map.SklearnXGBClassifier.root_nodes (50,)

Also, I want to see the Relay graph after fusion. You can dump it with

opt_mod, opt_params = relay.optimize(model, target=target, params=params)
print(opt_mod["main"])

In the current HB, there should be one big function, something like this:

  %59 = fn (%p02: Tensor[(3200, 1), float32], %p12: Tensor[(1000000), int64], %p22: Tensor[(1000000), int64], %p32: Tensor[(20000, 28), float32], %p4: Tensor[(200), int64], %p5: Tensor[(200), float32], %p6: Tensor[(400), int64], %p7: Tensor[(400), float32], %p8: Tensor[(800), int64], %p9: Tensor[(800), float32], %p10: Tensor[(1600), int64], %p111: Tensor[(1600), float32], Primitive=1) -> Tensor[(20000, 1, 50), float32] {
    %15 = multiply(%p12, 2 /* ty=int64 */) /* ty=Tensor[(1000000), int64] */;
    %16 = add(%15, %p22) /* ty=Tensor[(1000000), int64] */;
    %17 = multiply(%16, 2 /* ty=int64 */) /* ty=Tensor[(1000000), int64] */;
    %18 = take(%p4, %16, axis=0) /* ty=Tensor[(1000000), int64] */;
    %19 = reshape(%18, newshape=[-1, 50]) /* ty=Tensor[(20000, 50), int64] */;
    %20 = gather(%p32, %19, axis=1) /* ty=Tensor[(20000, 50), float32] */;
    %21 = reshape(%20, newshape=[-1]) /* ty=Tensor[(1000000), float32] */;
    %22 = take(%p5, %16, axis=0) /* ty=Tensor[(1000000), float32] */;
    %23 = greater_equal(%21, %22) /* ty=Tensor[(1000000), bool] */;
    %24 = cast(%23, dtype="int64") /* ty=Tensor[(1000000), int64] */;
    %25 = reshape(%24, newshape=[-1]) /* ty=Tensor[(1000000), int64] */;
    %26 = add(%17, %25) /* ty=Tensor[(1000000), int64] */;
    %27 = multiply(%26, 2 /* ty=int64 */) /* ty=Tensor[(1000000), int64] */;
    %28 = take(%p6, %26, axis=0) /* ty=Tensor[(1000000), int64] */;
    %29 = reshape(%28, newshape=[-1, 50]) /* ty=Tensor[(20000, 50), int64] */;
    %30 = gather(%p32, %29, axis=1) /* ty=Tensor[(20000, 50), float32] */;
    %31 = reshape(%30, newshape=[-1]) /* ty=Tensor[(1000000), float32] */;
    %32 = take(%p7, %26, axis=0) /* ty=Tensor[(1000000), float32] */;
    %33 = greater_equal(%31, %32) /* ty=Tensor[(1000000), bool] */;
    %34 = cast(%33, dtype="int64") /* ty=Tensor[(1000000), int64] */;
    %35 = reshape(%34, newshape=[-1]) /* ty=Tensor[(1000000), int64] */;
    %36 = add(%27, %35) /* ty=Tensor[(1000000), int64] */;
    %37 = multiply(%36, 2 /* ty=int64 */) /* ty=Tensor[(1000000), int64] */;
    %38 = take(%p8, %36, axis=0) /* ty=Tensor[(1000000), int64] */;
    %39 = reshape(%38, newshape=[-1, 50]) /* ty=Tensor[(20000, 50), int64] */;
    %40 = gather(%p32, %39, axis=1) /* ty=Tensor[(20000, 50), float32] */;
    %41 = reshape(%40, newshape=[-1]) /* ty=Tensor[(1000000), float32] */;
    %42 = take(%p9, %36, axis=0) /* ty=Tensor[(1000000), float32] */;
    %43 = greater_equal(%41, %42) /* ty=Tensor[(1000000), bool] */;
    %44 = cast(%43, dtype="int64") /* ty=Tensor[(1000000), int64] */;
    %45 = reshape(%44, newshape=[-1]) /* ty=Tensor[(1000000), int64] */;
    %46 = add(%37, %45) /* ty=Tensor[(1000000), int64] */;
    %47 = multiply(%46, 2 /* ty=int64 */) /* ty=Tensor[(1000000), int64] */;
    %48 = take(%p10, %46, axis=0) /* ty=Tensor[(1000000), int64] */;
    %49 = reshape(%48, newshape=[-1, 50]) /* ty=Tensor[(20000, 50), int64] */;
    %50 = gather(%p32, %49, axis=1) /* ty=Tensor[(20000, 50), float32] */;
    %51 = reshape(%50, newshape=[-1]) /* ty=Tensor[(1000000), float32] */;
    %52 = take(%p111, %46, axis=0) /* ty=Tensor[(1000000), float32] */;
    %53 = greater_equal(%51, %52) /* ty=Tensor[(1000000), bool] */;
    %54 = cast(%53, dtype="int64") /* ty=Tensor[(1000000), int64] */;
    %55 = reshape(%54, newshape=[-1]) /* ty=Tensor[(1000000), int64] */;
    %56 = add(%47, %55) /* ty=Tensor[(1000000), int64] */;
    %57 = reshape(%56, newshape=[-1]) /* ty=Tensor[(1000000), int64] */;
    %58 = take(%p02, %57, axis=0) /* ty=Tensor[(1000000, 1), float32] */;
    reshape(%58, newshape=[20000, 1, 50]) /* ty=Tensor[(20000, 1, 50), float32] */
  };

This is the function that gets translated to the lowered IR function with the many placeholders we discussed above. Note that this function takes as parameters tensors of different sizes (like p02, p12, etc.). These parameters become placeholders in the lowered IR.

So the output from the current HB makes sense to me. Since I don't know why the lowered IR of the custom version doesn't have many placeholders, I want to see how the corresponding Relay function looks before lowering (for example, how many parameters it has, etc.).

@interesaaat
Collaborator Author

Does the param print here make sense? I am trying to print the same for the custom version.

@masahi
Collaborator

masahi commented Sep 17, 2020

Ah right, for the current HB they make sense. Yes, I want to see the params for the custom version.

I don't need the values of the params; just the shapes are enough.

@interesaaat
Collaborator Author

Yes, that was quite verbose :) I am trying to print the same for the other case, but it takes a while... Question: do you want to have a call and go over this live? Maybe we can find a solution with fewer iterations :)

@masahi
Collaborator

masahi commented Sep 17, 2020

:) ok I'll send an email to you (this thread is already too long)

@interesaaat
Collaborator Author

This is the print of params for the custom version.

{'p14': <tvm.NDArray shape=(32000,), cpu(0)>
array([-0.      , -2.643936, -0.      , ..., -0.      , -0.      ,  -0.      ], dtype=float32), 'p9': <tvm.NDArray shape=(8000,), cpu(0)>
array([0, 0, 0, ..., 0, 0, 0], dtype=int32), 'p16': <tvm.NDArray shape=(64000,), cpu(0)>
array([-0., -0., -0., ..., -0., -0., -0.], dtype=float32), 'p4': <tvm.NDArray shape=(1000,), cpu(0)>
array([...], dtype=float32), 'p15': <tvm.NDArray shape=(64000,), cpu(0)>
array([0, 0, 0, ..., 0, 0, 0], dtype=int32), 'p12': <tvm.NDArray shape=(16000,), cpu(0)>
array([-0.6967661, -0.       , -0.       , ..., -0.       , -0.       ,
       -0.       ], dtype=float32), 'p5': <tvm.NDArray shape=(2000,), cpu(0)>
array([ 2,  3, 19, ...,  0,  0,  0], dtype=int32), 'p1': <tvm.NDArray shape=(500,), cpu(0)>
array([...], dtype=int32), 'p0': <tvm.NDArray shape=(128000, 1), cpu(0)>
array([...], dtype=float32), 'p6': <tvm.NDArray shape=(2000,), cpu(0)>
array([...], dtype=float32), 'p8': <tvm.NDArray shape=(4000,), cpu(0)>
array([ 1.2566634 , -0.        , -0.12022677, ..., -0.     ], dtype=float32), 'p10': <tvm.NDArray shape=(8000,), cpu(0)>
array([ 1.8571583, -0.       , -0.       , ..., -0.       , -0.    ], dtype=float32), 'p3': <tvm.NDArray shape=(1000,), cpu(0)>
array([...], dtype=int32), 'p11': <tvm.NDArray shape=(16000,), cpu(0)>
array([10,  0,  0, ...,  0,  0,  0], dtype=int32), 'p7': <tvm.NDArray shape=(4000,), cpu(0)>
array([ 9,  0, 25, ...,  0,  0,  0], dtype=int32), 'p13': <tvm.NDArray shape=(32000,), cpu(0)>

@interesaaat
Collaborator Author

I don't think there is an easy solution to the reshape problem. index_select requires a 1-D index; that's why we always flatten the tensor. I don't think there is any op in PyTorch (or TVM) that does exactly what we want. Maybe I can implement the function in Python, and with TorchScript it will get compiled into what we want (without having to resort to pattern matching or custom ops)?
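To make the issue concrete, here is a minimal sketch of the pattern (the function and tensor names are illustrative, not the actual Hummingbird code): gathering with a 1-D index forces a flatten before the lookup and a reshape after it.

import torch

def lookup(params_1d, idx_2d):
    # index_select only accepts a 1-D index tensor, so we flatten the 2-D
    # indices, select, and then restore the original shape. The two view
    # calls are exactly the reshapes we would like to eliminate.
    flat_idx = idx_2d.view(-1)                         # reshape #1
    picked = torch.index_select(params_1d, 0, flat_idx)
    return picked.view(idx_2d.shape)                   # reshape #2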

@masahi
Collaborator

masahi commented Sep 19, 2020

I don't think having a custom PyTorch op would help.

We can at least remove the last reshape here (it doesn't do anything) https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/_tree_implementations.py#L328

If we need to keep the indices around as 1-D, can we do the gather at https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/_tree_implementations.py#L327 also with 1-D indices, by flattening the input x to 1-D beforehand, for example? If we could somehow do that, we could remove the remaining two reshapes.

Also, do the indices need to be 64-bit? If 32-bit is enough, that would cut bandwidth in half, and I think that would be a good deal. Does the custom version also use 64-bit indices? From the parameters you posted above, they seem to be 32-bit.

@interesaaat
Collaborator Author

I don't think having a custom PyTorch op would help.

We can at least remove the last reshape here (it doesn't do anything) https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/_tree_implementations.py#L328

Yeah, I also saw that one. Will remove.

If we need to keep the indices around as 1-D, can we do the gather at https://github.com/microsoft/hummingbird/blob/master/hummingbird/ml/operator_converters/_tree_implementations.py#L327 also with 1-D indices, by flattening the input x to 1-D beforehand, for example? If we could somehow do that, we could remove the remaining two reshapes.

I see. Let me try to make everything a 1d array. Hopefully it will work.

Also, do the indices need to be 64-bit? If 32-bit is enough, that would cut bandwidth in half, and I think that would be a good deal. Does the custom version also use 64-bit indices? From the parameters you posted above, they seem to be 32-bit.

I think the indices are supposed to be Long (int64) in PyTorch.
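For reference, a quick sanity check (a minimal sketch; exact behavior may vary across PyTorch versions) shows that gather insists on int64 indices, so we can't simply switch the model to 32-bit indices on the PyTorch side:

import torch

x = torch.rand(4, 8)
idx64 = torch.zeros(4, 8, dtype=torch.int64)
idx32 = idx64.to(torch.int32)

torch.gather(x, 1, idx64)  # OK
torch.gather(x, 1, idx32)  # RuntimeError: gather(): Expected dtype int64 for index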

@interesaaat
Collaborator Author

Unfortunately there is no solution to this, because we cannot gather with 1-D indices. Unless @scnakandala has some idea. (BTW, I removed the unnecessary view ops, but the performance in TVM didn't change.)

@interesaaat
Collaborator Author

interesaaat commented Sep 23, 2020

@masahi We tried to get rid of the reshape ops in PR #306.

Unfortunately, since we are using expand, TVM crashes with a segfault for the 10k batch size. I was able to make it work with a batch size of 100, and the optimized graph is the following:

def @main(%input: Tensor[(62, 28), float32]) -> (Tensor[(62), int32], Tensor[(62, 2), float32]) {
  %55 = fn (%p0: Tensor[(128000, 1, 1), float32], %p1: Tensor[(62, 28), float32], %p2: Tensor[(500), int64], %p3: Tensor[(500), float32], %p4: Tensor[(500), int64], %p5: Tensor[(62, 1000), int64], %p6: Tensor[(62, 1000), float32], %p7: Tensor[(62, 2000), int64], %p8: Tensor[(62, 2000), float32], %p9: Tensor[(62, 4000), int64], %p10: Tensor[(62, 4000), float32], %p11: Tensor[(62, 8000), int64], %p12: Tensor[(62, 8000), float32], %p13: Tensor[(62, 16000), int64], %p14: Tensor[(62, 16000), float32], %p15: Tensor[(62, 32000), int64], %p16: Tensor[(62, 32000), float32], %p17: Tensor[(62, 64000), int64], %p18: Tensor[(62, 64000), float32], Primitive=1) -> Tensor[(62, 1, 500), float32] {
    %0 = take(%p1, %p2, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %1 = greater_equal(%0, %p3) /* ty=Tensor[(62, 500), bool] */;
    %2 = cast(%1, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %3 = add(%2, %p4) /* ty=Tensor[(62, 500), int64] */;
    %4 = multiply(%3, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %5 = gather(%p5, %3, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %6 = gather(%p1, %5, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %7 = gather(%p6, %3, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %8 = greater_equal(%6, %7) /* ty=Tensor[(62, 500), bool] */;
    %9 = cast(%8, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %10 = add(%4, %9) /* ty=Tensor[(62, 500), int64] */;
    %11 = multiply(%10, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %12 = gather(%p7, %10, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %13 = gather(%p1, %12, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %14 = gather(%p8, %10, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %15 = greater_equal(%13, %14) /* ty=Tensor[(62, 500), bool] */;
    %16 = cast(%15, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %17 = add(%11, %16) /* ty=Tensor[(62, 500), int64] */;
    %18 = multiply(%17, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %19 = gather(%p9, %17, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %20 = gather(%p1, %19, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %21 = gather(%p10, %17, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %22 = greater_equal(%20, %21) /* ty=Tensor[(62, 500), bool] */;
    %23 = cast(%22, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %24 = add(%18, %23) /* ty=Tensor[(62, 500), int64] */;
    %25 = multiply(%24, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %26 = gather(%p11, %24, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %27 = gather(%p1, %26, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %28 = gather(%p12, %24, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %29 = greater_equal(%27, %28) /* ty=Tensor[(62, 500), bool] */;
    %30 = cast(%29, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %31 = add(%25, %30) /* ty=Tensor[(62, 500), int64] */;
    %32 = multiply(%31, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %33 = gather(%p13, %31, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %34 = gather(%p1, %33, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %35 = gather(%p14, %31, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %36 = greater_equal(%34, %35) /* ty=Tensor[(62, 500), bool] */;
    %37 = cast(%36, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %38 = add(%32, %37) /* ty=Tensor[(62, 500), int64] */;
    %39 = multiply(%38, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %40 = gather(%p15, %38, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %41 = gather(%p1, %40, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %42 = gather(%p16, %38, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %43 = greater_equal(%41, %42) /* ty=Tensor[(62, 500), bool] */;
    %44 = cast(%43, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %45 = add(%39, %44) /* ty=Tensor[(62, 500), int64] */;
    %46 = multiply(%45, 2 /* ty=int64 */) /* ty=Tensor[(62, 500), int64] */;
    %47 = gather(%p17, %45, axis=1) /* ty=Tensor[(62, 500), int64] */;
    %48 = gather(%p1, %47, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %49 = gather(%p18, %45, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %50 = greater_equal(%48, %49) /* ty=Tensor[(62, 500), bool] */;
    %51 = cast(%50, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %52 = add(%46, %51) /* ty=Tensor[(62, 500), int64] */;
    %53 = reshape(%52, newshape=[-1]) /* ty=Tensor[(31000), int64] */;
    %54 = take(%p0, %53, axis=0) /* ty=Tensor[(31000, 1, 1), float32] */;
    reshape(%54, newshape=[62, 1, 500]) /* ty=Tensor[(62, 1, 500), float32] */
  };
  %56 = %55(meta[relay.Constant][0] /* ty=Tensor[(128000, 1, 1), float32] */, %input, meta[relay.Constant][1] /* ty=Tensor[(500), int64] */, meta[relay.Constant][2] /* ty=Tensor[(500), float32] */, meta[relay.Constant][3] /* ty=Tensor[(500), int64] */, meta[relay.Constant][4] /* ty=Tensor[(62, 1000), int64] */, meta[relay.Constant][5] /* ty=Tensor[(62, 1000), float32] */, meta[relay.Constant][6] /* ty=Tensor[(62, 2000), int64] */, meta[relay.Constant][7] /* ty=Tensor[(62, 2000), float32] */, meta[relay.Constant][8] /* ty=Tensor[(62, 4000), int64] */, meta[relay.Constant][9] /* ty=Tensor[(62, 4000), float32] */, meta[relay.Constant][10] /* ty=Tensor[(62, 8000), int64] */, meta[relay.Constant][11] /* ty=Tensor[(62, 8000), float32] */, meta[relay.Constant][12] /* ty=Tensor[(62, 16000), int64] */, meta[relay.Constant][13] /* ty=Tensor[(62, 16000), float32] */, meta[relay.Constant][14] /* ty=Tensor[(62, 32000), int64] */, meta[relay.Constant][15] /* ty=Tensor[(62, 32000), float32] */, meta[relay.Constant][16] /* ty=Tensor[(62, 64000), int64] */, meta[relay.Constant][17] /* ty=Tensor[(62, 64000), float32] */) /* ty=Tensor[(62, 1, 500), float32] */;
  %57 = fn (%p01: Tensor[(62, 1, 500), float32], Primitive=1) -> Tensor[(62, 1), float32] {
    sum(%p01, axis=[2]) /* ty=Tensor[(62, 1), float32] */
  };
  %58 = %57(%56) /* ty=Tensor[(62, 1), float32] */;
  %63 = fn (%p02: Tensor[(62, 1), float32], Primitive=1) -> Tensor[(62, 2), float32] {
    %59 = sigmoid(%p02) /* ty=Tensor[(62, 1), float32] */;
    %60 = multiply(1f /* ty=float32 */, %59) /* ty=Tensor[(62, 1), float32] */;
    %61 = subtract(1f /* ty=float32 */, %60) /* ty=Tensor[(62, 1), float32] */;
    %62 = (%61, %59);
    concatenate(%62, axis=1) /* ty=Tensor[(62, 2), float32] */
  };
  %64 = %63(%58) /* ty=Tensor[(62, 2), float32] */;
  %65 = fn (%p03: Tensor[(62, 2), float32], Primitive=1) -> Tensor[(62), int32] {
    argmax(%p03, axis=[1]) /* ty=Tensor[(62), int32] */
  };
  %66 = %65(%64) /* ty=Tensor[(62), int32] */;
  (%66, %64)
}

This version is actually about 2x slower than the previous one. Even without the reshapes, I don't see any major operator fusion happening.

@interesaaat
Collaborator Author

interesaaat commented Sep 23, 2020

Actually, is %55 the result of fusion? In that case it actually fuses a lot.

Although the original version without expand actually fuses more.

def @main(%input: Tensor[(62, 28), float32]) -> (Tensor[(62), int32], Tensor[(62, 2), float32]) {
  %69 = fn (%p0: Tensor[(128000, 1), float32], %p1: Tensor[(62, 28), float32], %p2: Tensor[(500), int64], %p3: Tensor[(500), float32], %p4: Tensor[(500), int64], %p5: Tensor[(1000), int64], %p6: Tensor[(1000), float32], %p7: Tensor[(2000), int64], %p8: Tensor[(2000), float32], %p9: Tensor[(4000), int64], %p10: Tensor[(4000), float32], %p11: Tensor[(8000), int64], %p12: Tensor[(8000), float32], %p13: Tensor[(16000), int64], %p14: Tensor[(16000), float32], %p15: Tensor[(32000), int64], %p16: Tensor[(32000), float32], %p17: Tensor[(64000), int64], %p18: Tensor[(64000), float32], Primitive=1) -> Tensor[(62, 1, 500), float32] {
    %0 = take(%p1, %p2, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %1 = greater_equal(%0, %p3) /* ty=Tensor[(62, 500), bool] */;
    %2 = cast(%1, dtype="int64") /* ty=Tensor[(62, 500), int64] */;
    %3 = add(%2, %p4) /* ty=Tensor[(62, 500), int64] */;
    %4 = reshape(%3, newshape=[-1]) /* ty=Tensor[(31000), int64] */;
    %5 = multiply(%4, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %6 = take(%p5, %4, axis=0) /* ty=Tensor[(31000), int64] */;
    %7 = reshape(%6, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %8 = gather(%p1, %7, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %9 = reshape(%8, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %10 = take(%p6, %4, axis=0) /* ty=Tensor[(31000), float32] */;
    %11 = greater_equal(%9, %10) /* ty=Tensor[(31000), bool] */;
    %12 = cast(%11, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %13 = add(%5, %12) /* ty=Tensor[(31000), int64] */;
    %14 = multiply(%13, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %15 = take(%p7, %13, axis=0) /* ty=Tensor[(31000), int64] */;
    %16 = reshape(%15, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %17 = gather(%p1, %16, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %18 = reshape(%17, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %19 = take(%p8, %13, axis=0) /* ty=Tensor[(31000), float32] */;
    %20 = greater_equal(%18, %19) /* ty=Tensor[(31000), bool] */;
    %21 = cast(%20, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %22 = add(%14, %21) /* ty=Tensor[(31000), int64] */;
    %23 = multiply(%22, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %24 = take(%p9, %22, axis=0) /* ty=Tensor[(31000), int64] */;
    %25 = reshape(%24, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %26 = gather(%p1, %25, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %27 = reshape(%26, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %28 = take(%p10, %22, axis=0) /* ty=Tensor[(31000), float32] */;
    %29 = greater_equal(%27, %28) /* ty=Tensor[(31000), bool] */;
    %30 = cast(%29, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %31 = add(%23, %30) /* ty=Tensor[(31000), int64] */;
    %32 = multiply(%31, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %33 = take(%p11, %31, axis=0) /* ty=Tensor[(31000), int64] */;
    %34 = reshape(%33, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %35 = gather(%p1, %34, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %36 = reshape(%35, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %37 = take(%p12, %31, axis=0) /* ty=Tensor[(31000), float32] */;
    %38 = greater_equal(%36, %37) /* ty=Tensor[(31000), bool] */;
    %39 = cast(%38, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %40 = add(%32, %39) /* ty=Tensor[(31000), int64] */;
    %41 = multiply(%40, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %42 = take(%p13, %40, axis=0) /* ty=Tensor[(31000), int64] */;
    %43 = reshape(%42, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %44 = gather(%p1, %43, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %45 = reshape(%44, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %46 = take(%p14, %40, axis=0) /* ty=Tensor[(31000), float32] */;
    %47 = greater_equal(%45, %46) /* ty=Tensor[(31000), bool] */;
    %48 = cast(%47, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %49 = add(%41, %48) /* ty=Tensor[(31000), int64] */;
    %50 = multiply(%49, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %51 = take(%p15, %49, axis=0) /* ty=Tensor[(31000), int64] */;
    %52 = reshape(%51, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %53 = gather(%p1, %52, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %54 = reshape(%53, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %55 = take(%p16, %49, axis=0) /* ty=Tensor[(31000), float32] */;
    %56 = greater_equal(%54, %55) /* ty=Tensor[(31000), bool] */;
    %57 = cast(%56, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %58 = add(%50, %57) /* ty=Tensor[(31000), int64] */;
    %59 = multiply(%58, 2 /* ty=int64 */) /* ty=Tensor[(31000), int64] */;
    %60 = take(%p17, %58, axis=0) /* ty=Tensor[(31000), int64] */;
    %61 = reshape(%60, newshape=[-1, 500]) /* ty=Tensor[(62, 500), int64] */;
    %62 = gather(%p1, %61, axis=1) /* ty=Tensor[(62, 500), float32] */;
    %63 = reshape(%62, newshape=[-1]) /* ty=Tensor[(31000), float32] */;
    %64 = take(%p18, %58, axis=0) /* ty=Tensor[(31000), float32] */;
    %65 = greater_equal(%63, %64) /* ty=Tensor[(31000), bool] */;
    %66 = cast(%65, dtype="int64") /* ty=Tensor[(31000), int64] */;
    %67 = add(%59, %66) /* ty=Tensor[(31000), int64] */;
    %68 = take(%p0, %67, axis=0) /* ty=Tensor[(31000, 1), float32] */;
    reshape(%68, newshape=[62, 1, 500]) /* ty=Tensor[(62, 1, 500), float32] */
  };
  %70 = %69(meta[relay.Constant][0] /* ty=Tensor[(128000, 1), float32] */, %input, meta[relay.Constant][1] /* ty=Tensor[(500), int64] */, meta[relay.Constant][2] /* ty=Tensor[(500), float32] */, meta[relay.Constant][3] /* ty=Tensor[(500), int64] */, meta[relay.Constant][4] /* ty=Tensor[(1000), int64] */, meta[relay.Constant][5] /* ty=Tensor[(1000), float32] */, meta[relay.Constant][6] /* ty=Tensor[(2000), int64] */, meta[relay.Constant][7] /* ty=Tensor[(2000), float32] */, meta[relay.Constant][8] /* ty=Tensor[(4000), int64] */, meta[relay.Constant][9] /* ty=Tensor[(4000), float32] */, meta[relay.Constant][10] /* ty=Tensor[(8000), int64] */, meta[relay.Constant][11] /* ty=Tensor[(8000), float32] */, meta[relay.Constant][12] /* ty=Tensor[(16000), int64] */, meta[relay.Constant][13] /* ty=Tensor[(16000), float32] */, meta[relay.Constant][14] /* ty=Tensor[(32000), int64] */, meta[relay.Constant][15] /* ty=Tensor[(32000), float32] */, meta[relay.Constant][16] /* ty=Tensor[(64000), int64] */, meta[relay.Constant][17] /* ty=Tensor[(64000), float32] */) /* ty=Tensor[(62, 1, 500), float32] */;
  %71 = fn (%p01: Tensor[(62, 1, 500), float32], Primitive=1) -> Tensor[(62, 1), float32] {
    sum(%p01, axis=[2]) /* ty=Tensor[(62, 1), float32] */
  };
  %72 = %71(%70) /* ty=Tensor[(62, 1), float32] */;
  %77 = fn (%p02: Tensor[(62, 1), float32], Primitive=1) -> Tensor[(62, 2), float32] {
    %73 = sigmoid(%p02) /* ty=Tensor[(62, 1), float32] */;
    %74 = multiply(1f /* ty=float32 */, %73) /* ty=Tensor[(62, 1), float32] */;
    %75 = subtract(1f /* ty=float32 */, %74) /* ty=Tensor[(62, 1), float32] */;
    %76 = (%75, %73);
    concatenate(%76, axis=1) /* ty=Tensor[(62, 2), float32] */
  };
  %78 = %77(%72) /* ty=Tensor[(62, 2), float32] */;
  %79 = fn (%p03: Tensor[(62, 2), float32], Primitive=1) -> Tensor[(62), int32] {
    argmax(%p03, axis=[1]) /* ty=Tensor[(62), int32] */
  };
  %80 = %79(%78) /* ty=Tensor[(62), int32] */;
  (%80, %78)
}

@masahi
Collaborator

masahi commented Sep 23, 2020

Actually, is %55 the result of fusion? In that case it actually fuses a lot.

Yes, exactly. The new version has more gather ops, and that doesn't seem to help performance. The segfault is likely due to the same issue we hit previously: you need to fix the batch size at compile time and pass the concatenated parameters to TVM.

In general, I recommend setting the fuse depth to something like 50, even if you can compile without it. It seems too much fusing hurts performance (for a similar reason that too-aggressive loop unrolling does).
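Something along these lines (a sketch, assuming a TVM build where the fuse depth is exposed through the relay.FuseOps.max_depth pass-config option; mod and params stand for whatever the frontend returned):

import tvm
from tvm import relay

# Cap operator fusion so a single fused function doesn't blow up codegen.
with tvm.transform.PassContext(opt_level=3, config={"relay.FuseOps.max_depth": 50}):
    lib = relay.build(mod, target="llvm", params=params)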

@ksaur
Contributor

ksaur commented Nov 3, 2020

closed with #236

@ksaur ksaur closed this as completed Nov 3, 2020