Skip to content

Commit 7772a0b

Browse files
authored
Merge pull request #2425 from Thirumalai-Shaktivel/simd_02
[C] Simd changes from LFortran
2 parents 03ce89b + 25ca1c8 commit 7772a0b

File tree

6 files changed

+62
-29
lines changed

6 files changed

+62
-29
lines changed

src/libasr/asr_utils.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ void update_call_args(Allocator &al, SymbolTable *current_scope, bool implicit_i
190190
}
191191
return sym;
192192
}
193-
193+
194194
void handle_Var(ASR::expr_t* arg_expr, ASR::expr_t** expr_to_replace) {
195195
if (ASR::is_a<ASR::Var_t>(*arg_expr)) {
196196
ASR::Var_t* arg_var = ASR::down_cast<ASR::Var_t>(arg_expr);
@@ -1521,7 +1521,12 @@ void make_ArrayBroadcast_t_util(Allocator& al, const Location& loc,
15211521
if (ret_type == nullptr) {
15221522
// TODO: Construct appropriate return type here
15231523
// For now simply coping the type from expr1
1524-
ret_type = expr1_type;
1524+
if (ASRUtils::is_simd_array(expr1)) {
1525+
// TODO: Make this more general; do not check for SIMDArray
1526+
ret_type = ASRUtils::duplicate_type(al, expr1_type);
1527+
} else {
1528+
ret_type = expr1_type;
1529+
}
15251530
}
15261531
expr2 = ASRUtils::EXPR(ASR::make_ArrayBroadcast_t(al, loc, expr2, dest_shape, ret_type, value));
15271532

src/libasr/asr_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4731,6 +4731,12 @@ inline ASR::ttype_t* make_Pointer_t_util(Allocator& al, const Location& loc, ASR
47314731

47324732
int64_t compute_trailing_zeros(int64_t number);
47334733

4734+
static inline bool is_simd_array(ASR::expr_t *v) {
4735+
return (ASR::is_a<ASR::Array_t>(*expr_type(v)) &&
4736+
ASR::down_cast<ASR::Array_t>(expr_type(v))->m_physical_type
4737+
== ASR::array_physical_typeType::SIMDArray);
4738+
}
4739+
47344740
} // namespace ASRUtils
47354741

47364742
} // namespace LCompilers

src/libasr/codegen/asr_to_c.cpp

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,32 +1204,24 @@ R"( // Initialise Numpy
12041204
src = this->check_tmp_buffer() + out;
12051205
}
12061206

1207-
void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t& x) {
1207+
void visit_ArrayBroadcast(const ASR::ArrayBroadcast_t &x) {
12081208
/*
12091209
!LF$ attributes simd :: A
12101210
real :: A(8)
12111211
A = 1
12121212
We need to generate:
12131213
a = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
12141214
*/
1215-
12161215
CHECK_FAST_C(compiler_options, x)
1217-
if (x.m_value) {
1218-
ASR::expr_t* value = x.m_value;
1219-
LCOMPILERS_ASSERT(ASR::is_a<ASR::ArrayConstant_t>(*value));
1220-
ASR::ArrayConstant_t* array_const = ASR::down_cast<ASR::ArrayConstant_t>(value);
1221-
std::string array_const_str = "{";
1222-
for( size_t i = 0; i < array_const->n_args; i++ ) {
1223-
ASR::expr_t* array_const_arg = array_const->m_args[i];
1224-
this->visit_expr(*array_const_arg);
1225-
array_const_str += src + ", ";
1226-
}
1227-
array_const_str.pop_back();
1228-
array_const_str.pop_back();
1229-
array_const_str += "}";
1230-
1231-
src = array_const_str;
1216+
size_t size = ASRUtils::get_fixed_size_of_array(x.m_type);
1217+
std::string array_const_str = "{";
1218+
for( size_t i = 0; i < size; i++ ) {
1219+
this->visit_expr(*x.m_array);
1220+
array_const_str += src;
1221+
if (i < size - 1) array_const_str += ", ";
12321222
}
1223+
array_const_str += "}";
1224+
src = array_const_str;
12331225
}
12341226

12351227
void visit_ArraySize(const ASR::ArraySize_t& x) {

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,13 +1049,14 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
10491049
}
10501050

10511051
void visit_ArrayPhysicalCast(const ASR::ArrayPhysicalCast_t& x) {
1052-
src = "";
1053-
this->visit_expr(*x.m_arg);
1054-
if (x.m_old == ASR::array_physical_typeType::FixedSizeArray &&
1052+
src = "";
1053+
this->visit_expr(*x.m_arg);
1054+
if (x.m_old == ASR::array_physical_typeType::FixedSizeArray &&
10551055
x.m_new == ASR::array_physical_typeType::SIMDArray) {
10561056
std::string arr_element_type = CUtils::get_c_type_from_ttype_t(ASRUtils::expr_type(x.m_arg));
10571057
int64_t size = ASRUtils::get_fixed_size_of_array(ASRUtils::expr_type(x.m_arg));
1058-
std::string cast = arr_element_type + " __attribute__ (( vector_size(sizeof(" + arr_element_type + ") * " + std::to_string(size) + ") ))";
1058+
std::string cast = arr_element_type + " __attribute__ (( vector_size(sizeof("
1059+
+ arr_element_type + ") * " + std::to_string(size) + ") ))";
10591060
src = "(" + cast + ") " + src;
10601061
}
10611062
}
@@ -1245,7 +1246,37 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
12451246
bool is_value_dict = ASR::is_a<ASR::Dict_t>(*m_value_type);
12461247
bool alloc_return_var = false;
12471248
std::string indent(indentation_level*indentation_spaces, ' ');
1248-
if (ASR::is_a<ASR::Var_t>(*x.m_target)) {
1249+
if (ASRUtils::is_simd_array(x.m_target)) {
1250+
this->visit_expr(*x.m_target);
1251+
target = src;
1252+
if (ASR::is_a<ASR::Var_t>(*x.m_value) ||
1253+
ASR::is_a<ASR::ArraySection_t>(*x.m_value)) {
1254+
std::string arr_element_type = CUtils::get_c_type_from_ttype_t(
1255+
ASRUtils::expr_type(x.m_value));
1256+
std::string size = std::to_string(ASRUtils::get_fixed_size_of_array(
1257+
ASRUtils::expr_type(x.m_target)));
1258+
std::string value;
1259+
if (ASR::is_a<ASR::ArraySection_t>(*x.m_value)) {
1260+
ASR::ArraySection_t *arr = ASR::down_cast<ASR::ArraySection_t>(x.m_value);
1261+
this->visit_expr(*arr->m_v);
1262+
value = src;
1263+
if(!ASR::is_a<ASR::ArrayBound_t>(*arr->m_args->m_left)) {
1264+
this->visit_expr(*arr->m_args->m_left);
1265+
int n_dims = ASRUtils::extract_n_dims_from_ttype(arr->m_type) - 1;
1266+
value += "->data + (" + src + " - "+ value +"->dims["
1267+
+ std::to_string(n_dims) +"].lower_bound)";
1268+
} else {
1269+
value += "->data";
1270+
}
1271+
} else if (ASR::is_a<ASR::Var_t>(*x.m_value)) {
1272+
this->visit_expr(*x.m_value);
1273+
value = src + "->data";
1274+
}
1275+
src = indent + "memcpy(&"+ target +", "+ value +", sizeof("
1276+
+ arr_element_type + ") * "+ size +");\n";
1277+
return;
1278+
}
1279+
} else if (ASR::is_a<ASR::Var_t>(*x.m_target)) {
12491280
ASR::Var_t* x_m_target = ASR::down_cast<ASR::Var_t>(x.m_target);
12501281
visit_Var(*x_m_target);
12511282
target = src;
@@ -1398,6 +1429,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
13981429
}
13991430

14001431
void visit_Associate(const ASR::Associate_t &x) {
1432+
std::string indent(indentation_level*indentation_spaces, ' ');
14011433
if (ASR::is_a<ASR::ArraySection_t>(*x.m_value)) {
14021434
self().visit_expr(*x.m_target);
14031435
std::string target = std::move(src);
@@ -1422,7 +1454,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
14221454
}
14231455
c += left + ":" + right + ":" + step + ",";
14241456
}
1425-
src = target + "= " + value + "; // TODO: " + value + "(" + c + ")\n";
1457+
src = indent + target + "= " + value + "; // TODO: " + value + "(" + c + ")\n";
14261458
} else {
14271459
throw CodeGenError("Associate only implemented for ArraySection so far");
14281460
}

src/libasr/pass/array_op.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,9 +1586,7 @@ class ArrayOpVisitor : public ASR::CallReplacerOnExpressionsVisitor<ArrayOpVisit
15861586
}
15871587

15881588
void visit_Assignment(const ASR::Assignment_t &x) {
1589-
if (ASR::is_a<ASR::Array_t>(*ASRUtils::expr_type(x.m_target)) &&
1590-
ASR::down_cast<ASR::Array_t>(ASRUtils::expr_type(x.m_target))->m_physical_type
1591-
== ASR::array_physical_typeType::SIMDArray) {
1589+
if (ASRUtils::is_simd_array(x.m_target)) {
15921590
return;
15931591
}
15941592
if( (ASR::is_a<ASR::Pointer_t>(*ASRUtils::expr_type(x.m_target)) &&

src/libasr/pass/print_arr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class PrintArrVisitor : public PassUtils::PassVisitor<PrintArrVisitor>
134134
std::vector<ASR::expr_t*> print_body;
135135
ASR::stmt_t* empty_print_endl;
136136
ASR::stmt_t* print_stmt;
137-
if (x.m_values[0] != nullptr && ASR::is_a<ASR::StringFormat_t>(*x.m_values[0])) {
137+
if (x.n_values > 0 && ASR::is_a<ASR::StringFormat_t>(*x.m_values[0])) {
138138
empty_print_endl = ASRUtils::STMT(ASR::make_Print_t(al, x.base.base.loc,
139139
nullptr, 0, nullptr, nullptr));
140140
ASR::StringFormat_t* format = ASR::down_cast<ASR::StringFormat_t>(x.m_values[0]);

0 commit comments

Comments
 (0)