Skip to content

Commit 6e9fb0d

Browse files
authored
Merge pull request #2282 from Smit-create/i-1671-1
Fix FMA pass to use IntrinsicFunction
2 parents 91207c9 + 9638342 commit 6e9fb0d

File tree

6 files changed

+112
-13
lines changed

6 files changed

+112
-13
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c)
480480
RUN(NAME expr_19 LABELS cpython llvm c)
481481
RUN(NAME expr_20 LABELS cpython llvm c)
482482
RUN(NAME expr_21 LABELS cpython llvm c)
483+
RUN(NAME expr_22 LABELS cpython llvm c)
483484

484485
RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
485486
RUN(NAME expr_02u LABELS cpython llvm c NOFAST)

integration_tests/expr_22.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from lpython import f64
2+
3+
# test issue 1671
4+
def test_fast_fma() -> f64:
5+
a : f64 = 5.00
6+
a = a + a * 10.00
7+
assert abs(a - 55.00) < 1e-12
8+
return a
9+
10+
print(test_fast_fma())

src/libasr/pass/fma.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
118118
}
119119

120120
fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
121-
al, unit, pass_options, current_scope, x.base.base.loc,
122-
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
121+
al, unit, x.base.base.loc);
123122
from_fma = false;
124123
}
125124

@@ -170,6 +169,8 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
170169
const LCompilers::PassOptions& pass_options) {
171170
FMAVisitor v(al, unit, pass_options);
172171
v.visit_TranslationUnit(unit);
172+
PassUtils::UpdateDependenciesVisitor u(al);
173+
u.visit_TranslationUnit(unit);
173174
}
174175

175176

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ enum class IntrinsicScalarFunctions : int64_t {
4141
Exp,
4242
Exp2,
4343
Expm1,
44+
FMA,
4445
ListIndex,
4546
Partition,
4647
ListReverse,
@@ -93,6 +94,7 @@ inline std::string get_intrinsic_name(int x) {
9394
INTRINSIC_NAME_CASE(Exp)
9495
INTRINSIC_NAME_CASE(Exp2)
9596
INTRINSIC_NAME_CASE(Expm1)
97+
INTRINSIC_NAME_CASE(FMA)
9698
INTRINSIC_NAME_CASE(ListIndex)
9799
INTRINSIC_NAME_CASE(Partition)
98100
INTRINSIC_NAME_CASE(ListReverse)
@@ -1281,6 +1283,82 @@ namespace Sign {
12811283

12821284
} // namespace Sign
12831285

1286+
namespace FMA {
1287+
1288+
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
1289+
ASRUtils::require_impl(x.n_args == 3,
1290+
"ASR Verify: Call to FMA must have exactly 3 arguments",
1291+
x.base.base.loc, diagnostics);
1292+
ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
1293+
ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
1294+
ASR::ttype_t *type3 = ASRUtils::expr_type(x.m_args[2]);
1295+
ASRUtils::require_impl((is_real(*type1) && is_real(*type2) && is_real(*type3)),
1296+
"ASR Verify: Arguments to FMA must be of real type",
1297+
x.base.base.loc, diagnostics);
1298+
}
1299+
1300+
static ASR::expr_t *eval_FMA(Allocator &al, const Location &loc,
1301+
ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
1302+
double a = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
1303+
double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
1304+
double c = ASR::down_cast<ASR::RealConstant_t>(args[2])->m_r;
1305+
return make_ConstantWithType(make_RealConstant_t, a + b*c, t1, loc);
1306+
}
1307+
1308+
static inline ASR::asr_t* create_FMA(Allocator& al, const Location& loc,
1309+
Vec<ASR::expr_t*>& args,
1310+
const std::function<void (const std::string &, const Location &)> err) {
1311+
if (args.size() != 3) {
1312+
err("Intrinsic FMA function accepts exactly 3 arguments", loc);
1313+
}
1314+
ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]);
1315+
ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]);
1316+
ASR::ttype_t *type3 = ASRUtils::expr_type(args[2]);
1317+
if (!ASRUtils::is_real(*type1) || !ASRUtils::is_real(*type2) || !ASRUtils::is_real(*type3)) {
1318+
err("Argument of the FMA function must be Real",
1319+
args[0]->base.loc);
1320+
}
1321+
ASR::expr_t *m_value = nullptr;
1322+
if (all_args_evaluated(args)) {
1323+
Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 3);
1324+
arg_values.push_back(al, expr_value(args[0]));
1325+
arg_values.push_back(al, expr_value(args[1]));
1326+
arg_values.push_back(al, expr_value(args[2]));
1327+
m_value = eval_FMA(al, loc, expr_type(args[0]), arg_values);
1328+
}
1329+
return ASR::make_IntrinsicScalarFunction_t(al, loc,
1330+
static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
1331+
args.p, args.n, 0, ASRUtils::expr_type(args[0]), m_value);
1332+
}
1333+
1334+
static inline ASR::expr_t* instantiate_FMA(Allocator &al, const Location &loc,
1335+
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
1336+
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/,
1337+
ASR::expr_t* compile_time_value) {
1338+
if (compile_time_value) {
1339+
return compile_time_value;
1340+
}
1341+
declare_basic_variables("_lcompilers_optimization_fma_" + type_to_str_python(arg_types[0]));
1342+
fill_func_arg("a", arg_types[0]);
1343+
fill_func_arg("b", arg_types[0]);
1344+
fill_func_arg("c", arg_types[0]);
1345+
auto result = declare(fn_name, return_type, ReturnVar);
1346+
/*
1347+
* result = a + b*c
1348+
*/
1349+
1350+
ASR::expr_t *op1 = b.ElementalMul(args[1], args[2], loc);
1351+
body.push_back(al, b.Assignment(result,
1352+
b.ElementalAdd(args[0], op1, loc)));
1353+
1354+
ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
1355+
body, result, Source, Implementation, nullptr);
1356+
scope->add_symbol(fn_name, f_sym);
1357+
return b.Call(f_sym, new_args, return_type, nullptr);
1358+
}
1359+
1360+
} // namespace FMA
1361+
12841362
#define create_exp_macro(X, stdeval) \
12851363
namespace X { \
12861364
static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
@@ -2314,6 +2392,8 @@ namespace IntrinsicScalarFunctionRegistry {
23142392
{nullptr, &UnaryIntrinsicFunction::verify_args}},
23152393
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
23162394
{nullptr, &UnaryIntrinsicFunction::verify_args}},
2395+
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
2396+
{&FMA::instantiate_FMA, &FMA::verify_args}},
23172397
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
23182398
{&Abs::instantiate_Abs, &Abs::verify_args}},
23192399
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
@@ -2400,6 +2480,8 @@ namespace IntrinsicScalarFunctionRegistry {
24002480
"exp"},
24012481
{static_cast<int64_t>(IntrinsicScalarFunctions::Exp2),
24022482
"exp2"},
2483+
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
2484+
"fma"},
24032485
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
24042486
"expm1"},
24052487
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),
@@ -2474,6 +2556,7 @@ namespace IntrinsicScalarFunctionRegistry {
24742556
{"exp", {&Exp::create_Exp, &Exp::eval_Exp}},
24752557
{"exp2", {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
24762558
{"expm1", {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
2559+
{"fma", {&FMA::create_FMA, &FMA::eval_FMA}},
24772560
{"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
24782561
{"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
24792562
{"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}},

src/libasr/pass/pass_utils.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,17 @@ namespace LCompilers {
666666
}
667667

668668
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
669-
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
670-
SymbolTable*& current_scope, Location& loc,
671-
const std::function<void (const std::string &, const Location &)> err) {
672-
ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
673-
al, unit, pass_options, current_scope, arg0->base.loc);
669+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){
670+
671+
ASRUtils::impl_function instantiate_function =
672+
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
673+
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
674+
Vec<ASR::ttype_t*> arg_types;
675+
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
676+
arg_types.reserve(al, 3);
677+
arg_types.push_back(al, ASRUtils::expr_type(arg0));
678+
arg_types.push_back(al, ASRUtils::expr_type(arg1));
679+
arg_types.push_back(al, ASRUtils::expr_type(arg2));
674680
Vec<ASR::call_arg_t> args;
675681
args.reserve(al, 3);
676682
ASR::call_arg_t arg0_, arg1_, arg2_;
@@ -680,9 +686,9 @@ namespace LCompilers {
680686
args.push_back(al, arg1_);
681687
arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
682688
args.push_back(al, arg2_);
683-
return ASRUtils::EXPR(
684-
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
685-
loc, v, args, current_scope, al, err));
689+
return instantiate_function(al, loc,
690+
unit.m_global_scope, arg_types, type, args, 0,
691+
nullptr);
686692
}
687693

688694
ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,

src/libasr/pass/pass_utils.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,7 @@ namespace LCompilers {
9090
ASR::intentType var_intent=ASR::intentType::Local);
9191

9292
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
93-
Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
94-
SymbolTable*& current_scope,Location& loc,
95-
const std::function<void (const std::string &, const Location &)> err);
93+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
9694

9795
ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
9896
Allocator& al, ASR::TranslationUnit_t& unit,

0 commit comments

Comments
 (0)