diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 67a1f56e87..77e88c0277 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -2,7 +2,7 @@ from eth.codecs import abi from tests.utils import decimal_to_int -from vyper.exceptions import ArgumentException, StackTooDeep, StructureException +from vyper.exceptions import ArgumentException, StructureException TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex() @@ -201,7 +201,6 @@ def abi_decode(x: Bytes[{len}]) -> DynArray[DynArray[uint256, 3], 3]: @pytest.mark.parametrize("args", nested_3d_array_args) @pytest.mark.parametrize("unwrap_tuple", (True, False)) -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_nested_dynarray2(get_contract, args, unwrap_tuple): if unwrap_tuple is True: encoded = abi.encode("(uint256[][][])", (args,)) @@ -279,7 +278,6 @@ def foo(bs: Bytes[160]) -> (uint256, DynArray[uint256, 3]): assert c.foo(encoded) == [2**256 - 1, bs] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_decode_private_nested_dynarray(get_contract): code = """ bytez: DynArray[DynArray[DynArray[uint256, 3], 3], 3] diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py index 2331fb1a9e..387dab789a 100644 --- a/tests/functional/builtins/codegen/test_abi_encode.py +++ b/tests/functional/builtins/codegen/test_abi_encode.py @@ -2,7 +2,6 @@ from eth.codecs import abi from tests.utils import decimal_to_int -from vyper.exceptions import StackTooDeep # @pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"]) @@ -227,7 +226,6 @@ def abi_encode( @pytest.mark.parametrize("args", nested_3d_array_args) -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_nested_dynarray_2(get_contract, args): code = """ @external @@ -332,7 +330,6 @@ def foo(bs: DynArray[uint256, 3]) -> (uint256, Bytes[160]): assert c.foo(bs) == [2**256 - 1, abi.encode("(uint256[])", (bs,))] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_abi_encode_private_nested_dynarray(get_contract): code = """ bytez: Bytes[1696] diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e64a35811d..0fd2c6212e 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -1,5 +1,8 @@ import pytest +from vyper.exceptions import StaticAssertionException +from vyper.utils import SizeLimits + def test_basic_repeater(get_contract_with_gas_estimation): basic_repeater = """ @@ -271,17 +274,38 @@ def test(): @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) -def test_for_range_oob_check(get_contract, tx_failed, typ): +def test_for_range_oob_compile_time_check(get_contract, tx_failed, typ, experimental_codegen): code = f""" @external def test(): x: {typ} = max_value({typ}) + for i: {typ} in range(x, x + 2, bound=2): + pass + """ + if not experimental_codegen: + return + with pytest.raises(StaticAssertionException): + get_contract(code) + + +@pytest.mark.parametrize( + "typ, max_value", + [ + ("uint8", SizeLimits.MAX_UINT8), + ("int128", SizeLimits.MAX_INT128), + ("uint256", SizeLimits.MAX_UINT256), + 
],
+)
+def test_for_range_oob_runtime_check(get_contract, tx_failed, typ, max_value):
+    code = f"""
+@external
+def test(x: {typ}):
     for i: {typ} in range(x, x + 2, bound=2):
         pass
     """
     c = get_contract(code)
     with tx_failed():
-        c.test()
+        c.test(max_value)
 
 
 @pytest.mark.parametrize("typ", ["int128", "uint256"])
@@ -416,7 +440,25 @@ def foo(a: {typ}) -> {typ}:
     assert c.foo(0) == 31337
 
 
-def test_for_range_signed_int_overflow(get_contract, tx_failed):
+def test_for_range_signed_int_overflow_runtime_check(get_contract, tx_failed, experimental_codegen):
+    code = """
+@external
+def foo(_min: int256, _max: int256) -> DynArray[int256, 10]:
+    res: DynArray[int256, 10] = empty(DynArray[int256, 10])
+    x: int256 = _max
+    y: int256 = _min + 2
+    for i: int256 in range(x, y, bound=10):
+        res.append(i)
+    return res
+    """
+    c = get_contract(code)
+    with tx_failed():
+        c.foo(SizeLimits.MIN_INT256, SizeLimits.MAX_INT256)
+
+
+def test_for_range_signed_int_overflow_compile_time_check(
+    get_contract, tx_failed, experimental_codegen
+):
     code = """
 @external
 def foo() -> DynArray[int256, 10]:
@@ -427,6 +469,7 @@ def foo() -> DynArray[int256, 10]:
         res.append(i)
     return res
     """
-    c = get_contract(code)
-    with tx_failed():
-        c.foo()
+    if not experimental_codegen:
+        return
+    with pytest.raises(StaticAssertionException):
+        get_contract(code)
diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py
index d96a889497..9146ace8a6 100644
--- a/tests/functional/codegen/features/test_constructor.py
+++ b/tests/functional/codegen/features/test_constructor.py
@@ -1,8 +1,6 @@
 import pytest
 from web3.exceptions import ValidationError
 
-from vyper.exceptions import StackTooDeep
-
 
 def test_init_argument_test(get_contract_with_gas_estimation):
     init_argument_test = """
@@ -165,7 +163,6 @@ def get_foo() -> uint256:
     assert c.get_foo() == 39
 
 
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
 def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estimation):
     code = """
 foo: int128
@@ -211,7 +208,6 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]:
     assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]]
 
 
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
 def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation):
     code = """
 foo: DynArray[DynArray[DynArray[int128, 3], 3], 3]
diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py
index 874600633a..49ff54b353 100644
--- a/tests/functional/codegen/features/test_immutable.py
+++ b/tests/functional/codegen/features/test_immutable.py
@@ -1,7 +1,6 @@
 import pytest
 
 from vyper.compiler.settings import OptimizationLevel
-from vyper.exceptions import StackTooDeep
 
 
 @pytest.mark.parametrize(
@@ -199,7 +198,6 @@ def get_idx_two() -> uint256:
     assert c.get_idx_two() == expected_values[2][2]
 
 
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
 def test_nested_dynarray_immutable(get_contract):
     code = """
 my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3])
diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py
index e063f981ec..01c2cd74c4 100644
--- a/tests/functional/codegen/types/numbers/test_signed_ints.py
+++ b/tests/functional/codegen/types/numbers/test_signed_ints.py
@@ -8,6 +8,7 @@ from vyper.exceptions import (
     InvalidOperation,
    
OverflowException, + StaticAssertionException, TypeMismatch, ZeroDivisionException, ) @@ -73,18 +74,35 @@ def foo(x: int256) -> int256: # TODO: make this test pass @pytest.mark.parametrize("base", (0, 1)) -def test_exponent_negative_power(get_contract, tx_failed, base): +def test_exponent_negative_power_runtime_check(get_contract, tx_failed, base, experimental_codegen): # #2985 code = f""" @external -def bar() -> int16: - x: int16 = -2 +def bar(negative:int16) -> int16: + x: int16 = negative return {base} ** x """ c = get_contract(code) # known bug: 2985 with tx_failed(): - c.bar() + c.bar(-2) + + +@pytest.mark.parametrize("base", (0, 1)) +def test_exponent_negative_power_compile_time_check( + get_contract, tx_failed, base, experimental_codegen +): + # #2985 + code = f""" +@external +def bar() -> int16: + x: int16 = -2 + return {base} ** x + """ + if not experimental_codegen: + return + with pytest.raises(StaticAssertionException): + get_contract(code) def test_exponent_min_int16(get_contract): diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 63d3d0b8f1..1383895c7a 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -62,7 +62,6 @@ def loo(x: DynArray[DynArray[int128, 2], 2]) -> int128: print("Passed list tests") -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_string_list(get_contract): code = """ @external @@ -1491,7 +1490,6 @@ def foo(x: int128) -> int128: assert c.foo(7) == 392 -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_struct_of_lists(get_contract): code = """ struct Foo: @@ -1580,7 +1578,6 @@ def bar(x: int128) -> DynArray[int128, 3]: assert c.bar(7) == [7, 14] -@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize): code = """ struct nestedFoo: @@ -1710,9 +1707,7 @@ def __init__(): ("DynArray[DynArray[DynArray[uint256, 5], 5], 5]", [[[], []], []]), ], ) -def test_empty_nested_dynarray(get_contract, typ, val, venom_xfail): - if val == [[[], []], []]: - venom_xfail(raises=StackTooDeep, reason="stack scheduler regression") +def test_empty_nested_dynarray(get_contract, typ, val): code = f""" @external def foo() -> {typ}: diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py new file mode 100644 index 0000000000..8102a0d89c --- /dev/null +++ b/tests/unit/compiler/venom/test_sccp.py @@ -0,0 +1,139 @@ +from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable +from vyper.venom.function import IRFunction +from vyper.venom.passes.make_ssa import MakeSSA +from vyper.venom.passes.sccp import SCCP +from vyper.venom.passes.sccp.sccp import LatticeEnum + + +def test_simple_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + p1 = bb.append_instruction("param") + op1 = bb.append_instruction("store", 32) + op2 = bb.append_instruction("store", 64) + op3 = bb.append_instruction("add", op1, op2) + bb.append_instruction("return", p1, op3) + + make_ssa_pass = MakeSSA() + make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) + sccp = SCCP(make_ssa_pass.dom) + sccp.run_pass(ctx, ctx.basic_blocks[0]) + + assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM + assert sccp.lattice[IRVariable("%2")].value == 32 + assert sccp.lattice[IRVariable("%3")].value == 64 + assert 
sccp.lattice[IRVariable("%4")].value == 96 + + +def test_cont_jump_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), ctx) + ctx.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), ctx) + ctx.append_basic_block(br2) + + p1 = bb.append_instruction("param") + op1 = bb.append_instruction("store", 32) + op2 = bb.append_instruction("store", 64) + op3 = bb.append_instruction("add", op1, op2) + bb.append_instruction("jnz", op3, br1.label, br2.label) + + br1.append_instruction("add", op3, 10) + br1.append_instruction("stop") + br2.append_instruction("add", op3, p1) + br2.append_instruction("stop") + + make_ssa_pass = MakeSSA() + make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) + sccp = SCCP(make_ssa_pass.dom) + sccp.run_pass(ctx, ctx.basic_blocks[0]) + + assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM + assert sccp.lattice[IRVariable("%2")].value == 32 + assert sccp.lattice[IRVariable("%3")].value == 64 + assert sccp.lattice[IRVariable("%4")].value == 96 + assert sccp.lattice[IRVariable("%5")].value == 106 + assert sccp.lattice.get(IRVariable("%6")) == LatticeEnum.BOTTOM + + +def test_cont_phi_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), ctx) + ctx.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), ctx) + ctx.append_basic_block(br2) + join = IRBasicBlock(IRLabel("join"), ctx) + ctx.append_basic_block(join) + + p1 = bb.append_instruction("param") + op1 = bb.append_instruction("store", 32) + op2 = bb.append_instruction("store", 64) + op3 = bb.append_instruction("add", op1, op2) + bb.append_instruction("jnz", op3, br1.label, br2.label) + + op4 = br1.append_instruction("add", op3, 10) + br1.append_instruction("jmp", join.label) + br2.append_instruction("add", op3, p1, ret=op4) + br2.append_instruction("jmp", join.label) + + join.append_instruction("return", op4, p1) + + make_ssa_pass = MakeSSA() + make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) + + sccp = SCCP(make_ssa_pass.dom) + sccp.run_pass(ctx, ctx.basic_blocks[0]) + + assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM + assert sccp.lattice[IRVariable("%2")].value == 32 + assert sccp.lattice[IRVariable("%3")].value == 64 + assert sccp.lattice[IRVariable("%4")].value == 96 + assert sccp.lattice[IRVariable("%5", version=1)].value == 106 + assert sccp.lattice[IRVariable("%5", version=2)] == LatticeEnum.BOTTOM + assert sccp.lattice[IRVariable("%5")].value == 2 + + +def test_cont_phi_const_case(): + ctx = IRFunction(IRLabel("_global")) + + bb = ctx.get_basic_block() + + br1 = IRBasicBlock(IRLabel("then"), ctx) + ctx.append_basic_block(br1) + br2 = IRBasicBlock(IRLabel("else"), ctx) + ctx.append_basic_block(br2) + join = IRBasicBlock(IRLabel("join"), ctx) + ctx.append_basic_block(join) + + p1 = bb.append_instruction("store", 1) + op1 = bb.append_instruction("store", 32) + op2 = bb.append_instruction("store", 64) + op3 = bb.append_instruction("add", op1, op2) + bb.append_instruction("jnz", op3, br1.label, br2.label) + + op4 = br1.append_instruction("add", op3, 10) + br1.append_instruction("jmp", join.label) + br2.append_instruction("add", op3, p1, ret=op4) + br2.append_instruction("jmp", join.label) + + join.append_instruction("return", op4, p1) + + make_ssa_pass = MakeSSA() + make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) + sccp = SCCP(make_ssa_pass.dom) + sccp.run_pass(ctx, ctx.basic_blocks[0]) + + assert sccp.lattice[IRVariable("%1")].value == 1 + assert 
sccp.lattice[IRVariable("%2")].value == 32 + assert sccp.lattice[IRVariable("%3")].value == 64 + assert sccp.lattice[IRVariable("%4")].value == 96 + assert sccp.lattice[IRVariable("%5", version=1)].value == 106 + assert sccp.lattice[IRVariable("%5", version=2)].value == 97 + assert sccp.lattice[IRVariable("%5")].value == 2 diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 35cbbe648d..bd9d8b2e5e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -248,16 +248,16 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: dst_typ = BytesT(length.value) # allocate a buffer for the return value - np = context.new_internal_variable(dst_typ) + buf = context.new_internal_variable(dst_typ) # `msg.data` by `calldatacopy` if sub.value == "~calldata": node = [ "seq", _make_slice_bounds_check(start, length, "calldatasize"), - ["mstore", np, length], - ["calldatacopy", np + 32, start, length], - np, + ["mstore", buf, length], + ["calldatacopy", add_ofst(buf, 32), start, length], + buf, ] # `self.code` by `codecopy` @@ -265,9 +265,9 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: node = [ "seq", _make_slice_bounds_check(start, length, "codesize"), - ["mstore", np, length], - ["codecopy", np + 32, start, length], - np, + ["mstore", buf, length], + ["codecopy", add_ofst(buf, 32), start, length], + buf, ] # `
.code` by `extcodecopy` @@ -280,9 +280,9 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context: [ "seq", _make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]), - ["mstore", np, length], - ["extcodecopy", "_extcode_address", np + 32, start, length], - np, + ["mstore", buf, length], + ["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length], + buf, ], ] @@ -552,10 +552,8 @@ def build_IR(self, expr, context): # respect API of copy_bytes bufsize = dst_maxlen + 32 - buf = context.new_internal_variable(BytesT(bufsize)) - - # Node representing the position of the output in memory - dst = IRnode.from_list(buf, typ=ret_typ, location=MEMORY, annotation="concat destination") + dst = context.new_internal_variable(BytesT(bufsize)) + dst.annotation = "concat destination" ret = ["seq"] # stack item representing our current offset in the dst buffer @@ -783,9 +781,9 @@ def build_IR(self, expr, args, kwargs, context): # clear output memory first, ecrecover can return 0 bytes ["mstore", output_buf, 0], ["mstore", input_buf, args[0]], - ["mstore", input_buf + 32, args[1]], - ["mstore", input_buf + 64, args[2]], - ["mstore", input_buf + 96, args[3]], + ["mstore", add_ofst(input_buf, 32), args[1]], + ["mstore", add_ofst(input_buf, 64), args[2]], + ["mstore", add_ofst(input_buf, 96), args[3]], ["staticcall", "gas", 1, input_buf, 128, output_buf, 32], ["mload", output_buf], ], @@ -799,9 +797,7 @@ def build_IR(self, expr, _args, kwargs, context): args_tuple = ir_tuple_from_args(_args) args_t = args_tuple.typ - input_buf = IRnode.from_list( - context.new_internal_variable(args_t), typ=args_t, location=MEMORY - ) + input_buf = context.new_internal_variable(args_t) ret_t = self._return_type ret = ["seq"] @@ -1103,9 +1099,7 @@ def build_IR(self, expr, args, kwargs, context): args_ofst = add_ofst(input_buf, 32) args_len = ["mload", input_buf] - output_node = IRnode.from_list( - context.new_internal_variable(BytesT(outsize)), typ=BytesT(outsize), location=MEMORY - ) + output_node = context.new_internal_variable(BytesT(outsize)) bool_ty = BoolT() @@ -1712,8 +1706,8 @@ def _build_create_IR(self, expr, args, context, value, salt, revert_on_failure): return [ "seq", ["mstore", buf, forwarder_preamble], - ["mstore", ["add", buf, preamble_length], aligned_target], - ["mstore", ["add", buf, preamble_length + 20], forwarder_post], + ["mstore", add_ofst(buf, preamble_length), aligned_target], + ["mstore", add_ofst(buf, preamble_length + 20), forwarder_post], _create_ir(value, buf, buf_len, salt, revert_on_failure), ] @@ -1822,9 +1816,7 @@ def _build_create_IR( # pretend we allocated enough memory for the encoder # (we didn't, but we are clobbering unused memory so it's safe.) 
bufsz = to_encode.typ.abi_type.size_bound() - argbuf = IRnode.from_list( - context.new_internal_variable(get_type_for_exact_size(bufsz)), location=MEMORY - ) + argbuf = context.new_internal_variable(get_type_for_exact_size(bufsz)) # return a complex expression which writes to memory and returns # the length of the encoded data @@ -2071,13 +2063,17 @@ def build_IR(self, expr, args, kwargs, context): # clobber val, and return it as a pointer [ "seq", - ["mstore", ["sub", buf + n_digits, i], i], - ["set", val, ["sub", buf + n_digits, i]], + ["mstore", ["sub", add_ofst(buf, n_digits), i], i], + ["set", val, ["sub", add_ofst(buf, n_digits), i]], "break", ], [ "seq", - ["mstore", ["sub", buf + n_digits, i], ["add", 48, ["mod", val, 10]]], + [ + "mstore", + ["sub", add_ofst(buf, n_digits), i], + ["add", 48, ["mod", val, 10]], + ], ["set", val, ["div", val, 10]], ], ], @@ -2093,7 +2089,7 @@ def build_IR(self, expr, args, kwargs, context): ret = [ "if", ["eq", val, 0], - ["seq", ["mstore", buf + 1, ord("0")], ["mstore", buf, 1], buf], + ["seq", ["mstore", add_ofst(buf, 1), ord("0")], ["mstore", buf, 1], buf], ["seq", ret, val], ] @@ -2271,7 +2267,7 @@ def build_IR(self, expr, args, kwargs, context): ret = ["seq"] ret.append(["mstore", buf, method_id]) - encode = abi_encode(buf + 32, args_as_tuple, context, buflen, returns_len=True) + encode = abi_encode(add_ofst(buf, 32), args_as_tuple, context, buflen, returns_len=True) else: method_id = method_id_int("log(string,bytes)") @@ -2285,7 +2281,9 @@ def build_IR(self, expr, args, kwargs, context): ret.append(["mstore", schema_buf, len(schema)]) # TODO use Expr.make_bytelike, or better have a `bytestring` IRnode type - ret.append(["mstore", schema_buf + 32, bytes_to_int(schema.ljust(32, b"\x00"))]) + ret.append( + ["mstore", add_ofst(schema_buf, 32), bytes_to_int(schema.ljust(32, b"\x00"))] + ) payload_buflen = args_abi_t.size_bound() payload_t = BytesT(payload_buflen) @@ -2293,7 +2291,7 @@ def build_IR(self, expr, args, kwargs, context): # 32 bytes extra space for the method id payload_buf = context.new_internal_variable(payload_t) encode_payload = abi_encode( - payload_buf + 32, args_as_tuple, context, payload_buflen, returns_len=True + add_ofst(payload_buf, 32), args_as_tuple, context, payload_buflen, returns_len=True ) ret.append(["mstore", payload_buf, encode_payload]) @@ -2308,11 +2306,13 @@ def build_IR(self, expr, args, kwargs, context): buflen = 32 + args_as_tuple.typ.abi_type.size_bound() buf = context.new_internal_variable(get_type_for_exact_size(buflen)) ret.append(["mstore", buf, method_id]) - encode = abi_encode(buf + 32, args_as_tuple, context, buflen, returns_len=True) + encode = abi_encode(add_ofst(buf, 32), args_as_tuple, context, buflen, returns_len=True) # debug address that tooling uses CONSOLE_ADDRESS = 0x000000000000000000636F6E736F6C652E6C6F67 - ret.append(["staticcall", "gas", CONSOLE_ADDRESS, buf + 28, ["add", 4, encode], 0, 0]) + ret.append( + ["staticcall", "gas", CONSOLE_ADDRESS, add_ofst(buf, 28), ["add", 4, encode], 0, 0] + ) return IRnode.from_list(ret, annotation="print:" + sig) @@ -2415,15 +2415,19 @@ def build_IR(self, expr, args, kwargs, context): # <32 bytes length> | <4 bytes method_id> | # write the unaligned method_id first, then we will # overwrite the 28 bytes of zeros with the bytestring length - ret += [["mstore", buf + 4, method_id]] + ret += [["mstore", add_ofst(buf, 4), method_id]] # abi encode, and grab length as stack item - length = abi_encode(buf + 36, encode_input, context, returns_len=True, 
bufsz=maxlen) + length = abi_encode( + add_ofst(buf, 36), encode_input, context, returns_len=True, bufsz=maxlen + ) # write the output length to where bytestring stores its length ret += [["mstore", buf, ["add", length, 4]]] else: # abi encode and grab length as stack item - length = abi_encode(buf + 32, encode_input, context, returns_len=True, bufsz=maxlen) + length = abi_encode( + add_ofst(buf, 32), encode_input, context, returns_len=True, bufsz=maxlen + ) # write the output length to where bytestring stores its length ret += [["mstore", buf, length]] @@ -2508,13 +2512,12 @@ def build_IR(self, expr, args, kwargs, context): # input validation output_buf = context.new_internal_variable(wrapped_typ) - output = IRnode.from_list(output_buf, typ=wrapped_typ, location=MEMORY) # sanity check buffer size for wrapped output type will not buffer overflow assert wrapped_typ.memory_bytes_required == output_typ.memory_bytes_required - ret.append(make_setter(output, to_decode)) + ret.append(make_setter(output_buf, to_decode)) - ret.append(output) + ret.append(output_buf) # finalize. set the type and location for the return buffer. # (note: unwraps the tuple type if necessary) ret = IRnode.from_list(ret, typ=output_typ, location=MEMORY) diff --git a/vyper/codegen/abi_encoder.py b/vyper/codegen/abi_encoder.py index b227161d94..09a22cd857 100644 --- a/vyper/codegen/abi_encoder.py +++ b/vyper/codegen/abi_encoder.py @@ -159,9 +159,9 @@ def abi_encoding_matches_vyper(typ): # the abi_encode routine will push the output len onto the stack, # otherwise it will return 0 items to the stack. def abi_encode(dst, ir_node, context, bufsz, returns_len=False): - # TODO change dst to be an IRnode so it has type info to begin with. - # setting the typ of dst to ir_node.typ is a footgun. 
+ # cast dst to the type of the input so that make_setter works dst = IRnode.from_list(dst, typ=ir_node.typ, location=MEMORY) + abi_t = dst.typ.abi_type size_bound = abi_t.size_bound() diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5ac8cdd758..42488f06da 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -3,7 +3,8 @@ from dataclasses import dataclass from typing import Any, Optional -from vyper.codegen.ir_node import Encoding +from vyper.codegen.ir_node import Encoding, IRnode +from vyper.compiler.settings import get_global_settings from vyper.evm.address_space import MEMORY, AddrSpace from vyper.exceptions import CompilerPanic, StateAccessViolation from vyper.semantics.types import VyperType @@ -14,6 +15,17 @@ class Constancy(enum.Enum): Constant = 1 +@dataclass(frozen=True) +class Alloca: + name: str + offset: int + typ: VyperType + size: int + + def __post_init__(self): + assert self.typ.memory_bytes_required == self.size + + # Function variable @dataclass class VariableRecord: @@ -27,6 +39,9 @@ class VariableRecord: blockscopes: Optional[list] = None defined_at: Any = None is_internal: bool = False + alloca: Optional[Alloca] = None + + # the following members are probably dead is_immutable: bool = False is_transient: bool = False data_offset: Optional[int] = None @@ -43,6 +58,20 @@ def __repr__(self): ret["allocated"] = self.typ.memory_bytes_required return f"VariableRecord({ret})" + def as_ir_node(self): + ret = IRnode.from_list( + self.pos, + typ=self.typ, + annotation=self.name, + encoding=self.encoding, + mutable=self.mutable, + location=self.location, + ) + ret._referenced_variables = {self} + if self.alloca is not None: + ret.passthrough_metadata["alloca"] = self.alloca + return ret + # compilation context for a function class Context: @@ -91,6 +120,8 @@ def __init__( # either the constructor, or called from the constructor self.is_ctor_context = is_ctor_context + self.settings = get_global_settings() + def is_constant(self): return self.constancy is Constancy.Constant or self.in_range_expr @@ -140,10 +171,7 @@ def internal_memory_scope(self): (k, v) for k, v in self.vars.items() if v.is_internal and scope_id in v.blockscopes ] for name, var in released: - n = var.typ.memory_bytes_required - assert n == var.size - self.memory_allocator.deallocate_memory(var.pos, n) - del self.vars[name] + self.deallocate_variable(name, var) # Remove block scopes self._scopes.remove(scope_id) @@ -164,36 +192,68 @@ def block_scope(self): # Remove all variables that have specific scope_id attached released = [(k, v) for k, v in self.vars.items() if scope_id in v.blockscopes] for name, var in released: - n = var.typ.memory_bytes_required - # sanity check the type's size hasn't changed since allocation. - assert n == var.size - self.memory_allocator.deallocate_memory(var.pos, n) - del self.vars[name] + self.deallocate_variable(name, var) # Remove block scopes self._scopes.remove(scope_id) - def _new_variable( - self, name: str, typ: VyperType, var_size: int, is_internal: bool, is_mutable: bool = True - ) -> int: - var_pos = self.memory_allocator.allocate_memory(var_size) + def deallocate_variable(self, varname, var): + assert varname == var.name + + # sanity check the type's size hasn't changed since allocation. 
+ n = var.typ.memory_bytes_required + assert n == var.size - assert var_pos + var_size <= self.memory_allocator.size_of_mem, "function frame overrun" + if self.settings.experimental_codegen: + # do not deallocate at this stage because this will break + # analysis in venom; venom will do its own alloc/dealloc/analysis. + pass + else: + self.memory_allocator.deallocate_memory(var.pos, var.size) - self.vars[name] = VariableRecord( + del self.vars[var.name] + + def _new_variable( + self, + name: str, + typ: VyperType, + is_internal: bool, + is_mutable: bool = True, + internal_function=False, + ) -> IRnode: + size = typ.memory_bytes_required + + ofst = self.memory_allocator.allocate_memory(size) + assert ofst + size <= self.memory_allocator.size_of_mem, "function frame overrun" + + pos = ofst + alloca = None + if self.settings.experimental_codegen: + # convert it into an abstract pointer + if internal_function: + pos = f"$palloca_{ofst}_{size}" + else: + pos = f"$alloca_{ofst}_{size}" + alloca = Alloca(name=name, offset=ofst, typ=typ, size=size) + + var = VariableRecord( name=name, - pos=var_pos, + pos=pos, typ=typ, - size=var_size, + size=size, mutable=is_mutable, blockscopes=self._scopes.copy(), is_internal=is_internal, + alloca=alloca, ) - return var_pos + self.vars[name] = var + return var.as_ir_node() - def new_variable(self, name: str, typ: VyperType, is_mutable: bool = True) -> int: + def new_variable( + self, name: str, typ: VyperType, is_mutable: bool = True, internal_function=False + ) -> IRnode: """ - Allocate memory for a user-defined variable. + Allocate memory for a user-defined variable and return an IR node referencing it. Arguments --------- @@ -208,8 +268,9 @@ def new_variable(self, name: str, typ: VyperType, is_mutable: bool = True) -> in Memory offset for the variable """ - var_size = typ.memory_bytes_required - return self._new_variable(name, typ, var_size, False, is_mutable=is_mutable) + return self._new_variable( + name, typ, is_internal=False, is_mutable=is_mutable, internal_function=internal_function + ) def fresh_varname(self, name: str) -> str: """ @@ -219,8 +280,7 @@ def fresh_varname(self, name: str) -> str: self._internal_var_iter += 1 return f"{name}{t}" - # do we ever allocate immutable internal variables? - def new_internal_variable(self, typ: VyperType) -> int: + def new_internal_variable(self, typ: VyperType) -> IRnode: """ Allocate memory for an internal variable. @@ -237,10 +297,9 @@ def new_internal_variable(self, typ: VyperType) -> int: # internal variable names begin with a number sign so there is no chance for collision name = self.fresh_varname("#internal") - var_size = typ.memory_bytes_required - return self._new_variable(name, typ, var_size, True) + return self._new_variable(name, typ, is_internal=True) - def lookup_var(self, varname): + def lookup_var(self, varname) -> VariableRecord: return self.vars[varname] # Pretty print constancy for error messages diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index aca524cd6a..69ffb2bfd6 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -168,17 +168,7 @@ def parse_Name(self): if self.expr.id == "self": return IRnode.from_list(["address"], typ=AddressT()) elif self.expr.id in self.context.vars: - var = self.context.vars[self.expr.id] - ret = IRnode.from_list( - var.pos, - typ=var.typ, - location=var.location, # either 'memory' or 'calldata' storage is handled above. 
- encoding=var.encoding, - annotation=self.expr.id, - mutable=var.mutable, - ) - ret._referenced_variables = {var} - return ret + return self.context.lookup_var(self.expr.id).as_ir_node() elif (varinfo := self.expr._expr_info.var_info) is not None: if varinfo.is_constant: diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py index 1e50886baf..607872b052 100644 --- a/vyper/codegen/external_call.py +++ b/vyper/codegen/external_call.py @@ -1,9 +1,11 @@ +import copy from dataclasses import dataclass import vyper.utils as util from vyper.codegen.abi_encoder import abi_encode from vyper.codegen.core import ( _freshname, + add_ofst, calculate_type_for_external_return, check_assign, check_external_call, @@ -54,7 +56,7 @@ def _pack_arguments(fn_type, args, context): buf_t = get_type_for_exact_size(buflen) buf = context.new_internal_variable(buf_t) - args_ofst = buf + 28 + args_ofst = add_ofst(buf, 28) args_len = args_abi_t.size_bound() + 4 abi_signature = fn_type.name + dst_tuple_t.abi_type.selector_name() @@ -69,7 +71,7 @@ def _pack_arguments(fn_type, args, context): pack_args.append(["mstore", buf, util.method_id_int(abi_signature)]) if len(args) != 0: - pack_args.append(abi_encode(buf + 32, args_as_tuple, context, bufsz=buflen)) + pack_args.append(abi_encode(add_ofst(buf, 32), args_as_tuple, context, bufsz=buflen)) return buf, pack_args, args_ofst, args_len @@ -93,13 +95,11 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp encoding = Encoding.ABI - buf = IRnode.from_list( - buf, - typ=wrapped_return_t, - location=MEMORY, - encoding=encoding, - annotation=f"{expr.node_source_code} returndata buffer", - ) + assert buf.location == MEMORY + buf = copy.copy(buf) + buf.typ = wrapped_return_t + buf.encoding = encoding + buf.annotation = f"{expr.node_source_code} returndata buffer" unpacker = ["seq"] @@ -117,7 +117,6 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp # unpack strictly if needs_clamp(wrapped_return_t, encoding): return_buf = context.new_internal_variable(wrapped_return_t) - return_buf = IRnode.from_list(return_buf, typ=wrapped_return_t, location=MEMORY) # note: make_setter does ABI decoding and clamps unpacker.append(make_setter(return_buf, buf)) diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index fe706699bb..a9b4a93025 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -11,7 +11,7 @@ ) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body -from vyper.evm.address_space import CALLDATA, DATA, MEMORY +from vyper.evm.address_space import CALLDATA, DATA from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT from vyper.utils import calc_mem_gas @@ -35,8 +35,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list if needs_clamp(arg.typ, Encoding.ABI): # allocate a memory slot for it and copy - p = context.new_variable(arg.name, arg.typ, is_mutable=False) - dst = IRnode(p, typ=arg.typ, location=MEMORY) + dst = context.new_variable(arg.name, arg.typ, is_mutable=False) copy_arg = make_setter(dst, arg_ir) copy_arg.ast_source = arg.ast_source @@ -94,9 +93,7 @@ def handler_for(calldata_kwargs, default_kwargs): for i, arg_meta in enumerate(calldata_kwargs): k = func_t.n_positional_args + i - dst = context.lookup_var(arg_meta.name).pos - 
- lhs = IRnode(dst, location=MEMORY, typ=arg_meta.typ) + lhs = context.lookup_var(arg_meta.name).as_ir_node() rhs = get_element_ptr(calldata_kwargs_ofst, k, array_bounds_check=False) @@ -105,8 +102,7 @@ def handler_for(calldata_kwargs, default_kwargs): ret.append(copy_arg) for x in default_kwargs: - dst = context.lookup_var(x.name).pos - lhs = IRnode(dst, location=MEMORY, typ=x.typ) + lhs = context.lookup_var(x.name).as_ir_node() lhs.ast_source = x.ast_source kw_ast_val = func_t.default_values[x.name] # e.g. `3` in x: int = 3 rhs = Expr(kw_ast_val, context).ir_node diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cde1ec5c87..0ad993b33c 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -49,7 +49,7 @@ def generate_ir_for_internal_function( for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable - context.new_variable(arg.name, arg.typ, is_mutable=True) + context.new_variable(arg.name, arg.typ, is_mutable=True, internal_function=True) # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 2363de3641..1114dd46cc 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -77,9 +77,7 @@ def ir_for_self_call(stmt_expr, context): if args_as_tuple.contains_self_call: copy_args = ["seq"] # TODO deallocate me - tmp_args_buf = IRnode( - context.new_internal_variable(dst_tuple_t), typ=dst_tuple_t, location=MEMORY - ) + tmp_args_buf = context.new_internal_variable(dst_tuple_t) copy_args.append( # --> args evaluate here <-- make_setter(tmp_args_buf, args_as_tuple) diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 9c5720d61c..1018a6ec43 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -7,6 +7,7 @@ LOAD, STORE, IRnode, + add_ofst, clamp_le, get_dyn_array_count, get_element_ptr, @@ -136,14 +137,14 @@ def _assert_reason(self, test_expr, msg): # abi encode method_id + bytestring to `buf+32`, then # write method_id to `buf` and get out of here - payload_buf = buf + 32 + payload_buf = add_ofst(buf, 32) bufsz -= 32 # reduce buffer by size of `method_id` slot encoded_length = abi_encode(payload_buf, msg_ir, self.context, bufsz, returns_len=True) with encoded_length.cache_when_complex("encoded_len") as (b1, encoded_length): revert_seq = [ "seq", ["mstore", buf, method_id], - ["revert", buf + 28, ["add", 4, encoded_length]], + ["revert", add_ofst(buf, 28), ["add", 4, encoded_length]], ] revert_seq = b1.resolve(revert_seq) diff --git a/vyper/utils.py b/vyper/utils.py index cf8a709997..01ae37e213 100644 --- a/vyper/utils.py +++ b/vyper/utils.py @@ -405,6 +405,7 @@ class SizeLimits: MAX_AST_DECIMAL = decimal.Decimal(2**167 - 1) / DECIMAL_DIVISOR MAX_UINT8 = 2**8 - 1 MAX_UINT256 = 2**256 - 1 + CEILING_UINT256 = 2**256 def quantize(d: decimal.Decimal, places=MAX_DECIMAL_PLACES, rounding_mode=decimal.ROUND_DOWN): diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 0a75e567ba..a60f679a76 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -1,7 +1,7 @@ # maybe rename this `main.py` or `venom.py` # (can have an `__init__.py` which exposes the API). 
-from typing import Any, Optional +from typing import Optional from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel @@ -11,13 +11,12 @@ ir_pass_optimize_unused_variables, ir_pass_remove_unreachable_blocks, ) -from vyper.venom.dominators import DominatorTree from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom -from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation from vyper.venom.passes.dft import DFTPass from vyper.venom.passes.make_ssa import MakeSSA -from vyper.venom.passes.normalization import NormalizationPass +from vyper.venom.passes.mem2var import Mem2Var +from vyper.venom.passes.sccp import SCCP from vyper.venom.passes.simplify_cfg import SimplifyCFGPass from vyper.venom.venom_to_assembly import VenomCompiler @@ -56,10 +55,30 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None: for entry in internals: SimplifyCFGPass().run_pass(ctx, entry) + dfg = DFG.build_dfg(ctx) + Mem2Var().run_pass(ctx, ctx.basic_blocks[0], dfg) + for entry in internals: + Mem2Var().run_pass(ctx, entry, dfg) + make_ssa_pass = MakeSSA() make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0]) + + cfg_dirty = False + sccp_pass = SCCP(make_ssa_pass.dom) + sccp_pass.run_pass(ctx, ctx.basic_blocks[0]) + cfg_dirty |= sccp_pass.cfg_dirty + for entry in internals: make_ssa_pass.run_pass(ctx, entry) + sccp_pass = SCCP(make_ssa_pass.dom) + sccp_pass.run_pass(ctx, entry) + cfg_dirty |= sccp_pass.cfg_dirty + + calculate_cfg(ctx) + SimplifyCFGPass().run_pass(ctx, ctx.basic_blocks[0]) + + calculate_cfg(ctx) + calculate_liveness(ctx) while True: changes = 0 diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py index 066a60f45e..8e4c24fea3 100644 --- a/vyper/venom/analysis.py +++ b/vyper/venom/analysis.py @@ -152,6 +152,10 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]: def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]: return self._dfg_outputs.get(op) + @property + def outputs(self) -> dict[IRVariable, IRInstruction]: + return self._dfg_outputs + @classmethod def build_dfg(cls, ctx: IRFunction) -> "DFG": dfg = cls() diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 0993fb0515..7c54145018 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -1,14 +1,14 @@ -from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from vyper.codegen.ir_node import IRnode from vyper.utils import OrderedSet # instructions which can terminate a basic block -BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit"]) +BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "stop", "exit"]) VOLATILE_INSTRUCTIONS = frozenset( [ "param", - "alloca", "call", "staticcall", "delegatecall", @@ -190,16 +190,15 @@ class IRInstruction: """ opcode: str - volatile: bool operands: list[IROperand] output: Optional[IROperand] # set of live variables at this instruction liveness: OrderedSet[IRVariable] dup_requirements: OrderedSet[IRVariable] - parent: Optional["IRBasicBlock"] + parent: "IRBasicBlock" fence_id: int annotation: Optional[str] - ast_source: Optional[int] + ast_source: Optional[IRnode] error_msg: Optional[str] def __init__( @@ -211,34 +210,36 @@ def __init__( assert isinstance(opcode, str), "opcode must be an str" assert isinstance(operands, list | Iterator), "operands must be a list" self.opcode = 
opcode - self.volatile = opcode in VOLATILE_INSTRUCTIONS self.operands = list(operands) # in case we get an iterator self.output = output self.liveness = OrderedSet() self.dup_requirements = OrderedSet() - self.parent = None self.fence_id = -1 self.annotation = None self.ast_source = None self.error_msg = None - def get_label_operands(self) -> list[IRLabel]: + @property + def volatile(self) -> bool: + return self.opcode in VOLATILE_INSTRUCTIONS + + def get_label_operands(self) -> Iterator[IRLabel]: """ Get all labels in instruction. """ - return [op for op in self.operands if isinstance(op, IRLabel)] + return (op for op in self.operands if isinstance(op, IRLabel)) - def get_non_label_operands(self) -> list[IROperand]: + def get_non_label_operands(self) -> Iterator[IROperand]: """ Get input operands for instruction which are not labels """ - return [op for op in self.operands if not isinstance(op, IRLabel)] + return (op for op in self.operands if not isinstance(op, IRLabel)) - def get_inputs(self) -> list[IRVariable]: + def get_inputs(self) -> Iterator[IRVariable]: """ Get all input operands for instruction. """ - return [op for op in self.operands if isinstance(op, IRVariable)] + return (op for op in self.operands if isinstance(op, IRVariable)) def get_outputs(self) -> list[IROperand]: """ @@ -268,7 +269,7 @@ def replace_label_operands(self, replacements: dict) -> None: self.operands[i] = replacements[operand.value] @property - def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]: + def phi_operands(self) -> Iterator[tuple[IRLabel, IROperand]]: """ Get phi operands for instruction. """ @@ -277,9 +278,30 @@ def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]: label = self.operands[i] var = self.operands[i + 1] assert isinstance(label, IRLabel), "phi operand must be a label" - assert isinstance(var, IRVariable), "phi operand must be a variable" + assert isinstance( + var, (IRVariable, IRLiteral) + ), "phi operand must be a variable or literal" yield label, var + def remove_phi_operand(self, label: IRLabel) -> None: + """ + Remove a phi operand from the instruction. 
+ """ + assert self.opcode == "phi", "instruction must be a phi" + for i in range(0, len(self.operands), 2): + if self.operands[i] == label: + del self.operands[i : i + 2] + return + + def get_ast_source(self) -> Optional[IRnode]: + if self.ast_source: + return self.ast_source + idx = self.parent.instructions.index(self) + for inst in reversed(self.parent.instructions[:idx]): + if inst.ast_source: + return inst.ast_source + return self.parent.parent.ast_source + def __repr__(self) -> str: s = "" if self.output: @@ -451,6 +473,15 @@ def get_assignments(self): """ return [inst.output for inst in self.instructions if inst.output] + def get_uses(self) -> dict[IRVariable, OrderedSet[IRInstruction]]: + uses: dict[IRVariable, OrderedSet[IRInstruction]] = {} + for inst in self.instructions: + for op in inst.get_inputs(): + if op not in uses: + uses[op] = OrderedSet() + uses[op].add(inst) + return uses + @property def is_empty(self) -> bool: """ diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py index 60dd8bbee1..284a1f1b9c 100644 --- a/vyper/venom/bb_optimizer.py +++ b/vyper/venom/bb_optimizer.py @@ -13,7 +13,8 @@ def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]: for i, inst in enumerate(bb.instructions[:-1]): if inst.volatile: continue - if inst.output and inst.output not in bb.instructions[i + 1].liveness: + next_liveness = bb.instructions[i + 1].liveness + if (inst.output and inst.output not in next_liveness) or inst.opcode == "nop": removeList.add(inst) bb.instructions = [inst for inst in bb.instructions if inst not in removeList] diff --git a/vyper/venom/function.py b/vyper/venom/function.py index d1680385f5..8756642f80 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -30,7 +30,7 @@ class IRFunction: last_variable: int # Used during code generation - _ast_source_stack: list[int] + _ast_source_stack: list[IRnode] _error_msg_stack: list[str] _bb_index: dict[str, int] @@ -158,6 +158,12 @@ def remove_unreachable_blocks(self) -> int: continue in_labels = inst.get_label_operands() if bb.label in in_labels: + inst.remove_phi_operand(bb.label) + op_len = len(inst.operands) + if op_len == 2: + inst.opcode = "store" + inst.operands = [inst.operands[1]] + elif op_len == 0: out_bb.remove_instruction(inst) return len(removed) @@ -230,7 +236,7 @@ def pop_source(self): self._error_msg_stack.pop() @property - def ast_source(self) -> Optional[int]: + def ast_source(self) -> Optional[IRnode]: return self._ast_source_stack[-1] if len(self._ast_source_stack) > 0 else None @property diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 7dc6bd2d47..775b55f9a8 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -62,9 +62,13 @@ "gasprice", "gaslimit", "returndatasize", + "mload", "iload", + "istore", "sload", + "sstore", "tload", + "tstore", "coinbase", "number", "prevrandao", @@ -88,9 +92,6 @@ "codecopy", "returndatacopy", "revert", - "istore", - "sstore", - "tstore", "create", "create2", "addmod", @@ -104,10 +105,14 @@ NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"]) SymbolTable = dict[str, Optional[IROperand]] +_global_symbols: SymbolTable = {} # convert IRnode directly to venom def ir_node_to_venom(ir: IRnode) -> IRFunction: + global _global_symbols + _global_symbols = {} + ctx = IRFunction() _convert_ir_bb(ctx, ir, {}) @@ -234,7 +239,7 @@ def pop_source(*args, **kwargs): @pop_source_on_return def _convert_ir_bb(ctx, ir, symbols): assert 
isinstance(ir, IRnode), ir - global _break_target, _continue_target, current_func, var_list + global _break_target, _continue_target, current_func, var_list, _global_symbols ctx.push_source(ir) @@ -267,6 +272,7 @@ def _convert_ir_bb(ctx, ir, symbols): # Internal definition var_list = ir.args[0].args[1] does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args + _global_symbols = {} symbols = {} _handle_internal_func(ctx, ir, does_return_data, symbols) for ir_node in ir.args[1:]: @@ -274,6 +280,7 @@ def _convert_ir_bb(ctx, ir, symbols): return ret elif is_external: + _global_symbols = {} ret = _convert_ir_bb(ctx, ir.args[0], symbols) _append_return_args(ctx) else: @@ -294,12 +301,24 @@ def _convert_ir_bb(ctx, ir, symbols): cont_ret = _convert_ir_bb(ctx, cond, symbols) cond_block = ctx.get_basic_block() - cond_symbols = symbols.copy() + saved_global_symbols = _global_symbols.copy() + then_block = IRBasicBlock(ctx.get_next_label("then"), ctx) else_block = IRBasicBlock(ctx.get_next_label("else"), ctx) - ctx.append_basic_block(else_block) + + # convert "then" + cond_symbols = symbols.copy() + ctx.append_basic_block(then_block) + then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols) + if isinstance(then_ret_val, IRLiteral): + then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) + + then_block_finish = ctx.get_basic_block() # convert "else" + cond_symbols = symbols.copy() + _global_symbols = saved_global_symbols.copy() + ctx.append_basic_block(else_block) else_ret_val = None if len(ir.args) == 3: else_ret_val = _convert_ir_bb(ctx, ir.args[2], cond_symbols) @@ -309,20 +328,9 @@ def _convert_ir_bb(ctx, ir, symbols): else_block_finish = ctx.get_basic_block() - # convert "then" - cond_symbols = symbols.copy() - - then_block = IRBasicBlock(ctx.get_next_label("then"), ctx) - ctx.append_basic_block(then_block) - - then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols) - if isinstance(then_ret_val, IRLiteral): - then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val) - + # finish the condition block cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label) - then_block_finish = ctx.get_basic_block() - # exit bb exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), ctx) exit_bb = ctx.append_basic_block(exit_bb) @@ -338,6 +346,8 @@ def _convert_ir_bb(ctx, ir, symbols): if not then_block_finish.is_terminated: then_block_finish.append_instruction("jmp", exit_bb.label) + _global_symbols = saved_global_symbols + return if_ret elif ir.value == "with": @@ -345,13 +355,12 @@ def _convert_ir_bb(ctx, ir, symbols): ret = ctx.get_basic_block().append_instruction("store", ret) - # Handle with nesting with same symbol - with_symbols = symbols.copy() - sym = ir.args[0] + with_symbols = symbols.copy() with_symbols[sym.value] = ret return _convert_ir_bb(ctx, ir.args[2], with_symbols) # body + elif ir.value == "goto": _append_jmp(ctx, IRLabel(ir.args[0].value)) elif ir.value == "djump": @@ -423,27 +432,13 @@ def _convert_ir_bb(ctx, ir, symbols): bb.append_instruction("dloadbytes", len_, src, dst) return None - elif ir.value == "mload": - arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols) - bb = ctx.get_basic_block() - if isinstance(arg_0, IRVariable): - return bb.append_instruction("mload", arg_0) - - if isinstance(arg_0, IRLiteral): - avar = symbols.get(f"%{arg_0.value}") - if avar is not None: - return bb.append_instruction("mload", avar) - - return bb.append_instruction("mload", arg_0) elif ir.value == "mstore": # some 
upstream code depends on reversed order of evaluation -- # to fix upstream. - arg_1, arg_0 = _convert_ir_bb_list(ctx, reversed(ir.args), symbols) + val, ptr = _convert_ir_bb_list(ctx, reversed(ir.args), symbols) - if isinstance(arg_1, IRVariable): - symbols[f"&{arg_0.value}"] = arg_1 + return ctx.get_basic_block().append_instruction("mstore", val, ptr) - ctx.get_basic_block().append_instruction("mstore", arg_1, arg_0) elif ir.value == "ceil32": x = ir.args[0] expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]]) @@ -467,11 +462,13 @@ def _convert_ir_bb(ctx, ir, symbols): elif ir.value == "repeat": def emit_body_blocks(): - global _break_target, _continue_target + global _break_target, _continue_target, _global_symbols old_targets = _break_target, _continue_target _break_target, _continue_target = exit_block, incr_block + saved_global_symbols = _global_symbols.copy() _convert_ir_bb(ctx, body, symbols.copy()) _break_target, _continue_target = old_targets + _global_symbols = saved_global_symbols sym = ir.args[0] start, end, _ = _convert_ir_bb_list(ctx, ir.args[1:4], symbols) @@ -545,8 +542,17 @@ def emit_body_blocks(): ctx.get_basic_block().append_instruction("log", topic_count, *args) elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes(): _convert_ir_opcode(ctx, ir, symbols) - elif isinstance(ir.value, str) and ir.value in symbols: - return symbols[ir.value] + elif isinstance(ir.value, str): + if ir.value.startswith("$alloca") and ir.value not in _global_symbols: + alloca = ir.passthrough_metadata["alloca"] + ptr = ctx.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size) + _global_symbols[ir.value] = ptr + elif ir.value.startswith("$palloca") and ir.value not in _global_symbols: + alloca = ir.passthrough_metadata["alloca"] + ptr = ctx.get_basic_block().append_instruction("store", alloca.offset) + _global_symbols[ir.value] = ptr + + return _global_symbols.get(ir.value) or symbols.get(ir.value) elif ir.is_literal: return IRLiteral(ir.value) else: diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 5d149cf003..994ab9d70d 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -22,13 +22,15 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: # if the instruction is a terminator, we need to place # it at the end of the basic block # along with all the instructions that "lead" to it - if uses_this.opcode in BB_TERMINATORS: - offset = len(bb.instructions) self._process_instruction_r(bb, uses_this, offset) if inst in self.visited_instructions: return self.visited_instructions.add(inst) + self.inst_order_num += 1 + + if inst.opcode in BB_TERMINATORS: + offset = len(bb.instructions) if inst.opcode == "phi": # phi instructions stay at the beginning of the basic block @@ -45,7 +47,6 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset: continue self._process_instruction_r(bb, target, offset) - self.inst_order_num += 1 self.inst_order[inst] = self.inst_order_num + offset def _process_basic_block(self, bb: IRBasicBlock) -> None: diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py new file mode 100644 index 0000000000..9d74dfec0b --- /dev/null +++ b/vyper/venom/passes/mem2var.py @@ -0,0 +1,66 @@ +from vyper.utils import OrderedSet +from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable +from vyper.venom.function import IRFunction +from 
vyper.venom.passes.base_pass import IRPass
+
+
+class Mem2Var(IRPass):
+    """
+    This pass promotes memory operations to variable operations when possible.
+    It does not yet do any memory aliasing analysis, so it is conservative.
+    """
+
+    ctx: IRFunction
+    defs: dict[IRVariable, OrderedSet[IRBasicBlock]]
+    dfg: DFG
+
+    def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock, dfg: DFG) -> int:
+        self.ctx = ctx
+        self.dfg = dfg
+
+        calculate_cfg(ctx)
+
+        dfg = DFG.build_dfg(ctx)
+        self.dfg = dfg
+
+        calculate_liveness(ctx)
+
+        self.var_name_count = 0
+        for var, inst in dfg.outputs.items():
+            if inst.opcode != "alloca":
+                continue
+            self._process_alloca_var(dfg, var)
+
+        return 0
+
+    def _process_alloca_var(self, dfg: DFG, var: IRVariable):
+        """
+        Process an alloca-allocated variable. If it is only used by mstore/mload/return
+        instructions, it is promoted to a stack variable. Otherwise, it is left as is.
+        """
+        uses = dfg.get_uses(var)
+        if all([inst.opcode == "mload" for inst in uses]):
+            return
+        elif all([inst.opcode == "mstore" for inst in uses]):
+            return
+        elif all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]):
+            var_name = f"addr{var.name}_{self.var_name_count}"
+            self.var_name_count += 1
+            for inst in uses:
+                if inst.opcode == "mstore":
+                    inst.opcode = "store"
+                    inst.output = IRVariable(var_name)
+                    inst.operands = [inst.operands[0]]
+                elif inst.opcode == "mload":
+                    inst.opcode = "store"
+                    inst.operands = [IRVariable(var_name)]
+                elif inst.opcode == "return":
+                    bb = inst.parent
+                    new_var = self.ctx.get_next_variable()
+                    idx = bb.instructions.index(inst)
+                    bb.insert_instruction(
+                        IRInstruction("mstore", [IRVariable(var_name), inst.operands[1]], new_var),
+                        idx,
+                    )
+                    inst.operands[1] = new_var
diff --git a/vyper/venom/passes/sccp/__init__.py b/vyper/venom/passes/sccp/__init__.py
new file mode 100644
index 0000000000..866d55e801
--- /dev/null
+++ b/vyper/venom/passes/sccp/__init__.py
@@ -0,0 +1 @@
+from vyper.venom.passes.sccp.sccp import SCCP
diff --git a/vyper/venom/passes/sccp/eval.py b/vyper/venom/passes/sccp/eval.py
new file mode 100644
index 0000000000..8acca039c0
--- /dev/null
+++ b/vyper/venom/passes/sccp/eval.py
@@ -0,0 +1,132 @@
+import operator
+from typing import Callable
+
+from vyper.utils import SizeLimits, evm_div, evm_mod, signed_to_unsigned, unsigned_to_signed
+from vyper.venom.basicblock import IROperand
+
+
+def _unsigned_to_signed(value: int) -> int:
+    if value <= SizeLimits.MAX_INT256:
+        return value  # fast exit
+    else:
+        return unsigned_to_signed(value, 256)
+
+
+def _signed_to_unsigned(value: int) -> int:
+    if value >= 0:
+        return value  # fast exit
+    else:
+        return signed_to_unsigned(value, 256)
+
+
+def _wrap_signed_binop(operation):
+    def wrapper(ops: list[IROperand]) -> int:
+        first = _unsigned_to_signed(ops[1].value)
+        second = _unsigned_to_signed(ops[0].value)
+        return _signed_to_unsigned(int(operation(first, second)))
+
+    return wrapper
+
+
+def _wrap_binop(operation):
+    def wrapper(ops: list[IROperand]) -> int:
+        first = ops[1].value
+        second = ops[0].value
+        return (int(operation(first, second))) & SizeLimits.MAX_UINT256
+
+    return wrapper
+
+
+def _evm_signextend(ops: list[IROperand]) -> int:
+    value = ops[0].value
+    nbytes = ops[1].value
+
+    assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+
+    if nbytes > 31:
+        return value
+
+    sign_bit = 1 << (nbytes * 8 + 7)
+    if value & sign_bit:
+        value |= SizeLimits.CEILING_UINT256 - sign_bit
+    else:
+        value &= sign_bit - 1
+
+    return value
+
+
+def _evm_iszero(ops: 
+
+
+def _evm_iszero(ops: list[IROperand]) -> int:
+    value = ops[0].value
+    assert SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+    return int(value == 0)  # 1 if True else 0
+
+
+def _evm_shr(ops: list[IROperand]) -> int:
+    value = ops[0].value
+    shift_len = ops[1].value
+    assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+    return value >> shift_len
+
+
+def _evm_shl(ops: list[IROperand]) -> int:
+    value = ops[0].value
+    shift_len = ops[1].value
+    assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+    if shift_len >= 256:
+        return 0
+    return (value << shift_len) & SizeLimits.MAX_UINT256
+
+
+def _evm_sar(ops: list[IROperand]) -> int:
+    value = _unsigned_to_signed(ops[0].value)
+    assert SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256, "Value out of bounds"
+    shift_len = ops[1].value
+    # wrap negative results back to the EVM's unsigned representation
+    return _signed_to_unsigned(value >> shift_len)
+
+
+def _evm_not(ops: list[IROperand]) -> int:
+    value = ops[0].value
+    assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+    return SizeLimits.MAX_UINT256 ^ value
+
+
+def _evm_exp(ops: list[IROperand]) -> int:
+    base = ops[1].value
+    exponent = ops[0].value
+
+    if base == 0:
+        return 0
+
+    return pow(base, exponent, SizeLimits.CEILING_UINT256)
+
+
+ARITHMETIC_OPS: dict[str, Callable[[list[IROperand]], int]] = {
+    "add": _wrap_binop(operator.add),
+    "sub": _wrap_binop(operator.sub),
+    "mul": _wrap_binop(operator.mul),
+    "div": _wrap_binop(evm_div),
+    "sdiv": _wrap_signed_binop(evm_div),
+    "mod": _wrap_binop(evm_mod),
+    "smod": _wrap_signed_binop(evm_mod),
+    "exp": _evm_exp,
+    "eq": _wrap_binop(operator.eq),
+    "ne": _wrap_binop(operator.ne),
+    "lt": _wrap_binop(operator.lt),
+    "le": _wrap_binop(operator.le),
+    "gt": _wrap_binop(operator.gt),
+    "ge": _wrap_binop(operator.ge),
+    "slt": _wrap_signed_binop(operator.lt),
+    "sle": _wrap_signed_binop(operator.le),
+    "sgt": _wrap_signed_binop(operator.gt),
+    "sge": _wrap_signed_binop(operator.ge),
+    "or": _wrap_binop(operator.or_),
+    "and": _wrap_binop(operator.and_),
+    "xor": _wrap_binop(operator.xor),
+    "not": _evm_not,
+    "signextend": _evm_signextend,
+    "iszero": _evm_iszero,
+    "shr": _evm_shr,
+    "shl": _evm_shl,
+    "sar": _evm_sar,
+    "store": lambda ops: ops[0].value,
+}
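+
+# Example of how the table is meant to be consumed (editor's sketch): operands
+# arrive in reverse order, mirroring the EVM stack, so ops[1] is the first
+# argument of the wrapped operation:
+#
+#     ARITHMETIC_OPS["sub"]([IRLiteral(1), IRLiteral(0)]) == SizeLimits.MAX_UINT256
+#
+# i.e. 0 - 1 wraps around to 2**256 - 1 in 256-bit arithmetic.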
diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py
new file mode 100644
index 0000000000..7dfca8edd4
--- /dev/null
+++ b/vyper/venom/passes/sccp/sccp.py
@@ -0,0 +1,332 @@
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Union
+
+from vyper.exceptions import CompilerPanic, StaticAssertionException
+from vyper.utils import OrderedSet
+from vyper.venom.basicblock import (
+    IRBasicBlock,
+    IRInstruction,
+    IRLabel,
+    IRLiteral,
+    IROperand,
+    IRVariable,
+)
+from vyper.venom.dominators import DominatorTree
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+from vyper.venom.passes.sccp.eval import ARITHMETIC_OPS
+
+
+class LatticeEnum(Enum):
+    TOP = 1
+    BOTTOM = 2
+
+
+@dataclass
+class SSAWorkListItem:
+    inst: IRInstruction
+
+
+@dataclass
+class FlowWorkItem:
+    start: IRBasicBlock
+    end: IRBasicBlock
+
+
+WorkListItem = Union[FlowWorkItem, SSAWorkListItem]
+LatticeItem = Union[LatticeEnum, IRLiteral]
+Lattice = dict[IROperand, LatticeItem]
+
+
+class SCCP(IRPass):
+    """
+    This class implements the Sparse Conditional Constant Propagation
+    algorithm by Wegman and Zadeck. It is a forward dataflow analysis
+    that propagates constant values through the IR graph. It is used
+    to optimize the IR by removing dead code and replacing variables
+    with their constant values.
+    """
+
+    ctx: IRFunction
+    dom: DominatorTree
+    uses: dict[IRVariable, OrderedSet[IRInstruction]]
+    lattice: Lattice
+    work_list: list[WorkListItem]
+    cfg_dirty: bool
+    cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
+
+    def __init__(self, dom: DominatorTree):
+        self.dom = dom
+        self.lattice = {}
+        self.work_list: list[WorkListItem] = []
+        self.cfg_dirty = False
+
+    def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int:
+        self.ctx = ctx
+        self._compute_uses(self.dom)
+        self._calculate_sccp(entry)
+        self._propagate_constants()
+
+        # self._propagate_variables()
+        return 0
+
+    def _calculate_sccp(self, entry: IRBasicBlock):
+        """
+        This method is the main entry point for the SCCP algorithm. It
+        initializes the work list and the lattice, then iterates over the
+        work list until it is empty, visiting each reachable basic block
+        in the CFG and processing the instructions in the block.
+
+        This method does not update the IR, it only updates the lattice
+        and the work list. The `_propagate_constants()` method is responsible
+        for updating the IR with the constant values.
+        """
+        self.cfg_in_exec = {bb: OrderedSet() for bb in self.ctx.basic_blocks}
+
+        dummy = IRBasicBlock(IRLabel("__dummy_start"), self.ctx)
+        self.work_list.append(FlowWorkItem(dummy, entry))
+
+        # Initialize the lattice with TOP values for all variables
+        for v in self.uses.keys():
+            self.lattice[v] = LatticeEnum.TOP
+
+        # Iterate over the work list until it is empty
+        # Items in the work list can be either FlowWorkItem or SSAWorkListItem
+        while len(self.work_list) > 0:
+            work_item = self.work_list.pop()
+            if isinstance(work_item, FlowWorkItem):
+                self._handle_flow_work_item(work_item)
+            elif isinstance(work_item, SSAWorkListItem):
+                self._handle_SSA_work_item(work_item)
+            else:
+                raise CompilerPanic("Invalid work item type")
+
+    def _handle_flow_work_item(self, work_item: FlowWorkItem):
+        """
+        This method handles a FlowWorkItem.
+        """
+        start = work_item.start
+        end = work_item.end
+        if start in self.cfg_in_exec[end]:
+            return
+        self.cfg_in_exec[end].add(start)
+
+        for inst in end.instructions:
+            if inst.opcode == "phi":
+                self._visit_phi(inst)
+            else:
+                # Stop at the first non-phi instruction
+                # as phis are only valid at the beginning of a block
+                break
+
+        if len(self.cfg_in_exec[end]) == 1:
+            for inst in end.instructions:
+                if inst.opcode == "phi":
+                    continue
+                self._visit_expr(inst)
+
+        if len(end.cfg_out) == 1:
+            self.work_list.append(FlowWorkItem(end, end.cfg_out.first()))
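+
+    # Lattice refresher (editor's note, informal): TOP means "no information
+    # yet", an IRLiteral means "known to be exactly this constant", and BOTTOM
+    # means "provably not a constant". _meet() (defined at the bottom of this
+    # file) combines facts arriving along different edges:
+    #
+    #     _meet(LatticeEnum.TOP, IRLiteral(3))  == IRLiteral(3)
+    #     _meet(IRLiteral(3), IRLiteral(3))     == IRLiteral(3)
+    #     _meet(IRLiteral(3), IRLiteral(4))     == LatticeEnum.BOTTOM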
+ """ + if work_item.inst.opcode == "phi": + self._visit_phi(work_item.inst) + elif len(self.cfg_in_exec[work_item.inst.parent]) > 0: + self._visit_expr(work_item.inst) + + def _visit_phi(self, inst: IRInstruction): + assert inst.opcode == "phi", "Can't visit non phi instruction" + in_vars: list[LatticeItem] = [] + for bb_label, var in inst.phi_operands: + bb = self.ctx.get_basic_block(bb_label.name) + if bb not in self.cfg_in_exec[inst.parent]: + continue + in_vars.append(self.lattice[var]) + value = reduce(_meet, in_vars, LatticeEnum.TOP) # type: ignore + assert inst.output in self.lattice, "Got undefined var for phi" + if value != self.lattice[inst.output]: + self.lattice[inst.output] = value + self._add_ssa_work_items(inst) + + def _visit_expr(self, inst: IRInstruction): + opcode = inst.opcode + if opcode in ["store", "alloca"]: + if isinstance(inst.operands[0], IRLiteral): + self.lattice[inst.output] = inst.operands[0] # type: ignore + else: + self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore + self._add_ssa_work_items(inst) + elif opcode == "jmp": + target = self.ctx.get_basic_block(inst.operands[0].value) + self.work_list.append(FlowWorkItem(inst.parent, target)) + elif opcode == "jnz": + lat = self.lattice[inst.operands[0]] + assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}" + if lat == LatticeEnum.BOTTOM: + for out_bb in inst.parent.cfg_out: + self.work_list.append(FlowWorkItem(inst.parent, out_bb)) + else: + if _meet(lat, IRLiteral(0)) == LatticeEnum.BOTTOM: + target = self.ctx.get_basic_block(inst.operands[1].name) + self.work_list.append(FlowWorkItem(inst.parent, target)) + if _meet(lat, IRLiteral(1)) == LatticeEnum.BOTTOM: + target = self.ctx.get_basic_block(inst.operands[2].name) + self.work_list.append(FlowWorkItem(inst.parent, target)) + elif opcode == "djmp": + lat = self.lattice[inst.operands[0]] + assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}" + if lat == LatticeEnum.BOTTOM: + for op in inst.operands[1:]: + target = self.ctx.get_basic_block(op.name) + self.work_list.append(FlowWorkItem(inst.parent, target)) + elif isinstance(lat, IRLiteral): + raise CompilerPanic("Unimplemented djmp with literal") + + elif opcode in ["param", "calldataload"]: + self.lattice[inst.output] = LatticeEnum.BOTTOM # type: ignore + self._add_ssa_work_items(inst) + elif opcode == "mload": + self.lattice[inst.output] = LatticeEnum.BOTTOM # type: ignore + elif opcode in ARITHMETIC_OPS: + self._eval(inst) + else: + if inst.output is not None: + self.lattice[inst.output] = LatticeEnum.BOTTOM + + def _eval(self, inst) -> LatticeItem: + """ + This method evaluates an arithmetic operation and returns the result. + At the same time it updates the lattice with the result and adds the + instruction to the SSA work list if the knowledge about the variable + changed. 
+ """ + opcode = inst.opcode + + ops = [] + for op in inst.operands: + if isinstance(op, IRVariable): + ops.append(self.lattice[op]) + elif isinstance(op, IRLabel): + return LatticeEnum.BOTTOM + else: + ops.append(op) + + ret = None + if LatticeEnum.BOTTOM in ops: + ret = LatticeEnum.BOTTOM + else: + if opcode in ARITHMETIC_OPS: + fn = ARITHMETIC_OPS[opcode] + ret = IRLiteral(fn(ops)) # type: ignore + elif len(ops) > 0: + ret = ops[0] # type: ignore + else: + raise CompilerPanic("Bad constant evaluation") + + old_val = self.lattice.get(inst.output, LatticeEnum.TOP) + if old_val != ret: + self.lattice[inst.output] = ret # type: ignore + self._add_ssa_work_items(inst) + + return ret # type: ignore + + def _add_ssa_work_items(self, inst: IRInstruction): + for target_inst in self._get_uses(inst.output): # type: ignore + self.work_list.append(SSAWorkListItem(target_inst)) + + def _compute_uses(self, dom: DominatorTree): + """ + This method computes the uses for each variable in the IR. + It iterates over the dominator tree and collects all the + instructions that use each variable. + """ + self.uses = {} + for bb in dom.dfs_walk: + for var, insts in bb.get_uses().items(): + self._get_uses(var).update(insts) + + def _get_uses(self, var: IRVariable): + if var not in self.uses: + self.uses[var] = OrderedSet() + return self.uses[var] + + def _propagate_constants(self): + """ + This method iterates over the IR and replaces constant values + with their actual values. It also replaces conditional jumps + with unconditional jumps if the condition is a constant value. + """ + for bb in self.dom.dfs_walk: + for inst in bb.instructions: + self._replace_constants(inst, self.lattice) + + def _replace_constants(self, inst: IRInstruction, lattice: Lattice): + """ + This method replaces constant values in the instruction with + their actual values. It also updates the instruction opcode in + case of jumps and asserts as needed. + """ + if inst.opcode == "jnz": + lat = lattice[inst.operands[0]] + if isinstance(lat, IRLiteral): + if lat.value == 0: + target = inst.operands[2] + else: + target = inst.operands[1] + inst.opcode = "jmp" + inst.operands = [target] + self.cfg_dirty = True + elif inst.opcode == "assert": + lat = lattice[inst.operands[0]] + if isinstance(lat, IRLiteral): + if lat.value > 0: + inst.opcode = "nop" + else: + raise StaticAssertionException( + f"assertion found to fail at compile time ({inst.error_msg}).", + inst.get_ast_source(), + ) + + inst.operands = [] + + elif inst.opcode == "phi": + return + + for i, op in enumerate(inst.operands): + if isinstance(op, IRVariable): + lat = lattice[op] + if isinstance(lat, IRLiteral): + inst.operands[i] = lat + + def _propagate_variables(self): + """ + Copy elimination. #NOTE: Not working yet, but it's also not needed atm. 
+ """ + for bb in self.dom.dfs_walk: + for inst in bb.instructions: + if inst.opcode == "store": + uses = self._get_uses(inst.output) + remove_inst = True + for usage_inst in uses: + if usage_inst.opcode == "phi": + remove_inst = False + continue + for i, op in enumerate(usage_inst.operands): + if op == inst.output: + usage_inst.operands[i] = inst.operands[0] + if remove_inst: + inst.opcode = "nop" + inst.operands = [] + + +def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem: + if x == LatticeEnum.TOP: + return y + if y == LatticeEnum.TOP or x == y: + return x + return LatticeEnum.BOTTOM diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py index 7f02ccf819..bebf2acd32 100644 --- a/vyper/venom/passes/simplify_cfg.py +++ b/vyper/venom/passes/simplify_cfg.py @@ -1,5 +1,7 @@ +from vyper.exceptions import CompilerPanic from vyper.utils import OrderedSet from vyper.venom.basicblock import IRBasicBlock +from vyper.venom.bb_optimizer import ir_pass_remove_unreachable_blocks from vyper.venom.function import IRFunction from vyper.venom.passes.base_pass import IRPass @@ -24,6 +26,11 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock): next_bb.remove_cfg_in(b) next_bb.add_cfg_in(a) + for inst in next_bb.instructions: + if inst.opcode != "phi": + break + inst.operands[inst.operands.index(b.label)] = a.label + self.ctx.basic_blocks.remove(b) def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock): @@ -79,4 +86,9 @@ def _collapse_chained_blocks(self, entry: IRBasicBlock): def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> None: self.ctx = ctx - self._collapse_chained_blocks(entry) + for _ in range(len(ctx.basic_blocks)): # essentially `while True` + self._collapse_chained_blocks(entry) + if ir_pass_remove_unreachable_blocks(ctx) == 0: + break + else: + raise CompilerPanic("Too many iterations collapsing chained blocks") diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py index 6924628958..7b58f0d54e 100644 --- a/vyper/venom/venom_to_assembly.py +++ b/vyper/venom/venom_to_assembly.py @@ -353,9 +353,10 @@ def _generate_evm_for_instruction( # Step 1: Apply instruction special stack manipulations if opcode in ["jmp", "djmp", "jnz", "invoke"]: - operands = inst.get_non_label_operands() + operands = list(inst.get_non_label_operands()) elif opcode == "alloca": - operands = inst.operands[1:2] + offset, _size = inst.operands + operands = [offset] # iload and istore are special cases because they can take a literal # that is handled specialy with the _OFST macro. Look below, after the @@ -381,7 +382,7 @@ def _generate_evm_for_instruction( if opcode == "phi": ret = inst.get_outputs()[0] - phis = inst.get_inputs() + phis = list(inst.get_inputs()) depth = stack.get_phi_depth(phis) # collapse the arguments to the phi node in the stack. # example, for `%56 = %label1 %13 %label2 %14`, we will @@ -523,6 +524,8 @@ def _generate_evm_for_instruction( assembly.append("MSTORE") elif opcode == "log": assembly.extend([f"LOG{log_topic_count}"]) + elif opcode == "nop": + pass else: raise Exception(f"Unknown opcode: {opcode}")