diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py
index 67a1f56e87..77e88c0277 100644
--- a/tests/functional/builtins/codegen/test_abi_decode.py
+++ b/tests/functional/builtins/codegen/test_abi_decode.py
@@ -2,7 +2,7 @@
from eth.codecs import abi
from tests.utils import decimal_to_int
-from vyper.exceptions import ArgumentException, StackTooDeep, StructureException
+from vyper.exceptions import ArgumentException, StructureException
TEST_ADDR = "0x" + b"".join(chr(i).encode("utf-8") for i in range(20)).hex()
@@ -201,7 +201,6 @@ def abi_decode(x: Bytes[{len}]) -> DynArray[DynArray[uint256, 3], 3]:
@pytest.mark.parametrize("args", nested_3d_array_args)
@pytest.mark.parametrize("unwrap_tuple", (True, False))
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_abi_decode_nested_dynarray2(get_contract, args, unwrap_tuple):
if unwrap_tuple is True:
encoded = abi.encode("(uint256[][][])", (args,))
@@ -279,7 +278,6 @@ def foo(bs: Bytes[160]) -> (uint256, DynArray[uint256, 3]):
assert c.foo(encoded) == [2**256 - 1, bs]
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_abi_decode_private_nested_dynarray(get_contract):
code = """
bytez: DynArray[DynArray[DynArray[uint256, 3], 3], 3]
diff --git a/tests/functional/builtins/codegen/test_abi_encode.py b/tests/functional/builtins/codegen/test_abi_encode.py
index 2331fb1a9e..387dab789a 100644
--- a/tests/functional/builtins/codegen/test_abi_encode.py
+++ b/tests/functional/builtins/codegen/test_abi_encode.py
@@ -2,7 +2,6 @@
from eth.codecs import abi
from tests.utils import decimal_to_int
-from vyper.exceptions import StackTooDeep
# @pytest.mark.parametrize("string", ["a", "abc", "abcde", "potato"])
@@ -227,7 +226,6 @@ def abi_encode(
@pytest.mark.parametrize("args", nested_3d_array_args)
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_abi_encode_nested_dynarray_2(get_contract, args):
code = """
@external
@@ -332,7 +330,6 @@ def foo(bs: DynArray[uint256, 3]) -> (uint256, Bytes[160]):
assert c.foo(bs) == [2**256 - 1, abi.encode("(uint256[])", (bs,))]
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_abi_encode_private_nested_dynarray(get_contract):
code = """
bytez: Bytes[1696]
diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py
index e64a35811d..0fd2c6212e 100644
--- a/tests/functional/codegen/features/iteration/test_for_range.py
+++ b/tests/functional/codegen/features/iteration/test_for_range.py
@@ -1,5 +1,8 @@
import pytest
+from vyper.exceptions import StaticAssertionException
+from vyper.utils import SizeLimits
+
def test_basic_repeater(get_contract_with_gas_estimation):
basic_repeater = """
@@ -271,17 +274,38 @@ def test():
@pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"])
-def test_for_range_oob_check(get_contract, tx_failed, typ):
+def test_for_range_oob_compile_time_check(get_contract, tx_failed, typ, experimental_codegen):
code = f"""
@external
def test():
x: {typ} = max_value({typ})
+ for i: {typ} in range(x, x + 2, bound=2):
+ pass
+ """
+ if not experimental_codegen:
+ return
+ with pytest.raises(StaticAssertionException):
+ get_contract(code)
+
+
+@pytest.mark.parametrize(
+ "typ, max_value",
+ [
+ ("uint8", SizeLimits.MAX_UINT8),
+ ("int128", SizeLimits.MAX_INT128),
+ ("uint256", SizeLimits.MAX_UINT256),
+ ],
+)
+def test_for_range_oob_runtime_check(get_contract, tx_failed, typ, max_value):
+ code = f"""
+@external
+def test(x: {typ}):
for i: {typ} in range(x, x + 2, bound=2):
pass
"""
c = get_contract(code)
with tx_failed():
- c.test()
+ c.test(max_value)
@pytest.mark.parametrize("typ", ["int128", "uint256"])
@@ -416,7 +440,25 @@ def foo(a: {typ}) -> {typ}:
assert c.foo(0) == 31337
-def test_for_range_signed_int_overflow(get_contract, tx_failed):
+def test_for_range_signed_int_overflow_runtime_check(get_contract, tx_failed, experimental_codegen):
+ code = """
+@external
+def foo(_min:int256, _max: int256) -> DynArray[int256, 10]:
+ res: DynArray[int256, 10] = empty(DynArray[int256, 10])
+    x: int256 = _max
+    y: int256 = _min + 2
+    for i: int256 in range(x, y, bound=10):
+ res.append(i)
+ return res
+ """
+ c = get_contract(code)
+ with tx_failed():
+ c.foo(SizeLimits.MIN_INT256, SizeLimits.MAX_INT256)
+
+
+def test_for_range_signed_int_overflow_compile_time_check(
+ get_contract, tx_failed, experimental_codegen
+):
code = """
@external
def foo() -> DynArray[int256, 10]:
@@ -427,6 +469,7 @@ def foo() -> DynArray[int256, 10]:
res.append(i)
return res
"""
- c = get_contract(code)
- with tx_failed():
- c.foo()
+ if not experimental_codegen:
+ return
+ with pytest.raises(StaticAssertionException):
+ get_contract(code)
diff --git a/tests/functional/codegen/features/test_constructor.py b/tests/functional/codegen/features/test_constructor.py
index d96a889497..9146ace8a6 100644
--- a/tests/functional/codegen/features/test_constructor.py
+++ b/tests/functional/codegen/features/test_constructor.py
@@ -1,8 +1,6 @@
import pytest
from web3.exceptions import ValidationError
-from vyper.exceptions import StackTooDeep
-
def test_init_argument_test(get_contract_with_gas_estimation):
init_argument_test = """
@@ -165,7 +163,6 @@ def get_foo() -> uint256:
assert c.get_foo() == 39
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_nested_dynamic_array_constructor_arg_2(w3, get_contract_with_gas_estimation):
code = """
foo: int128
@@ -211,7 +208,6 @@ def get_foo() -> DynArray[DynArray[uint256, 3], 3]:
assert c.get_foo() == [[37, 41, 73], [37041, 41073, 73037], [146, 123, 148]]
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_initialise_nested_dynamic_array_2(w3, get_contract_with_gas_estimation):
code = """
foo: DynArray[DynArray[DynArray[int128, 3], 3], 3]
diff --git a/tests/functional/codegen/features/test_immutable.py b/tests/functional/codegen/features/test_immutable.py
index 874600633a..49ff54b353 100644
--- a/tests/functional/codegen/features/test_immutable.py
+++ b/tests/functional/codegen/features/test_immutable.py
@@ -1,7 +1,6 @@
import pytest
from vyper.compiler.settings import OptimizationLevel
-from vyper.exceptions import StackTooDeep
@pytest.mark.parametrize(
@@ -199,7 +198,6 @@ def get_idx_two() -> uint256:
assert c.get_idx_two() == expected_values[2][2]
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_nested_dynarray_immutable(get_contract):
code = """
my_list: immutable(DynArray[DynArray[DynArray[int128, 3], 3], 3])
diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py
index e063f981ec..01c2cd74c4 100644
--- a/tests/functional/codegen/types/numbers/test_signed_ints.py
+++ b/tests/functional/codegen/types/numbers/test_signed_ints.py
@@ -8,6 +8,7 @@
from vyper.exceptions import (
InvalidOperation,
OverflowException,
+ StaticAssertionException,
TypeMismatch,
ZeroDivisionException,
)
@@ -73,18 +74,35 @@ def foo(x: int256) -> int256:
# TODO: make this test pass
@pytest.mark.parametrize("base", (0, 1))
-def test_exponent_negative_power(get_contract, tx_failed, base):
+def test_exponent_negative_power_runtime_check(get_contract, tx_failed, base, experimental_codegen):
# #2985
code = f"""
@external
-def bar() -> int16:
- x: int16 = -2
+def bar(negative: int16) -> int16:
+ x: int16 = negative
return {base} ** x
"""
c = get_contract(code)
# known bug: 2985
with tx_failed():
- c.bar()
+ c.bar(-2)
+
+
+@pytest.mark.parametrize("base", (0, 1))
+def test_exponent_negative_power_compile_time_check(
+ get_contract, tx_failed, base, experimental_codegen
+):
+ # #2985
+ code = f"""
+@external
+def bar() -> int16:
+ x: int16 = -2
+ return {base} ** x
+ """
+ if not experimental_codegen:
+ return
+ with pytest.raises(StaticAssertionException):
+ get_contract(code)
def test_exponent_min_int16(get_contract):
diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py
index 63d3d0b8f1..1383895c7a 100644
--- a/tests/functional/codegen/types/test_dynamic_array.py
+++ b/tests/functional/codegen/types/test_dynamic_array.py
@@ -62,7 +62,6 @@ def loo(x: DynArray[DynArray[int128, 2], 2]) -> int128:
print("Passed list tests")
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_string_list(get_contract):
code = """
@external
@@ -1491,7 +1490,6 @@ def foo(x: int128) -> int128:
assert c.foo(7) == 392
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_struct_of_lists(get_contract):
code = """
struct Foo:
@@ -1580,7 +1578,6 @@ def bar(x: int128) -> DynArray[int128, 3]:
assert c.bar(7) == [7, 14]
-@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize):
code = """
struct nestedFoo:
@@ -1710,9 +1707,7 @@ def __init__():
("DynArray[DynArray[DynArray[uint256, 5], 5], 5]", [[[], []], []]),
],
)
-def test_empty_nested_dynarray(get_contract, typ, val, venom_xfail):
- if val == [[[], []], []]:
- venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
+def test_empty_nested_dynarray(get_contract, typ, val):
code = f"""
@external
def foo() -> {typ}:
diff --git a/tests/unit/compiler/venom/test_sccp.py b/tests/unit/compiler/venom/test_sccp.py
new file mode 100644
index 0000000000..8102a0d89c
--- /dev/null
+++ b/tests/unit/compiler/venom/test_sccp.py
@@ -0,0 +1,139 @@
+from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.make_ssa import MakeSSA
+from vyper.venom.passes.sccp import SCCP
+from vyper.venom.passes.sccp.sccp import LatticeEnum
+
+
+def test_simple_case():
+ ctx = IRFunction(IRLabel("_global"))
+
+ bb = ctx.get_basic_block()
+ p1 = bb.append_instruction("param")
+ op1 = bb.append_instruction("store", 32)
+ op2 = bb.append_instruction("store", 64)
+ op3 = bb.append_instruction("add", op1, op2)
+ bb.append_instruction("return", p1, op3)
+
+ make_ssa_pass = MakeSSA()
+ make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0])
+ sccp = SCCP(make_ssa_pass.dom)
+ sccp.run_pass(ctx, ctx.basic_blocks[0])
+
+ assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM
+ assert sccp.lattice[IRVariable("%2")].value == 32
+ assert sccp.lattice[IRVariable("%3")].value == 64
+ assert sccp.lattice[IRVariable("%4")].value == 96
+
+
+def test_cont_jump_case():
+ ctx = IRFunction(IRLabel("_global"))
+
+ bb = ctx.get_basic_block()
+
+ br1 = IRBasicBlock(IRLabel("then"), ctx)
+ ctx.append_basic_block(br1)
+ br2 = IRBasicBlock(IRLabel("else"), ctx)
+ ctx.append_basic_block(br2)
+
+ p1 = bb.append_instruction("param")
+ op1 = bb.append_instruction("store", 32)
+ op2 = bb.append_instruction("store", 64)
+ op3 = bb.append_instruction("add", op1, op2)
+ bb.append_instruction("jnz", op3, br1.label, br2.label)
+
+ br1.append_instruction("add", op3, 10)
+ br1.append_instruction("stop")
+ br2.append_instruction("add", op3, p1)
+ br2.append_instruction("stop")
+
+ make_ssa_pass = MakeSSA()
+ make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0])
+ sccp = SCCP(make_ssa_pass.dom)
+ sccp.run_pass(ctx, ctx.basic_blocks[0])
+
+ assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM
+ assert sccp.lattice[IRVariable("%2")].value == 32
+ assert sccp.lattice[IRVariable("%3")].value == 64
+ assert sccp.lattice[IRVariable("%4")].value == 96
+ assert sccp.lattice[IRVariable("%5")].value == 106
+ assert sccp.lattice.get(IRVariable("%6")) == LatticeEnum.BOTTOM
+
+
+def test_cont_phi_case():
+ ctx = IRFunction(IRLabel("_global"))
+
+ bb = ctx.get_basic_block()
+
+ br1 = IRBasicBlock(IRLabel("then"), ctx)
+ ctx.append_basic_block(br1)
+ br2 = IRBasicBlock(IRLabel("else"), ctx)
+ ctx.append_basic_block(br2)
+ join = IRBasicBlock(IRLabel("join"), ctx)
+ ctx.append_basic_block(join)
+
+ p1 = bb.append_instruction("param")
+ op1 = bb.append_instruction("store", 32)
+ op2 = bb.append_instruction("store", 64)
+ op3 = bb.append_instruction("add", op1, op2)
+ bb.append_instruction("jnz", op3, br1.label, br2.label)
+
+ op4 = br1.append_instruction("add", op3, 10)
+ br1.append_instruction("jmp", join.label)
+ br2.append_instruction("add", op3, p1, ret=op4)
+ br2.append_instruction("jmp", join.label)
+
+ join.append_instruction("return", op4, p1)
+
+ make_ssa_pass = MakeSSA()
+ make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0])
+
+ sccp = SCCP(make_ssa_pass.dom)
+ sccp.run_pass(ctx, ctx.basic_blocks[0])
+
+ assert sccp.lattice[IRVariable("%1")] == LatticeEnum.BOTTOM
+ assert sccp.lattice[IRVariable("%2")].value == 32
+ assert sccp.lattice[IRVariable("%3")].value == 64
+ assert sccp.lattice[IRVariable("%4")].value == 96
+ assert sccp.lattice[IRVariable("%5", version=1)].value == 106
+ assert sccp.lattice[IRVariable("%5", version=2)] == LatticeEnum.BOTTOM
+ assert sccp.lattice[IRVariable("%5")].value == 2
+
+
+def test_cont_phi_const_case():
+ ctx = IRFunction(IRLabel("_global"))
+
+ bb = ctx.get_basic_block()
+
+ br1 = IRBasicBlock(IRLabel("then"), ctx)
+ ctx.append_basic_block(br1)
+ br2 = IRBasicBlock(IRLabel("else"), ctx)
+ ctx.append_basic_block(br2)
+ join = IRBasicBlock(IRLabel("join"), ctx)
+ ctx.append_basic_block(join)
+
+ p1 = bb.append_instruction("store", 1)
+ op1 = bb.append_instruction("store", 32)
+ op2 = bb.append_instruction("store", 64)
+ op3 = bb.append_instruction("add", op1, op2)
+ bb.append_instruction("jnz", op3, br1.label, br2.label)
+
+ op4 = br1.append_instruction("add", op3, 10)
+ br1.append_instruction("jmp", join.label)
+ br2.append_instruction("add", op3, p1, ret=op4)
+ br2.append_instruction("jmp", join.label)
+
+ join.append_instruction("return", op4, p1)
+
+ make_ssa_pass = MakeSSA()
+ make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0])
+ sccp = SCCP(make_ssa_pass.dom)
+ sccp.run_pass(ctx, ctx.basic_blocks[0])
+
+ assert sccp.lattice[IRVariable("%1")].value == 1
+ assert sccp.lattice[IRVariable("%2")].value == 32
+ assert sccp.lattice[IRVariable("%3")].value == 64
+ assert sccp.lattice[IRVariable("%4")].value == 96
+ assert sccp.lattice[IRVariable("%5", version=1)].value == 106
+ assert sccp.lattice[IRVariable("%5", version=2)].value == 97
+ assert sccp.lattice[IRVariable("%5")].value == 2
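
# A minimal sketch (not part of the patch) of the lattice the asserts above
# inspect: SCCP maps each SSA variable to a constant value when it can fold
# the producing instruction, and to BOTTOM when an input is unknowable at
# compile time (e.g. the result of `param`).

BOTTOM = object()

def fold_add(a, b):
    # wrapping add over the lattice, mirroring passes/sccp/eval.py
    if a is BOTTOM or b is BOTTOM:
        return BOTTOM
    return (a + b) % 2**256

p1 = BOTTOM         # %1 = param     -> BOTTOM
op1, op2 = 32, 64   # %2, %3 = store -> constants
op3 = fold_add(op1, op2)
assert op3 == 96    # matches lattice[IRVariable("%4")].value == 96
assert fold_add(op3, p1) is BOTTOM  # like the `add %4, %1` in the else branch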
diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py
index 35cbbe648d..bd9d8b2e5e 100644
--- a/vyper/builtins/functions.py
+++ b/vyper/builtins/functions.py
@@ -248,16 +248,16 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
dst_typ = BytesT(length.value)
# allocate a buffer for the return value
- np = context.new_internal_variable(dst_typ)
+ buf = context.new_internal_variable(dst_typ)
# `msg.data` by `calldatacopy`
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
- ["mstore", np, length],
- ["calldatacopy", np + 32, start, length],
- np,
+ ["mstore", buf, length],
+ ["calldatacopy", add_ofst(buf, 32), start, length],
+ buf,
]
# `self.code` by `codecopy`
@@ -265,9 +265,9 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
node = [
"seq",
_make_slice_bounds_check(start, length, "codesize"),
- ["mstore", np, length],
- ["codecopy", np + 32, start, length],
- np,
+ ["mstore", buf, length],
+ ["codecopy", add_ofst(buf, 32), start, length],
+ buf,
]
# `<address>.code` by `extcodecopy`
@@ -280,9 +280,9 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
[
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
- ["mstore", np, length],
- ["extcodecopy", "_extcode_address", np + 32, start, length],
- np,
+ ["mstore", buf, length],
+ ["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length],
+ buf,
],
]
@@ -552,10 +552,8 @@ def build_IR(self, expr, context):
# respect API of copy_bytes
bufsize = dst_maxlen + 32
- buf = context.new_internal_variable(BytesT(bufsize))
-
- # Node representing the position of the output in memory
- dst = IRnode.from_list(buf, typ=ret_typ, location=MEMORY, annotation="concat destination")
+ dst = context.new_internal_variable(BytesT(bufsize))
+ dst.annotation = "concat destination"
ret = ["seq"]
# stack item representing our current offset in the dst buffer
@@ -783,9 +781,9 @@ def build_IR(self, expr, args, kwargs, context):
# clear output memory first, ecrecover can return 0 bytes
["mstore", output_buf, 0],
["mstore", input_buf, args[0]],
- ["mstore", input_buf + 32, args[1]],
- ["mstore", input_buf + 64, args[2]],
- ["mstore", input_buf + 96, args[3]],
+ ["mstore", add_ofst(input_buf, 32), args[1]],
+ ["mstore", add_ofst(input_buf, 64), args[2]],
+ ["mstore", add_ofst(input_buf, 96), args[3]],
["staticcall", "gas", 1, input_buf, 128, output_buf, 32],
["mload", output_buf],
],
@@ -799,9 +797,7 @@ def build_IR(self, expr, _args, kwargs, context):
args_tuple = ir_tuple_from_args(_args)
args_t = args_tuple.typ
- input_buf = IRnode.from_list(
- context.new_internal_variable(args_t), typ=args_t, location=MEMORY
- )
+ input_buf = context.new_internal_variable(args_t)
ret_t = self._return_type
ret = ["seq"]
@@ -1103,9 +1099,7 @@ def build_IR(self, expr, args, kwargs, context):
args_ofst = add_ofst(input_buf, 32)
args_len = ["mload", input_buf]
- output_node = IRnode.from_list(
- context.new_internal_variable(BytesT(outsize)), typ=BytesT(outsize), location=MEMORY
- )
+ output_node = context.new_internal_variable(BytesT(outsize))
bool_ty = BoolT()
@@ -1712,8 +1706,8 @@ def _build_create_IR(self, expr, args, context, value, salt, revert_on_failure):
return [
"seq",
["mstore", buf, forwarder_preamble],
- ["mstore", ["add", buf, preamble_length], aligned_target],
- ["mstore", ["add", buf, preamble_length + 20], forwarder_post],
+ ["mstore", add_ofst(buf, preamble_length), aligned_target],
+ ["mstore", add_ofst(buf, preamble_length + 20), forwarder_post],
_create_ir(value, buf, buf_len, salt, revert_on_failure),
]
@@ -1822,9 +1816,7 @@ def _build_create_IR(
# pretend we allocated enough memory for the encoder
# (we didn't, but we are clobbering unused memory so it's safe.)
bufsz = to_encode.typ.abi_type.size_bound()
- argbuf = IRnode.from_list(
- context.new_internal_variable(get_type_for_exact_size(bufsz)), location=MEMORY
- )
+ argbuf = context.new_internal_variable(get_type_for_exact_size(bufsz))
# return a complex expression which writes to memory and returns
# the length of the encoded data
@@ -2071,13 +2063,17 @@ def build_IR(self, expr, args, kwargs, context):
# clobber val, and return it as a pointer
[
"seq",
- ["mstore", ["sub", buf + n_digits, i], i],
- ["set", val, ["sub", buf + n_digits, i]],
+ ["mstore", ["sub", add_ofst(buf, n_digits), i], i],
+ ["set", val, ["sub", add_ofst(buf, n_digits), i]],
"break",
],
[
"seq",
- ["mstore", ["sub", buf + n_digits, i], ["add", 48, ["mod", val, 10]]],
+ [
+ "mstore",
+ ["sub", add_ofst(buf, n_digits), i],
+ ["add", 48, ["mod", val, 10]],
+ ],
["set", val, ["div", val, 10]],
],
],
@@ -2093,7 +2089,7 @@ def build_IR(self, expr, args, kwargs, context):
ret = [
"if",
["eq", val, 0],
- ["seq", ["mstore", buf + 1, ord("0")], ["mstore", buf, 1], buf],
+ ["seq", ["mstore", add_ofst(buf, 1), ord("0")], ["mstore", buf, 1], buf],
["seq", ret, val],
]
@@ -2271,7 +2267,7 @@ def build_IR(self, expr, args, kwargs, context):
ret = ["seq"]
ret.append(["mstore", buf, method_id])
- encode = abi_encode(buf + 32, args_as_tuple, context, buflen, returns_len=True)
+ encode = abi_encode(add_ofst(buf, 32), args_as_tuple, context, buflen, returns_len=True)
else:
method_id = method_id_int("log(string,bytes)")
@@ -2285,7 +2281,9 @@ def build_IR(self, expr, args, kwargs, context):
ret.append(["mstore", schema_buf, len(schema)])
# TODO use Expr.make_bytelike, or better have a `bytestring` IRnode type
- ret.append(["mstore", schema_buf + 32, bytes_to_int(schema.ljust(32, b"\x00"))])
+ ret.append(
+ ["mstore", add_ofst(schema_buf, 32), bytes_to_int(schema.ljust(32, b"\x00"))]
+ )
payload_buflen = args_abi_t.size_bound()
payload_t = BytesT(payload_buflen)
@@ -2293,7 +2291,7 @@ def build_IR(self, expr, args, kwargs, context):
# 32 bytes extra space for the method id
payload_buf = context.new_internal_variable(payload_t)
encode_payload = abi_encode(
- payload_buf + 32, args_as_tuple, context, payload_buflen, returns_len=True
+ add_ofst(payload_buf, 32), args_as_tuple, context, payload_buflen, returns_len=True
)
ret.append(["mstore", payload_buf, encode_payload])
@@ -2308,11 +2306,13 @@ def build_IR(self, expr, args, kwargs, context):
buflen = 32 + args_as_tuple.typ.abi_type.size_bound()
buf = context.new_internal_variable(get_type_for_exact_size(buflen))
ret.append(["mstore", buf, method_id])
- encode = abi_encode(buf + 32, args_as_tuple, context, buflen, returns_len=True)
+ encode = abi_encode(add_ofst(buf, 32), args_as_tuple, context, buflen, returns_len=True)
# debug address that tooling uses
CONSOLE_ADDRESS = 0x000000000000000000636F6E736F6C652E6C6F67
- ret.append(["staticcall", "gas", CONSOLE_ADDRESS, buf + 28, ["add", 4, encode], 0, 0])
+ ret.append(
+ ["staticcall", "gas", CONSOLE_ADDRESS, add_ofst(buf, 28), ["add", 4, encode], 0, 0]
+ )
return IRnode.from_list(ret, annotation="print:" + sig)
@@ -2415,15 +2415,19 @@ def build_IR(self, expr, args, kwargs, context):
# <32 bytes length> | <4 bytes method_id> |
# write the unaligned method_id first, then we will
# overwrite the 28 bytes of zeros with the bytestring length
- ret += [["mstore", buf + 4, method_id]]
+ ret += [["mstore", add_ofst(buf, 4), method_id]]
# abi encode, and grab length as stack item
- length = abi_encode(buf + 36, encode_input, context, returns_len=True, bufsz=maxlen)
+ length = abi_encode(
+ add_ofst(buf, 36), encode_input, context, returns_len=True, bufsz=maxlen
+ )
# write the output length to where bytestring stores its length
ret += [["mstore", buf, ["add", length, 4]]]
else:
# abi encode and grab length as stack item
- length = abi_encode(buf + 32, encode_input, context, returns_len=True, bufsz=maxlen)
+ length = abi_encode(
+ add_ofst(buf, 32), encode_input, context, returns_len=True, bufsz=maxlen
+ )
# write the output length to where bytestring stores its length
ret += [["mstore", buf, length]]
@@ -2508,13 +2512,12 @@ def build_IR(self, expr, args, kwargs, context):
# input validation
output_buf = context.new_internal_variable(wrapped_typ)
- output = IRnode.from_list(output_buf, typ=wrapped_typ, location=MEMORY)
# sanity check buffer size for wrapped output type will not buffer overflow
assert wrapped_typ.memory_bytes_required == output_typ.memory_bytes_required
- ret.append(make_setter(output, to_decode))
+ ret.append(make_setter(output_buf, to_decode))
- ret.append(output)
+ ret.append(output_buf)
# finalize. set the type and location for the return buffer.
# (note: unwraps the tuple type if necessary)
ret = IRnode.from_list(ret, typ=output_typ, location=MEMORY)
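
# A minimal sketch (not part of the patch) of why `buf + 32` becomes
# `add_ofst(buf, 32)` throughout this file: `new_internal_variable` now
# returns an IRnode -- under experimental codegen an abstract "$alloca_..."
# pointer -- instead of a raw integer offset, so pointer arithmetic has to
# be expressed in IR rather than with Python's `+`.

def add_ofst_model(ptr, ofst: int):
    # hedged model, not vyper's actual implementation: fold when the
    # pointer is still a concrete integer, otherwise emit a symbolic add
    if isinstance(ptr, int):
        return ptr + ofst
    return ["add", ptr, ofst]

assert add_ofst_model(128, 32) == 160
assert add_ofst_model("$alloca_128_64", 32) == ["add", "$alloca_128_64", 32]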
diff --git a/vyper/codegen/abi_encoder.py b/vyper/codegen/abi_encoder.py
index b227161d94..09a22cd857 100644
--- a/vyper/codegen/abi_encoder.py
+++ b/vyper/codegen/abi_encoder.py
@@ -159,9 +159,9 @@ def abi_encoding_matches_vyper(typ):
# the abi_encode routine will push the output len onto the stack,
# otherwise it will return 0 items to the stack.
def abi_encode(dst, ir_node, context, bufsz, returns_len=False):
- # TODO change dst to be an IRnode so it has type info to begin with.
- # setting the typ of dst to ir_node.typ is a footgun.
+ # cast dst to the type of the input so that make_setter works
dst = IRnode.from_list(dst, typ=ir_node.typ, location=MEMORY)
+
abi_t = dst.typ.abi_type
size_bound = abi_t.size_bound()
diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py
index 5ac8cdd758..42488f06da 100644
--- a/vyper/codegen/context.py
+++ b/vyper/codegen/context.py
@@ -3,7 +3,8 @@
from dataclasses import dataclass
from typing import Any, Optional
-from vyper.codegen.ir_node import Encoding
+from vyper.codegen.ir_node import Encoding, IRnode
+from vyper.compiler.settings import get_global_settings
from vyper.evm.address_space import MEMORY, AddrSpace
from vyper.exceptions import CompilerPanic, StateAccessViolation
from vyper.semantics.types import VyperType
@@ -14,6 +15,17 @@ class Constancy(enum.Enum):
Constant = 1
+@dataclass(frozen=True)
+class Alloca:
+ name: str
+ offset: int
+ typ: VyperType
+ size: int
+
+ def __post_init__(self):
+ assert self.typ.memory_bytes_required == self.size
+
+
# Function variable
@dataclass
class VariableRecord:
@@ -27,6 +39,9 @@ class VariableRecord:
blockscopes: Optional[list] = None
defined_at: Any = None
is_internal: bool = False
+ alloca: Optional[Alloca] = None
+
+ # the following members are probably dead
is_immutable: bool = False
is_transient: bool = False
data_offset: Optional[int] = None
@@ -43,6 +58,20 @@ def __repr__(self):
ret["allocated"] = self.typ.memory_bytes_required
return f"VariableRecord({ret})"
+ def as_ir_node(self):
+ ret = IRnode.from_list(
+ self.pos,
+ typ=self.typ,
+ annotation=self.name,
+ encoding=self.encoding,
+ mutable=self.mutable,
+ location=self.location,
+ )
+ ret._referenced_variables = {self}
+ if self.alloca is not None:
+ ret.passthrough_metadata["alloca"] = self.alloca
+ return ret
+
# compilation context for a function
class Context:
@@ -91,6 +120,8 @@ def __init__(
# either the constructor, or called from the constructor
self.is_ctor_context = is_ctor_context
+ self.settings = get_global_settings()
+
def is_constant(self):
return self.constancy is Constancy.Constant or self.in_range_expr
@@ -140,10 +171,7 @@ def internal_memory_scope(self):
(k, v) for k, v in self.vars.items() if v.is_internal and scope_id in v.blockscopes
]
for name, var in released:
- n = var.typ.memory_bytes_required
- assert n == var.size
- self.memory_allocator.deallocate_memory(var.pos, n)
- del self.vars[name]
+ self.deallocate_variable(name, var)
# Remove block scopes
self._scopes.remove(scope_id)
@@ -164,36 +192,68 @@ def block_scope(self):
# Remove all variables that have specific scope_id attached
released = [(k, v) for k, v in self.vars.items() if scope_id in v.blockscopes]
for name, var in released:
- n = var.typ.memory_bytes_required
- # sanity check the type's size hasn't changed since allocation.
- assert n == var.size
- self.memory_allocator.deallocate_memory(var.pos, n)
- del self.vars[name]
+ self.deallocate_variable(name, var)
# Remove block scopes
self._scopes.remove(scope_id)
- def _new_variable(
- self, name: str, typ: VyperType, var_size: int, is_internal: bool, is_mutable: bool = True
- ) -> int:
- var_pos = self.memory_allocator.allocate_memory(var_size)
+ def deallocate_variable(self, varname, var):
+ assert varname == var.name
+
+ # sanity check the type's size hasn't changed since allocation.
+ n = var.typ.memory_bytes_required
+ assert n == var.size
- assert var_pos + var_size <= self.memory_allocator.size_of_mem, "function frame overrun"
+ if self.settings.experimental_codegen:
+ # do not deallocate at this stage because this will break
+ # analysis in venom; venom will do its own alloc/dealloc/analysis.
+ pass
+ else:
+ self.memory_allocator.deallocate_memory(var.pos, var.size)
- self.vars[name] = VariableRecord(
+ del self.vars[var.name]
+
+ def _new_variable(
+ self,
+ name: str,
+ typ: VyperType,
+ is_internal: bool,
+ is_mutable: bool = True,
+ internal_function=False,
+ ) -> IRnode:
+ size = typ.memory_bytes_required
+
+ ofst = self.memory_allocator.allocate_memory(size)
+ assert ofst + size <= self.memory_allocator.size_of_mem, "function frame overrun"
+
+ pos = ofst
+ alloca = None
+ if self.settings.experimental_codegen:
+ # convert it into an abstract pointer
+ if internal_function:
+ pos = f"$palloca_{ofst}_{size}"
+ else:
+ pos = f"$alloca_{ofst}_{size}"
+ alloca = Alloca(name=name, offset=ofst, typ=typ, size=size)
+
+ var = VariableRecord(
name=name,
- pos=var_pos,
+ pos=pos,
typ=typ,
- size=var_size,
+ size=size,
mutable=is_mutable,
blockscopes=self._scopes.copy(),
is_internal=is_internal,
+ alloca=alloca,
)
- return var_pos
+ self.vars[name] = var
+ return var.as_ir_node()
- def new_variable(self, name: str, typ: VyperType, is_mutable: bool = True) -> int:
+ def new_variable(
+ self, name: str, typ: VyperType, is_mutable: bool = True, internal_function=False
+ ) -> IRnode:
"""
- Allocate memory for a user-defined variable.
+ Allocate memory for a user-defined variable and return an IR node referencing it.
Arguments
---------
@@ -208,8 +268,9 @@ def new_variable(self, name: str, typ: VyperType, is_mutable: bool = True) -> in
-        Memory offset for the variable
+        IR node referencing the variable
"""
- var_size = typ.memory_bytes_required
- return self._new_variable(name, typ, var_size, False, is_mutable=is_mutable)
+ return self._new_variable(
+ name, typ, is_internal=False, is_mutable=is_mutable, internal_function=internal_function
+ )
def fresh_varname(self, name: str) -> str:
"""
@@ -219,8 +280,7 @@ def fresh_varname(self, name: str) -> str:
self._internal_var_iter += 1
return f"{name}{t}"
- # do we ever allocate immutable internal variables?
- def new_internal_variable(self, typ: VyperType) -> int:
+ def new_internal_variable(self, typ: VyperType) -> IRnode:
"""
Allocate memory for an internal variable.
@@ -237,10 +297,9 @@ def new_internal_variable(self, typ: VyperType) -> int:
# internal variable names begin with a number sign so there is no chance for collision
name = self.fresh_varname("#internal")
- var_size = typ.memory_bytes_required
- return self._new_variable(name, typ, var_size, True)
+ return self._new_variable(name, typ, is_internal=True)
- def lookup_var(self, varname):
+ def lookup_var(self, varname) -> VariableRecord:
return self.vars[varname]
# Pretty print constancy for error messages
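
# A minimal sketch (not part of the patch) of the abstract pointers
# `_new_variable` hands out under experimental codegen: instead of a concrete
# memory offset, a variable gets a symbolic "$alloca"/"$palloca" name plus an
# Alloca record, which ir_node_to_venom later lowers to an `alloca`
# instruction (or a plain `store` for internal-function arguments).

from dataclasses import dataclass

@dataclass(frozen=True)
class AllocaSketch:  # mirrors the Alloca dataclass added above, minus typ
    name: str
    offset: int
    size: int

def abstract_pointer(name: str, ofst: int, size: int, internal_function=False):
    prefix = "$palloca" if internal_function else "$alloca"
    return f"{prefix}_{ofst}_{size}", AllocaSketch(name, ofst, size)

pos, rec = abstract_pointer("x", 320, 64)
assert pos == "$alloca_320_64" and rec.offset == 320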
diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py
index aca524cd6a..69ffb2bfd6 100644
--- a/vyper/codegen/expr.py
+++ b/vyper/codegen/expr.py
@@ -168,17 +168,7 @@ def parse_Name(self):
if self.expr.id == "self":
return IRnode.from_list(["address"], typ=AddressT())
elif self.expr.id in self.context.vars:
- var = self.context.vars[self.expr.id]
- ret = IRnode.from_list(
- var.pos,
- typ=var.typ,
- location=var.location, # either 'memory' or 'calldata' storage is handled above.
- encoding=var.encoding,
- annotation=self.expr.id,
- mutable=var.mutable,
- )
- ret._referenced_variables = {var}
- return ret
+ return self.context.lookup_var(self.expr.id).as_ir_node()
elif (varinfo := self.expr._expr_info.var_info) is not None:
if varinfo.is_constant:
diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py
index 1e50886baf..607872b052 100644
--- a/vyper/codegen/external_call.py
+++ b/vyper/codegen/external_call.py
@@ -1,9 +1,11 @@
+import copy
from dataclasses import dataclass
import vyper.utils as util
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
_freshname,
+ add_ofst,
calculate_type_for_external_return,
check_assign,
check_external_call,
@@ -54,7 +56,7 @@ def _pack_arguments(fn_type, args, context):
buf_t = get_type_for_exact_size(buflen)
buf = context.new_internal_variable(buf_t)
- args_ofst = buf + 28
+ args_ofst = add_ofst(buf, 28)
args_len = args_abi_t.size_bound() + 4
abi_signature = fn_type.name + dst_tuple_t.abi_type.selector_name()
@@ -69,7 +71,7 @@ def _pack_arguments(fn_type, args, context):
pack_args.append(["mstore", buf, util.method_id_int(abi_signature)])
if len(args) != 0:
- pack_args.append(abi_encode(buf + 32, args_as_tuple, context, bufsz=buflen))
+ pack_args.append(abi_encode(add_ofst(buf, 32), args_as_tuple, context, bufsz=buflen))
return buf, pack_args, args_ofst, args_len
@@ -93,13 +95,11 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp
encoding = Encoding.ABI
- buf = IRnode.from_list(
- buf,
- typ=wrapped_return_t,
- location=MEMORY,
- encoding=encoding,
- annotation=f"{expr.node_source_code} returndata buffer",
- )
+ assert buf.location == MEMORY
+ buf = copy.copy(buf)
+ buf.typ = wrapped_return_t
+ buf.encoding = encoding
+ buf.annotation = f"{expr.node_source_code} returndata buffer"
unpacker = ["seq"]
@@ -117,7 +117,6 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp
# unpack strictly
if needs_clamp(wrapped_return_t, encoding):
return_buf = context.new_internal_variable(wrapped_return_t)
- return_buf = IRnode.from_list(return_buf, typ=wrapped_return_t, location=MEMORY)
# note: make_setter does ABI decoding and clamps
unpacker.append(make_setter(return_buf, buf))
diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py
index fe706699bb..a9b4a93025 100644
--- a/vyper/codegen/function_definitions/external_function.py
+++ b/vyper/codegen/function_definitions/external_function.py
@@ -11,7 +11,7 @@
)
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.codegen.stmt import parse_body
-from vyper.evm.address_space import CALLDATA, DATA, MEMORY
+from vyper.evm.address_space import CALLDATA, DATA
from vyper.semantics.types import TupleT
from vyper.semantics.types.function import ContractFunctionT
from vyper.utils import calc_mem_gas
@@ -35,8 +35,7 @@ def _register_function_args(func_t: ContractFunctionT, context: Context) -> list
if needs_clamp(arg.typ, Encoding.ABI):
# allocate a memory slot for it and copy
- p = context.new_variable(arg.name, arg.typ, is_mutable=False)
- dst = IRnode(p, typ=arg.typ, location=MEMORY)
+ dst = context.new_variable(arg.name, arg.typ, is_mutable=False)
copy_arg = make_setter(dst, arg_ir)
copy_arg.ast_source = arg.ast_source
@@ -94,9 +93,7 @@ def handler_for(calldata_kwargs, default_kwargs):
for i, arg_meta in enumerate(calldata_kwargs):
k = func_t.n_positional_args + i
- dst = context.lookup_var(arg_meta.name).pos
-
- lhs = IRnode(dst, location=MEMORY, typ=arg_meta.typ)
+ lhs = context.lookup_var(arg_meta.name).as_ir_node()
rhs = get_element_ptr(calldata_kwargs_ofst, k, array_bounds_check=False)
@@ -105,8 +102,7 @@ def handler_for(calldata_kwargs, default_kwargs):
ret.append(copy_arg)
for x in default_kwargs:
- dst = context.lookup_var(x.name).pos
- lhs = IRnode(dst, location=MEMORY, typ=x.typ)
+ lhs = context.lookup_var(x.name).as_ir_node()
lhs.ast_source = x.ast_source
kw_ast_val = func_t.default_values[x.name] # e.g. `3` in x: int = 3
rhs = Expr(kw_ast_val, context).ir_node
diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py
index cde1ec5c87..0ad993b33c 100644
--- a/vyper/codegen/function_definitions/internal_function.py
+++ b/vyper/codegen/function_definitions/internal_function.py
@@ -49,7 +49,7 @@ def generate_ir_for_internal_function(
for arg in func_t.arguments:
# allocate a variable for every arg, setting mutability
# to True to allow internal function arguments to be mutable
- context.new_variable(arg.name, arg.typ, is_mutable=True)
+ context.new_variable(arg.name, arg.typ, is_mutable=True, internal_function=True)
# Get nonreentrant lock
nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)
diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py
index 2363de3641..1114dd46cc 100644
--- a/vyper/codegen/self_call.py
+++ b/vyper/codegen/self_call.py
@@ -77,9 +77,7 @@ def ir_for_self_call(stmt_expr, context):
if args_as_tuple.contains_self_call:
copy_args = ["seq"]
# TODO deallocate me
- tmp_args_buf = IRnode(
- context.new_internal_variable(dst_tuple_t), typ=dst_tuple_t, location=MEMORY
- )
+ tmp_args_buf = context.new_internal_variable(dst_tuple_t)
copy_args.append(
# --> args evaluate here <--
make_setter(tmp_args_buf, args_as_tuple)
diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py
index 9c5720d61c..1018a6ec43 100644
--- a/vyper/codegen/stmt.py
+++ b/vyper/codegen/stmt.py
@@ -7,6 +7,7 @@
LOAD,
STORE,
IRnode,
+ add_ofst,
clamp_le,
get_dyn_array_count,
get_element_ptr,
@@ -136,14 +137,14 @@ def _assert_reason(self, test_expr, msg):
# abi encode method_id + bytestring to `buf+32`, then
# write method_id to `buf` and get out of here
- payload_buf = buf + 32
+ payload_buf = add_ofst(buf, 32)
bufsz -= 32 # reduce buffer by size of `method_id` slot
encoded_length = abi_encode(payload_buf, msg_ir, self.context, bufsz, returns_len=True)
with encoded_length.cache_when_complex("encoded_len") as (b1, encoded_length):
revert_seq = [
"seq",
["mstore", buf, method_id],
- ["revert", buf + 28, ["add", 4, encoded_length]],
+ ["revert", add_ofst(buf, 28), ["add", 4, encoded_length]],
]
revert_seq = b1.resolve(revert_seq)
diff --git a/vyper/utils.py b/vyper/utils.py
index cf8a709997..01ae37e213 100644
--- a/vyper/utils.py
+++ b/vyper/utils.py
@@ -405,6 +405,7 @@ class SizeLimits:
MAX_AST_DECIMAL = decimal.Decimal(2**167 - 1) / DECIMAL_DIVISOR
MAX_UINT8 = 2**8 - 1
MAX_UINT256 = 2**256 - 1
+ CEILING_UINT256 = 2**256
def quantize(d: decimal.Decimal, places=MAX_DECIMAL_PLACES, rounding_mode=decimal.ROUND_DOWN):
diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py
index 0a75e567ba..a60f679a76 100644
--- a/vyper/venom/__init__.py
+++ b/vyper/venom/__init__.py
@@ -1,7 +1,7 @@
# maybe rename this `main.py` or `venom.py`
# (can have an `__init__.py` which exposes the API).
-from typing import Any, Optional
+from typing import Optional
from vyper.codegen.ir_node import IRnode
from vyper.compiler.settings import OptimizationLevel
@@ -11,13 +11,12 @@
ir_pass_optimize_unused_variables,
ir_pass_remove_unreachable_blocks,
)
-from vyper.venom.dominators import DominatorTree
from vyper.venom.function import IRFunction
from vyper.venom.ir_node_to_venom import ir_node_to_venom
-from vyper.venom.passes.constant_propagation import ir_pass_constant_propagation
from vyper.venom.passes.dft import DFTPass
from vyper.venom.passes.make_ssa import MakeSSA
-from vyper.venom.passes.normalization import NormalizationPass
+from vyper.venom.passes.mem2var import Mem2Var
+from vyper.venom.passes.sccp import SCCP
from vyper.venom.passes.simplify_cfg import SimplifyCFGPass
from vyper.venom.venom_to_assembly import VenomCompiler
@@ -56,10 +55,30 @@ def _run_passes(ctx: IRFunction, optimize: OptimizationLevel) -> None:
for entry in internals:
SimplifyCFGPass().run_pass(ctx, entry)
+ dfg = DFG.build_dfg(ctx)
+ Mem2Var().run_pass(ctx, ctx.basic_blocks[0], dfg)
+ for entry in internals:
+ Mem2Var().run_pass(ctx, entry, dfg)
+
make_ssa_pass = MakeSSA()
make_ssa_pass.run_pass(ctx, ctx.basic_blocks[0])
+
+ cfg_dirty = False
+ sccp_pass = SCCP(make_ssa_pass.dom)
+ sccp_pass.run_pass(ctx, ctx.basic_blocks[0])
+ cfg_dirty |= sccp_pass.cfg_dirty
+
for entry in internals:
make_ssa_pass.run_pass(ctx, entry)
+ sccp_pass = SCCP(make_ssa_pass.dom)
+ sccp_pass.run_pass(ctx, entry)
+ cfg_dirty |= sccp_pass.cfg_dirty
+
+ calculate_cfg(ctx)
+ SimplifyCFGPass().run_pass(ctx, ctx.basic_blocks[0])
+
+ calculate_cfg(ctx)
+ calculate_liveness(ctx)
while True:
changes = 0
diff --git a/vyper/venom/analysis.py b/vyper/venom/analysis.py
index 066a60f45e..8e4c24fea3 100644
--- a/vyper/venom/analysis.py
+++ b/vyper/venom/analysis.py
@@ -152,6 +152,10 @@ def get_uses(self, op: IRVariable) -> list[IRInstruction]:
def get_producing_instruction(self, op: IRVariable) -> Optional[IRInstruction]:
return self._dfg_outputs.get(op)
+ @property
+ def outputs(self) -> dict[IRVariable, IRInstruction]:
+ return self._dfg_outputs
+
@classmethod
def build_dfg(cls, ctx: IRFunction) -> "DFG":
dfg = cls()
diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py
index 0993fb0515..7c54145018 100644
--- a/vyper/venom/basicblock.py
+++ b/vyper/venom/basicblock.py
@@ -1,14 +1,14 @@
-from typing import TYPE_CHECKING, Any, Generator, Iterator, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterator, Optional, Union
+from vyper.codegen.ir_node import IRnode
from vyper.utils import OrderedSet
# instructions which can terminate a basic block
-BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "revert", "stop", "exit"])
+BB_TERMINATORS = frozenset(["jmp", "djmp", "jnz", "ret", "return", "stop", "exit"])
VOLATILE_INSTRUCTIONS = frozenset(
[
"param",
- "alloca",
"call",
"staticcall",
"delegatecall",
@@ -190,16 +190,15 @@ class IRInstruction:
"""
opcode: str
- volatile: bool
operands: list[IROperand]
output: Optional[IROperand]
# set of live variables at this instruction
liveness: OrderedSet[IRVariable]
dup_requirements: OrderedSet[IRVariable]
- parent: Optional["IRBasicBlock"]
+ parent: "IRBasicBlock"
fence_id: int
annotation: Optional[str]
- ast_source: Optional[int]
+ ast_source: Optional[IRnode]
error_msg: Optional[str]
def __init__(
@@ -211,34 +210,36 @@ def __init__(
assert isinstance(opcode, str), "opcode must be an str"
assert isinstance(operands, list | Iterator), "operands must be a list"
self.opcode = opcode
- self.volatile = opcode in VOLATILE_INSTRUCTIONS
self.operands = list(operands) # in case we get an iterator
self.output = output
self.liveness = OrderedSet()
self.dup_requirements = OrderedSet()
- self.parent = None
self.fence_id = -1
self.annotation = None
self.ast_source = None
self.error_msg = None
- def get_label_operands(self) -> list[IRLabel]:
+ @property
+ def volatile(self) -> bool:
+ return self.opcode in VOLATILE_INSTRUCTIONS
+
+ def get_label_operands(self) -> Iterator[IRLabel]:
"""
Get all labels in instruction.
"""
- return [op for op in self.operands if isinstance(op, IRLabel)]
+ return (op for op in self.operands if isinstance(op, IRLabel))
- def get_non_label_operands(self) -> list[IROperand]:
+ def get_non_label_operands(self) -> Iterator[IROperand]:
"""
Get input operands for instruction which are not labels
"""
- return [op for op in self.operands if not isinstance(op, IRLabel)]
+ return (op for op in self.operands if not isinstance(op, IRLabel))
- def get_inputs(self) -> list[IRVariable]:
+ def get_inputs(self) -> Iterator[IRVariable]:
"""
Get all input operands for instruction.
"""
- return [op for op in self.operands if isinstance(op, IRVariable)]
+ return (op for op in self.operands if isinstance(op, IRVariable))
def get_outputs(self) -> list[IROperand]:
"""
@@ -268,7 +269,7 @@ def replace_label_operands(self, replacements: dict) -> None:
self.operands[i] = replacements[operand.value]
@property
- def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]:
+ def phi_operands(self) -> Iterator[tuple[IRLabel, IROperand]]:
"""
Get phi operands for instruction.
"""
@@ -277,9 +278,30 @@ def phi_operands(self) -> Generator[tuple[IRLabel, IRVariable], None, None]:
label = self.operands[i]
var = self.operands[i + 1]
assert isinstance(label, IRLabel), "phi operand must be a label"
- assert isinstance(var, IRVariable), "phi operand must be a variable"
+ assert isinstance(
+ var, (IRVariable, IRLiteral)
+ ), "phi operand must be a variable or literal"
yield label, var
+ def remove_phi_operand(self, label: IRLabel) -> None:
+ """
+ Remove a phi operand from the instruction.
+ """
+ assert self.opcode == "phi", "instruction must be a phi"
+ for i in range(0, len(self.operands), 2):
+ if self.operands[i] == label:
+ del self.operands[i : i + 2]
+ return
+
+ def get_ast_source(self) -> Optional[IRnode]:
+ if self.ast_source:
+ return self.ast_source
+ idx = self.parent.instructions.index(self)
+ for inst in reversed(self.parent.instructions[:idx]):
+ if inst.ast_source:
+ return inst.ast_source
+ return self.parent.parent.ast_source
+
def __repr__(self) -> str:
s = ""
if self.output:
@@ -451,6 +473,15 @@ def get_assignments(self):
"""
return [inst.output for inst in self.instructions if inst.output]
+ def get_uses(self) -> dict[IRVariable, OrderedSet[IRInstruction]]:
+ uses: dict[IRVariable, OrderedSet[IRInstruction]] = {}
+ for inst in self.instructions:
+ for op in inst.get_inputs():
+ if op not in uses:
+ uses[op] = OrderedSet()
+ uses[op].add(inst)
+ return uses
+
@property
def is_empty(self) -> bool:
"""
diff --git a/vyper/venom/bb_optimizer.py b/vyper/venom/bb_optimizer.py
index 60dd8bbee1..284a1f1b9c 100644
--- a/vyper/venom/bb_optimizer.py
+++ b/vyper/venom/bb_optimizer.py
@@ -13,7 +13,8 @@ def _optimize_unused_variables(ctx: IRFunction) -> set[IRInstruction]:
for i, inst in enumerate(bb.instructions[:-1]):
if inst.volatile:
continue
- if inst.output and inst.output not in bb.instructions[i + 1].liveness:
+ next_liveness = bb.instructions[i + 1].liveness
+ if (inst.output and inst.output not in next_liveness) or inst.opcode == "nop":
removeList.add(inst)
bb.instructions = [inst for inst in bb.instructions if inst not in removeList]
diff --git a/vyper/venom/function.py b/vyper/venom/function.py
index d1680385f5..8756642f80 100644
--- a/vyper/venom/function.py
+++ b/vyper/venom/function.py
@@ -30,7 +30,7 @@ class IRFunction:
last_variable: int
# Used during code generation
- _ast_source_stack: list[int]
+ _ast_source_stack: list[IRnode]
_error_msg_stack: list[str]
_bb_index: dict[str, int]
@@ -158,6 +158,12 @@ def remove_unreachable_blocks(self) -> int:
continue
in_labels = inst.get_label_operands()
if bb.label in in_labels:
+ inst.remove_phi_operand(bb.label)
+ op_len = len(inst.operands)
+ if op_len == 2:
+ inst.opcode = "store"
+ inst.operands = [inst.operands[1]]
+ elif op_len == 0:
out_bb.remove_instruction(inst)
return len(removed)
@@ -230,7 +236,7 @@ def pop_source(self):
self._error_msg_stack.pop()
@property
- def ast_source(self) -> Optional[int]:
+ def ast_source(self) -> Optional[IRnode]:
return self._ast_source_stack[-1] if len(self._ast_source_stack) > 0 else None
@property
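
# A minimal sketch (not part of the patch) of the phi pruning added to
# remove_unreachable_blocks above: when an unreachable predecessor is removed,
# its (label, value) pair is dropped from each phi; a phi left with a single
# pair degenerates to a plain `store`, and an empty phi is deleted outright.

def prune_phi(operands: list, dead_label: str):
    # operands use the flat [label1, val1, label2, val2, ...] layout
    for i in range(0, len(operands), 2):
        if operands[i] == dead_label:
            del operands[i : i + 2]
            break
    if len(operands) == 2:
        return "store", [operands[1]]  # only one source left
    if len(operands) == 0:
        return None, []                # caller removes the instruction
    return "phi", operands

assert prune_phi(["then", "%a", "else", "%b"], "else") == ("store", ["%a"])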
diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py
index 7dc6bd2d47..775b55f9a8 100644
--- a/vyper/venom/ir_node_to_venom.py
+++ b/vyper/venom/ir_node_to_venom.py
@@ -62,9 +62,13 @@
"gasprice",
"gaslimit",
"returndatasize",
+ "mload",
"iload",
+ "istore",
"sload",
+ "sstore",
"tload",
+ "tstore",
"coinbase",
"number",
"prevrandao",
@@ -88,9 +92,6 @@
"codecopy",
"returndatacopy",
"revert",
- "istore",
- "sstore",
- "tstore",
"create",
"create2",
"addmod",
@@ -104,10 +105,14 @@
NOOP_INSTRUCTIONS = frozenset(["pass", "cleanup_repeat", "var_list", "unique_symbol"])
SymbolTable = dict[str, Optional[IROperand]]
+_global_symbols: SymbolTable = {}
# convert IRnode directly to venom
def ir_node_to_venom(ir: IRnode) -> IRFunction:
+ global _global_symbols
+ _global_symbols = {}
+
ctx = IRFunction()
_convert_ir_bb(ctx, ir, {})
@@ -234,7 +239,7 @@ def pop_source(*args, **kwargs):
@pop_source_on_return
def _convert_ir_bb(ctx, ir, symbols):
assert isinstance(ir, IRnode), ir
- global _break_target, _continue_target, current_func, var_list
+ global _break_target, _continue_target, current_func, var_list, _global_symbols
ctx.push_source(ir)
@@ -267,6 +272,7 @@ def _convert_ir_bb(ctx, ir, symbols):
# Internal definition
var_list = ir.args[0].args[1]
does_return_data = IRnode.from_list(["return_buffer"]) in var_list.args
+ _global_symbols = {}
symbols = {}
_handle_internal_func(ctx, ir, does_return_data, symbols)
for ir_node in ir.args[1:]:
@@ -274,6 +280,7 @@ def _convert_ir_bb(ctx, ir, symbols):
return ret
elif is_external:
+ _global_symbols = {}
ret = _convert_ir_bb(ctx, ir.args[0], symbols)
_append_return_args(ctx)
else:
@@ -294,12 +301,24 @@ def _convert_ir_bb(ctx, ir, symbols):
cont_ret = _convert_ir_bb(ctx, cond, symbols)
cond_block = ctx.get_basic_block()
- cond_symbols = symbols.copy()
+ saved_global_symbols = _global_symbols.copy()
+ then_block = IRBasicBlock(ctx.get_next_label("then"), ctx)
else_block = IRBasicBlock(ctx.get_next_label("else"), ctx)
- ctx.append_basic_block(else_block)
+
+ # convert "then"
+ cond_symbols = symbols.copy()
+ ctx.append_basic_block(then_block)
+ then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols)
+ if isinstance(then_ret_val, IRLiteral):
+ then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val)
+
+ then_block_finish = ctx.get_basic_block()
# convert "else"
+ cond_symbols = symbols.copy()
+ _global_symbols = saved_global_symbols.copy()
+ ctx.append_basic_block(else_block)
else_ret_val = None
if len(ir.args) == 3:
else_ret_val = _convert_ir_bb(ctx, ir.args[2], cond_symbols)
@@ -309,20 +328,9 @@ def _convert_ir_bb(ctx, ir, symbols):
else_block_finish = ctx.get_basic_block()
- # convert "then"
- cond_symbols = symbols.copy()
-
- then_block = IRBasicBlock(ctx.get_next_label("then"), ctx)
- ctx.append_basic_block(then_block)
-
- then_ret_val = _convert_ir_bb(ctx, ir.args[1], cond_symbols)
- if isinstance(then_ret_val, IRLiteral):
- then_ret_val = ctx.get_basic_block().append_instruction("store", then_ret_val)
-
+ # finish the condition block
cond_block.append_instruction("jnz", cont_ret, then_block.label, else_block.label)
- then_block_finish = ctx.get_basic_block()
-
# exit bb
exit_bb = IRBasicBlock(ctx.get_next_label("if_exit"), ctx)
exit_bb = ctx.append_basic_block(exit_bb)
@@ -338,6 +346,8 @@ def _convert_ir_bb(ctx, ir, symbols):
if not then_block_finish.is_terminated:
then_block_finish.append_instruction("jmp", exit_bb.label)
+ _global_symbols = saved_global_symbols
+
return if_ret
elif ir.value == "with":
@@ -345,13 +355,12 @@ def _convert_ir_bb(ctx, ir, symbols):
ret = ctx.get_basic_block().append_instruction("store", ret)
- # Handle with nesting with same symbol
- with_symbols = symbols.copy()
-
sym = ir.args[0]
+ with_symbols = symbols.copy()
with_symbols[sym.value] = ret
return _convert_ir_bb(ctx, ir.args[2], with_symbols) # body
+
elif ir.value == "goto":
_append_jmp(ctx, IRLabel(ir.args[0].value))
elif ir.value == "djump":
@@ -423,27 +432,13 @@ def _convert_ir_bb(ctx, ir, symbols):
bb.append_instruction("dloadbytes", len_, src, dst)
return None
- elif ir.value == "mload":
- arg_0 = _convert_ir_bb(ctx, ir.args[0], symbols)
- bb = ctx.get_basic_block()
- if isinstance(arg_0, IRVariable):
- return bb.append_instruction("mload", arg_0)
-
- if isinstance(arg_0, IRLiteral):
- avar = symbols.get(f"%{arg_0.value}")
- if avar is not None:
- return bb.append_instruction("mload", avar)
-
- return bb.append_instruction("mload", arg_0)
elif ir.value == "mstore":
# some upstream code depends on reversed order of evaluation --
# to fix upstream.
- arg_1, arg_0 = _convert_ir_bb_list(ctx, reversed(ir.args), symbols)
+ val, ptr = _convert_ir_bb_list(ctx, reversed(ir.args), symbols)
- if isinstance(arg_1, IRVariable):
- symbols[f"&{arg_0.value}"] = arg_1
+ return ctx.get_basic_block().append_instruction("mstore", val, ptr)
- ctx.get_basic_block().append_instruction("mstore", arg_1, arg_0)
elif ir.value == "ceil32":
x = ir.args[0]
expanded = IRnode.from_list(["and", ["add", x, 31], ["not", 31]])
@@ -467,11 +462,13 @@ def _convert_ir_bb(ctx, ir, symbols):
elif ir.value == "repeat":
def emit_body_blocks():
- global _break_target, _continue_target
+ global _break_target, _continue_target, _global_symbols
old_targets = _break_target, _continue_target
_break_target, _continue_target = exit_block, incr_block
+ saved_global_symbols = _global_symbols.copy()
_convert_ir_bb(ctx, body, symbols.copy())
_break_target, _continue_target = old_targets
+ _global_symbols = saved_global_symbols
sym = ir.args[0]
start, end, _ = _convert_ir_bb_list(ctx, ir.args[1:4], symbols)
@@ -545,8 +542,17 @@ def emit_body_blocks():
ctx.get_basic_block().append_instruction("log", topic_count, *args)
elif isinstance(ir.value, str) and ir.value.upper() in get_opcodes():
_convert_ir_opcode(ctx, ir, symbols)
- elif isinstance(ir.value, str) and ir.value in symbols:
- return symbols[ir.value]
+ elif isinstance(ir.value, str):
+ if ir.value.startswith("$alloca") and ir.value not in _global_symbols:
+ alloca = ir.passthrough_metadata["alloca"]
+ ptr = ctx.get_basic_block().append_instruction("alloca", alloca.offset, alloca.size)
+ _global_symbols[ir.value] = ptr
+ elif ir.value.startswith("$palloca") and ir.value not in _global_symbols:
+ alloca = ir.passthrough_metadata["alloca"]
+ ptr = ctx.get_basic_block().append_instruction("store", alloca.offset)
+ _global_symbols[ir.value] = ptr
+
+ return _global_symbols.get(ir.value) or symbols.get(ir.value)
elif ir.is_literal:
return IRLiteral(ir.value)
else:
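
# A minimal sketch (not part of the patch) of the "$alloca"/"$palloca"
# handling just added: the first reference to a name materializes a pointer
# instruction and caches it in _global_symbols, so every later reference in
# the function reuses the same virtual register.

def resolve(name: str, offset: int, size: int, emit, cache: dict):
    if name not in cache:
        if name.startswith("$palloca"):
            cache[name] = emit("store", offset)  # arg already in the frame
        else:
            cache[name] = emit("alloca", offset, size)
    return cache[name]

emitted = []
def emit(*inst):
    emitted.append(inst)
    return f"%{len(emitted)}"

cache: dict = {}
assert resolve("$alloca_320_64", 320, 64, emit, cache) == "%1"
assert resolve("$alloca_320_64", 320, 64, emit, cache) == "%1"  # cached
assert emitted == [("alloca", 320, 64)]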
diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py
index 5d149cf003..994ab9d70d 100644
--- a/vyper/venom/passes/dft.py
+++ b/vyper/venom/passes/dft.py
@@ -22,13 +22,15 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset:
# if the instruction is a terminator, we need to place
# it at the end of the basic block
# along with all the instructions that "lead" to it
- if uses_this.opcode in BB_TERMINATORS:
- offset = len(bb.instructions)
self._process_instruction_r(bb, uses_this, offset)
if inst in self.visited_instructions:
return
self.visited_instructions.add(inst)
+ self.inst_order_num += 1
+
+ if inst.opcode in BB_TERMINATORS:
+ offset = len(bb.instructions)
if inst.opcode == "phi":
# phi instructions stay at the beginning of the basic block
@@ -45,7 +47,6 @@ def _process_instruction_r(self, bb: IRBasicBlock, inst: IRInstruction, offset:
continue
self._process_instruction_r(bb, target, offset)
- self.inst_order_num += 1
self.inst_order[inst] = self.inst_order_num + offset
def _process_basic_block(self, bb: IRBasicBlock) -> None:
diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py
new file mode 100644
index 0000000000..9d74dfec0b
--- /dev/null
+++ b/vyper/venom/passes/mem2var.py
@@ -0,0 +1,66 @@
+from vyper.utils import OrderedSet
+from vyper.venom.analysis import DFG, calculate_cfg, calculate_liveness
+from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+
+
+class Mem2Var(IRPass):
+ """
+    This pass promotes memory operations to variable operations, when possible.
+    It does not yet do any memory aliasing analysis, so it is conservative.
+ """
+
+ ctx: IRFunction
+ defs: dict[IRVariable, OrderedSet[IRBasicBlock]]
+ dfg: DFG
+
+ def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock, dfg: DFG) -> int:
+ self.ctx = ctx
+ self.dfg = dfg
+
+ calculate_cfg(ctx)
+
+ dfg = DFG.build_dfg(ctx)
+ self.dfg = dfg
+
+ calculate_liveness(ctx)
+
+ self.var_name_count = 0
+ for var, inst in dfg.outputs.items():
+ if inst.opcode != "alloca":
+ continue
+ self._process_alloca_var(dfg, var)
+
+ return 0
+
+ def _process_alloca_var(self, dfg: DFG, var: IRVariable):
+ """
+    Process an alloca-allocated variable. If it is only used by mstore/mload/return
+ instructions, it is promoted to a stack variable. Otherwise, it is left as is.
+ """
+ uses = dfg.get_uses(var)
+ if all([inst.opcode == "mload" for inst in uses]):
+ return
+ elif all([inst.opcode == "mstore" for inst in uses]):
+ return
+ elif all([inst.opcode in ["mstore", "mload", "return"] for inst in uses]):
+ var_name = f"addr{var.name}_{self.var_name_count}"
+ self.var_name_count += 1
+ for inst in uses:
+ if inst.opcode == "mstore":
+ inst.opcode = "store"
+ inst.output = IRVariable(var_name)
+ inst.operands = [inst.operands[0]]
+ elif inst.opcode == "mload":
+ inst.opcode = "store"
+ inst.operands = [IRVariable(var_name)]
+ elif inst.opcode == "return":
+ bb = inst.parent
+ new_var = self.ctx.get_next_variable()
+ idx = bb.instructions.index(inst)
+ bb.insert_instruction(
+ IRInstruction("mstore", [IRVariable(var_name), inst.operands[1]], new_var),
+ idx,
+ )
+ inst.operands[1] = new_var
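
# A minimal sketch (not part of the patch) of the promotion performed above:
# when every use of an alloca'd buffer is an mload or mstore (plus `return`
# of the buffer), the memory round-trip collapses into plain stack `store`s:
#
#   mstore %x, %p   ->   %addr0 = store %x
#   %y = mload %p   ->   %y = store %addr0

def promote(insts, ptr: str, fresh: str = "%addr0"):
    # hedged model of the rewrite in _process_alloca_var
    out = []
    for op, dst, args in insts:
        if op == "mstore" and args[1] == ptr:
            out.append(("store", fresh, [args[0]]))
        elif op == "mload" and args == [ptr]:
            out.append(("store", dst, [fresh]))
        else:
            out.append((op, dst, args))
    return out

before = [("mstore", None, ["%x", "%p"]), ("mload", "%y", ["%p"])]
after = [("store", "%addr0", ["%x"]), ("store", "%y", ["%addr0"])]
assert promote(before, "%p") == after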
diff --git a/vyper/venom/passes/sccp/__init__.py b/vyper/venom/passes/sccp/__init__.py
new file mode 100644
index 0000000000..866d55e801
--- /dev/null
+++ b/vyper/venom/passes/sccp/__init__.py
@@ -0,0 +1 @@
+from vyper.venom.passes.sccp.sccp import SCCP
diff --git a/vyper/venom/passes/sccp/eval.py b/vyper/venom/passes/sccp/eval.py
new file mode 100644
index 0000000000..8acca039c0
--- /dev/null
+++ b/vyper/venom/passes/sccp/eval.py
@@ -0,0 +1,132 @@
+import operator
+from typing import Callable
+
+from vyper.utils import SizeLimits, evm_div, evm_mod, signed_to_unsigned, unsigned_to_signed
+from vyper.venom.basicblock import IROperand
+
+
+def _unsigned_to_signed(value: int) -> int:
+ if value <= SizeLimits.MAX_INT256:
+ return value # fast exit
+ else:
+ return unsigned_to_signed(value, 256)
+
+
+def _signed_to_unsigned(value: int) -> int:
+ if value >= 0:
+ return value # fast exit
+ else:
+ return signed_to_unsigned(value, 256)
+
+
+def _wrap_signed_binop(operation):
+ def wrapper(ops: list[IROperand]) -> int:
+ first = _unsigned_to_signed(ops[1].value)
+ second = _unsigned_to_signed(ops[0].value)
+ return _signed_to_unsigned(int(operation(first, second)))
+
+ return wrapper
+
+
+def _wrap_binop(operation):
+ def wrapper(ops: list[IROperand]) -> int:
+ first = ops[1].value
+ second = ops[0].value
+ return (int(operation(first, second))) & SizeLimits.MAX_UINT256
+
+ return wrapper
+
+
+def _evm_signextend(ops: list[IROperand]) -> int:
+ value = ops[0].value
+ nbytes = ops[1].value
+
+ assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+
+ if nbytes > 31:
+ return value
+
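+    # byte `nbytes` (counted from the least-significant end) holds the sign bit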
+ sign_bit = 1 << (nbytes * 8 + 7)
+ if value & sign_bit:
+ value |= SizeLimits.CEILING_UINT256 - sign_bit
+ else:
+ value &= sign_bit - 1
+
+ return value
+
+
+def _evm_iszero(ops: list[IROperand]) -> int:
+ value = ops[0].value
+ assert SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+ return int(value == 0) # 1 if True else 0
+
+
+def _evm_shr(ops: list[IROperand]) -> int:
+ value = ops[0].value
+ shift_len = ops[1].value
+ assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+ return value >> shift_len
+
+
+def _evm_shl(ops: list[IROperand]) -> int:
+ value = ops[0].value
+ shift_len = ops[1].value
+ assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+ if shift_len >= 256:
+ return 0
+ return (value << shift_len) & SizeLimits.MAX_UINT256
+
+
+def _evm_sar(ops: list[IROperand]) -> int:
+    value = _unsigned_to_signed(ops[0].value)
+    assert SizeLimits.MIN_INT256 <= value <= SizeLimits.MAX_INT256, "Value out of bounds"
+    shift_len = ops[1].value
+    # python's arithmetic shift already sign-fills correctly; convert the
+    # result back to the unsigned representation used everywhere else
+    return _signed_to_unsigned(value >> shift_len)
+
+
+def _evm_not(ops: list[IROperand]) -> int:
+ value = ops[0].value
+ assert 0 <= value <= SizeLimits.MAX_UINT256, "Value out of bounds"
+ return SizeLimits.MAX_UINT256 ^ value
+
+
+def _evm_exp(ops: list[IROperand]) -> int:
+    base = ops[1].value
+    exponent = ops[0].value
+
+    # NOTE: pow() matches EVM semantics here, including 0 ** 0 == 1
+    return pow(base, exponent, SizeLimits.CEILING_UINT256)
+
+
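+# maps venom opcodes to compile-time evaluators. operands arrive in reverse
+# (stack) order, which is why the wrappers read ops[1] before ops[0].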
+ARITHMETIC_OPS: dict[str, Callable[[list[IROperand]], int]] = {
+ "add": _wrap_binop(operator.add),
+ "sub": _wrap_binop(operator.sub),
+ "mul": _wrap_binop(operator.mul),
+ "div": _wrap_binop(evm_div),
+ "sdiv": _wrap_signed_binop(evm_div),
+ "mod": _wrap_binop(evm_mod),
+ "smod": _wrap_signed_binop(evm_mod),
+ "exp": _evm_exp,
+ "eq": _wrap_binop(operator.eq),
+ "ne": _wrap_binop(operator.ne),
+ "lt": _wrap_binop(operator.lt),
+ "le": _wrap_binop(operator.le),
+ "gt": _wrap_binop(operator.gt),
+ "ge": _wrap_binop(operator.ge),
+ "slt": _wrap_signed_binop(operator.lt),
+ "sle": _wrap_signed_binop(operator.le),
+ "sgt": _wrap_signed_binop(operator.gt),
+ "sge": _wrap_signed_binop(operator.ge),
+ "or": _wrap_binop(operator.or_),
+ "and": _wrap_binop(operator.and_),
+ "xor": _wrap_binop(operator.xor),
+ "not": _evm_not,
+ "signextend": _evm_signextend,
+ "iszero": _evm_iszero,
+ "shr": _evm_shr,
+ "shl": _evm_shl,
+ "sar": _evm_sar,
+ "store": lambda ops: ops[0].value,
+}
diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py
new file mode 100644
index 0000000000..7dfca8edd4
--- /dev/null
+++ b/vyper/venom/passes/sccp/sccp.py
@@ -0,0 +1,332 @@
+from dataclasses import dataclass
+from enum import Enum
+from functools import reduce
+from typing import Union
+
+from vyper.exceptions import CompilerPanic, StaticAssertionException
+from vyper.utils import OrderedSet
+from vyper.venom.basicblock import (
+ IRBasicBlock,
+ IRInstruction,
+ IRLabel,
+ IRLiteral,
+ IROperand,
+ IRVariable,
+)
+from vyper.venom.dominators import DominatorTree
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRPass
+from vyper.venom.passes.sccp.eval import ARITHMETIC_OPS
+
+
+class LatticeEnum(Enum):
+ TOP = 1
+ BOTTOM = 2
+
+
+@dataclass
+class SSAWorkListItem:
+ inst: IRInstruction
+
+
+@dataclass
+class FlowWorkItem:
+ start: IRBasicBlock
+ end: IRBasicBlock
+
+
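+# lattice values: LatticeEnum.TOP means "no information yet", an IRLiteral
+# means "known constant", and LatticeEnum.BOTTOM means "proven non-constant"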
+WorkListItem = Union[FlowWorkItem, SSAWorkListItem]
+LatticeItem = Union[LatticeEnum, IRLiteral]
+Lattice = dict[IROperand, LatticeItem]
+
+
+class SCCP(IRPass):
+ """
+ This class implements the Sparse Conditional Constant Propagation
+ algorithm by Wegman and Zadeck. It is a forward dataflow analysis
+    that propagates constant values through the IR graph. It is used
+    to optimize the IR by replacing variables with their constant values
+    and by rewriting branches whose conditions are known at compile time.
+ """
+
+ ctx: IRFunction
+ dom: DominatorTree
+ uses: dict[IRVariable, OrderedSet[IRInstruction]]
+ lattice: Lattice
+ work_list: list[WorkListItem]
+ cfg_dirty: bool
+ cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
+
+ def __init__(self, dom: DominatorTree):
+ self.dom = dom
+ self.lattice = {}
+ self.work_list: list[WorkListItem] = []
+ self.cfg_dirty = False
+
+ def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> int:
+ self.ctx = ctx
+ self._compute_uses(self.dom)
+ self._calculate_sccp(entry)
+ self._propagate_constants()
+
+ # self._propagate_variables()
+ return 0
+
+ def _calculate_sccp(self, entry: IRBasicBlock):
+ """
+ This method is the main entry point for the SCCP algorithm. It
+ initializes the work list and the lattice and then iterates over
+        the work list until it is empty, visiting each reachable basic
+        block in the CFG and processing its instructions along the way.
+
+ This method does not update the IR, it only updates the lattice
+ and the work list. The `_propagate_constants()` method is responsible
+ for updating the IR with the constant values.
+ """
+ self.cfg_in_exec = {bb: OrderedSet() for bb in self.ctx.basic_blocks}
+
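+        # seed the work list with a synthetic edge into the entry block so it
+        # is processed like any other newly reachable block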
+ dummy = IRBasicBlock(IRLabel("__dummy_start"), self.ctx)
+ self.work_list.append(FlowWorkItem(dummy, entry))
+
+ # Initialize the lattice with TOP values for all variables
+ for v in self.uses.keys():
+ self.lattice[v] = LatticeEnum.TOP
+
+ # Iterate over the work list until it is empty
+ # Items in the work list can be either FlowWorkItem or SSAWorkListItem
+ while len(self.work_list) > 0:
+ work_item = self.work_list.pop()
+ if isinstance(work_item, FlowWorkItem):
+ self._handle_flow_work_item(work_item)
+ elif isinstance(work_item, SSAWorkListItem):
+ self._handle_SSA_work_item(work_item)
+ else:
+ raise CompilerPanic("Invalid work item type")
+
+ def _handle_flow_work_item(self, work_item: FlowWorkItem):
+ """
+        This method handles a FlowWorkItem, i.e. a CFG edge that has just
+        become executable. Phis of the target block are re-evaluated for
+        every new incoming edge; the remaining instructions are evaluated
+        only the first time the block becomes reachable.
+ """
+ start = work_item.start
+ end = work_item.end
+ if start in self.cfg_in_exec[end]:
+ return
+ self.cfg_in_exec[end].add(start)
+
+ for inst in end.instructions:
+ if inst.opcode == "phi":
+ self._visit_phi(inst)
+ else:
+ # Stop at the first non-phi instruction
+ # as phis are only valid at the beginning of a block
+ break
+
+ if len(self.cfg_in_exec[end]) == 1:
+ for inst in end.instructions:
+ if inst.opcode == "phi":
+ continue
+ self._visit_expr(inst)
+
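+        # a block with a single successor reaches it unconditionally, so that
+        # edge becomes executable as soon as the block does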
+ if len(end.cfg_out) == 1:
+ self.work_list.append(FlowWorkItem(end, end.cfg_out.first()))
+
+ def _handle_SSA_work_item(self, work_item: SSAWorkListItem):
+ """
+        This method handles an SSAWorkListItem, i.e. an instruction whose
+        input lattice values may have changed. Phis are always re-evaluated;
+        other instructions only if their block is already reachable.
+ """
+ if work_item.inst.opcode == "phi":
+ self._visit_phi(work_item.inst)
+ elif len(self.cfg_in_exec[work_item.inst.parent]) > 0:
+ self._visit_expr(work_item.inst)
+
+ def _visit_phi(self, inst: IRInstruction):
+ assert inst.opcode == "phi", "Can't visit non phi instruction"
+ in_vars: list[LatticeItem] = []
+ for bb_label, var in inst.phi_operands:
+ bb = self.ctx.get_basic_block(bb_label.name)
+ if bb not in self.cfg_in_exec[inst.parent]:
+ continue
+ in_vars.append(self.lattice[var])
+ value = reduce(_meet, in_vars, LatticeEnum.TOP) # type: ignore
+ assert inst.output in self.lattice, "Got undefined var for phi"
+ if value != self.lattice[inst.output]:
+ self.lattice[inst.output] = value
+ self._add_ssa_work_items(inst)
+
+ def _visit_expr(self, inst: IRInstruction):
+ opcode = inst.opcode
+ if opcode in ["store", "alloca"]:
+ if isinstance(inst.operands[0], IRLiteral):
+ self.lattice[inst.output] = inst.operands[0] # type: ignore
+ else:
+ self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore
+ self._add_ssa_work_items(inst)
+ elif opcode == "jmp":
+ target = self.ctx.get_basic_block(inst.operands[0].value)
+ self.work_list.append(FlowWorkItem(inst.parent, target))
+ elif opcode == "jnz":
+ lat = self.lattice[inst.operands[0]]
+            assert lat != LatticeEnum.TOP, f"Got undefined var at jnz at {inst.parent}"
+ if lat == LatticeEnum.BOTTOM:
+ for out_bb in inst.parent.cfg_out:
+ self.work_list.append(FlowWorkItem(inst.parent, out_bb))
+ else:
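+                # lat is a known literal: enqueue the nonzero target unless
+                # lat == 0, and the zero target unless lat == 1 (conservative
+                # for non-boolean literals, where both edges get enqueued)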
+ if _meet(lat, IRLiteral(0)) == LatticeEnum.BOTTOM:
+ target = self.ctx.get_basic_block(inst.operands[1].name)
+ self.work_list.append(FlowWorkItem(inst.parent, target))
+ if _meet(lat, IRLiteral(1)) == LatticeEnum.BOTTOM:
+ target = self.ctx.get_basic_block(inst.operands[2].name)
+ self.work_list.append(FlowWorkItem(inst.parent, target))
+ elif opcode == "djmp":
+ lat = self.lattice[inst.operands[0]]
+            assert lat != LatticeEnum.TOP, f"Got undefined var at djmp at {inst.parent}"
+ if lat == LatticeEnum.BOTTOM:
+ for op in inst.operands[1:]:
+ target = self.ctx.get_basic_block(op.name)
+ self.work_list.append(FlowWorkItem(inst.parent, target))
+ elif isinstance(lat, IRLiteral):
+ raise CompilerPanic("Unimplemented djmp with literal")
+
+ elif opcode in ["param", "calldataload"]:
+ self.lattice[inst.output] = LatticeEnum.BOTTOM # type: ignore
+ self._add_ssa_work_items(inst)
+ elif opcode == "mload":
+ self.lattice[inst.output] = LatticeEnum.BOTTOM # type: ignore
+ elif opcode in ARITHMETIC_OPS:
+ self._eval(inst)
+ else:
+ if inst.output is not None:
+ self.lattice[inst.output] = LatticeEnum.BOTTOM
+
+ def _eval(self, inst) -> LatticeItem:
+ """
+ This method evaluates an arithmetic operation and returns the result.
+ At the same time it updates the lattice with the result and adds the
+ instruction to the SSA work list if the knowledge about the variable
+ changed.
+ """
+ opcode = inst.opcode
+
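+        # resolve each operand to its lattice value; a label operand means
+        # the instruction cannot be evaluated at compile time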
+ ops = []
+ for op in inst.operands:
+ if isinstance(op, IRVariable):
+ ops.append(self.lattice[op])
+ elif isinstance(op, IRLabel):
+ return LatticeEnum.BOTTOM
+ else:
+ ops.append(op)
+
+ ret = None
+ if LatticeEnum.BOTTOM in ops:
+ ret = LatticeEnum.BOTTOM
+ else:
+ if opcode in ARITHMETIC_OPS:
+ fn = ARITHMETIC_OPS[opcode]
+ ret = IRLiteral(fn(ops)) # type: ignore
+ elif len(ops) > 0:
+ ret = ops[0] # type: ignore
+ else:
+ raise CompilerPanic("Bad constant evaluation")
+
+ old_val = self.lattice.get(inst.output, LatticeEnum.TOP)
+ if old_val != ret:
+ self.lattice[inst.output] = ret # type: ignore
+ self._add_ssa_work_items(inst)
+
+ return ret # type: ignore
+
+ def _add_ssa_work_items(self, inst: IRInstruction):
+ for target_inst in self._get_uses(inst.output): # type: ignore
+ self.work_list.append(SSAWorkListItem(target_inst))
+
+ def _compute_uses(self, dom: DominatorTree):
+ """
+ This method computes the uses for each variable in the IR.
+ It iterates over the dominator tree and collects all the
+ instructions that use each variable.
+ """
+ self.uses = {}
+ for bb in dom.dfs_walk:
+ for var, insts in bb.get_uses().items():
+ self._get_uses(var).update(insts)
+
+ def _get_uses(self, var: IRVariable):
+ if var not in self.uses:
+ self.uses[var] = OrderedSet()
+ return self.uses[var]
+
+ def _propagate_constants(self):
+ """
+ This method iterates over the IR and replaces constant values
+ with their actual values. It also replaces conditional jumps
+ with unconditional jumps if the condition is a constant value.
+ """
+ for bb in self.dom.dfs_walk:
+ for inst in bb.instructions:
+ self._replace_constants(inst, self.lattice)
+
+ def _replace_constants(self, inst: IRInstruction, lattice: Lattice):
+ """
+ This method replaces constant values in the instruction with
+ their actual values. It also updates the instruction opcode in
+ case of jumps and asserts as needed.
+ """
+ if inst.opcode == "jnz":
+ lat = lattice[inst.operands[0]]
+ if isinstance(lat, IRLiteral):
+ if lat.value == 0:
+ target = inst.operands[2]
+ else:
+ target = inst.operands[1]
+ inst.opcode = "jmp"
+ inst.operands = [target]
+ self.cfg_dirty = True
+ elif inst.opcode == "assert":
+ lat = lattice[inst.operands[0]]
+ if isinstance(lat, IRLiteral):
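+                # an assert whose condition is known either always passes
+                # (reduce it to a nop) or always fails (compile-time error)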
+ if lat.value > 0:
+ inst.opcode = "nop"
+ else:
+ raise StaticAssertionException(
+ f"assertion found to fail at compile time ({inst.error_msg}).",
+ inst.get_ast_source(),
+ )
+
+ inst.operands = []
+
+ elif inst.opcode == "phi":
+ return
+
+ for i, op in enumerate(inst.operands):
+ if isinstance(op, IRVariable):
+ lat = lattice[op]
+ if isinstance(lat, IRLiteral):
+ inst.operands[i] = lat
+
+ def _propagate_variables(self):
+ """
+        Copy elimination. NOTE: not wired up yet (see the commented-out call
+        in `_run_pass()`), but it's also not needed at the moment.
+ """
+ for bb in self.dom.dfs_walk:
+ for inst in bb.instructions:
+ if inst.opcode == "store":
+ uses = self._get_uses(inst.output)
+ remove_inst = True
+ for usage_inst in uses:
+ if usage_inst.opcode == "phi":
+ remove_inst = False
+ continue
+ for i, op in enumerate(usage_inst.operands):
+ if op == inst.output:
+ usage_inst.operands[i] = inst.operands[0]
+ if remove_inst:
+ inst.opcode = "nop"
+ inst.operands = []
+
+
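+# meet operator of the lattice: TOP is the identity, agreeing values stay
+# put, and any disagreement collapses to BOTTOM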
+def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem:
+ if x == LatticeEnum.TOP:
+ return y
+ if y == LatticeEnum.TOP or x == y:
+ return x
+ return LatticeEnum.BOTTOM
diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py
index 7f02ccf819..bebf2acd32 100644
--- a/vyper/venom/passes/simplify_cfg.py
+++ b/vyper/venom/passes/simplify_cfg.py
@@ -1,5 +1,7 @@
+from vyper.exceptions import CompilerPanic
from vyper.utils import OrderedSet
from vyper.venom.basicblock import IRBasicBlock
+from vyper.venom.bb_optimizer import ir_pass_remove_unreachable_blocks
from vyper.venom.function import IRFunction
from vyper.venom.passes.base_pass import IRPass
@@ -24,6 +26,11 @@ def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock):
next_bb.remove_cfg_in(b)
next_bb.add_cfg_in(a)
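+        # phis in the successor still reference `b` as a predecessor; repoint
+        # those incoming labels at `a`, which absorbs `b`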
+ for inst in next_bb.instructions:
+ if inst.opcode != "phi":
+ break
+ inst.operands[inst.operands.index(b.label)] = a.label
+
self.ctx.basic_blocks.remove(b)
def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock):
@@ -79,4 +86,9 @@ def _collapse_chained_blocks(self, entry: IRBasicBlock):
def _run_pass(self, ctx: IRFunction, entry: IRBasicBlock) -> None:
self.ctx = ctx
- self._collapse_chained_blocks(entry)
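+        # alternate collapsing chained blocks with pruning unreachable ones:
+        # each can expose more work for the other, so iterate to a fixed point
+        # (bounded by the block count)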
+ for _ in range(len(ctx.basic_blocks)): # essentially `while True`
+ self._collapse_chained_blocks(entry)
+ if ir_pass_remove_unreachable_blocks(ctx) == 0:
+ break
+ else:
+ raise CompilerPanic("Too many iterations collapsing chained blocks")
diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py
index 6924628958..7b58f0d54e 100644
--- a/vyper/venom/venom_to_assembly.py
+++ b/vyper/venom/venom_to_assembly.py
@@ -353,9 +353,10 @@ def _generate_evm_for_instruction(
# Step 1: Apply instruction special stack manipulations
if opcode in ["jmp", "djmp", "jnz", "invoke"]:
- operands = inst.get_non_label_operands()
+ operands = list(inst.get_non_label_operands())
elif opcode == "alloca":
- operands = inst.operands[1:2]
+ offset, _size = inst.operands
+ operands = [offset]
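+            # only the offset reaches the stack; the size operand is not
+            # used for code generation here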
# iload and istore are special cases because they can take a literal
# that is handled specially with the _OFST macro. Look below, after the
@@ -381,7 +382,7 @@ def _generate_evm_for_instruction(
if opcode == "phi":
ret = inst.get_outputs()[0]
- phis = inst.get_inputs()
+ phis = list(inst.get_inputs())
depth = stack.get_phi_depth(phis)
# collapse the arguments to the phi node in the stack.
# example, for `%56 = %label1 %13 %label2 %14`, we will
@@ -523,6 +524,8 @@ def _generate_evm_for_instruction(
assembly.append("MSTORE")
elif opcode == "log":
assembly.extend([f"LOG{log_topic_count}"])
+ elif opcode == "nop":
+ pass
else:
raise Exception(f"Unknown opcode: {opcode}")