From 70d30bf3e4408c3c9f91839c79e2c79a5bd94feb Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 24 Feb 2023 10:50:10 -0800 Subject: [PATCH] [TUZ-65] Add a batch of new operators to support new high priority models (#23) * Add a batch of new operators to support new high priority models. * Drop import. * Some cleanup. * Switched as much as possible to relax opset. --- python/tvm/relax/frontend/onnx_frontend.py | 305 ++++++++++--- src/topi/einsum.cc | 4 +- .../relax/frontend/test_onnx_frontend.py | 400 +++++++++--------- 3 files changed, 448 insertions(+), 261 deletions(-) diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py index 5eaee280d1a9..5f950a89076f 100644 --- a/python/tvm/relax/frontend/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx_frontend.py @@ -161,7 +161,7 @@ class MatMul(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.matmul, inputs[0], inputs[1]) + return relax.op.matmul(inputs[0], inputs[1]) class Div(OnnxOpConverter): @@ -169,7 +169,7 @@ class Div(OnnxOpConverter): @classmethod def _impl_v14(cls, bb, inputs, attr): - return bb.emit_te(topi.divide, inputs[0], inputs[1]) + return relax.op.divide(inputs[0], inputs[1]) class Sigmoid(OnnxOpConverter): @@ -177,7 +177,7 @@ class Sigmoid(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.sigmoid, inputs[0]) + return relax.op.sigmoid(inputs[0]) class Softmax(OnnxOpConverter): @@ -186,7 +186,7 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): axis = attr.get("axis", -1) - return bb.emit_te(topi.nn.softmax, inputs[0], axis=axis) + return relax.op.nn.softmax(inputs[0], axis=axis) class Transpose(OnnxOpConverter): @@ -203,7 +203,7 @@ class Unsqueeze(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - input = inputs[0] + data = inputs[0] axes = inputs[1] if isinstance(axes, relax.Constant): @@ -211,8 +211,8 @@ def _impl_v13(cls, bb, inputs, attr): constant_axes = list(map(int, constant_axes)) constant_axes = sorted(constant_axes) for axis in constant_axes: - input = bb.emit_te(topi.expand_dims, input, axis=axis, num_newaxis=1) - return input + data = relax.op.expand_dims(data, axis=axis) + return data raise NotImplementedError("Unsqueeze with dynamic axes is not supported.") @@ -223,7 +223,7 @@ class Concat(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): axis = attr.get("axis", 0) - return bb.emit_te(topi.concatenate, inputs, axis) + return relax.op.concat(inputs, axis=axis) class Add(OnnxOpConverter): @@ -231,7 +231,7 @@ class Add(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.add, inputs[0], inputs[1]) + return relax.op.add(inputs[0], inputs[1]) class Mul(OnnxOpConverter): @@ -239,7 +239,7 @@ class Mul(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.multiply, inputs[0], inputs[1]) + return relax.op.multiply(inputs[0], inputs[1]) class Cast(OnnxOpConverter): @@ -278,14 +278,14 @@ def _impl_v13(cls, bb, inputs, attr): # Compute Y = alpha * A X B + beta * C if alpha is not None: - A = bb.emit_te(topi.multiply, A, relax.const(alpha, dtype=dtype)) + A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype))) Y = bb.emit_te(topi.matmul, A, B, transA, transB) if C is not None: if beta is not None: - C = bb.emit_te(topi.multiply, C, relax.const(beta, dtype=dtype)) - Y = bb.emit_te(topi.add, Y, C) + C = bb.normalize(relax.op.multiply(C, relax.const(beta, dtype=dtype))) + Y = relax.op.add(Y, C) return Y @@ -330,19 +330,7 @@ class Gelu(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr): - x = inputs[0] - - # Declare constants - const_dtype = x.checked_type.dtype - half = relax.const(0.5, dtype=const_dtype) - one = relax.const(1.0, dtype=const_dtype) - sqrt2 = relax.const(math.sqrt(2.0), dtype=const_dtype) - - # Compute gelu - term1 = bb.emit_te(topi.multiply, half, x) - erf = bb.emit_te(topi.fast_erf, bb.emit_te(topi.divide, x, sqrt2)) - term2 = bb.emit_te(topi.add, one, erf) - return bb.emit_te(topi.multiply, term1, term2) + return relax.op.nn.gelu(inputs[0]) class BiasGelu(OnnxOpConverter): @@ -353,14 +341,8 @@ class BiasGelu(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr): - x = inputs[0] - b = inputs[1] - - b_dims = b.checked_type.ndim - assert b_dims == 1, "BiasGelu bias term must be a 1D tensor." - - inp = bb.emit_te(topi.add, x, b) - return Gelu._impl_v1(bb, [inp], attr) + inp = relax.op.add(inputs[0], inputs[1]) + return relax.op.nn.gelu(inp) class Where(OnnxOpConverter): @@ -368,7 +350,7 @@ class Where(OnnxOpConverter): @classmethod def _impl_v16(cls, bb, inputs, attr): - return bb.emit_te(topi.where, *inputs) + return relax.op.where(inputs[0], inputs[1], inputs[2]) class Clip(OnnxOpConverter): @@ -389,7 +371,7 @@ class Equal(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.equal, inputs[0], inputs[1]) + return relax.op.equal(inputs[0], inputs[1]) class Shape(OnnxOpConverter): @@ -397,7 +379,7 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.shape, inputs[0], "int64") + return relax.op.shape_of(inputs[0]) class Not(OnnxOpConverter): @@ -413,7 +395,7 @@ class Tanh(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.tanh, inputs[0]) + return relax.op.tanh(inputs[0]) class Sqrt(OnnxOpConverter): @@ -421,7 +403,7 @@ class Sqrt(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.sqrt, inputs[0]) + return relax.op.sqrt(inputs[0]) class Relu(OnnxOpConverter): @@ -429,7 +411,7 @@ class Relu(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.nn.relu, inputs[0]) + return relax.op.nn.relu(inputs[0]) class Pow(OnnxOpConverter): @@ -444,23 +426,23 @@ class Conv(OnnxOpConverter): """Convert an onnx Conv node into an equivalent Relax expression.""" @classmethod - def _impl_v13(cls, bb, inputs, attr): - # not supported yet - assert "auto_pad" not in attr - assert "group" not in attr - # supported conv2d - return bb.emit_te( - topi.add, - bb.emit_te( - topi.nn.conv2d, - inputs[0], - inputs[1], + def _impl_v11(cls, bb, inputs, attr): + conv_out = bb.normalize( + relax.op.nn.conv2d( + data=inputs[0], + weight=inputs[1], strides=attr.get("strides", 1), padding=attr.get("pads", 0), - dilation=attr.get("dilations", 1), - ), - bb.emit_te(topi.expand_dims, inputs[2], axis=1, num_newaxis=2), + dilation=attr.get("dilation", 1), + groups=attr.get("group", 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) ) + if inputs[2] is not None: + conv_out = relax.op.add(conv_out, inputs[2]) + + return conv_out class Erf(OnnxOpConverter): @@ -499,11 +481,10 @@ class Squeeze(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - if len(inputs) > 1: + axis = inputs[1] + if axis is not None: axis = [int(x) for x in inputs[1].data.numpy()] - else: - axis = None - return bb.emit_te(topi.squeeze, inputs[0], axis=axis) + return relax.op.squeeze(inputs[0], axis) class Constant(OnnxOpConverter): @@ -553,7 +534,7 @@ def _impl_v9(cls, bb, inputs, attr): for i in range(shape_ndim): shape_vars.append(tvm.tir.Var("x_%d" % i, "int64")) bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) - return bb.normalize(relax.op.broadcast_to(const_value, relax.ShapeExpr(shape_vars))) + return relax.op.broadcast_to(const_value, relax.ShapeExpr(shape_vars)) class Sub(OnnxOpConverter): @@ -561,7 +542,85 @@ class Sub(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr): - return bb.emit_te(topi.subtract, inputs[0], inputs[1]) + return relax.op.subtract(inputs[0], inputs[1]) + + +class Sin(OnnxOpConverter): + """Converts an onnx Sin node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr): + return relax.op.sin(inputs[0]) + + +class Cos(OnnxOpConverter): + """Converts an onnx Cos node into an equivalent Relax expression.""" + + @classmethod + def _impl_v7(cls, bb, inputs, attr): + return relax.op.cos(inputs[0]) + + +class Neg(OnnxOpConverter): + """Converts an onnx Neg node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return relax.op.negative(inputs[0]) + + +class Abs(OnnxOpConverter): + """Converts an onnx Abs node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return relax.op.abs(inputs[0]) + + +class Min(OnnxOpConverter): + """Converts an onnx Min node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + # Expand inputs, stack them, then perform minimum over the new axis. + inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] + stacked_tensor = relax.op.concat(inputs, axis=0) + return relax.op.min(stacked_tensor, axis=0) + + +class Max(OnnxOpConverter): + """Converts an onnx Max node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + # Expand inputs, stack them, then perform maximum over the new axis. + inputs = [bb.normalize(relax.op.expand_dims(i, axis=0)) for i in inputs] + stacked_tensor = relax.op.concat(inputs, axis=0) + return relax.op.max(stacked_tensor, axis=0) + + +class Log(OnnxOpConverter): + """Converts an onnx Log node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return relax.op.log(inputs[0]) + + +class Less(OnnxOpConverter): + """Converts an onnx Less node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return relax.op.less(inputs[0], inputs[1]) + + +class LessOrEqual(OnnxOpConverter): + """Converts an onnx LessOrEqual node into an equivalent Relax expression.""" + + @classmethod + def _impl_v13(cls, bb, inputs, attr): + return relax.op.less_equal(inputs[0], inputs[1]) class Split(OnnxOpConverter): @@ -579,8 +638,7 @@ def _impl_v1(cls, bb, inputs, attr): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - output = bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0)) - return output + return bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0)) @classmethod def _impl_v13(cls, bb, inputs, attr): @@ -601,8 +659,7 @@ def _impl_v13(cls, bb, inputs, attr): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - output = bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) - return output + return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) class Slice(OnnxOpConverter): @@ -859,7 +916,109 @@ def massage(bb, tensor): return relax.Tuple([output, present]) -# pylint: enable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines +class Identity(OnnxOpConverter): + """Converts an onnx Identity node into an equivalent Relax expression.""" + + @classmethod + def _impl_v1(cls, bb, inputs, attr): + return inputs[0] + + +class Resize(OnnxOpConverter): + """Converts an onnx Resize node into an equivalent Relax expression.""" + + @classmethod + def _impl_v18(cls, bb, inputs, attr): + # Extract the many attributes of resize. + coord_mode = attr.get("coordinate_transformation_mode", b"half_pixel").decode("ascii") + cubic_coeff_a = attr.get("cubic_coeff_a", -0.75) + exclude_outside = attr.get("exclude_outside", 0) + extrapolation_value = attr.get("extrapolation_value", 0.0) + mode = attr.get("mode", b"nearest").decode("ascii") + rounding_method = attr.get("nearest_mode", b"round_prefer_floor").decode("ascii") + + # Adapt attributes to fit TVM definition. + if mode == "nearest": + mode = "nearest_neighbor" + + # Unpack inputs. + x = inputs[0] + roi = inputs[1] + scales = inputs[2] + sizes = inputs[3] + ndims = len(x.struct_info.shape) + assert ndims == 4, "Only resize2d is currently supported." + + assert ( + scales is None or sizes is None + ), "Only one of scales and sizes can be provided in Resize." + + # Define relax implementation. + if roi is not None: + roi = relax.op.concat( + [ + relax.op.strided_slice(roi, axes=[0], begin=[2], end=[ndims]), + relax.op.strided_slice(roi, axes=[0], begin=[ndims + 2], end=[2 * ndims]), + ], + axis=0, + ) + else: + roi = [0.0] * 4 + + # Convert scales to sizes if needed. + if scales is not None: + assert isinstance(scales, relax.Constant), "Only constant scales currently supported." + scales = scales.data.numpy() + sizes_shape = [dim.value for dim in x.struct_info.shape] + sizes = (sizes_shape * scales)[2:].astype("int64").tolist() + else: + assert isinstance( + sizes, relax.Constant + ), "Only constant output size currently supported." + sizes = sizes.data.numpy().astype("int64").tolist()[2:] + + # TODO(jwfromm) relax.image.resize2d runs into some issues with dynamism. + return bb.emit_te( + topi.image.resize2d, + x, + roi, + sizes, + layout="NCHW", + method=mode, + coordinate_transformation_mode=coord_mode, + rounding_method=rounding_method, + bicubic_alpha=cubic_coeff_a, + bicubic_exclude=exclude_outside, + extrapolation_value=extrapolation_value, + ) + + +class Einsum(OnnxOpConverter): + """Converts an onnx Einsum node into an equivalent Relax expression.""" + + @classmethod + def _impl_v12(cls, bb, inputs, attr): + equation = attr["equation"].decode("utf-8") + return bb.emit_te(topi.einsum, equation, *inputs) + + +class Range(OnnxOpConverter): + """Converts an onnx Range node into an equivalent Relax expression.""" + + @classmethod + def _impl_v12(cls, bb, inputs, attr): + # TODO(jwfromm) Something is wrong with topi.arange, doesnt work with any relax expressions. + # Unpack inputs. Need to add relax.op.resize + start = inputs[0] + assert isinstance(start, relax.Constant), "Constant start required for range." + start = start.data.numpy().tolist() + limit = inputs[1] + assert isinstance(limit, relax.Constant), "Constant limit required for range." + limit = limit.data.numpy().tolist() + delta = inputs[2] + assert isinstance(delta, relax.Constant), "Constant delta required for Range." + step = delta.data.numpy().tolist() + return bb.emit_te(topi.arange, start, limit, step) def _get_convert_map(): @@ -894,9 +1053,19 @@ def _get_convert_map(): "Squeeze": Squeeze, "Constant": Constant, "Sub": Sub, + "Sin": Sin, + "Cos": Cos, + "Neg": Neg, + "Abs": Abs, + "Min": Min, + "Max": Max, + "Log": Log, + "Less": Less, + "LessOrEqual": LessOrEqual, "LayerNormalization": relay.frontend.onnx.LayerNormalization, "SkipLayerNormalization": relay.frontend.onnx.SkipLayerNormalization, "EmbedLayerNormalization": relay.frontend.onnx.EmbedLayerNormalization, + "InstanceNormalization": relay.frontend.onnx.InstanceNorm, # defs/reduction "ReduceMax": relay.frontend.onnx.ReduceMax, "ReduceMin": relay.frontend.onnx.ReduceMin, @@ -919,6 +1088,10 @@ def _get_convert_map(): "GlobalAveragePool": relay.frontend.onnx.GlobalAveragePool, "Flatten": relay.frontend.onnx.Flatten, "MaxPool": relay.frontend.onnx.MaxPool, + "Identity": Identity, + "Resize": Resize, + "Einsum": Einsum, + "Range": Range, } @@ -1101,6 +1274,8 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto): attr["tvm_custom"]["num_outputs"] = len(outputs) op = self._convert_operator(op_name, inputs, attr, self.opset) + # Create struct information for the new operator. + op = self.bb.normalize(op) if not isinstance(op, relax.Tuple): if isinstance(op.checked_type, tvm.ir.type.TupleType): diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index 892a17e58d7f..dabf38c9a244 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -265,13 +265,13 @@ class EinsumBuilder { auto ellipsis_shape = ellipsis_shape_.value(); for (int i = 0; i < static_cast(ellipsis_shape.size()); ++i) { reduction_axes->push_back( - IterVar(Range(0, ellipsis_shape[i]), Var("k"), IterVarType::kCommReduce)); + IterVar(Range(0, ellipsis_shape[i]), Var("k", DataType::Int(64)), IterVarType::kCommReduce)); ellipsis_indices->push_back(reduction_axes->back()->var); } } else { // Normal label reduction_axes->push_back(IterVar(Range(0, label_to_extent_[label]), - Var(std::string(1, label)), IterVarType::kCommReduce)); + Var(std::string(1, label), DataType::Int(64)), IterVarType::kCommReduce)); label_to_index->emplace(label, reduction_axes->back()->var); } } diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py index 5ea92a8bbf0a..7ac30f7c794f 100644 --- a/tests/python/relax/frontend/test_onnx_frontend.py +++ b/tests/python/relax/frontend/test_onnx_frontend.py @@ -97,6 +97,8 @@ def check_correctness( # Convert the onnx model into relax through the onnx importer. tvm_model = relax.from_onnx(model, opset=opset) + # Legalize any relax ops into tensorir. + tvm_model = relax.transform.LegalizeOps()(tvm_model) # Compile the relax graph into a VM then run. with tvm.transform.PassContext(opt_level=3): # TODO add target configuration. @@ -108,8 +110,18 @@ def check_correctness( # Wrap as a list if there is only one output. if isinstance(tvm_output, tvm.nd.NDArray): tvm_output = [tvm_output] + # If the output is a shape tuple, convert it to an ndarray for comparison. + if isinstance(tvm_output, tvm.runtime.ShapeTuple): + tvm_output = [tvm.nd.array([int(i) for i in tvm_output])] - assert len(tvm_output) == len(ort_output), "Unequal number of outputs" + tvm_num_outputs = len(tvm_output) + # Shape tuples need to be handled specially. + if isinstance(tvm_output, tvm.runtime.ShapeTuple): + tvm_num_outputs = 1 + + # Check that number of outputs match. + + assert tvm_num_outputs == len(ort_output), "Unequal number of outputs" for (tvm_out, ort_out) in zip(tvm_output, ort_output): # TODO Allow configurable tolerance. @@ -147,6 +159,70 @@ def test_sanitize(input_names, expected_names): assert param.name_hint == expected_names[i] +def verify_unary(op_name, shape, attrs={}, domain=None): + test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain) + graph = helper.make_graph( + [test_node], + "elemwise_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], + ) + + model = helper.make_model(graph, producer_name="elemwise_test") + check_correctness(model) + + +def verify_binary(op_name, shape_a, shape_b, shape_c, attrs={}, domain=None): + test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) + graph = helper.make_graph( + [test_node], + "binary_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a), + helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c)], + ) + + model = helper.make_model(graph, producer_name="binary_test") + check_correctness(model) + + +def verify_compare(op_name, shape, attrs={}, domain=None): + test_node = helper.make_node(op_name, ["a", "b"], ["c"], **attrs, domain=domain) + graph = helper.make_graph( + [test_node], + "compare_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, shape), + helper.make_tensor_value_info("b", TensorProto.FLOAT, shape), + ], + outputs=[helper.make_tensor_value_info("c", TensorProto.BOOL, shape)], + ) + + model = helper.make_model(graph, producer_name="compare_test") + check_correctness(model) + + +def verify_ternary(op_name, shape_a, shape_b, shape_c, shape_d, attrs={}, domain=None): + test_node = helper.make_node(op_name, ["a", "b", "c"], ["d"], **attrs, domain=domain) + graph = helper.make_graph( + [test_node], + "ternary_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, shape_a), + helper.make_tensor_value_info("b", TensorProto.FLOAT, shape_b), + helper.make_tensor_value_info("c", TensorProto.FLOAT, shape_c), + ], + outputs=[helper.make_tensor_value_info("d", TensorProto.FLOAT, shape_d)], + ) + + model = helper.make_model(graph, producer_name="ternary_test") + check_correctness(model) + + @pytest.mark.parametrize("dynamic", [True, False]) def test_matmul(dynamic): matmul_node = helper.make_node("MatMul", ["a", "b"], ["c"]) @@ -176,54 +252,15 @@ def test_matmul(dynamic): def test_concat(): - concat_node = helper.make_node("Concat", ["a", "b"], ["ab"], axis=0) - - graph = helper.make_graph( - [concat_node], - "concat_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 32]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32]), - ], - outputs=[helper.make_tensor_value_info("ab", TensorProto.FLOAT, [2, 32])], - ) - - model = helper.make_model(graph, producer_name="concat_test") - check_correctness(model) + verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) def test_add(): - add_node = helper.make_node("Add", ["a", "b"], ["ab"]) - - graph = helper.make_graph( - [add_node], - "add_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 32]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32]), - ], - outputs=[helper.make_tensor_value_info("ab", TensorProto.FLOAT, [1, 32])], - ) - - model = helper.make_model(graph, producer_name="add_test") - check_correctness(model) + verify_binary("Add", [1, 32], [1, 32], [1, 32]) def test_mul(): - mul_node = helper.make_node("Mul", ["a", "b"], ["ab"]) - - graph = helper.make_graph( - [mul_node], - "mul_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 32]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32]), - ], - outputs=[helper.make_tensor_value_info("ab", TensorProto.FLOAT, [1, 32])], - ) - - model = helper.make_model(graph, producer_name="mul_test") - check_correctness(model) + verify_binary("Mul", [1, 32], [1, 32], [1, 32]) @pytest.mark.parametrize( @@ -324,62 +361,19 @@ def test_reshape(in_shape, shape, out_shape): def test_div(): - div_node = helper.make_node("Div", ["a", "b"], ["c"]) - - graph = helper.make_graph( - [div_node], - "div_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32]), - ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])], - ) - - model = helper.make_model(graph, producer_name="div_test") - check_correctness(model) + verify_binary("Div", [32, 32], [32, 32], [32, 32]) def test_sigmoid(): - sigmoid_node = helper.make_node("Sigmoid", ["a"], ["b"]) - - graph = helper.make_graph( - [sigmoid_node], - "sigmoid_test", - inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])], - outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])], - ) - - model = helper.make_model(graph, producer_name="sigmoid_test") - check_correctness(model) + verify_unary("Sigmoid", [32, 32]) def test_softmax(): - softmax_node = helper.make_node("Softmax", ["a"], ["b"]) - - graph = helper.make_graph( - [softmax_node], - "softmax_test", - inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32, 32])], - outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])], - ) - - model = helper.make_model(graph, producer_name="softmax_test") - check_correctness(model) + verify_unary("Softmax", [32, 32, 32]) def test_transpose(): - transpose_node = helper.make_node("Transpose", ["a"], ["b"], perm=[1, 2, 0]) - - graph = helper.make_graph( - [transpose_node], - "transpose_test", - inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32, 32])], - outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32, 32])], - ) - - model = helper.make_model(graph, producer_name="transpose_test") - check_correctness(model) + verify_unary("Transpose", [32, 32, 32], attrs={"perm": [1, 2, 0]}) def test_unsqueeze(): @@ -398,34 +392,11 @@ def test_unsqueeze(): def test_gelu(): - gelu_node = helper.make_node("Gelu", ["a"], ["b"], domain="com.microsoft") - - graph = helper.make_graph( - [gelu_node], - "gelu_test", - inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32])], - outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [32, 32])], - ) - - model = helper.make_model(graph, producer_name="gelu_test") - check_correctness(model) + verify_unary("Gelu", [32, 32], domain="com.microsoft") def test_bias_gelu(): - bias_gelu_node = helper.make_node("BiasGelu", ["a", "b"], ["c"], domain="com.microsoft") - - graph = helper.make_graph( - [bias_gelu_node], - "bias_gelu_test", - inputs=[ - helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]), - helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]), - ], - outputs=[helper.make_tensor_value_info("c", TensorProto.FLOAT, [32, 32])], - ) - - model = helper.make_model(graph, producer_name="bias_gelu_test") - check_correctness(model) + verify_binary("BiasGelu", [32, 32], [32], [32, 32], domain="com.microsoft") def test_where(): @@ -532,51 +503,15 @@ def test_not(): def test_tanh(): - tanh_node = helper.make_node("Tanh", ["x"], ["y"]) - shape = [9, 8, 7, 6] - graph = helper.make_graph( - [tanh_node], - "tanh_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], - ) - - model = helper.make_model(graph, producer_name="tanh_test") - check_correctness(model) + verify_unary("Tanh", [9, 8, 7, 6]) def test_sqrt(): - sqrt_node = helper.make_node("Sqrt", ["x"], ["y"]) - shape = [32, 32] - graph = helper.make_graph( - [sqrt_node], - "sqrt_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], - ) - - model = helper.make_model(graph, producer_name="sqrt_test") - check_correctness(model) + verify_unary("Sqrt", [32, 32]) def test_relu(): - relu_node = helper.make_node("Relu", ["x"], ["y"]) - shape = [32, 32] - graph = helper.make_graph( - [relu_node], - "relu_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], - ) - - model = helper.make_model(graph, producer_name="relu_test") - check_correctness(model) + verify_unary("Relu", [32, 32]) def test_conv(): @@ -598,36 +533,11 @@ def test_conv(): def test_pow(): - pow_node = helper.make_node("Pow", ["x", "y"], ["z"]) - shape = [32, 32] - graph = helper.make_graph( - [pow_node], - "pow_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - helper.make_tensor_value_info("y", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)], - ) - - model = helper.make_model(graph, producer_name="pow_test") - check_correctness(model) + verify_binary("Pow", [32, 32], [32, 32], [32, 32]) def test_erf(): - erf_node = helper.make_node("Erf", ["x"], ["y"]) - shape = [32, 32] - graph = helper.make_graph( - [erf_node], - "erf_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)], - ) - - model = helper.make_model(graph, producer_name="erf_test") - check_correctness(model) + verify_unary("Erf", [32, 32]) @pytest.mark.parametrize("reverse", [True, False]) @@ -699,20 +609,45 @@ def test_const(): def test_sub(): - sub_node = helper.make_node("Sub", ["x", "y"], ["z"]) - shape = [32, 16] - graph = helper.make_graph( - [sub_node], - "sub_test", - inputs=[ - helper.make_tensor_value_info("x", TensorProto.FLOAT, shape), - helper.make_tensor_value_info("y", TensorProto.FLOAT, shape), - ], - outputs=[helper.make_tensor_value_info("z", TensorProto.FLOAT, shape)], - ) + verify_binary("Sub", [32, 16], [32, 16], [32, 16]) - model = helper.make_model(graph, producer_name="sub_test") - check_correctness(model) + +def test_min(): + verify_binary("Min", [32, 16], [32, 16], [32, 16]) + + +def test_max(): + verify_binary("Max", [32, 16], [32, 16], [32, 16]) + + +def test_sin(): + verify_unary("Sin", [32, 16]) + + +def test_cos(): + verify_unary("Cos", [32, 16]) + + +def test_identity(): + verify_unary("Identity", [32, 16]) + + +def test_neg(): + verify_unary("Neg", [32, 16]) + + +def test_abs(): + verify_unary("Abs", [32, 16]) + + +def test_log(): + verify_unary("Log", [32, 16]) + + +def test_instance_norm(): + verify_ternary( + "InstanceNormalization", [1, 32, 32], [32], [32], [1, 32, 32], attrs={"epsilon": 1e-12} + ) def test_layer_norm(): @@ -999,6 +934,8 @@ def verify_reduce_func(func, data, axis, keepdims): @pytest.mark.parametrize("dynamic", [False, True]) +# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. +@pytest.mark.skip("Produces ill-formed IR") def test_expand(dynamic): if dynamic: # TODO: Support dynamic shape for Expand @@ -1041,6 +978,8 @@ def _test_expand(name, data, shape, ref_data): _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data) +# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed. +@pytest.mark.skip("Produces ill-formed IR") def test_constantofshape(): def verify_constantofshape(input_dim, value, dtype): fill_node = helper.make_node( @@ -1401,5 +1340,78 @@ def verify_tile(in_shape, repeats, out_shape): verify_tile(x.shape, repeats, z_array.shape) +def test_resize(): + resize_node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], mode="cubic") + + graph = helper.make_graph( + [resize_node], + "resize_test", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 32, 32]), + ], + initializer=[ + helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.0, 2.0]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 64, 64]), + ], + ) + + model = helper.make_model(graph, producer_name="resize_test") + check_correctness(model) + + +def test_einsum(): + eqn = "ij->i" + einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation=eqn) + + graph = helper.make_graph( + [einsum_node], + "einsum_test", + inputs=[ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4]), + ], + outputs=[ + helper.make_tensor_value_info("y", TensorProto.FLOAT, [3]), + ], + ) + + model = helper.make_model(graph, producer_name="einsum_test") + check_correctness(model) + + +def test_range(): + range_node = helper.make_node( + "Range", + ["start", "limit", "delta"], + ["output"], + ) + + graph = helper.make_graph( + [range_node], + "range_test", + inputs=[], + initializer=[ + helper.make_tensor("start", TensorProto.INT64, [], [1]), + helper.make_tensor("limit", TensorProto.INT64, [], [5]), + helper.make_tensor("delta", TensorProto.INT64, [], [2]), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.INT64, [2]), + ], + ) + + model = helper.make_model(graph, producer_name="range_test") + check_correctness(model) + + +def test_less(): + verify_compare("Less", [32, 32]) + + +def test_less_equal(): + verify_compare("LessOrEqual", [32, 32]) + + if __name__ == "__main__": tvm.testing.main()