Skip to content

Commit e510c2b

Browse files
committed
Update FMA pass to use intrinsic function
1 parent fad2ad3 commit e510c2b

File tree

4 files changed

+23
-13
lines changed

4 files changed

+23
-13
lines changed

src/libasr/pass/fma.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
118118
}
119119

120120
fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
121-
al, unit, pass_options, current_scope, x.base.base.loc,
122-
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
121+
al, unit, x.base.base.loc);
123122
from_fma = false;
124123
}
125124

@@ -170,6 +169,8 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
170169
const LCompilers::PassOptions& pass_options) {
171170
FMAVisitor v(al, unit, pass_options);
172171
v.visit_TranslationUnit(unit);
172+
PassUtils::UpdateDependenciesVisitor u(al);
173+
u.visit_TranslationUnit(unit);
173174
}
174175

175176

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,6 +2392,8 @@ namespace IntrinsicScalarFunctionRegistry {
23922392
{nullptr, &UnaryIntrinsicFunction::verify_args}},
23932393
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
23942394
{nullptr, &UnaryIntrinsicFunction::verify_args}},
2395+
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
2396+
{&FMA::instantiate_FMA, &FMA::verify_args}},
23952397
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
23962398
{&Abs::instantiate_Abs, &Abs::verify_args}},
23972399
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
@@ -2478,6 +2480,8 @@ namespace IntrinsicScalarFunctionRegistry {
24782480
"exp"},
24792481
{static_cast<int64_t>(IntrinsicScalarFunctions::Exp2),
24802482
"exp2"},
2483+
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
2484+
"fma"},
24812485
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
24822486
"expm1"},
24832487
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
@@ -2552,6 +2556,7 @@ namespace IntrinsicScalarFunctionRegistry {
25522556
{"exp", {&Exp::create_Exp, &Exp::eval_Exp}},
25532557
{"exp2", {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
25542558
{"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
2559+
{"fma", {&FMA::create_FMA, &FMA::eval_FMA}},
25552560
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
25562561
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
25572562
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},

src/libasr/pass/pass_utils.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,17 @@ namespace LCompilers {
666666
}
667667

668668
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
669-
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
670-
SymbolTable*& current_scope, Location& loc,
671-
const std::function<void (const std::string &, const Location &)> err) {
672-
ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
673-
al, unit, pass_options, current_scope, arg0->base.loc);
669+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){
670+
671+
ASRUtils::impl_function instantiate_function =
672+
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
673+
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
674+
Vec<ASR::ttype_t*> arg_types;
675+
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
676+
arg_types.reserve(al, 3);
677+
arg_types.push_back(al, ASRUtils::expr_type(arg0));
678+
arg_types.push_back(al, ASRUtils::expr_type(arg1));
679+
arg_types.push_back(al, ASRUtils::expr_type(arg2));
674680
Vec<ASR::call_arg_t> args;
675681
args.reserve(al, 3);
676682
ASR::call_arg_t arg0_, arg1_, arg2_;
@@ -680,9 +686,9 @@ namespace LCompilers {
680686
args.push_back(al, arg1_);
681687
arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
682688
args.push_back(al, arg2_);
683-
return ASRUtils::EXPR(
684-
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
685-
loc, v, args, current_scope, al, err));
689+
return instantiate_function(al, loc,
690+
unit.m_global_scope, arg_types, type, args, 0,
691+
nullptr);
686692
}
687693

688694
ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,

src/libasr/pass/pass_utils.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ namespace LCompilers {
9090
ASR::intentType var_intent=ASR::intentType::Local);
9191

9292
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
93-
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
94-
SymbolTable*& current_scope,Location& loc,
95-
const std::function<void (const std::string &, const Location &)> err);
93+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
9694

9795
ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
9896
Allocator& al, ASR::TranslationUnit_t& unit,

0 commit comments

Comments
 (0)