Skip to content

Commit af972f0

Browse files
Tai78641Jerry-Ge
andauthored
[TOSA] Add StatefulOps to TOSA Dialect (#66843)
This patch adds tosa.variable, tosa.variable.read and tosa.variable.write operators and tests. Change-Id: I647e2e5c3762d7890b03f6aa7c09a29198b7d355 --------- Signed-off-by: Jerry Ge <jerry.ge@arm.com> Co-authored-by: Jerry Ge <jerry.ge@arm.com>
1 parent ff1329e commit af972f0

File tree

10 files changed

+281
-19
lines changed

10 files changed

+281
-19
lines changed

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ std::unique_ptr<Pass> createTosaToLinalgNamed();
3535
void addTosaToLinalgPasses(
3636
OpPassManager &pm, const TosaToLinalgOptions &options,
3737
// Note: Default to 'none' level unless otherwise specified.
38-
tosa::ValidationOptions const &validationOptions =
39-
tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
38+
tosa::TosaValidationOptions const &validationOptions = {
39+
tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None});
4040

4141
/// Populates conversion passes from TOSA dialect to Linalg dialect.
4242
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class PatternRewriter;
3434

3535
namespace tosa {
3636

37+
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
38+
Attribute &attr);
39+
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
40+
Attribute attr);
41+
3742
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
3843

3944
} // namespace tosa

mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
7979
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
8080
}
8181

82+
//===----------------------------------------------------------------------===//
83+
// Operator: variable
84+
//===----------------------------------------------------------------------===//
85+
def Tosa_VariableOp : Tosa_Op<"variable", []> {
86+
let summary = "Defines a variable";
87+
88+
let description = [{
89+
Defines a new TOSA variable. This is a mutable value.
90+
Modifications are expressed using read/write semantics.
91+
}];
92+
93+
let arguments = (ins
94+
SymbolNameAttr:$name,
95+
TypeAttr:$type,
96+
OptionalAttr<AnyAttr>:$initial_value
97+
);
98+
99+
let assemblyFormat = [{
100+
$name
101+
attr-dict
102+
custom<TypeOrAttr>($type, $initial_value)
103+
}];
104+
}
105+
106+
//===----------------------------------------------------------------------===//
107+
// Operator: variable.write
108+
//===----------------------------------------------------------------------===//
109+
def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> {
110+
let summary = "write_buffer operator";
111+
112+
let description = [{
113+
Assigns a value to pseudo-buffer resource holding a mutable tensor.
114+
}];
115+
116+
let arguments = (ins
117+
SymbolNameAttr:$name,
118+
AnyType:$value
119+
);
120+
121+
let assemblyFormat = [{
122+
$name attr-dict `,` $value `:` type($value)
123+
}];
124+
}
125+
126+
//===----------------------------------------------------------------------===//
127+
// Operator: variable.read
128+
//===----------------------------------------------------------------------===//
129+
def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> {
130+
let summary = "read_buffer operator";
131+
132+
let description = [{
133+
Reads the value from a pseudo-buffer resource holding a mutable tensor.
134+
}];
135+
136+
let arguments = (ins
137+
SymbolNameAttr:$name
138+
);
139+
140+
let results = (outs
141+
AnyType:$value
142+
);
143+
144+
let assemblyFormat = [{
145+
$name attr-dict `:` type($value)
146+
}];
147+
}
148+
82149
#endif // TOSA_UTIL_OPS

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ struct ValidationOptions {
6868
}
6969
};
7070

71-
std::unique_ptr<Pass> createTosaValidationPass(
72-
ValidationOptions const &options = ValidationOptions());
73-
7471
#define GEN_PASS_REGISTRATION
7572
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
7673

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,12 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
8989
let cppNamespace = "mlir::tosa";
9090
}
9191

92-
def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
92+
def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> {
9393
let summary = "Validates TOSA dialect";
9494
let description = [{
9595
This pass validates if input TOSA operations match the specification for given
9696
criteria, e.g. TOSA profile.
9797
}];
98-
let constructor = "createTosaValidationPass()";
9998

10099
let options = [
101100
Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
7676

7777
void mlir::tosa::addTosaToLinalgPasses(
7878
OpPassManager &pm, const TosaToLinalgOptions &options,
79-
tosa::ValidationOptions const &validationOptions) {
79+
tosa::TosaValidationOptions const &validationOptions) {
8080
// Optional decompositions are designed to benefit linalg.
8181
if (!options.disableTosaDecompositions)
8282
pm.addNestedPass<func::FuncOp>(tosa::createTosaOptionalDecompositions());
@@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses(
9090
pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
9191
{options.aggressiveReduceConstant}));
9292
pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
93-
pm.addNestedPass<func::FuncOp>(
94-
tosa::createTosaValidationPass(validationOptions));
93+
pm.addNestedPass<func::FuncOp>(tosa::createTosaValidation(validationOptions));
9594
pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
9695
}

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
146146
return nullptr;
147147
}
148148

149+
//===----------------------------------------------------------------------===//
150+
// Parsers and printers
151+
//===----------------------------------------------------------------------===//
152+
153+
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
154+
Attribute &attr) {
155+
if (succeeded(parser.parseOptionalEqual())) {
156+
if (failed(parser.parseAttribute(attr))) {
157+
return parser.emitError(parser.getCurrentLocation())
158+
<< "expected attribute";
159+
}
160+
if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
161+
typeAttr = TypeAttr::get(typedAttr.getType());
162+
}
163+
return success();
164+
}
165+
166+
Type type;
167+
if (failed(parser.parseColonType(type))) {
168+
return parser.emitError(parser.getCurrentLocation()) << "expected type";
169+
}
170+
typeAttr = TypeAttr::get(type);
171+
172+
return success();
173+
}
174+
175+
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
176+
Attribute attr) {
177+
bool needsSpace = false;
178+
auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
179+
if (!typedAttr || typedAttr.getType() != type.getValue()) {
180+
p << ": ";
181+
p.printAttribute(type);
182+
needsSpace = true; // subsequent attr value needs a space separator
183+
}
184+
if (attr) {
185+
if (needsSpace)
186+
p << ' ';
187+
p << "= ";
188+
p.printAttribute(attr);
189+
}
190+
}
191+
149192
//===----------------------------------------------------------------------===//
150193
// TOSA Operator Verifiers.
151194
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
1515
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
1616

17+
#include <string>
18+
#include <unordered_map>
19+
1720
#include "mlir/Dialect/Func/IR/FuncOps.h"
1821
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1922
#include "mlir/IR/Builders.h"
@@ -96,12 +99,13 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
9699
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
97100
public:
98101
explicit TosaValidation() { populateConstantOperandChecks(); }
99-
explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
102+
explicit TosaValidation(const TosaValidationOptions &options)
103+
: TosaValidation() {
100104
this->profile = options.profile;
101-
this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
105+
this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
102106
this->level = options.level;
103107
}
104-
void runOnOperation() override;
108+
void runOnOperation() final;
105109

106110
LogicalResult applyConstantOperandCheck(Operation *op) {
107111
for (auto &checker : const_checkers) {
@@ -113,6 +117,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
113117

114118
LogicalResult applyLevelCheck(Operation *op);
115119

120+
// check variable read/write data types against variable declarations
121+
LogicalResult applyVariableCheck(Operation *op);
122+
116123
private:
117124
void populateConstantOperandChecks() {
118125
const_checkers.emplace_back(checkConstantOperandPad);
@@ -398,8 +405,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
398405
}
399406
}
400407

408+
bool CheckVariable(Operation *op);
409+
bool CheckVariableReadOrWrite(Operation *op);
410+
401411
SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
402412
tosa_level_t tosa_level;
413+
DenseMap<const mlir::StringAttr *, mlir::Type> variables_map;
403414
};
404415

405416
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
@@ -427,6 +438,69 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
427438
return success();
428439
}
429440

441+
inline bool CompatibleTypes(const mlir::Type &type,
442+
const mlir::Type &declared_type) {
443+
// for now, simply use type equality comparison
444+
return type == declared_type;
445+
}
446+
447+
bool TosaValidation::CheckVariable(Operation *op) {
448+
if (isa<mlir::tosa::VariableOp>(op)) {
449+
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
450+
451+
if (variables_map.count(&name_attr)) {
452+
op->emitOpError() << "name has already been declared";
453+
return false;
454+
}
455+
456+
auto type_attr = cast<mlir::TypeAttr>(op->getAttr("type"));
457+
mlir::Type type = type_attr.getValue();
458+
459+
variables_map[&name_attr] = type;
460+
}
461+
462+
return true;
463+
}
464+
465+
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
466+
if (isa<mlir::tosa::VariableReadOp>(op) ||
467+
isa<mlir::tosa::VariableWriteOp>(op)) {
468+
auto name_attr = cast<mlir::StringAttr>(op->getAttr("name"));
469+
470+
if (!variables_map.count(&name_attr)) {
471+
op->emitOpError() << "name has not been declared";
472+
return false;
473+
}
474+
475+
auto var_type = variables_map[&name_attr];
476+
477+
for (auto v : op->getOperands()) {
478+
auto type = v.getType();
479+
if (!CompatibleTypes(type, var_type)) {
480+
op->emitOpError() << "operand type does not equal variable type";
481+
return false;
482+
}
483+
}
484+
485+
for (auto v : op->getResults()) {
486+
auto type = v.getType();
487+
if (!CompatibleTypes(type, var_type)) {
488+
op->emitOpError() << "result type does not equal variable type";
489+
return false;
490+
}
491+
}
492+
}
493+
494+
return true;
495+
}
496+
497+
LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
498+
if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
499+
return failure();
500+
}
501+
return success();
502+
}
503+
430504
void TosaValidation::runOnOperation() {
431505
configLevelAndProfile();
432506
getOperation().walk([&](Operation *op) {
@@ -440,18 +514,18 @@ void TosaValidation::runOnOperation() {
440514
}
441515
}
442516

443-
// Some uses of TOSA rely on the constant operands of particular operations.
517+
// Some uses of TOSA rely on the constant operands of particular
518+
// operations.
444519
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
445520
signalPassFailure();
446521

447522
// do level checks
448523
if (failed(applyLevelCheck(op)))
449524
signalPassFailure();
525+
526+
// do variable type checks
527+
if (failed(applyVariableCheck(op)))
528+
signalPassFailure();
450529
});
451530
}
452531
} // namespace
453-
454-
std::unique_ptr<Pass>
455-
mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
456-
return std::make_unique<TosaValidation>(options);
457-
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<
203203
: (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
204204
return %0 : tensor<1x7x7x9xf32>
205205
}
206+
207+
// -----
208+
209+
func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () {
210+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
211+
// expected-error@+1 {{'tosa.variable' op name has already been declared}}
212+
tosa.variable @stored_var : tensor<1x4x8xi32>
213+
return
214+
}
215+
216+
// -----
217+
218+
func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () {
219+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
220+
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
221+
%0 = tosa.variable.read @stored_var : tensor<2x4x8xi16>
222+
return
223+
}
224+
225+
// -----
226+
227+
func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () {
228+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
229+
// expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}}
230+
%0 = tosa.variable.read @stored_var : tensor<1x4x8xi32>
231+
return
232+
}
233+
234+
// -----
235+
236+
func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () {
237+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
238+
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
239+
tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16>
240+
return
241+
}
242+
243+
// -----
244+
245+
func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () {
246+
tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32>
247+
// expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}}
248+
tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32>
249+
return
250+
}

0 commit comments

Comments
 (0)