Skip to content

Commit bec0dea

Browse files
authored
[Python][OM] Handle BoolAttr's before IntegerAttr's. (#7438)
BoolAttr's are IntegerAttr's, check them first. IntegerAttr's that happen to have the characteristics of BoolAttr will accordingly become Python boolean values. Unclear where these come from but we do lower booleans to MLIR bool constants so make sure to handle that. Add test for object model IR with bool constants.
1 parent cbdee94 commit bec0dea

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

integration_test/Bindings/Python/dialects/om.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@
6868
6969
%map = om.map_create %entry1, %entry2: !om.string, !om.integer
7070
om.class.field @map_create, %map : !om.map<!om.string, !om.integer>
71+
72+
%true = om.constant true
73+
om.class.field @true, %true : i1
74+
%false = om.constant false
75+
om.class.field @false, %false : i1
7176
}
7277
7378
om.class @Child(%0: !om.integer) {
@@ -157,7 +162,7 @@
157162

158163
# CHECK: 14
159164
print(obj.child.foo)
160-
# CHECK: loc("-":60:7)
165+
# CHECK: loc("-":65:7)
161166
print(obj.child.get_field_loc("foo"))
162167
# CHECK: ('Root', 'x')
163168
print(obj.reference)
@@ -224,6 +229,11 @@
224229
# CHECK-NEXT: Y 15
225230
print(k, v)
226231

232+
# CHECK: True
233+
print(obj.true)
234+
# CHECK: False
235+
print(obj.false)
236+
227237
obj = evaluator.instantiate("Client")
228238
object_dict: Dict[om.Object, str] = {}
229239
for field_name, data in obj:

lib/Bindings/Python/OMModule.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,15 +366,6 @@ Map::dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key) {
366366
// Convert a generic MLIR Attribute to a PythonValue. This is basically a C++
367367
// fast path of the parts of attribute_to_var that we use in the OM dialect.
368368
static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
369-
if (mlirAttributeIsAInteger(attr)) {
370-
MlirType type = mlirAttributeGetType(attr);
371-
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
372-
return py::int_(mlirIntegerAttrGetValueInt(attr));
373-
if (mlirIntegerTypeIsSigned(type))
374-
return py::int_(mlirIntegerAttrGetValueSInt(attr));
375-
return py::int_(mlirIntegerAttrGetValueUInt(attr));
376-
}
377-
378369
if (omAttrIsAIntegerAttr(attr)) {
379370
auto strRef = omIntegerAttrToString(attr);
380371
return py::int_(py::str(strRef.data, strRef.length));
@@ -389,10 +380,20 @@ static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
389380
return py::str(strRef.data, strRef.length);
390381
}
391382

383+
// BoolAttr's are IntegerAttr's, check this first.
392384
if (mlirAttributeIsABool(attr)) {
393385
return py::bool_(mlirBoolAttrGetValue(attr));
394386
}
395387

388+
if (mlirAttributeIsAInteger(attr)) {
389+
MlirType type = mlirAttributeGetType(attr);
390+
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
391+
return py::int_(mlirIntegerAttrGetValueInt(attr));
392+
if (mlirIntegerTypeIsSigned(type))
393+
return py::int_(mlirIntegerAttrGetValueSInt(attr));
394+
return py::int_(mlirIntegerAttrGetValueUInt(attr));
395+
}
396+
396397
if (omAttrIsAReferenceAttr(attr)) {
397398
auto innerRef = omReferenceAttrGetInnerRef(attr);
398399
auto moduleStrRef =

0 commit comments

Comments
 (0)