diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index d40d936cdc83d..df835d63dc2d5 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -2,6 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import inspect from typing import ( List as _List, Optional as _Optional, @@ -50,6 +51,32 @@ def segmented_accessor(elements, raw_segments, idx): return elements[start:end] +def get_source_location(): + """ + Returns a source location from the frame just before the one whose + filename includes 'python_packages'. + """ + frame = inspect.currentframe() + outer_frames = inspect.getouterframes(frame) + + # Traverse the frames in reverse order, excluding the current frame + selected_frame = None + for i in range(len(outer_frames) - 1, -1, -1): + current_frame = outer_frames[i] + if "python_packages" in current_frame.filename: + # Select the frame before the one containing 'python_packages' + selected_frame = outer_frames[i + 1] if i - 1 >= 0 else current_frame + break + if selected_frame is None: + # If no frame containing 'python_packages' is found, use the last frame + selected_frame = outer_frames[-1] + + # Create file location using the selected frame + file_loc = _cext.ir.Location.file(selected_frame.filename, selected_frame.lineno, 0) + loc = _cext.ir.Location.name(selected_frame.function, childLoc=file_loc) + return loc + + def equally_sized_accessor( elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic ): @@ -138,11 +165,10 @@ def get_op_result_or_op_results( return ( list(get_op_results_or_values(op)) if len(op.results) > 1 - else get_op_result_or_value(op) - if len(op.results) > 0 - else op + else get_op_result_or_value(op) if len(op.results) > 0 else op ) + ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py index f66d6c501dcf5..db87ef450cb6a 100644 --- a/mlir/test/python/ir/location.py +++ b/mlir/test/python/ir/location.py @@ -2,6 +2,7 @@ import gc from mlir.ir import * +from mlir.dialects import arith def run(f): @@ -43,6 +44,7 @@ def testLocationAttr(): run(testLocationAttr) + # CHECK-LABEL: TEST: testFileLineCol def testFileLineCol(): with Context() as ctx: @@ -150,3 +152,33 @@ def testLocationCapsule(): run(testLocationCapsule) + + +# CHECK-LABEL: TEST: autoGeneratedLocation +def autoGeneratedLocation(): + def generator(): + return arith.ConstantOp(value=123, result=IntegerType.get_signless(32)) + + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + a = arith.ConstantOp(value=42, result=IntegerType.get_signless(32)) + b = arith.AddIOp(a, generator()) + module.operation.print(enable_debug_info=True) + + +# CHECK: module { +# CHECK: %{{.*}} = arith.constant 42 : i32 loc(#loc4) +# CHECK: %{{.*}} = arith.constant 123 : i32 loc(#loc5) +# CHECK: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i32 loc(#loc6) +# CHECK: } loc(#loc) + +# CHECK: #loc = loc(unknown) +# CHECK: #loc1 = loc({{.*}}:164:0) +# CHECK: #loc2 = loc({{.*}}:160:0) +# CHECK: #loc3 = loc({{.*}}:165:0) +# CHECK: #loc4 = loc("autoGeneratedLocation"(#loc1)) +# CHECK: #loc5 = loc("generator"(#loc2)) +# CHECK: #loc6 = loc("autoGeneratedLocation"(#loc3)) + +run(autoGeneratedLocation) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 0c5c936f5adde..bcea716748bd3 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -35,6 +35,7 @@ constexpr const char *fileHeader = R"Py( from ._ods_common import _cext as _ods_cext from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, + get_source_location as _get_source_location, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_op_results as _get_op_result_or_op_results, get_op_result_or_value as _get_op_result_or_value, @@ -491,6 +492,8 @@ constexpr const char *initTemplate = R"Py( attributes = {{} regions = None {1} + if loc is None: + loc = _get_source_location() super().__init__(self.build_generic({2})) )Py";