Skip to content

Commit

Permalink
[tir][testing] Enable intra-IRModule primfunc call
Browse files Browse the repository at this point in the history
Changes / fixes to allow `PrimFunc`-`PrimFunc` calls
within the same `IRModule`:

- Fix bug in `BufferAllocationLocator` constructor in which
    `BufferDecl` nodes could trigger duplicate buffer allocations.

- Add `tvm::CallingConv::kIntraModule` calling convention, to
  allow a PrimFunc's signature to remain unchanged by the
  `MakePackedAPI` pass during lowering.

- Add `tests/python/contrib/test_hexagon/test_call_tir.py` unit tests
  to demonstrate one viable approach for intra-module `PrimFunc`
  calls.

- Add a new tensor-content generator, `TensorContentFullRangeCOrder`,
  to `tests/python/contrib/test_hexagon/pytest_util.py`.
  • Loading branch information
Christian Convey committed Jan 31, 2023
1 parent 5e652c1 commit 2f25578
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 0 deletions.
45 changes: 45 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,51 @@ enum class CallingConv : int {
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
/*!
* \brief For functions only called by other functions within the same IR module.
*
* This calling convention exists to support one PrimFunc calling another PrimFunc
* within the same IRModule.
*
* Overview / Purpose:
*
* - This calling convention is intended only for PrimFuncs whose callers reside in
* the same IRModule. This is the only supported use case.
*
* - The details of the calling convention may change frequently as TVM evolves.
* Therefore users are discouraged from attempting to use this calling convention
* outside of the supported use case(s).
*
* Background: Why the other calling conventions aren't sufficient.
* - kDefault and kCPackedFunc both involve the PackedFunc signature standard, which
* is not compatible with the current approach for intra-module PrimFunc calls.
*
* - kDeviceKernelLaunch indicates a split between host-side caller and device-side
* callee. This is similar to kIntraModule, but kIntraModule is intended for
* both caller and callee residing in the same runtime module.
*
* Current mechanics / usage requirements:
*
* - A PrimFunc with this calling convention will NOT undergo any of the signature
* transformations provided by the MakePackedAPI pass.
*
* - Supported use cases, and their corresponding unit tests, are all expressed as
* TVMScript.
*
* - The callsite must use `T.call_extern`.
*
* - There's a 1:1 correspondence between the caller argument list and the callee
* parameter list. That is, the n'th call argument maps to the n'th callee parameter.
*
* - For a given argument / parameter, the following mappings are supported. Other mappings
* may work but are not verified in TVM's unit tests.
*
* - (caller: 'T.int8') --> (callee: 'T.int8')
*
* - (caller: 'A.data', where A is a Buffer with dtype=int8 and axis_separators=[]) -->
* (callee: 'T.Ptr[T.int8]')
*/
kIntraModule = 3,
};

/*!
Expand Down
7 changes: 7 additions & 0 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/transform.h>

#include "ir_utils.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -111,6 +112,12 @@ class BufferAllocationLocator : public StmtExprMutator {
collector(func->body);
unmanaged_allocations_ = collector.unmanaged_allocations;

for (Var param : func->params) {
if (param->type_annotation.defined() && param->type_annotation.as<PointerTypeNode>()) {
unmanaged_allocations_.insert(param.get());
}
}

for (const auto& kv : func->buffer_map) {
const Buffer& buffer = kv.second;
arg_buffer_vars.emplace(buffer->data.get());
Expand Down
6 changes: 6 additions & 0 deletions tests/python/contrib/test_hexagon/pytest_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def get_numpy_dtype_info(dtype) -> Union[np.finfo, np.iinfo]:
# Tag types used by `create_populated_numpy_ndarray` to select how a test
# tensor's contents are generated.  Each namedtuple acts as a small "content
# spec"; the strategies visible in this file are:
#   - TensorContentSequentialCOrder: elements form an arithmetic sequence
#     (start_value, then repeatedly adding `increment`) in C (row-major) order.
#   - TensorContentFullRangeCOrder: elements span the dtype's full min..max
#     range (via np.linspace), laid out in C order.
# The remaining tags (Random, DtypeMin, DtypeMax) are handled elsewhere in
# this module; presumably random fill and all-min / all-max fill — confirm
# against the generator's dispatch code.
TensorContentSequentialCOrder = collections.namedtuple(
    "TensorContentSequentialCOrder", ["start_value", "increment"]
)
TensorContentFullRangeCOrder = collections.namedtuple("TensorContentFullRangeCOrder", [])
TensorContentRandom = collections.namedtuple("TensorContentRandom", [])
TensorContentDtypeMin = collections.namedtuple("TensorContentDtypeMin", [])
TensorContentDtypeMax = collections.namedtuple("TensorContentDtypeMax", [])
Expand Down Expand Up @@ -172,5 +173,10 @@ def create_populated_numpy_ndarray(
next_elem_val += itp.increment
return a

elif isinstance(itp, TensorContentFullRangeCOrder):
num_elements = np.prod(input_shape)
info = get_numpy_dtype_info(dtype)
return np.linspace(info.min, info.max, num=num_elements, dtype=dtype).reshape(input_shape)

else:
raise ValueError(f"Unexpected input_tensor_populator type: {type(itp)}")
215 changes: 215 additions & 0 deletions tests/python/contrib/test_hexagon/test_call_tir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

""" Various tests related to the (WIP) support for having
one PrimFunc call another PrimFunc within the same IRModule.
"""

from typing import List
import pytest
import numpy as np

import tvm
import tvm.testing
import tvm.script
from tvm.script import tir as T

from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon import allocate_hexagon_array
import test_hexagon.pytest_util as pytest_util
from test_hexagon.infrastructure import get_hexagon_target


# NOTE(cconvey): These pylint warnings should be re-enabled as TVM's pylint configuration matures.
# pylint: disable=missing-function-docstring,no-self-argument,invalid-name
# pylint: disable=redefined-outer-name,missing-class-docstring

# --------------------------------------------------------------------------------------------------
# Test parameters
# --------------------------------------------------------------------------------------------------

# The shape of the original (unsplit) tensors.
# We assume that each shape describes a non-empty 2D tensor.
original_shape = tvm.testing.parameter(
# degenerate cases...
[1, 1],
[1, 2],
[2, 1],
[2, 2],
# arbitrary, provided for variety
[5, 3],
[3, 5],
)

# This dtype is arbitrary, but it must match the dtype that's hardcoded into the
# callee's function signature. E.g., 'a_data: T.Ptr[T.int8]'.
#
# Hopefully over time we'll find a way to soften this limitation, at least for
# some approaches to PrimFunc-to-PrimFunc calls.
dtype = tvm.testing.parameter("int8")

# --------------------------------------------------------------------------------------------------
# Helper functions / definitions...
# --------------------------------------------------------------------------------------------------

HEXAGON_TARGET_ = get_hexagon_target("v69")

ENTRY_PRIMFUNC_NAME_ = "main"


def get_reference_input_tensor_(shape: list, dtype: str) -> np.array:
    """Return the canonical input tensor used by these tests.

    The contents span the full representable range of `dtype`, laid out in
    C (row-major) order, so boundary values are always exercised.
    """
    content_spec = pytest_util.TensorContentFullRangeCOrder()
    return pytest_util.create_populated_numpy_ndarray(shape, dtype, content_spec)


def get_reference_output_tensor_(shape: list, dtype: str) -> np.array:
    """Return the expected output tensor: the reference input, elementwise + 1.

    NOTE(review): for integer dtypes the dtype's max element wraps around on
    "+ 1" (numpy modular arithmetic); presumably intentional since the kernel
    under test does the same wrapping add — confirm.
    """
    expected = get_reference_input_tensor_(shape, dtype)
    expected = expected + 1
    return expected


def evaluate_ir_module_(
    hexagon_session: Session, shape: List, dtype: str, ir_mod: tvm.ir.module.IRModule
) -> np.array:
    """Build `ir_mod` for Hexagon, run it remotely, and verify its output.

    The module's entry function is expected to take (input, output) handles
    and add 1 to every element of the input tensor.

    Parameters
    ----------
    hexagon_session : Session
        Live session to the Hexagon device / simulator.
    shape : List
        Shape of the input and output tensors.
    dtype : str
        Element dtype of both tensors.
    ir_mod : tvm.ir.module.IRModule
        The module under test; built with entry point ENTRY_PRIMFUNC_NAME_.

    Returns
    -------
    np.array
        The output tensor produced on the device (after validation).

    Raises
    ------
    AssertionError
        If the device output does not match the reference output.
    """
    reference_input_np = get_reference_input_tensor_(shape, dtype)
    reference_output_np = get_reference_output_tensor_(shape, dtype)

    # Use the module-level target constant so the Hexagon ISA version is
    # defined in exactly one place (previously "v69" was re-hardcoded here).
    hexagon_mod_local = tvm.build(
        ir_mod,
        target=HEXAGON_TARGET_,
        name=ENTRY_PRIMFUNC_NAME_,
    )

    hexagon_mod_remote = hexagon_session.load_module(hexagon_mod_local)

    input_data = allocate_hexagon_array(
        hexagon_session.device,
        data=reference_input_np,
    )

    # Pre-fill the output buffer with zeros.  Use the caller-supplied dtype
    # rather than a hardcoded "int8" so this helper stays correct if the
    # test's dtype parameterization is ever widened.
    output_data = allocate_hexagon_array(
        hexagon_session.device,
        tensor_shape=reference_output_np.shape,
        dtype=reference_output_np.dtype,
        data=np.zeros(shape, dtype=dtype),
    )

    hexagon_mod_remote(input_data, output_data)

    output_data_np = output_data.numpy()
    tvm.testing.assert_allclose(reference_output_np, output_data_np)

    # Return the device output, matching the declared return annotation
    # (previously the function returned None despite `-> np.array`).
    return output_data_np


# --------------------------------------------------------------------------------------------------
# Test cases...
# --------------------------------------------------------------------------------------------------


@tvm.testing.requires_hexagon
def test_baseline(
    hexagon_session: Session, original_shape: List, dtype: str
) -> tvm.ir.module.IRModule:
    # Baseline sanity check: a single conventional PrimFunc (no intra-module
    # calls) that adds 1 to every element.  The other tests in this file must
    # produce the same results via PrimFunc-to-PrimFunc calls.
    # NOTE(review): despite the annotation, this test function returns None —
    # the annotation appears inaccurate; confirm before relying on it.
    dim0_size, dim1_size = original_shape

    @tvm.script.ir_module
    class AddOneBaseline:
        """
        Provides "add-one" functionality in a single, traditional PrimFunc.
        Used as a baseline for comparison / validation with other approaches.
        I.e., approaches that use various aspects of PrimFunc slicing and/or
        one PrimFunc calling into another.
        """

        @T.prim_func
        def main(a: T.handle, b: T.handle):
            # We exchange data between functions by handles, which are similar to pointers.
            T.func_attr({"global_symbol": "main", "tir.noalias": True})

            # Bind the opaque handles to typed buffers of the parameterized
            # shape/dtype (closed over from the enclosing test function).
            A = T.match_buffer(a, original_shape, dtype=dtype)
            B = T.match_buffer(b, original_shape, dtype=dtype)

            # Elementwise add-one over the whole 2D tensor.
            for i in range(dim0_size):
                for j in range(dim1_size):
                    B[i, j] = A[i, j] + T.cast(1, dtype)

    evaluate_ir_module_(hexagon_session, original_shape, dtype, AddOneBaseline)


@tvm.testing.requires_hexagon
def test_pass_pointers(
    hexagon_session: Session, original_shape: List, dtype: str
) -> tvm.ir.module.IRModule:
    # Demonstrates one viable approach to intra-IRModule PrimFunc calls:
    # the caller passes raw data pointers (plus a row index) to the callee
    # via `T.call_extern`, and the callee reconstructs a buffer view over
    # the pointed-to data.
    #
    # Some notable requirements for this approach to intra-IRModule primfunc calls:
    #
    # - The specific dtype must be hardcoded into the callee's signature, e.g.
    #   'a_data: T.Ptr[T.int8]'.
    #
    # - The module's entry function must have a PrimFunc with the attribute
    #   "tir.is_entry_func": True.
    #   (This is related to having an IRModule with multiple PrimFuncs.)
    #
    # - The callee PrimFunc must have the "calling_conv": 3 attribute.
    #   (Where '3' is the number corresponding to 'tvm::CallingConv::kIntraModule'.)
    #   This ensures that the caller's 'A.data' argument, and the callee's 'a_data: T.ptr[T.int8]'
    #   parameter, both lower to 'uint8_t *', or something equivalent.
    #
    # - The callee must use 'T.buffer_decl' to describe the tile on which the callee
    #   shall operate.  As of this writing, there's no clear way to make this
    #   work with 'T.decl_buffer'.
    if dtype != "int8":
        pytest.skip(f"Unsupported dtype for this test: {dtype}")

    dim0_size, dim1_size = original_shape

    # Each callee invocation operates on one row of the original tensor.
    tile_shape = (dim1_size,)

    @tvm.script.ir_module
    class AddOnePassPointers:
        @T.prim_func
        def main(a: T.handle, b: T.handle):
            T.func_attr({"global_symbol": "main", "tir.noalias": True, "tir.is_entry_func": True})

            A = T.match_buffer(a, original_shape, dtype=dtype)
            B = T.match_buffer(b, original_shape, dtype=dtype)

            # One callee invocation per row: pass the underlying data
            # pointers and the row index.  (The first argument to
            # `T.call_extern` is the call's return dtype: none here.)
            for i in range(dim0_size):
                T.call_extern("", "callee", A.data, B.data, i)

        @T.prim_func
        def callee(a_data: T.Ptr[T.int8], b_data: T.Ptr[T.int8], i: T.int32):
            T.func_attr(
                {
                    "global_symbol": "callee",
                    "tir.noalias": True,
                    "calling_conv": 3,  # tvm::CallingConv::kIntraModule
                }
            )

            # Reconstruct 1-row buffer views over the raw pointers, offset
            # to row `i` of the original 2D tensors.
            A_tile = T.buffer_decl(tile_shape, dtype, a_data, elem_offset=dim1_size * i)
            B_tile = T.buffer_decl(tile_shape, dtype, b_data, elem_offset=dim1_size * i)

            for j in range(dim1_size):
                B_tile[j] = A_tile[j] + T.cast(1, dtype)

    evaluate_ir_module_(hexagon_session, original_shape, dtype, AddOnePassPointers)


# --------------------------------------------------------------------------------------------------

# Allow running this test file directly (outside of pytest).
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 2f25578

Please sign in to comment.