Skip to content

Commit c1eab57

Browse files
authored
[mlir] fix Operation::getDiscardableAttrs in absence of properties (#76816)
When properties are not enabled in an operation, inherent attributes are stored in the common dictionary with discardable attributes. However, `getDiscardableAttrs` and `getDiscardableAttrDictionary` were returning the entire dictionary, making the caller mistakenly believe that all inherent attributes are discardable. Fix this by filtering out attributes whose names are registered with the operation, i.e., inherent attributes. This requires an API change so `getDiscardableAttrs` returns a filter range.
1 parent 82e33d6 commit c1eab57

File tree

4 files changed

+66
-12
lines changed

4 files changed

+66
-12
lines changed

mlir/include/mlir/IR/Operation.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,19 +475,33 @@ class alignas(8) Operation final
475475
return removeDiscardableAttr(StringAttr::get(getContext(), name));
476476
}
477477

478-
/// Return all of the discardable attributes on this operation.
479-
ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
478+
/// Return a range of all of discardable attributes on this operation. Note
479+
/// that for unregistered operations that are not storing inherent attributes
480+
/// as properties, all attributes are considered discardable.
481+
auto getDiscardableAttrs() {
482+
std::optional<RegisteredOperationName> opName = getRegisteredInfo();
483+
ArrayRef<StringAttr> attributeNames =
484+
opName ? getRegisteredInfo()->getAttributeNames()
485+
: ArrayRef<StringAttr>();
486+
return llvm::make_filter_range(
487+
attrs.getValue(),
488+
[this, attributeNames](const NamedAttribute attribute) {
489+
return getPropertiesStorage() ||
490+
!llvm::is_contained(attributeNames, attribute.getName());
491+
});
492+
}
480493

481494
/// Return all of the discardable attributes on this operation as a
482495
/// DictionaryAttr.
483-
DictionaryAttr getDiscardableAttrDictionary() { return attrs; }
496+
DictionaryAttr getDiscardableAttrDictionary() {
497+
if (getPropertiesStorage())
498+
return attrs;
499+
return DictionaryAttr::get(getContext(),
500+
llvm::to_vector(getDiscardableAttrs()));
501+
}
484502

485503
/// Return all of the attributes on this operation.
486-
ArrayRef<NamedAttribute> getAttrs() {
487-
if (!getPropertiesStorage())
488-
return getDiscardableAttrs();
489-
return getAttrDictionary().getValue();
490-
}
504+
ArrayRef<NamedAttribute> getAttrs() { return getAttrDictionary().getValue(); }
491505

492506
/// Return all of the attributes on this operation as a DictionaryAttr.
493507
DictionaryAttr getAttrDictionary();

mlir/lib/CAPI/IR/IR.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,12 +613,14 @@ void mlirOperationSetInherentAttributeByName(MlirOperation op,
613613
}
614614

615615
intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
616-
return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
616+
return static_cast<intptr_t>(
617+
llvm::range_size(unwrap(op)->getDiscardableAttrs()));
617618
}
618619

619620
MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
620621
intptr_t pos) {
621-
NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
622+
NamedAttribute attr =
623+
*std::next(unwrap(op)->getDiscardableAttrs().begin(), pos);
622624
return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
623625
}
624626

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3543,7 +3543,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
35433543
}
35443544

35453545
auto attrs = op->getDiscardableAttrs();
3546-
printOptionalAttrDict(attrs);
3546+
printOptionalAttrDict(llvm::to_vector(attrs));
35473547

35483548
// Print the type signature of the operation.
35493549
os << " : ";

mlir/unittests/IR/OpPropertiesTest.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/IR/Attributes.h"
910
#include "mlir/IR/OpDefinition.h"
1011
#include "mlir/Parser/Parser.h"
1112
#include "gtest/gtest.h"
@@ -132,6 +133,23 @@ class OpWithProperties : public Op<OpWithProperties> {
132133
}
133134
};
134135

136+
/// A custom operation for the purpose of showcasing how discardable attributes
137+
/// are handled in absence of properties.
138+
class OpWithoutProperties : public Op<OpWithoutProperties> {
139+
public:
140+
// Begin boilerplate.
141+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties)
142+
using Op::Op;
143+
static ArrayRef<StringRef> getAttributeNames() {
144+
static StringRef attributeNames[] = {StringRef("inherent_attr")};
145+
return ArrayRef(attributeNames);
146+
};
147+
static StringRef getOperationName() {
148+
return "test_op_properties.op_without_properties";
149+
}
150+
// End boilerplate.
151+
};
152+
135153
// A trivial supporting dialect to register the above operation.
136154
class TestOpPropertiesDialect : public Dialect {
137155
public:
@@ -142,7 +160,7 @@ class TestOpPropertiesDialect : public Dialect {
142160
explicit TestOpPropertiesDialect(MLIRContext *context)
143161
: Dialect(getDialectNamespace(), context,
144162
TypeID::get<TestOpPropertiesDialect>()) {
145-
addOperations<OpWithProperties>();
163+
addOperations<OpWithProperties, OpWithoutProperties>();
146164
}
147165
};
148166

@@ -359,4 +377,24 @@ TEST(OpPropertiesTest, getOrAddProperties) {
359377
op->erase();
360378
}
361379

380+
constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir(
381+
"test_op_properties.op_without_properties"()
382+
{inherent_attr = 42, other_attr = 56} : () -> ()
383+
)mlir";
384+
385+
TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
386+
MLIRContext context;
387+
context.getOrLoadDialect<TestOpPropertiesDialect>();
388+
ParserConfig config(&context);
389+
OwningOpRef<Operation *> op =
390+
parseSourceString(withoutPropertiesAttrsSrc, config);
391+
ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u);
392+
EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(),
393+
"other_attr");
394+
395+
EXPECT_EQ(op->getAttrs().size(), 2u);
396+
EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt);
397+
EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute());
398+
}
399+
362400
} // namespace

0 commit comments

Comments
 (0)