Skip to content

Commit 4c2f136

Browse files
[ASR Pass] Symbolic: Simplify basic_has_symbol to return FunctionCall
1 parent ec60a8d commit 4c2f136

File tree

1 file changed

+54
-58
lines changed

1 file changed

+54
-58
lines changed

src/libasr/pass/replace_symbolic.cpp

Lines changed: 54 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,58 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
747747
return ASRUtils::STMT(ASR::make_SubroutineCall_t(al, loc, basic_unaryop_sym,
748748
basic_unaryop_sym, call_args.p, call_args.n, nullptr));
749749
}
750+
751+
ASR::expr_t *basic_has_symbol(const Location &loc, ASR::expr_t *value_01, ASR::expr_t *value_02) {
752+
std::string fn_name = "basic_has_symbol";
753+
symbolic_dependencies.push_back(fn_name);
754+
ASR::symbol_t *basic_has_symbol_sym = current_scope->resolve_symbol(fn_name);
755+
if ( !basic_has_symbol_sym ) {
756+
std::string header = "symengine/cwrapper.h";
757+
SymbolTable* fn_symtab = al.make_new<SymbolTable>(current_scope->parent);
758+
759+
Vec<ASR::expr_t*> args; args.reserve(al, 1);
760+
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
761+
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
762+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
763+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
764+
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
765+
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
766+
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
767+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
768+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
769+
fn_symtab->add_symbol(s2c(al, "x"), arg2);
770+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
771+
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
772+
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
773+
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
774+
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
775+
fn_symtab->add_symbol(s2c(al, "y"), arg3);
776+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));
777+
778+
Vec<ASR::stmt_t*> body; body.reserve(al, 1);
779+
Vec<char*> dep; dep.reserve(al, 1);
780+
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg1));
781+
basic_has_symbol_sym = ASR::down_cast<ASR::symbol_t>(ASRUtils::make_Function_t_util(al, loc,
782+
fn_symtab, s2c(al, fn_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
783+
return_var, ASR::abiType::BindC, ASR::accessType::Public,
784+
ASR::deftypeType::Interface, s2c(al, fn_name), false, false, false,
785+
false, false, nullptr, 0, false, false, false, s2c(al, header)));
786+
current_scope->parent->add_symbol(s2c(al, fn_name), basic_has_symbol_sym);
787+
}
788+
789+
Vec<ASR::call_arg_t> call_args;
790+
call_args.reserve(al, 1);
791+
ASR::call_arg_t call_arg1, call_arg2;
792+
call_arg1.loc = loc;
793+
call_arg1.m_value = handle_argument(al, loc, value_01);
794+
call_args.push_back(al, call_arg1);
795+
call_arg2.loc = loc;
796+
call_arg2.m_value = handle_argument(al, loc, value_02);
797+
call_args.push_back(al, call_arg2);
798+
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
799+
basic_has_symbol_sym, basic_has_symbol_sym, call_args.p, call_args.n,
800+
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
801+
}
750802
/********************************** Utils *********************************/
751803

752804
void visit_Function(const ASR::Function_t &x) {
@@ -958,64 +1010,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9581010
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
9591011
switch (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id)) {
9601012
case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: {
961-
std::string name = "basic_has_symbol";
962-
symbolic_dependencies.push_back(name);
963-
if (!module_scope->get_symbol(name)) {
964-
std::string header = "symengine/cwrapper.h";
965-
SymbolTable* fn_symtab = al.make_new<SymbolTable>(module_scope);
966-
967-
Vec<ASR::expr_t*> args;
968-
args.reserve(al, 1);
969-
ASR::symbol_t* arg1 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
970-
al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar,
971-
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)),
972-
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false));
973-
fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1);
974-
ASR::symbol_t* arg2 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
975-
al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
976-
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
977-
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
978-
fn_symtab->add_symbol(s2c(al, "x"), arg2);
979-
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2)));
980-
ASR::symbol_t* arg3 = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
981-
al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In,
982-
nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)),
983-
nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
984-
fn_symtab->add_symbol(s2c(al, "y"), arg3);
985-
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3)));
986-
987-
Vec<ASR::stmt_t*> body;
988-
body.reserve(al, 1);
989-
990-
Vec<char*> dep;
991-
dep.reserve(al, 1);
992-
993-
ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable")));
994-
ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc,
995-
fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n,
996-
return_var, ASR::abiType::BindC, ASR::accessType::Public,
997-
ASR::deftypeType::Interface, s2c(al, name), false, false, false,
998-
false, false, nullptr, 0, false, false, false, s2c(al, header));
999-
ASR::symbol_t* symbol = ASR::down_cast<ASR::symbol_t>(subrout);
1000-
module_scope->add_symbol(s2c(al, name), symbol);
1001-
}
1002-
1003-
ASR::symbol_t* basic_has_symbol = module_scope->get_symbol(name);
1004-
ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]);
1005-
ASR::expr_t* value2 = handle_argument(al, loc, intrinsic_func->m_args[1]);
1006-
Vec<ASR::call_arg_t> call_args;
1007-
call_args.reserve(al, 1);
1008-
ASR::call_arg_t call_arg1, call_arg2;
1009-
call_arg1.loc = loc;
1010-
call_arg1.m_value = value1;
1011-
call_args.push_back(al, call_arg1);
1012-
call_arg2.loc = loc;
1013-
call_arg2.m_value = value2;
1014-
call_args.push_back(al, call_arg2);
1015-
return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc,
1016-
basic_has_symbol, basic_has_symbol, call_args.p, call_args.n,
1017-
ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr));
1018-
break;
1013+
return basic_has_symbol(loc, intrinsic_func->m_args[0],
1014+
intrinsic_func->m_args[1]);
10191015
}
10201016
// (sym_name, n) where n = 16, 15, ... as the right value of the
10211017
// IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM

0 commit comments

Comments
 (0)