@@ -41,6 +41,7 @@ enum class IntrinsicScalarFunctions : int64_t {
41
41
Exp,
42
42
Exp2,
43
43
Expm1,
44
+ FMA,
44
45
ListIndex,
45
46
Partition,
46
47
ListReverse,
@@ -93,6 +94,7 @@ inline std::string get_intrinsic_name(int x) {
93
94
INTRINSIC_NAME_CASE (Exp)
94
95
INTRINSIC_NAME_CASE (Exp2)
95
96
INTRINSIC_NAME_CASE (Expm1)
97
+ INTRINSIC_NAME_CASE (FMA)
96
98
INTRINSIC_NAME_CASE (ListIndex)
97
99
INTRINSIC_NAME_CASE (Partition)
98
100
INTRINSIC_NAME_CASE (ListReverse)
@@ -1281,6 +1283,82 @@ namespace Sign {
1281
1283
1282
1284
} // namespace Sign
1283
1285
1286
+ namespace FMA {
1287
+
1288
+ static inline void verify_args (const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
1289
+ ASRUtils::require_impl (x.n_args == 3 ,
1290
+ " ASR Verify: Call to FMA must have exactly 3 arguments" ,
1291
+ x.base .base .loc , diagnostics);
1292
+ ASR::ttype_t *type1 = ASRUtils::expr_type (x.m_args [0 ]);
1293
+ ASR::ttype_t *type2 = ASRUtils::expr_type (x.m_args [1 ]);
1294
+ ASR::ttype_t *type3 = ASRUtils::expr_type (x.m_args [2 ]);
1295
+ ASRUtils::require_impl ((is_real (*type1) && is_real (*type2) && is_real (*type3)),
1296
+ " ASR Verify: Arguments to FMA must be of real type" ,
1297
+ x.base .base .loc , diagnostics);
1298
+ }
1299
+
1300
+ static ASR::expr_t *eval_FMA (Allocator &al, const Location &loc,
1301
+ ASR::ttype_t * t1, Vec<ASR::expr_t *> &args) {
1302
+ double a = ASR::down_cast<ASR::RealConstant_t>(args[0 ])->m_r ;
1303
+ double b = ASR::down_cast<ASR::RealConstant_t>(args[1 ])->m_r ;
1304
+ double c = ASR::down_cast<ASR::RealConstant_t>(args[2 ])->m_r ;
1305
+ return make_ConstantWithType (make_RealConstant_t, a + b*c, t1, loc);
1306
+ }
1307
+
1308
+ static inline ASR::asr_t * create_FMA (Allocator& al, const Location& loc,
1309
+ Vec<ASR::expr_t *>& args,
1310
+ const std::function<void (const std::string &, const Location &)> err) {
1311
+ if (args.size () != 3 ) {
1312
+ err (" Intrinsic FMA function accepts exactly 3 arguments" , loc);
1313
+ }
1314
+ ASR::ttype_t *type1 = ASRUtils::expr_type (args[0 ]);
1315
+ ASR::ttype_t *type2 = ASRUtils::expr_type (args[1 ]);
1316
+ ASR::ttype_t *type3 = ASRUtils::expr_type (args[2 ]);
1317
+ if (!ASRUtils::is_real (*type1) || !ASRUtils::is_real (*type2) || !ASRUtils::is_real (*type3)) {
1318
+ err (" Argument of the FMA function must be Real" ,
1319
+ args[0 ]->base .loc );
1320
+ }
1321
+ ASR::expr_t *m_value = nullptr ;
1322
+ if (all_args_evaluated (args)) {
1323
+ Vec<ASR::expr_t *> arg_values; arg_values.reserve (al, 3 );
1324
+ arg_values.push_back (al, expr_value (args[0 ]));
1325
+ arg_values.push_back (al, expr_value (args[1 ]));
1326
+ arg_values.push_back (al, expr_value (args[2 ]));
1327
+ m_value = eval_FMA (al, loc, expr_type (args[0 ]), arg_values);
1328
+ }
1329
+ return ASR::make_IntrinsicScalarFunction_t (al, loc,
1330
+ static_cast <int64_t >(IntrinsicScalarFunctions::FMA),
1331
+ args.p , args.n , 0 , ASRUtils::expr_type (args[0 ]), m_value);
1332
+ }
1333
+
1334
+ static inline ASR::expr_t * instantiate_FMA (Allocator &al, const Location &loc,
1335
+ SymbolTable *scope, Vec<ASR::ttype_t *>& arg_types, ASR::ttype_t *return_type,
1336
+ Vec<ASR::call_arg_t >& new_args, int64_t /* overload_id*/ ,
1337
+ ASR::expr_t * compile_time_value) {
1338
+ if (compile_time_value) {
1339
+ return compile_time_value;
1340
+ }
1341
+ declare_basic_variables (" _lcompilers_optimization_fma_" + type_to_str_python (arg_types[0 ]));
1342
+ fill_func_arg (" a" , arg_types[0 ]);
1343
+ fill_func_arg (" b" , arg_types[0 ]);
1344
+ fill_func_arg (" c" , arg_types[0 ]);
1345
+ auto result = declare (fn_name, return_type, ReturnVar);
1346
+ /*
1347
+ * result = a + b*c
1348
+ */
1349
+
1350
+ ASR::expr_t *op1 = b.ElementalMul (args[1 ], args[2 ], loc);
1351
+ body.push_back (al, b.Assignment (result,
1352
+ b.ElementalAdd (args[0 ], op1, loc)));
1353
+
1354
+ ASR::symbol_t *f_sym = make_Function_t (fn_name, fn_symtab, dep, args,
1355
+ body, result, Source, Implementation, nullptr );
1356
+ scope->add_symbol (fn_name, f_sym);
1357
+ return b.Call (f_sym, new_args, return_type, nullptr );
1358
+ }
1359
+
1360
+ } // namespace FMA
1361
+
1284
1362
#define create_exp_macro (X, stdeval ) \
1285
1363
namespace X { \
1286
1364
static inline ASR::expr_t * eval_##X(Allocator &al, const Location &loc, \
@@ -2314,6 +2392,8 @@ namespace IntrinsicScalarFunctionRegistry {
2314
2392
{nullptr , &UnaryIntrinsicFunction::verify_args}},
2315
2393
{static_cast <int64_t >(IntrinsicScalarFunctions::Expm1),
2316
2394
{nullptr , &UnaryIntrinsicFunction::verify_args}},
2395
+ {static_cast <int64_t >(IntrinsicScalarFunctions::FMA),
2396
+ {&FMA::instantiate_FMA, &FMA::verify_args}},
2317
2397
{static_cast <int64_t >(IntrinsicScalarFunctions::Abs),
2318
2398
{&Abs::instantiate_Abs, &Abs::verify_args}},
2319
2399
{static_cast <int64_t >(IntrinsicScalarFunctions::Partition),
@@ -2400,6 +2480,8 @@ namespace IntrinsicScalarFunctionRegistry {
2400
2480
" exp" },
2401
2481
{static_cast <int64_t >(IntrinsicScalarFunctions::Exp2),
2402
2482
" exp2" },
2483
+ {static_cast <int64_t >(IntrinsicScalarFunctions::FMA),
2484
+ " fma" },
2403
2485
{static_cast <int64_t >(IntrinsicScalarFunctions::Expm1),
2404
2486
" expm1" },
2405
2487
{static_cast <int64_t >(IntrinsicScalarFunctions::ListIndex),
@@ -2474,6 +2556,7 @@ namespace IntrinsicScalarFunctionRegistry {
2474
2556
{" exp" , {&Exp::create_Exp, &Exp::eval_Exp}},
2475
2557
{" exp2" , {&Exp2::create_Exp2, &Exp2::eval_Exp2}},
2476
2558
{" expm1" , {&Expm1::create_Expm1, &Expm1::eval_Expm1}},
2559
+ {" fma" , {&FMA::create_FMA, &FMA::eval_FMA}},
2477
2560
{" list.index" , {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}},
2478
2561
{" list.reverse" , {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}},
2479
2562
{" list.pop" , {&ListPop::create_ListPop, &ListPop::eval_list_pop}},
0 commit comments