Skip to content

Commit 61ffba8

Browse files
anutosh491certik
authored andcommitted
Added visit_SubroutineCall for the ASR symbolic pass
1 parent f9b09dd commit 61ffba8

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

integration_tests/symbolics_09.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sympy import Symbol, pi, S
1+
from sympy import Symbol, pi, sin, cos
22
from lpython import S, i32
33

44
def addInteger(x: S, y: S, z: S, i: i32):
@@ -9,7 +9,11 @@ def call_addInteger():
99
a: S = Symbol("x")
1010
b: S = Symbol("y")
1111
c: S = pi
12-
addInteger(a, b, c, 2)
12+
d: S = sin(a)
13+
e: S = cos(b)
14+
addInteger(c, d, e, 2)
15+
addInteger(c, sin(a), cos(b), 2)
16+
addInteger(c, sin(Symbol("x")), cos(Symbol("y")), 2)
1317

1418
def main0():
1519
call_addInteger()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
695695
}
696696
}
697697

698+
void visit_SubroutineCall(const ASR::SubroutineCall_t &x) {
699+
SymbolTable* module_scope = current_scope->parent;
700+
Vec<ASR::call_arg_t> call_args;
701+
call_args.reserve(al, 1);
702+
703+
for (size_t i=0; i<x.n_args; i++) {
704+
ASR::expr_t* val = x.m_args[i].m_value;
705+
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*val) && ASR::is_a<ASR::SymbolicExpression_t>(*ASRUtils::expr_type(val))) {
706+
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(val);
707+
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc));
708+
std::string symengine_var = symengine_stack.push();
709+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
710+
al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local,
711+
nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr,
712+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
713+
current_scope->add_symbol(s2c(al, symengine_var), arg);
714+
for (auto &item : current_scope->get_scope()) {
715+
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
716+
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
717+
this->visit_Variable(*s);
718+
}
719+
}
720+
721+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
722+
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
723+
724+
ASR::call_arg_t call_arg;
725+
call_arg.loc = x.base.base.loc;
726+
call_arg.m_value = target;
727+
call_args.push_back(al, call_arg);
728+
} else {
729+
call_args.push_back(al, x.m_args[i]);
730+
}
731+
}
732+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, x.m_name,
733+
x.m_name, call_args.p, call_args.n, nullptr));
734+
pass_result.push_back(al, stmt);
735+
}
736+
698737
void visit_Print(const ASR::Print_t &x) {
699738
std::vector<ASR::expr_t*> print_tmp;
700739
SymbolTable* module_scope = current_scope->parent;
@@ -739,6 +778,25 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
739778
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
740779
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
741780

781+
// Now create the FunctionCall node for basic_str
782+
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
783+
Vec<ASR::call_arg_t> call_args;
784+
call_args.reserve(al, 1);
785+
ASR::call_arg_t call_arg;
786+
call_arg.loc = x.base.base.loc;
787+
call_arg.m_value = target;
788+
call_args.push_back(al, call_arg);
789+
ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc,
790+
basic_str_sym, basic_str_sym, call_args.p, call_args.n,
791+
ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr));
792+
print_tmp.push_back(function_call);
793+
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
794+
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
795+
ASR::symbol_t *var_sym = nullptr;
796+
this->visit_Cast(*cast_t);
797+
var_sym = current_scope->get_symbol(symengine_stack.pop());
798+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));
799+
742800
// Now create the FunctionCall node for basic_str
743801
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
744802
Vec<ASR::call_arg_t> call_args;

0 commit comments

Comments
 (0)