Skip to content

Commit 0d9fff9

Browse files
authored
Merge pull request #2330 from khushi-411/lp_fix
[numpy] add fix
2 parents 5569d74 + abd6ef1 commit 0d9fff9

32 files changed

+1343
-1191
lines changed

integration_tests/elemental_13.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from lpython import f32, f64
2-
from numpy import trunc, empty, sqrt, reshape, int32, float32, float64
2+
from numpy import trunc, fix, empty, sqrt, reshape, int32, float32, float64
33

44

55
def elemental_trunc64():
@@ -60,5 +60,65 @@ def elemental_trunc32():
6060
assert abs(trunc(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps
6161

6262

63+
def elemental_fix64():
64+
i: i32
65+
j: i32
66+
k: i32
67+
l: i32
68+
eps: f32
69+
eps = f32(1e-6)
70+
71+
arraynd: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
72+
73+
newshape: i32[1] = empty(1, dtype = int32)
74+
newshape[0] = 16384
75+
76+
for i in range(32):
77+
for j in range(16):
78+
for k in range(8):
79+
for l in range(4):
80+
arraynd[i, j, k, l] = f64((-1)**l) * sqrt(float(i + j + j + l))
81+
82+
observed: f64[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float64)
83+
observed = fix(arraynd)
84+
85+
observed1d: f64[16384] = empty(16384, dtype=float64)
86+
observed1d = reshape(observed, newshape)
87+
88+
array: f64[16384] = empty(16384, dtype=float64)
89+
array = reshape(arraynd, newshape)
90+
91+
for i in range(16384):
92+
assert f32(abs(fix(array[i]) - observed1d[i])) <= eps
93+
94+
95+
def elemental_fix32():
96+
i: i32
97+
j: i32
98+
k: i32
99+
l: i32
100+
eps: f32
101+
eps = f32(1e-6)
102+
103+
arraynd: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
104+
105+
for i in range(32):
106+
for j in range(16):
107+
for k in range(8):
108+
for l in range(4):
109+
arraynd[i, j, k, l] = f32(f64((-1)**l) * sqrt(float(i + j + j + l)))
110+
111+
observed: f32[32, 16, 8, 4] = empty((32, 16, 8, 4), dtype=float32)
112+
observed = fix(arraynd)
113+
114+
for i in range(32):
115+
for j in range(16):
116+
for k in range(8):
117+
for l in range(4):
118+
assert abs(fix(arraynd[i, j, k, l]) - observed[i, j, k, l]) <= eps
119+
120+
63121
elemental_trunc64()
64-
elemental_trunc32()
122+
elemental_trunc32()
123+
elemental_fix64()
124+
elemental_fix32()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2799,6 +2799,7 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27992799
SET_INTRINSIC_NAME(Exp2, "exp2");
28002800
SET_INTRINSIC_NAME(Expm1, "expm1");
28012801
SET_INTRINSIC_NAME(Trunc, "trunc");
2802+
SET_INTRINSIC_NAME(Fix, "fix");
28022803
default : {
28032804
throw LCompilersException("IntrinsicScalarFunction: `"
28042805
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/codegen/asr_to_julia.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,6 +1900,7 @@ class ASRToJuliaVisitor : public ASR::BaseVisitor<ASRToJuliaVisitor>
19001900
SET_INTRINSIC_NAME(Exp2, "exp2");
19011901
SET_INTRINSIC_NAME(Expm1, "expm1");
19021902
SET_INTRINSIC_NAME(Trunc, "trunc");
1903+
SET_INTRINSIC_NAME(Fix, "fix");
19031904
default : {
19041905
throw LCompilersException("IntrinsicFunction: `"
19051906
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ enum class IntrinsicScalarFunctions : int64_t {
3939
Gamma,
4040
LogGamma,
4141
Trunc,
42+
Fix,
4243
Abs,
4344
Exp,
4445
Exp2,
@@ -98,6 +99,7 @@ inline std::string get_intrinsic_name(int x) {
9899
INTRINSIC_NAME_CASE(Gamma)
99100
INTRINSIC_NAME_CASE(LogGamma)
100101
INTRINSIC_NAME_CASE(Trunc)
102+
INTRINSIC_NAME_CASE(Fix)
101103
INTRINSIC_NAME_CASE(Abs)
102104
INTRINSIC_NAME_CASE(Exp)
103105
INTRINSIC_NAME_CASE(Exp2)
@@ -1182,6 +1184,46 @@ namespace X {
11821184

11831185
create_trunc_macro(Trunc, trunc)
11841186

1187+
namespace Fix {
1188+
static inline ASR::expr_t *eval_Fix(Allocator &al, const Location &loc,
1189+
ASR::ttype_t *t, Vec<ASR::expr_t*>& args) {
1190+
LCOMPILERS_ASSERT(args.size() == 1);
1191+
double rv = ASR::down_cast<ASR::RealConstant_t>(args[0])->m_r;
1192+
double val;
1193+
if (rv > 0.0) {
1194+
val = floor(rv);
1195+
} else {
1196+
val = ceil(rv);
1197+
}
1198+
return make_ConstantWithType(make_RealConstant_t, val, t, loc);
1199+
}
1200+
1201+
static inline ASR::asr_t* create_Fix(Allocator& al, const Location& loc,
1202+
Vec<ASR::expr_t*>& args,
1203+
const std::function<void (const std::string &, const Location &)> err) {
1204+
ASR::ttype_t *type = ASRUtils::expr_type(args[0]);
1205+
if (args.n != 1) {
1206+
err("Intrinsic `fix` accepts exactly one argument", loc);
1207+
} else if (!ASRUtils::is_real(*type)) {
1208+
err("`fix` argument of `fix` must be real",
1209+
args[0]->base.loc);
1210+
}
1211+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args,
1212+
eval_Fix, static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
1213+
0, type);
1214+
}
1215+
1216+
static inline ASR::expr_t* instantiate_Fix (Allocator &al,
1217+
const Location &loc, SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types,
1218+
ASR::ttype_t *return_type, Vec<ASR::call_arg_t>& new_args,
1219+
int64_t overload_id) {
1220+
ASR::ttype_t* arg_type = arg_types[0];
1221+
return UnaryIntrinsicFunction::instantiate_functions(al, loc, scope,
1222+
"fix", arg_type, return_type, new_args, overload_id);
1223+
}
1224+
1225+
} // namespace Fix
1226+
11851227
// `X` is the name of the function in the IntrinsicScalarFunctions enum and
11861228
// we use the same name for `create_X` and other places
11871229
// `stdeval` is the name of the function in the `std` namespace for compile
@@ -2921,6 +2963,8 @@ namespace IntrinsicScalarFunctionRegistry {
29212963
{&LogGamma::instantiate_LogGamma, &UnaryIntrinsicFunction::verify_args}},
29222964
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
29232965
{&Trunc::instantiate_Trunc, &UnaryIntrinsicFunction::verify_args}},
2966+
{static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
2967+
{&Fix::instantiate_Fix, &UnaryIntrinsicFunction::verify_args}},
29242968
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
29252969
{&Sin::instantiate_Sin, &UnaryIntrinsicFunction::verify_args}},
29262970
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
@@ -3021,6 +3065,8 @@ namespace IntrinsicScalarFunctionRegistry {
30213065

30223066
{static_cast<int64_t>(IntrinsicScalarFunctions::Trunc),
30233067
"trunc"},
3068+
{static_cast<int64_t>(IntrinsicScalarFunctions::Fix),
3069+
"fix"},
30243070
{static_cast<int64_t>(IntrinsicScalarFunctions::Sin),
30253071
"sin"},
30263072
{static_cast<int64_t>(IntrinsicScalarFunctions::Cos),
@@ -3119,6 +3165,7 @@ namespace IntrinsicScalarFunctionRegistry {
31193165
eval_intrinsic_function>>& intrinsic_function_by_name_db = {
31203166
{"log_gamma", {&LogGamma::create_LogGamma, &LogGamma::eval_log_gamma}},
31213167
{"trunc", {&Trunc::create_Trunc, &Trunc::eval_Trunc}},
3168+
{"fix", {&Fix::create_Fix, &Fix::eval_Fix}},
31223169
{"sin", {&Sin::create_Sin, &Sin::eval_Sin}},
31233170
{"cos", {&Cos::create_Cos, &Cos::eval_Cos}},
31243171
{"tan", {&Tan::create_Tan, &Tan::eval_Tan}},
@@ -3180,6 +3227,7 @@ namespace IntrinsicScalarFunctionRegistry {
31803227
id_ == IntrinsicScalarFunctions::Gamma ||
31813228
id_ == IntrinsicScalarFunctions::LogGamma ||
31823229
id_ == IntrinsicScalarFunctions::Trunc ||
3230+
id_ == IntrinsicScalarFunctions::Fix ||
31833231
id_ == IntrinsicScalarFunctions::Sin ||
31843232
id_ == IntrinsicScalarFunctions::Exp ||
31853233
id_ == IntrinsicScalarFunctions::Exp2 ||

src/libasr/runtime/lfortran_intrinsics.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,26 @@ LFORTRAN_API double _lfortran_dtrunc(double x)
11461146
return trunc(x);
11471147
}
11481148

1149+
// fix -----------------------------------------------------------------------
1150+
1151+
LFORTRAN_API float _lfortran_sfix(float x)
1152+
{
1153+
if (x > 0.0) {
1154+
return floorf(x);
1155+
} else {
1156+
return ceilf(x);
1157+
}
1158+
}
1159+
1160+
LFORTRAN_API double _lfortran_dfix(double x)
1161+
{
1162+
if (x > 0.0) {
1163+
return floor(x);
1164+
} else {
1165+
return ceil(x);
1166+
}
1167+
}
1168+
11491169
// phase --------------------------------------------------------------------
11501170

11511171
LFORTRAN_API float _lfortran_cphase(float_complex_t x)

src/libasr/runtime/lfortran_intrinsics.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ LFORTRAN_API float_complex_t _lfortran_catanh(float_complex_t x);
170170
LFORTRAN_API double_complex_t _lfortran_zatanh(double_complex_t x);
171171
LFORTRAN_API float _lfortran_strunc(float x);
172172
LFORTRAN_API double _lfortran_dtrunc(double x);
173+
LFORTRAN_API float _lfortran_sfix(float x);
174+
LFORTRAN_API double _lfortran_dfix(double x);
173175
LFORTRAN_API float _lfortran_cphase(float_complex_t x);
174176
LFORTRAN_API double _lfortran_zphase(double_complex_t x);
175177
LFORTRAN_API bool _lpython_str_compare_eq(char** s1, char** s2);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7313,7 +7313,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
73137313
if (!s) {
73147314
std::string intrinsic_name = call_name;
73157315
std::set<std::string> not_cpython_builtin = {
7316-
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc",
7316+
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand", "trunc", "fix",
73177317
"sum" // For sum called over lists
73187318
};
73197319
std::set<std::string> symbolic_functions = {

src/runtime/lpython_intrinsic_numpy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,23 @@ def _lfortran_strunc(x: f32) -> f32:
430430
@vectorize
431431
def trunc(x: f32) -> f32:
432432
return _lfortran_strunc(x)
433+
434+
########## fix ##########
435+
436+
@ccall
437+
def _lfortran_dfix(x: f64) -> f64:
438+
pass
439+
440+
@overload
441+
@vectorize
442+
def fix(x: f64) -> f64:
443+
return _lfortran_dfix(x)
444+
445+
@ccall
446+
def _lfortran_sfix(x: f32) -> f32:
447+
pass
448+
449+
@overload
450+
@vectorize
451+
def fix(x: f32) -> f32:
452+
return _lfortran_sfix(x)

tests/reference/asr-array_01_decl-39cf894.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_01_decl-39cf894.stdout",
9-
"stdout_hash": "2aa47467473392c970bb1ddde961e3007d4c157bb0ea507b5e0db4a4",
9+
"stdout_hash": "b0dc16e057dc08b7ec8adac23b2d98fa29d536fca17934c2689425d8",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)