Skip to content

Commit 31a6be1

Browse files
authored
Allow passing backend optimisation flags in lpython decorator (#2201)
1 parent f2bf702 commit 31a6be1

File tree

4 files changed

+60
-24
lines changed

4 files changed

+60
-24
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ jobs:
349349
- name: Test Linux
350350
shell: bash -e -l {0}
351351
run: |
352-
ctest
352+
ctest --rerun-failed --output-on-failure
353353
./run_tests.py -s
354354
cd integration_tests
355355
./run_tests.py -b llvm c

integration_tests/lpython_decorator_01.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from numpy import array
22
from lpython import i32, f64, lpython
33

4-
@lpython
4+
@lpython(backend="c", backend_optimisation_flags=["-ffast-math", "-funroll-loops", "-O3"])
55
def fast_sum(n: i32, x: f64[:]) -> f64:
66
s: f64 = 0.0
77
i: i32

integration_tests/lpython_decorator_02.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
n = TypeVar("n")
55

6-
@lpython
6+
@lpython(backend="c", backend_optimisation_flags=["-ffast-math", "-funroll-loops"])
77
def multiply_01(n: i32, x: f64[:]) -> f64[n]:
88
i: i32
99
for i in range(n):

src/runtime/lpython/lpython.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ctypes
44
import platform
55
from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass
6+
import functools
67

78

89
# TODO: this does not seem to restrict other imports
@@ -647,37 +648,49 @@ def ccallable(f):
647648
def ccallback(f):
648649
return f
649650

650-
class lpython:
651-
"""
652-
The @lpython decorator compiles a given function using LPython.
651+
class LpythonJITCache:
653652

654-
The decorator should be used from CPython mode, i.e., when the module is
655-
being run using CPython. When possible, it is recommended to use LPython
656-
for the main program, and use the @cpython decorator from the LPython mode
657-
to access CPython features that are not supported by LPython.
658-
"""
653+
def __init__(self):
654+
self.pyfunc2compiledfunc = {}
655+
656+
def compile(self, function, backend, optimisation_flags):
657+
if function in self.pyfunc2compiledfunc:
658+
return self.pyfunc2compiledfunc[function]
659+
660+
if optimisation_flags is not None and backend is None:
661+
raise ValueError("backend must be specified if backend_optimisation_flags are provided.")
662+
663+
if backend is None:
664+
backend = "c"
659665

660-
def __init__(self, function):
661666
def get_rtlib_dir():
662667
current_dir = os.path.dirname(os.path.abspath(__file__))
663668
return os.path.join(current_dir, "..")
664669

665-
self.fn_name = function.__name__
670+
fn_name = function.__name__
666671
# Get the source code of the function
667672
source_code = getsource(function)
668673
source_code = source_code[source_code.find('\n'):]
669674

670-
dir_name = "./lpython_decorator_" + self.fn_name
675+
dir_name = "./lpython_decorator_" + fn_name
671676
if not os.path.exists(dir_name):
672677
os.mkdir(dir_name)
673-
filename = dir_name + "/" + self.fn_name
678+
filename = dir_name + "/" + fn_name
674679

675680
# Open the file for writing
676681
with open(filename + ".py", "w") as file:
677682
# Write the Python source code to the file
678683
file.write("@pythoncallable")
679684
file.write(source_code)
680685

686+
if backend != "c":
687+
raise NotImplementedError("Backend %s is not supported with @lpython yet."%(backend))
688+
689+
opt_flags = " "
690+
if optimisation_flags is not None:
691+
for opt_flag in optimisation_flags:
692+
opt_flags += opt_flag + " "
693+
681694
# ----------------------------------------------------------------------
682695
# Generate the shared library
683696
# TODO: Use LLVM instead of C backend
@@ -687,12 +700,14 @@ def get_rtlib_dir():
687700

688701
gcc_flags = ""
689702
if platform.system() == "Linux":
690-
gcc_flags = " -shared -fPIC "
703+
gcc_flags = " -shared -fPIC"
691704
elif platform.system() == "Darwin":
692-
gcc_flags = " -bundle -flat_namespace -undefined suppress "
705+
gcc_flags = " -bundle -flat_namespace -undefined suppress"
693706
else:
694707
raise NotImplementedError("Platform not implemented")
695708

709+
gcc_flags += opt_flags
710+
696711
from numpy import get_include
697712
from distutils.sysconfig import get_python_inc, get_python_lib, \
698713
get_python_version
@@ -706,17 +721,38 @@ def get_rtlib_dir():
706721

707722
# ----------------------------------------------------------------------
708723
# Compile the C file and create a shared library
724+
shared_library_name = "lpython_module_" + fn_name
709725
r = os.system("gcc -g" + gcc_flags + python_path + numpy_path +
710-
filename + ".c -o lpython_module_" + self.fn_name + ".so " +
726+
filename + ".c -o " + shared_library_name + ".so " +
711727
rt_path_01 + rt_path_02 + python_lib)
712728
assert r == 0, "Failed to create the shared library"
729+
self.pyfunc2compiledfunc[function] = (shared_library_name, fn_name)
730+
return self.pyfunc2compiledfunc[function]
713731

714-
def __call__(self, *args, **kwargs):
715-
import sys; sys.path.append('.')
716-
# import the symbol from the shared library
717-
function = getattr(__import__("lpython_module_" + self.fn_name),
718-
self.fn_name)
719-
return function(*args, **kwargs)
732+
lpython_jit_cache = LpythonJITCache()
733+
734+
# Taken from https://stackoverflow.com/a/24617244
735+
def lpython(original_function=None, backend=None, backend_optimisation_flags=None):
736+
"""
737+
The @lpython decorator compiles a given function using LPython.
738+
739+
The decorator should be used from CPython mode, i.e., when the module is
740+
being run using CPython. When possible, it is recommended to use LPython
741+
for the main program, and use the @cpython decorator from the LPython mode
742+
to access CPython features that are not supported by LPython.
743+
"""
744+
def _lpython(function):
745+
@functools.wraps(function)
746+
def __lpython(*args, **kwargs):
747+
import sys; sys.path.append('.')
748+
lib_name, fn_name = lpython_jit_cache.compile(
749+
function, backend, backend_optimisation_flags)
750+
return getattr(__import__(lib_name), fn_name)(*args, **kwargs)
751+
return __lpython
752+
753+
if original_function:
754+
return _lpython(original_function)
755+
return _lpython
720756

721757
def bitnot(x, bitsize):
722758
return (~x) % (2 ** bitsize)

0 commit comments

Comments
 (0)