Skip to content

Commit 37cc6eb

Browse files
anutosh491certik
authored andcommitted
Added support for casting within visit_SubroutineCall
1 parent 46df3e5 commit 37cc6eb

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

integration_tests/symbolics_09.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def call_addInteger():
1313
e: S = cos(b)
1414
addInteger(c, d, e, 2)
1515
addInteger(c, sin(a), cos(b), 2)
16-
addInteger(c, sin(Symbol("x")), cos(Symbol("y")), 2)
16+
addInteger(pi, sin(Symbol("x")), cos(Symbol("y")), 2)
1717

1818
def main0():
1919
call_addInteger()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,17 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
721721
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg));
722722
process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target);
723723

724+
ASR::call_arg_t call_arg;
725+
call_arg.loc = x.base.base.loc;
726+
call_arg.m_value = target;
727+
call_args.push_back(al, call_arg);
728+
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
729+
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
730+
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
731+
this->visit_Cast(*cast_t);
732+
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
733+
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));
734+
724735
ASR::call_arg_t call_arg;
725736
call_arg.loc = x.base.base.loc;
726737
call_arg.m_value = target;
@@ -793,9 +804,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
793804
} else if (ASR::is_a<ASR::Cast_t>(*val)) {
794805
ASR::Cast_t* cast_t = ASR::down_cast<ASR::Cast_t>(val);
795806
if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return;
796-
ASR::symbol_t *var_sym = nullptr;
797807
this->visit_Cast(*cast_t);
798-
var_sym = current_scope->get_symbol(symengine_stack.pop());
808+
ASR::symbol_t *var_sym = current_scope->get_symbol(symengine_stack.pop());
799809
ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, var_sym));
800810

801811
// Now create the FunctionCall node for basic_str

0 commit comments

Comments
 (0)