[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request NethermindEth#163 from NethermindEth/murcake/fix-variable-inliner
Browse files (browse the repository at this point in the history)

invalidate variables' values on function calls in VariableInliner
  • Loading branch information
temyurchenko committed Dec 20, 2021
2 parents 9d66d67 + 73b4e2c commit c68941d
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 35 deletions.
1 change: 1 addition & 0 deletions tests/behaviour/for-loop-no-condition.sol
36 changes: 36 additions & 0 deletions tests/behaviour/for-loop-no-condition_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os.path as path
import shutil
from pathlib import Path
from tempfile import mkdtemp

import pytest
from starkware.starknet.compiler.compile import compile_starknet_files
from starkware.starknet.testing.state import StarknetState
from yul.main import transpile_from_solidity
from yul.starknet_utils import deploy_contract, invoke_method

# Directory holding this test and its Solidity fixture.
TEST_DIR = Path(__file__).parent
# Repository root: two levels above this test's directory
# (tests/behaviour/ -> tests/ -> repository root).
WARP_ROOT = TEST_DIR.parents[1]
# Cairo support-library directory, handed to the Starknet compiler as an
# import path.
CAIRO_PATH = WARP_ROOT.joinpath("warp", "cairo-src")


def spit(fname, content):
    """Write the string ``content`` to the file at ``fname``.

    Any existing contents of the file are replaced. Nothing is returned.
    """
    target = Path(fname)
    target.write_text(content)


@pytest.mark.asyncio
async def test():
    """Transpile, compile, deploy and invoke ``for-loop-no-condition.sol``.

    The Solidity source is transpiled to Cairo in a scratch directory. The
    Cairo code is then compiled, deployed into an empty StarknetState, and
    ``f`` is invoked; its raw ``retdata`` is checked.

    NOTE(review): the expected retdata ``[32, 2, 0, 10]`` presumably encodes
    (returndata size in bytes, returndata length in felts, then the two
    felts of the returned uint256 value 10). Confirm against ``invoke_method``.
    """
    sol = TEST_DIR / "for-loop-no-condition.sol"
    tmpdir = mkdtemp()
    try:
        cairo = path.join(tmpdir, "for-loop-no-condition.cairo")
        info = transpile_from_solidity(sol, "WARP")
        spit(cairo, info["cairo_code"])
        def_ = compile_starknet_files([cairo], debug_info=True, cairo_path=[CAIRO_PATH])

        starknet = await StarknetState.empty()
        address = await deploy_contract(starknet, info, def_)
        res = await invoke_method(starknet, info, address, "f")
        assert res.retdata == [32, 2, 0, 10]
    finally:
        # Remove the scratch directory even when transpilation, compilation,
        # deployment, invocation or the assertion fails; previously it leaked.
        shutil.rmtree(tmpdir)
187 changes: 187 additions & 0 deletions tests/golden/for-loop-no-condition.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
%lang starknet
%builtins pedersen range_check bitwise

from evm.array import validate_array
from evm.calls import calldataload, calldatasize
from evm.exec_env import ExecutionEnvironment
from evm.memory import uint256_mload, uint256_mstore
from evm.uint256 import is_eq, is_gt, is_lt, is_zero, slt, u256_add, u256_shr
from evm.yul_api import warp_return
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, HashBuiltin
from starkware.cairo.common.default_dict import default_dict_finalize, default_dict_new
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.uint256 import Uint256, uint256_sub

func __warp_constant_0() -> (res : Uint256):
return (Uint256(low=0, high=0))
end

@constructor
func constructor{range_check_ptr}(calldata_size, calldata_len, calldata : felt*):
alloc_locals
validate_array(calldata_size, calldata_len, calldata)
let (memory_dict) = default_dict_new(0)
let memory_dict_start = memory_dict
let msize = 0
with memory_dict, msize:
__constructor_meat()
end
default_dict_finalize(memory_dict_start, memory_dict, 0)
return ()
end

@external
func __main{bitwise_ptr : BitwiseBuiltin*, range_check_ptr}(
calldata_size, calldata_len, calldata : felt*) -> (
returndata_size, returndata_len, returndata : felt*):
alloc_locals
validate_array(calldata_size, calldata_len, calldata)
let (__fp__, _) = get_fp_and_pc()
local exec_env_ : ExecutionEnvironment = ExecutionEnvironment(calldata_size=calldata_size, calldata_len=calldata_len, calldata=calldata, returndata_size=0, returndata_len=0, returndata=cast(0, felt*), to_returndata_size=0, to_returndata_len=0, to_returndata=cast(0, felt*))
let exec_env : ExecutionEnvironment* = &exec_env_
let (memory_dict) = default_dict_new(0)
let memory_dict_start = memory_dict
let msize = 0
let termination_token = 0
with exec_env, memory_dict, msize, termination_token:
__main_meat()
end
default_dict_finalize(memory_dict_start, memory_dict, 0)
return (exec_env.to_returndata_size, exec_env.to_returndata_len, exec_env.to_returndata)
end

func __constructor_meat{memory_dict : DictAccess*, msize, range_check_ptr}() -> ():
alloc_locals
uint256_mstore(offset=Uint256(low=64, high=0), value=Uint256(low=128, high=0))
let (__warp_subexpr_0 : Uint256) = __warp_constant_0()
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
assert 0 = 1
jmp rel 0
else:
return ()
end
end

func __warp_loop_body_0{range_check_ptr}(
_1 : Uint256, _2 : Uint256, __warp_break_0 : Uint256, var_ret : Uint256,
var_ret_1 : Uint256) -> (__warp_break_0 : Uint256, var_ret : Uint256):
alloc_locals
let (__warp_subexpr_0 : Uint256) = is_gt(
var_ret,
Uint256(low=340282366920938463463374607431768211454, high=340282366920938463463374607431768211455))
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
assert 0 = 1
jmp rel 0
end
let (sum : Uint256) = u256_add(var_ret, var_ret_1)
let var_ret_2 : Uint256 = sum
let var_ret : Uint256 = sum
let (__warp_subexpr_2 : Uint256) = is_lt(var_ret_2, Uint256(low=10, high=0))
let (__warp_subexpr_1 : Uint256) = is_zero(__warp_subexpr_2)
if __warp_subexpr_1.low + __warp_subexpr_1.high != 0:
let __warp_break_0 : Uint256 = Uint256(low=1, high=0)
return (__warp_break_0, var_ret)
else:
return (__warp_break_0, var_ret)
end
end

func __warp_loop_0{range_check_ptr}(
_1 : Uint256, _2 : Uint256, var_ret : Uint256, var_ret_1 : Uint256) -> (var_ret : Uint256):
alloc_locals
let __warp_break_0 : Uint256 = Uint256(low=0, high=0)
let (__warp_subexpr_0 : Uint256) = is_zero(var_ret_1)
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
return (var_ret)
end
let (__warp_break_0 : Uint256, var_ret : Uint256) = __warp_loop_body_0(
_1, _2, __warp_break_0, var_ret, var_ret_1)
if __warp_break_0.low + __warp_break_0.high != 0:
return (var_ret)
end
let (var_ret : Uint256) = __warp_loop_0(_1, _2, var_ret, var_ret_1)
return (var_ret)
end

func abi_encode_uint256{memory_dict : DictAccess*, msize, range_check_ptr}(
headStart : Uint256, value0 : Uint256) -> (tail : Uint256):
alloc_locals
let (tail : Uint256) = u256_add(headStart, Uint256(low=32, high=0))
uint256_mstore(offset=headStart, value=value0)
return (tail)
end

func __warp_block_1{
bitwise_ptr : BitwiseBuiltin*, exec_env : ExecutionEnvironment*, memory_dict : DictAccess*,
msize, range_check_ptr, termination_token}() -> ():
alloc_locals
let (__warp_subexpr_2 : Uint256) = calldatasize()
let (__warp_subexpr_1 : Uint256) = u256_add(
__warp_subexpr_2,
Uint256(low=340282366920938463463374607431768211452, high=340282366920938463463374607431768211455))
let (__warp_subexpr_0 : Uint256) = slt(__warp_subexpr_1, Uint256(low=0, high=0))
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
assert 0 = 1
jmp rel 0
end
let var_ret : Uint256 = Uint256(low=1, high=0)
let (var_ret : Uint256) = __warp_loop_0(
Uint256(low=4, high=0), Uint256(low=0, high=0), var_ret, Uint256(low=1, high=0))
let (memPos : Uint256) = uint256_mload(Uint256(low=64, high=0))
let (__warp_subexpr_4 : Uint256) = abi_encode_uint256(memPos, var_ret)
let (__warp_subexpr_3 : Uint256) = uint256_sub(__warp_subexpr_4, memPos)
warp_return(memPos, __warp_subexpr_3)
return ()
end

func __warp_if_1{
bitwise_ptr : BitwiseBuiltin*, exec_env : ExecutionEnvironment*, memory_dict : DictAccess*,
msize, range_check_ptr, termination_token}(__warp_subexpr_0 : Uint256) -> ():
alloc_locals
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
__warp_block_1()
return ()
else:
return ()
end
end

func __warp_block_0{
bitwise_ptr : BitwiseBuiltin*, exec_env : ExecutionEnvironment*, memory_dict : DictAccess*,
msize, range_check_ptr, termination_token}() -> ():
alloc_locals
let (__warp_subexpr_2 : Uint256) = calldataload(Uint256(low=0, high=0))
let (__warp_subexpr_1 : Uint256) = u256_shr(Uint256(low=224, high=0), __warp_subexpr_2)
let (__warp_subexpr_0 : Uint256) = is_eq(Uint256(low=638722032, high=0), __warp_subexpr_1)
__warp_if_1(__warp_subexpr_0)
return ()
end

func __warp_if_0{
bitwise_ptr : BitwiseBuiltin*, exec_env : ExecutionEnvironment*, memory_dict : DictAccess*,
msize, range_check_ptr, termination_token}(__warp_subexpr_0 : Uint256) -> ():
alloc_locals
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
__warp_block_0()
return ()
else:
return ()
end
end

func __main_meat{
bitwise_ptr : BitwiseBuiltin*, exec_env : ExecutionEnvironment*, memory_dict : DictAccess*,
msize, range_check_ptr, termination_token}() -> ():
alloc_locals
uint256_mstore(offset=Uint256(low=64, high=0), value=Uint256(low=128, high=0))
let (__warp_subexpr_2 : Uint256) = calldatasize()
let (__warp_subexpr_1 : Uint256) = is_lt(__warp_subexpr_2, Uint256(low=4, high=0))
let (__warp_subexpr_0 : Uint256) = is_zero(__warp_subexpr_1)
__warp_if_0(__warp_subexpr_0)
if termination_token == 1:
return ()
end
assert 0 = 1
jmp rel 0
end
11 changes: 11 additions & 0 deletions tests/golden/for-loop-no-condition.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
pragma solidity ^0.8.6;

// Golden fixture: a `for` loop whose init, condition and post clauses are all
// empty, so the loop can only terminate through `break`.
contract WARP {
// Starts ret at 1 and increments it once per iteration, breaking once
// ret >= 10; the returned value is therefore 10.
function f() public returns(uint ret) {
ret = 1;
for (;;) {
ret += 1;
if (ret >= 10) break;
}
}
}
6 changes: 5 additions & 1 deletion tests/golden/for-loop-with-break.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,17 @@ end
func __warp_loop_0{range_check_ptr}(
_1 : Uint256, value0 : Uint256, value1 : Uint256, var_k : Uint256) -> (var_k : Uint256):
alloc_locals
let __warp_break_0 : Uint256 = Uint256(low=0, high=0)
let (__warp_subexpr_1 : Uint256) = is_lt(var_k, value0)
let (__warp_subexpr_0 : Uint256) = is_zero(__warp_subexpr_1)
if __warp_subexpr_0.low + __warp_subexpr_0.high != 0:
return (var_k)
end
let (__warp_break_0 : Uint256, var_k : Uint256) = __warp_loop_body_0(
_1, Uint256(low=0, high=0), value1, var_k)
_1, __warp_break_0, value1, var_k)
if __warp_break_0.low + __warp_break_0.high != 0:
return (var_k)
end
let (var_k : Uint256) = __warp_loop_0(_1, value0, value1, var_k)
return (var_k)
end
Expand Down
23 changes: 18 additions & 5 deletions tests/golden/for-loop-with-nested-return.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,31 @@ func __warp_loop_0{range_check_ptr}(
return (__warp_leave_0, var, var_k_1)
end

func __warp_block_0{range_check_ptr}(var : Uint256, var_i : Uint256, var_j : Uint256) -> (
var : Uint256, var_k_1 : Uint256):
func __warp_block_0{range_check_ptr}(
__warp_leave_5 : Uint256, var : Uint256, var_i : Uint256, var_j : Uint256,
var_k_1 : Uint256) -> (__warp_leave_5 : Uint256, var : Uint256, var_k_1 : Uint256):
alloc_locals
let __warp_leave_0 : Uint256 = Uint256(low=0, high=0)
let (__warp_leave_0 : Uint256, var : Uint256, var_k_1 : Uint256) = __warp_loop_0(
Uint256(low=0, high=0), var, var_i, var_j, Uint256(low=0, high=0), Uint256(low=0, high=0))
return (var, var_k_1)
__warp_leave_0, var, var_i, var_j, Uint256(low=0, high=0), var_k_1)
if __warp_leave_0.low + __warp_leave_0.high != 0:
let __warp_leave_5 : Uint256 = Uint256(low=1, high=0)
return (__warp_leave_5, var, var_k_1)
else:
return (__warp_leave_5, var, var_k_1)
end
end

func fun_transferFrom{range_check_ptr}(var_i : Uint256, var_j : Uint256) -> (var : Uint256):
alloc_locals
let var : Uint256 = Uint256(low=0, high=0)
let (var : Uint256, var_k_1 : Uint256) = __warp_block_0(var, var_i, var_j)
let __warp_leave_5 : Uint256 = Uint256(low=0, high=0)
let var_k_1 : Uint256 = Uint256(low=0, high=0)
let (__warp_leave_5 : Uint256, var : Uint256, var_k_1 : Uint256) = __warp_block_0(
__warp_leave_5, var, var_i, var_j, var_k_1)
if __warp_leave_5.low + __warp_leave_5.high != 0:
return (var)
end
let var : Uint256 = Uint256(low=1, high=0)
return (var)
end
Expand Down
8 changes: 7 additions & 1 deletion warp/yul/ForLoopEliminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def visit_for_loop(self, node: ast.ForLoop):
assert not node.post.statements, "Loop not simplified"

with self._new_for_loop():
assert self.loop_name and self.body_name
assert self.leave_name and self.break_name

body = self.visit(node.body)
body_fun, body_stmt = extract_block_as_function(body, self.body_name)

Expand All @@ -56,6 +59,7 @@ def visit_for_loop(self, node: ast.ForLoop):
self.aux_functions.extend((body_fun, head_fun))

leave_id = ast.Identifier(self.leave_name)
call_statements: ast.Statements
if leave_id not in get_scope(body).modified_variables:
call_statements = (head_stmt,)
else:
Expand All @@ -71,6 +75,7 @@ def visit_for_loop(self, node: ast.ForLoop):
return ast.Block(call_statements)

def visit_break(self, node: ast.Break):
assert self.break_name
return ast.Block(
(
ast.Assignment(
Expand Down Expand Up @@ -114,10 +119,11 @@ def _new_for_loop(self):
def _make_loop_head(
self, condition: ast.Expression, body_stmt: ast.Statement, rec: ast.Statement
) -> ast.Block:
assert self.break_name and self.leave_name
break_id = ast.Identifier(self.break_name)
leave_id = ast.Identifier(self.leave_name)
modified_vars = get_scope(ast.Block((body_stmt,))).modified_variables
head_stmts = []
head_stmts: list[ast.Statement] = []
if break_id in modified_vars:
head_stmts.append(
ast.VariableDeclaration(
Expand Down
Loading

0 comments on commit c68941d

Please sign in to comment.