Skip to content

Commit 984ca76

Browse files
authored
Merge pull request #2284 from Shaikh-Ubaid/syntactic_sugar_for_array
Syntactic sugar for Array annotation
2 parents 6e9fb0d + 3c71a96 commit 984ca76

File tree

4 files changed

+65
-4
lines changed

4 files changed

+65
-4
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ RUN(NAME array_01 LABELS cpython llvm wasm c)
424424
RUN(NAME array_02 LABELS cpython wasm c)
425425
RUN(NAME array_03 LABELS cpython llvm c)
426426
RUN(NAME array_04 LABELS cpython llvm c)
427+
RUN(NAME array_05 LABELS cpython llvm c)
427428
RUN(NAME bindc_01 LABELS cpython llvm c)
428429
RUN(NAME bindc_02 LABELS cpython llvm c)
429430
RUN(NAME bindc_04 LABELS llvm c NOFAST)

integration_tests/array_05.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from lpython import i32, f64, Array
2+
from numpy import empty, int32, float64
3+
4+
5+
def test_1():
6+
y: Array[f64, 3] = empty([3], dtype=float64)
7+
y[0] = 3.14
8+
y[1] = -4.14
9+
y[2] = 100.100
10+
11+
print(y)
12+
assert abs(y[0] - (3.14)) <= 1e-6
13+
assert abs(y[1] - (-4.14)) <= 1e-6
14+
assert abs(y[2] - (100.100)) <= 1e-6
15+
16+
def test_2():
17+
x: Array[i32, 2, 3] = empty([2, 3], dtype=int32)
18+
19+
x[0, 0] = 5
20+
x[0, 1] = -10
21+
x[0, 2] = 15
22+
x[1, 0] = 4
23+
x[1, 1] = -14
24+
x[1, 2] = 100
25+
26+
print(x)
27+
assert x[0, 0] == 5
28+
assert x[0, 1] == -10
29+
assert x[0, 2] == 15
30+
assert x[1, 0] == 4
31+
assert x[1, 1] == -14
32+
assert x[1, 2] == 100
33+
34+
35+
def main0():
36+
test_1()
37+
test_2()
38+
39+
main0()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1851,17 +1851,35 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
18511851
is_allocatable, raise_error, abi, is_argument);
18521852
return ASRUtils::TYPE(ASR::make_Const_t(al, loc, type));
18531853
} else {
1854+
AST::expr_t* dim_info = s->m_slice;
1855+
1856+
if (var_annotation == "Array") {
1857+
LCOMPILERS_ASSERT(AST::is_a<AST::Tuple_t>(*s->m_slice));
1858+
AST::Tuple_t *t = AST::down_cast<AST::Tuple_t>(s->m_slice);
1859+
LCOMPILERS_ASSERT(t->n_elts >= 2);
1860+
LCOMPILERS_ASSERT(AST::is_a<AST::Name_t>(*t->m_elts[0]));
1861+
var_annotation = AST::down_cast<AST::Name_t>(t->m_elts[0])->m_id;
1862+
Vec<AST::expr_t*> dims;
1863+
dims.reserve(al, 0);
1864+
for (size_t i = 1; i < t->n_elts; i++) {
1865+
dims.push_back(al, t->m_elts[i]);
1866+
}
1867+
AST::ast_t* dim_tuple = AST::make_Tuple_t(al, t->base.base.loc, dims.p, dims.size(),
1868+
AST::expr_contextType::Load);
1869+
dim_info = AST::down_cast<AST::expr_t>(dim_tuple);
1870+
}
1871+
18541872
ASR::ttype_t* type = get_type_from_var_annotation(var_annotation,
18551873
annotation.base.loc, dims, m_args, n_args, raise_error, abi, is_argument);
18561874

1857-
if (AST::is_a<AST::Slice_t>(*s->m_slice)) {
1875+
if (AST::is_a<AST::Slice_t>(*dim_info)) {
18581876
ASR::dimension_t dim;
18591877
dim.loc = loc;
18601878
dim.m_start = nullptr;
18611879
dim.m_length = nullptr;
18621880
dims.push_back(al, dim);
1863-
} else if( is_runtime_array(s->m_slice) ) {
1864-
AST::Tuple_t* tuple_multidim = AST::down_cast<AST::Tuple_t>(s->m_slice);
1881+
} else if( is_runtime_array(dim_info) ) {
1882+
AST::Tuple_t* tuple_multidim = AST::down_cast<AST::Tuple_t>(dim_info);
18651883
for( size_t i = 0; i < tuple_multidim->n_elts; i++ ) {
18661884
if( AST::is_a<AST::Slice_t>(*tuple_multidim->m_elts[i]) ) {
18671885
ASR::dimension_t dim;
@@ -1872,7 +1890,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
18721890
}
18731891
}
18741892
} else {
1875-
this->visit_expr(*s->m_slice);
1893+
this->visit_expr(*dim_info);
18761894
ASR::expr_t *value = ASRUtils::EXPR(tmp);
18771895
fill_dims_for_asr_type(dims, value, loc);
18781896
}

src/runtime/lpython/lpython.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def __init__(self, type, dims):
7777
self._type = type
7878
self._dims = dims
7979

80+
def __class_getitem__(self, params):
81+
return Array(params[0], params[1:])
82+
8083
i1 = Type("i1")
8184
i8 = Type("i8")
8285
i16 = Type("i16")

0 commit comments

Comments
 (0)