[tir][testing] Enable intra-IRModule primfunc call
Changes / fixes to allow `PrimFunc`-`PrimFunc` calls
within the same `IRModule`:

- Fix bug in `BufferAllocationLocator` constructor in which
    `BufferDecl` nodes could trigger duplicate buffer allocations.

- Add `tvm::CallingConv::kIntraModule` calling convention, to
  allow a PrimFunc's signature to remain unchanged by the
  `MakePackedAPI` pass during lowering.

- Add `tests/python/contrib/test_hexagon/test_call_tir.py` unit tests
  to demonstrate one viable approach for intra-module `PrimFunc`
  calls.
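
For reference, here is a minimal sketch of the intended usage (shapes and
names are illustrative; the complete, runnable version lives in the new
test file below). The callee opts out of `MakePackedAPI` by carrying the
new calling convention as a function attribute:

    from tvm.script import tir as T

    @T.prim_func
    def callee(a_data: T.Ptr[T.int8], i: T.int32):
        T.func_attr(
            {
                "global_symbol": "callee",
                "tir.noalias": True,
                "calling_conv": 3,  # tvm::CallingConv::kIntraModule
            }
        )
        # View one tile of the caller-owned buffer through the raw pointer.
        A_tile = T.buffer_decl((4,), "int8", a_data, elem_offset=4 * i)
        for j in range(4):
            A_tile[j] = A_tile[j] + T.cast(1, "int8")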
Christian Convey committed Jan 11, 2023
1 parent a9c6f13 commit 5edefbe
Showing 3 changed files with 262 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/tvm/ir/function.h
@@ -63,6 +63,13 @@ enum class CallingConv : int {
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
/*!
* \brief For functions only called by other functions within the same IR module.
*
* Indicates that various signature transformations (e.g. those provided by the
* MakePackedAPI pass) are not desired.
*/
kIntraModule = 3,
};

/*!
7 changes: 7 additions & 0 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -27,6 +27,7 @@
#include <tvm/tir/transform.h>

#include "ir_utils.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace tir {
@@ -111,6 +112,12 @@ class BufferAllocationLocator : public StmtExprMutator {
collector(func->body);
unmanaged_allocations_ = collector.unmanaged_allocations;

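// Pointer-typed parameters (e.g. the raw data pointers passed in
// intra-module PrimFunc calls) refer to memory owned by the caller,
// so buffers declared over them must not trigger a fresh allocation.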
for (Var param : func->params) {
if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
unmanaged_allocations_.insert(param.get());
}
}

for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
arg_buffer_vars.emplace(buffer->data.get());
248 changes: 248 additions & 0 deletions tests/python/contrib/test_hexagon/test_call_tir.py
@@ -0,0 +1,248 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Various tests related to the (WIP) support for having
one PrimFunc call another PrimFunc within the same IRModule.
"""

from typing import List
import pytest
import numpy as np

import tvm
import tvm.testing
import tvm.script
from tvm.script import tir as T

from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon import allocate_hexagon_array
from .infrastructure import get_hexagon_target

# NOTE(cconvey): These pylint warnings should be re-enabled as TVM's pylint configuration matures.
# pylint: disable=missing-function-docstring,no-self-argument,invalid-name
# pylint: disable=redefined-outer-name,missing-class-docstring

# --------------------------------------------------------------------------------------------------
# Test parameters
# --------------------------------------------------------------------------------------------------

# The shape of the original (unsplit) tensors.
# We assume that each shape describes a non-empty 2D tensor.
original_shape = tvm.testing.parameter(
# degenerate cases...
[1, 1],
[1, 2],
[2, 1],
[2, 2],
# arbitrary, provided for variety
[5, 3],
[3, 5],
)

# This dtype is arbitrary, but it must match the dtype that's hardcoded into the
# callee's function signature. E.g., 'a_data: T.Ptr[T.int8]'.
#
# Hopefully over time we'll find a way to soften this limitation, at least for
# some approaches to PrimFunc-to-PrimFunc calls.
dtype = tvm.testing.parameter("int8")

# --------------------------------------------------------------------------------------------------
# Helper functions / definitions...
# --------------------------------------------------------------------------------------------------

HEXAGON_TARGET_ = get_hexagon_target("v69")

ENTRY_PRIMFUNC_NAME_ = "main"


def get_reference_input_tensor_(shape: list, dtype: str) -> np.ndarray:
assert len(shape) == 2

a = np.ndarray(shape, dtype=dtype)
np_dtype = a.dtype

if np_dtype.kind in ["i", "u"]:
# We allow overflow for integer types because it tends to be well-behaved
# and well-understood...
min_value = np.iinfo(np_dtype).min
max_value = np.iinfo(np_dtype).max

next_value = min_value

for i in range(shape[0]):
for j in range(shape[1]):
a[i, j] = next_value
next_value += 1

elif np_dtype.kind == "f":
# NOTE: For simplicity, we avoid test data that would require
# well-defined behavior on floating-point overflow.
# But it may be reasonable to test that in the future.
min_value = np.finfo(np_dtype).min
max_value = np.finfo(np_dtype).max

min_input_value = min_value / 2.0 + 1
max_input_value = max_value / 2.0 - 2
delta = (max_input_value - min_input_value) / (shape[0] * shape[1])

next_value = min_input_value

for i in range(shape[0]):
for j in range(shape[1]):
a[i, j] = next_value
next_value += delta

else:
assert False, f"Unexpected data type: {np_dtype}"

return a


def get_reference_output_tensor_(shape: list, dtype: str) -> np.ndarray:
return get_reference_input_tensor_(shape, dtype) + 1


def evaluate_ir_module_(
hexagon_session: Session, shape: List, dtype: str, ir_mod: tvm.ir.module.IRModule
) -> None:
reference_input_np = get_reference_input_tensor_(shape, dtype)
reference_output_np = get_reference_output_tensor_(shape, dtype)

hexagon_mod_local = tvm.build(
ir_mod,
target=get_hexagon_target("v69"),
name=ENTRY_PRIMFUNC_NAME_,
)

hexagon_mod_remote = hexagon_session.load_module(hexagon_mod_local)

input_data = allocate_hexagon_array(
hexagon_session.device,
data=reference_input_np,
)

output_data = allocate_hexagon_array(
hexagon_session.device,
tensor_shape=reference_output_np.shape,
dtype=reference_output_np.dtype,
data=np.full(shape, 0, dtype="int8"),
)

hexagon_mod_remote(input_data, output_data)

output_data_np = output_data.numpy()
tvm.testing.assert_allclose(reference_output_np, output_data_np)


# --------------------------------------------------------------------------------------------------
# Test cases...
# --------------------------------------------------------------------------------------------------


@tvm.testing.requires_hexagon
def test_baseline(
hexagon_session: Session, original_shape: List, dtype: str
) -> None:
dim0_size, dim1_size = original_shape

@tvm.script.ir_module
class AddOneBaseline:
"""
Provides "add-one" functionality in a single, traditional PrimFunc.
Used as a baseline for comparison / validation with other approaches.
I.e., approaches that use various aspects of PrimFunc slicing and/or
one PrimFunc calling into another.
"""

@T.prim_func
def main(a: T.handle, b: T.handle):
# We exchange data between functions by handles, which are similar to pointers.
T.func_attr({"global_symbol": "main", "tir.noalias": True})

A = T.match_buffer(a, original_shape, dtype=dtype)
B = T.match_buffer(b, original_shape, dtype=dtype)

for i in range(dim0_size):
for j in range(dim1_size):
B[i, j] = A[i, j] + T.cast(1, dtype)

evaluate_ir_module_(hexagon_session, original_shape, dtype, AddOneBaseline)


@tvm.testing.requires_hexagon
def test_pass_pointers(
hexagon_session: Session, original_shape: List, dtype: str
) -> None:
# Some notable requirements for this approach to intra-IRModule primfunc calls:
#
# - The specific dtype must be hardcoded into the callee's signature, e.g.
# 'a_data: T.Ptr[T.int8]'.
#
# - The module's entry function must have a PrimFunc with the attribute
# "tir.is_entry_func": True.
# (This is related to having an IRModule with multiple PrimFuncs.)
#
# - The callee PrimFunc must have the "calling_conv": 3 attribute.
# (Where '3' is the number corresponding to 'tvm::CallingConv::kIntraModule'.)
# This ensures that the caller's 'A.data' argument, and the callee's 'a_data: T.Ptr[T.int8]'
# parameter, both lower to 'uint8_t *', or something equivalent.
#
# - The callee must use 'T.buffer_decl' to describe the tile on which the callee
# shall operate. As of this writing, there's no clear way to make this
# work with 'T.decl_buffer'.
if dtype != "int8":
pytest.skip(f"Unsupported dtype for this test: {dtype}")

dim0_size, dim1_size = original_shape

tile_shape = (dim1_size,)

@tvm.script.ir_module
class AddOnePassPointers:
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True, "tir.is_entry_func": True})

A = T.match_buffer(a, original_shape, dtype=dtype)
B = T.match_buffer(b, original_shape, dtype=dtype)

for i in range(dim0_size):
T.call_extern("", "callee", A.data, B.data, i)

@T.prim_func
def callee(a_data: T.Ptr[T.int8], b_data: T.Ptr[T.int8], i: T.int32):
T.func_attr(
{
"global_symbol": "callee",
"tir.noalias": True,
"calling_conv": 3, # tvm::CallingConv::kIntraModule
}
)

A_tile = T.buffer_decl(tile_shape, dtype, a_data, elem_offset=dim1_size * i)
B_tile = T.buffer_decl(tile_shape, dtype, b_data, elem_offset=dim1_size * i)

for j in range(dim1_size):
B_tile[j] = A_tile[j] + T.cast(1, dtype)

evaluate_ir_module_(hexagon_session, original_shape, dtype, AddOnePassPointers)


# --------------------------------------------------------------------------------------------------

if __name__ == "__main__":
tvm.testing.main()
