Skip to content

Commit 11b7b5b

Browse files
authored
Merge pull request #2345 from iamdefinitelyahuman/fix-parse-sequence
Fix: memory corruption from nested internal function calls
2 parents c4cdb01 + 435ac0d commit 11b7b5b

File tree

4 files changed

+163
-80
lines changed

4 files changed

+163
-80
lines changed

tests/parser/types/test_lists.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,79 @@ def test_values(arr: int128[2][3], s: String[10]) -> (int128[2][3], String[10]):
286286

287287
c = get_contract(code)
288288
assert c.test_values([[1, 2], [3, 4], [5, 6]], "abcdef") == [[[1, 2], [3, 4], [5, 6]], "abcdef"]
289+
290+
291+
def test_nested_index_of_returned_array(get_contract):
292+
code = """
293+
@internal
294+
def inner() -> (int128, int128):
295+
return 1,2
296+
297+
@external
298+
def outer() -> int128[2]:
299+
return [333, self.inner()[0]]
300+
"""
301+
302+
c = get_contract(code)
303+
assert c.outer() == [333, 1]
304+
305+
306+
def test_nested_calls_inside_arrays(get_contract):
307+
code = """
308+
@internal
309+
def _foo(a: uint256, b: uint256[2]) -> (uint256, uint256, uint256, uint256, uint256):
310+
return 1, a, b[0], b[1], 5
311+
312+
@internal
313+
def _foo2() -> uint256:
314+
a: uint256[10] = [6,7,8,9,10,11,12,13,15,16]
315+
return 4
316+
317+
@external
318+
def foo() -> (uint256, uint256, uint256, uint256, uint256):
319+
return self._foo(2, [3, self._foo2()])
320+
"""
321+
322+
c = get_contract(code)
323+
assert c.foo() == [1, 2, 3, 4, 5]
324+
325+
326+
def test_nested_calls_inside_arrays_with_index_access(get_contract):
327+
code = """
328+
@internal
329+
def _foo(a: uint256[2], b: uint256[2]) -> (uint256, uint256, uint256, uint256, uint256):
330+
return a[1]-b[0], 2, a[0]-b[1], 8-b[1], 5
331+
332+
@internal
333+
def _foo2() -> (uint256, uint256):
334+
a: uint256[10] = [6,7,8,9,10,11,12,13,15,16]
335+
return a[6], 4
336+
337+
@external
338+
def foo() -> (uint256, uint256, uint256, uint256, uint256):
339+
return self._foo([7, self._foo2()[0]], [11, self._foo2()[1]])
340+
"""
341+
342+
c = get_contract(code)
343+
assert c.foo() == [1, 2, 3, 4, 5]
344+
345+
346+
def test_so_many_things_you_should_never_do(get_contract):
347+
code = """
348+
@internal
349+
def _foo(a: uint256[2], b: uint256[2]) -> uint256[5]:
350+
return [a[1]-b[0], 2, a[0]-b[1], 8-b[1], 5]
351+
352+
@internal
353+
def _foo2() -> (uint256, uint256):
354+
b: uint256[2] = [5, 8]
355+
a: uint256[10] = [6,7,8,9,10,11,12,13,self._foo([44,b[0]],b)[4],16]
356+
return a[6], 4
357+
358+
@external
359+
def foo() -> (uint256, uint256[3], uint256[2]):
360+
x: uint256[3] = [1, 14-self._foo2()[0], self._foo([7,self._foo2()[0]], [11,self._foo2()[1]])[2]]
361+
return 666, x, [88, self._foo2()[0]]
362+
"""
363+
c = get_contract(code)
364+
assert c.foo() == [666, [1, 2, 3], [88, 12]]

vyper/parser/expr.py

Lines changed: 76 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -949,28 +949,25 @@ def parse_Call(self):
949949
return external_call.make_external_call(self.expr, self.context)
950950

951951
def parse_List(self):
952-
if not len(self.expr.elements):
953-
return
954-
955-
def get_out_type(lll_node):
956-
if isinstance(lll_node, ListType):
957-
return get_out_type(lll_node.subtype)
958-
return lll_node.typ
952+
call_lll, multi_lll = parse_sequence(self.expr, self.expr.elements, self.context)
953+
out_type = next((i.typ for i in multi_lll if not i.typ.is_literal), multi_lll[0].typ)
954+
typ = ListType(out_type, len(self.expr.elements), is_literal=True)
955+
multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr))
956+
if not call_lll:
957+
return multi_lll
959958

960-
lll_node = []
961-
out_type = None
959+
lll_node = ["seq_unchecked"] + call_lll + [multi_lll]
960+
return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr))
962961

963-
for elt in self.expr.elements:
964-
current_lll_node = Expr(elt, self.context).lll_node
965-
if not out_type or not current_lll_node.typ.is_literal:
966-
# prefer to use a non-literal type here, because literals can be ambiguous
967-
# this should be removed altogether as we refactor types out of parser
968-
out_type = current_lll_node.typ
969-
lll_node.append(current_lll_node)
962+
def parse_Tuple(self):
963+
call_lll, multi_lll = parse_sequence(self.expr, self.expr.elements, self.context)
964+
typ = TupleType([x.typ for x in multi_lll], is_literal=True)
965+
multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr))
966+
if not call_lll:
967+
return multi_lll
970968

971-
return LLLnode.from_list(
972-
["multi"] + lll_node, typ=ListType(out_type, len(lll_node)), pos=getpos(self.expr),
973-
)
969+
lll_node = ["seq_unchecked"] + call_lll + [multi_lll]
970+
return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr))
974971

975972
@staticmethod
976973
def struct_literals(expr, name, context):
@@ -990,39 +987,6 @@ def struct_literals(expr, name, context):
990987
pos=getpos(expr),
991988
)
992989

993-
def parse_Tuple(self):
994-
if not len(self.expr.elements):
995-
return
996-
call_lll = []
997-
multi_lll = []
998-
for node in self.expr.elements:
999-
if isinstance(node, vy_ast.Call):
1000-
# for calls inside the tuple, we perform the call prior to building the tuple and
1001-
# assign its result to memory - otherwise there is potential for memory corruption
1002-
lll_node = Expr(node, self.context).lll_node
1003-
target = LLLnode.from_list(
1004-
self.context.new_internal_variable(lll_node.typ),
1005-
typ=lll_node.typ,
1006-
location="memory",
1007-
pos=getpos(self.expr),
1008-
)
1009-
call_lll.append(make_setter(target, lll_node, "memory", pos=getpos(self.expr)))
1010-
multi_lll.append(
1011-
LLLnode.from_list(
1012-
target, typ=lll_node.typ, pos=getpos(self.expr), location="memory"
1013-
),
1014-
)
1015-
else:
1016-
multi_lll.append(Expr(node, self.context).lll_node)
1017-
1018-
typ = TupleType([x.typ for x in multi_lll], is_literal=True)
1019-
multi_lll = LLLnode.from_list(["multi"] + multi_lll, typ=typ, pos=getpos(self.expr))
1020-
if not call_lll:
1021-
return multi_lll
1022-
1023-
lll_node = ["seq_unchecked"] + call_lll + [multi_lll]
1024-
return LLLnode.from_list(lll_node, typ=typ, pos=getpos(self.expr))
1025-
1026990
# Parse an expression that results in a value
1027991
@classmethod
1028992
def parse_value_expr(cls, expr, context):
@@ -1035,3 +999,63 @@ def parse_variable_location(cls, expr, context):
1035999
if not o.location:
10361000
raise StructureException("Looking for a variable location, instead got a value", expr)
10371001
return o
1002+
1003+
1004+
def parse_sequence(base_node, elements, context):
1005+
"""
1006+
Generate an LLL node from a sequence of Vyper AST nodes, such as values inside a
1007+
list/tuple or arguments inside a call.
1008+
1009+
Arguments
1010+
---------
1011+
base_node : VyperNode
1012+
Parent node which contains the sequence being parsed.
1013+
elements : List[VyperNode]
1014+
A list of nodes within the sequence.
1015+
context : Context
1016+
Currently active local context.
1017+
1018+
Returns
1019+
-------
1020+
List[LLLNode]
1021+
LLL nodes that must execute prior to generating the actual sequence in order to
1022+
avoid memory corruption issues. This list may be empty, depending on the values
1023+
within `elements`.
1024+
List[LLLNode]
1025+
LLL nodes which collectively represent `elements`.
1026+
"""
1027+
init_lll = []
1028+
sequence_lll = []
1029+
for node in elements:
1030+
if isinstance(node, vy_ast.List):
1031+
# for nested lists, ensure the init LLL is also processed before the values
1032+
init, seq = parse_sequence(node, node.elements, context)
1033+
init_lll.extend(init)
1034+
out_type = next((i.typ for i in seq if not i.typ.is_literal), seq[0].typ)
1035+
typ = ListType(out_type, len(node.elements), is_literal=True)
1036+
multi_lll = LLLnode.from_list(["multi"] + seq, typ=typ, pos=getpos(node))
1037+
sequence_lll.append(multi_lll)
1038+
continue
1039+
1040+
lll_node = Expr(node, context).lll_node
1041+
if isinstance(node, vy_ast.Call) or (
1042+
isinstance(node, vy_ast.Subscript) and isinstance(node.value, vy_ast.Call)
1043+
):
1044+
# nodes which potentially create their own internal memory variables, and so must
1045+
# be parsed prior to generating the final sequence to avoid memory corruption
1046+
target = LLLnode.from_list(
1047+
context.new_internal_variable(lll_node.typ),
1048+
typ=lll_node.typ,
1049+
location="memory",
1050+
pos=getpos(base_node),
1051+
)
1052+
init_lll.append(make_setter(target, lll_node, "memory", pos=getpos(base_node)))
1053+
sequence_lll.append(
1054+
LLLnode.from_list(
1055+
target, typ=lll_node.typ, pos=getpos(base_node), location="memory"
1056+
),
1057+
)
1058+
else:
1059+
sequence_lll.append(lll_node)
1060+
1061+
return init_lll, sequence_lll

vyper/parser/parser_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,14 @@ def make_setter(left, right, location, pos, in_function_call=False):
534534
left = LLLnode.from_list(["sha3_32", left], typ=left.typ, location="storage_prehashed")
535535
left_token.location = "storage_prehashed"
536536
# If the right side is a literal
537-
if right.value == "multi":
538-
subs = []
537+
if right.value in ["multi", "seq_unchecked"] and right.typ.is_literal:
538+
if right.value == "seq_unchecked":
539+
# when the LLL is `seq_unchecked`, this is a literal where one or
540+
# more values must be pre-processed to avoid memory corruption
541+
subs = right.args[:-1]
542+
right = right.args[-1]
543+
else:
544+
subs = []
539545
for i in range(left.typ.count):
540546
lhs_setter = _make_array_index_setter(left, left_token, pos, location, i)
541547
subs.append(make_setter(lhs_setter, right.args[i], location, pos=pos,))

vyper/parser/self_call.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import itertools
22

3-
from vyper import ast as vy_ast
43
from vyper.codegen.abi import abi_decode
54
from vyper.exceptions import (
65
StateAccessViolation,
76
StructureException,
87
TypeCheckFailure,
98
)
109
from vyper.parser.lll_node import LLLnode
11-
from vyper.parser.parser_utils import getpos, make_setter, pack_arguments
10+
from vyper.parser.parser_utils import getpos, pack_arguments
1211
from vyper.signatures.function_signature import FunctionSignature
1312
from vyper.types import (
1413
BaseType,
@@ -57,37 +56,15 @@ def make_call(stmt_expr, context):
5756
# (x) pop return values
5857
# (x) pop local variables
5958

60-
pre_init = []
6159
pop_local_vars = []
6260
push_local_vars = []
6361
pop_return_values = []
6462
push_args = []
65-
66-
from vyper.parser.expr import Expr
67-
6863
method_name = stmt_expr.func.attr
6964

70-
expr_args = []
71-
for arg in stmt_expr.args:
72-
lll_node = Expr(arg, context).lll_node
73-
if isinstance(arg, vy_ast.Call):
74-
# if the argument is a function call, perform the call separately and
75-
# assign its result to memory, then reference the memory location when
76-
# building this call. otherwise there is potential for memory corruption
77-
target = LLLnode.from_list(
78-
context.new_internal_variable(lll_node.typ),
79-
typ=lll_node.typ,
80-
location="memory",
81-
pos=getpos(arg),
82-
)
83-
setter = make_setter(target, lll_node, "memory", pos=getpos(arg))
84-
expr_args.append(
85-
LLLnode.from_list(target, typ=lll_node.typ, pos=getpos(arg), location="memory")
86-
)
87-
pre_init.append(setter)
88-
else:
89-
expr_args.append(lll_node)
65+
from vyper.parser.expr import parse_sequence
9066

67+
pre_init, expr_args = parse_sequence(stmt_expr, stmt_expr.args, context)
9168
sig = FunctionSignature.lookup_sig(context.sigs, method_name, expr_args, stmt_expr, context,)
9269

9370
if context.is_constant() and sig.mutability not in ("view", "pure"):

0 commit comments

Comments
 (0)