Skip to content

Commit 8513ff0

Browse files
author
Nicolas Vasilache
committed
[mlir][VectorOps][EDSC] Add EDSC for VectorOps
Summary: This revision adds EDSC support for VectorOps to enable the creation of a `vector_matmul` declaratively. The `vector_matmul` is a simple configuration of the `vector.contract` op that follows the StructuredOps abstraction. Differential Revision: https://reviews.llvm.org/D74284
1 parent 62ce7e6 commit 8513ff0

File tree

12 files changed

+228
-29
lines changed

12 files changed

+228
-29
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- Builders.h - MLIR Declarative Vector Builders ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides intuitive composable interfaces for building structured MLIR
10+
// snippets in a declarative fashion.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
#ifndef MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
14+
#define MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
15+
16+
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
17+
#include "mlir/Dialect/VectorOps/VectorOps.h"
18+
#include "mlir/EDSC/Builders.h"
19+
#include "mlir/EDSC/Intrinsics.h"
20+
#include "mlir/IR/AffineExpr.h"
21+
#include "mlir/IR/Builders.h"
22+
23+
namespace mlir {
24+
namespace edsc {
25+
namespace ops {
26+
27+
/// Build a generic vector contraction, that is a `vector.contract` op with
28+
/// specified `iteratorTypes`. The client is responsible for specifying proper
29+
/// indexings when creating the StructuredIndexed.
30+
/// The computation represents a notional (A * B + C) where indexings specify
31+
/// which dimensions are reduced and reordered.
32+
/// Return the result of the `vector.contract` op
33+
///
34+
/// Prerequisites:
35+
/// A, B and C capture values of proper vector types, and indexing expressions
36+
/// that match semantics of the `vector.contract` op.
37+
Value vector_contraction(StructuredIndexed A, StructuredIndexed B,
38+
StructuredIndexed C,
39+
ArrayRef<IteratorType> iteratorTypes);
40+
41+
/// Build a generic vector contraction that computes a matmul on vectors.
42+
/// Return the result of C(i, j) + sum_k {A(i, k) * B(k, j)} on vectors.
43+
///
44+
/// Prerequisites:
45+
/// A, B and C capture values of proper vector types. For instance
46+
/// `A: vector<4x8xf32>`, `B: vector<8x16f32>` and `C: vector<4x16xf32>`.
47+
Value vector_matmul(Value A, Value B, Value C);
48+
49+
} // namespace ops
50+
} // namespace edsc
51+
} // namespace mlir
52+
53+
#endif // MLIR_DIALECT_VECTOR_EDSC_BUILDERS_H_
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- Intrinsics.h - MLIR EDSC Intrinsics for VectorOps --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
9+
#define MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_
10+
11+
#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
12+
13+
namespace mlir {
14+
namespace edsc {
15+
namespace intrinsics {
16+
17+
using vector_contract = ValueBuilder<vector::ContractionOp>;
18+
19+
} // namespace intrinsics
20+
} // namespace edsc
21+
} // namespace mlir
22+
23+
#endif // MLIR_DIALECT_VECTOROPS_EDSC_INTRINSICS_H_

mlir/include/mlir/Dialect/VectorOps/VectorOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,11 @@ def Vector_ContractionOp :
141141
}];
142142
let builders = [OpBuilder<
143143
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
144-
"Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
144+
"Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">,
145+
OpBuilder<
146+
"Builder *builder, OperationState &result, Value lhs, Value rhs, "
147+
"Value acc, ArrayRef<ArrayRef<AffineExpr>> indexingExprs, "
148+
"ArrayRef<StringRef> iteratorTypes">];
145149
let extraClassDeclaration = [{
146150
VectorType getLhsType() {
147151
return lhs().getType().cast<VectorType>();

mlir/include/mlir/EDSC/Builders.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,9 @@ struct StructuredIndexed : public ValueHandle {
436436
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
437437
: ValueHandle(v), exprs(indexings.begin(), indexings.end()) {
438438
assert((v.getType().isa<MemRefType>() ||
439-
v.getType().isa<RankedTensorType>()) &&
440-
"MemRef or RankedTensor expected");
439+
v.getType().isa<RankedTensorType>() ||
440+
v.getType().isa<VectorType>()) &&
441+
"MemRef, RankedTensor or Vector expected");
441442
}
442443
StructuredIndexed(ValueHandle vh, ArrayRef<AffineExpr> indexings)
443444
: ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {}

mlir/include/mlir/IR/AffineMap.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ class AffineMap {
6363
static AffineMap getPermutationMap(ArrayRef<unsigned> permutation,
6464
MLIRContext *context);
6565

66+
/// Returns a vector of AffineMaps; each with as many results as
67+
/// `exprs.size()`, as many dims as the largest dim in `exprs` and as many
68+
/// symbols as the largest symbol in `exprs`.
69+
static SmallVector<AffineMap, 4>
70+
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList);
71+
static SmallVector<AffineMap, 4>
72+
inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList);
73+
6674
MLIRContext *getContext() const;
6775

6876
explicit operator bool() { return map != nullptr; }

mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,6 @@ GenericLoopNestRangeBuilder<loop::ParallelOp>::GenericLoopNestRangeBuilder(
130130
} // namespace edsc
131131
} // namespace mlir
132132

133-
static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
134-
unsigned &pos) {
135-
for (auto sidx : structuredIndices) {
136-
for (auto expr : sidx.getExprs()) {
137-
expr.walk([&pos](AffineExpr e) {
138-
if (auto d = e.dyn_cast<AffineDimExpr>())
139-
pos = std::max(pos, d.getPosition());
140-
});
141-
}
142-
}
143-
}
144-
145133
Operation *mlir::edsc::makeGenericLinalgOp(
146134
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
147135
ArrayRef<StructuredIndexed> outputs,
@@ -155,20 +143,16 @@ Operation *mlir::edsc::makeGenericLinalgOp(
155143
auto *ctx = builder.getContext();
156144
unsigned nInputs = inputs.size();
157145
unsigned nOutputs = outputs.size();
158-
unsigned maxPos = 0;
159-
getMaxDimIndex(inputs, maxPos);
160-
getMaxDimIndex(outputs, maxPos);
161-
// maxPos is 0 indexed, need to turn this into a count (i.e. +1)
162-
unsigned nDims = maxPos + 1;
163-
164-
SmallVector<AffineMap, 4> maps;
165-
maps.reserve(nInputs + nOutputs);
166-
for (auto in : inputs)
167-
maps.push_back(
168-
AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs()));
169-
for (auto out : outputs)
170-
maps.push_back(
171-
AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs()));
146+
147+
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
148+
exprsList.reserve(nInputs + nOutputs);
149+
for (auto structuredIndexed : inputs)
150+
exprsList.emplace_back(structuredIndexed.getExprs().begin(),
151+
structuredIndexed.getExprs().end());
152+
for (auto structuredIndexed : outputs)
153+
exprsList.emplace_back(structuredIndexed.getExprs().begin(),
154+
structuredIndexed.getExprs().end());
155+
auto maps = AffineMap::inferFromExprList(exprsList);
172156

173157
unsigned nViews = nInputs + nOutputs;
174158
SmallVector<Value, 4> values;

mlir/lib/Dialect/VectorOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_llvm_library(MLIRVectorOps
33
VectorOps.cpp
44
VectorTransforms.cpp
55
VectorUtils.cpp
6+
EDSC/Builders.cpp
67

78
ADDITIONAL_HEADER_DIRS
89
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/VectorOps
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- Builders.cpp - MLIR Declarative Linalg Builders --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/VectorOps/EDSC/Builders.h"
10+
#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
11+
#include "mlir/Dialect/VectorOps/VectorOps.h"
12+
#include "mlir/EDSC/Builders.h"
13+
#include "mlir/EDSC/Intrinsics.h"
14+
#include "mlir/IR/AffineExpr.h"
15+
#include "mlir/IR/Builders.h"
16+
#include "mlir/Support/Functional.h"
17+
18+
using namespace mlir;
19+
using namespace mlir::edsc;
20+
using namespace mlir::edsc::intrinsics;
21+
using namespace mlir::edsc::ops;
22+
23+
Value mlir::edsc::ops::vector_contraction(
24+
StructuredIndexed A, StructuredIndexed B, StructuredIndexed C,
25+
ArrayRef<IteratorType> iteratorTypes) {
26+
using IndexingExprs = ArrayRef<ArrayRef<AffineExpr>>;
27+
return vector_contract(
28+
A.getValue(), B.getValue(), C.getValue(),
29+
IndexingExprs{A.getExprs(), B.getExprs(), C.getExprs()},
30+
ArrayRef<StringRef>{functional::map(toString, iteratorTypes)});
31+
}
32+
33+
Value mlir::edsc::ops::vector_matmul(Value A, Value B, Value C) {
34+
AffineExpr m, n, k;
35+
bindDims(ScopedContext::getContext(), m, n, k);
36+
return vector_contraction(StructuredIndexed(A, {m, k}),
37+
StructuredIndexed(B, {k, n}),
38+
StructuredIndexed(C, {m, n}),
39+
{IteratorType::Parallel, IteratorType::Parallel,
40+
IteratorType::Reduction});
41+
}

mlir/lib/Dialect/VectorOps/VectorOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,19 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
6464
// ContractionOp
6565
//===----------------------------------------------------------------------===//
6666

67+
void vector::ContractionOp::build(Builder *builder, OperationState &result,
68+
Value lhs, Value rhs, Value acc,
69+
ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
70+
ArrayRef<StringRef> iteratorTypes) {
71+
result.addOperands({lhs, rhs, acc});
72+
result.addTypes(acc.getType());
73+
result.addAttribute(getIndexingMapsAttrName(),
74+
builder->getAffineMapArrayAttr(
75+
AffineMap::inferFromExprList(indexingExprs)));
76+
result.addAttribute(getIteratorTypesAttrName(),
77+
builder->getStrArrayAttr(iteratorTypes));
78+
}
79+
6780
void vector::ContractionOp::build(Builder *builder, OperationState &result,
6881
Value lhs, Value rhs, Value acc,
6982
ArrayAttr indexingMaps,

mlir/lib/IR/AffineMap.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,44 @@ AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
111111
return permutationMap;
112112
}
113113

114+
template <typename AffineExprContainer>
115+
static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
116+
int64_t &maxDim, int64_t &maxSym) {
117+
for (const auto &exprs : exprsList) {
118+
for (auto expr : exprs) {
119+
expr.walk([&maxDim, &maxSym](AffineExpr e) {
120+
if (auto d = e.dyn_cast<AffineDimExpr>())
121+
maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
122+
if (auto s = e.dyn_cast<AffineSymbolExpr>())
123+
maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
124+
});
125+
}
126+
}
127+
}
128+
129+
template <typename AffineExprContainer>
130+
SmallVector<AffineMap, 4>
131+
inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
132+
int64_t maxDim = -1, maxSym = -1;
133+
getMaxDimAndSymbol(exprsList, maxDim, maxSym);
134+
SmallVector<AffineMap, 4> maps;
135+
maps.reserve(exprsList.size());
136+
for (const auto &exprs : exprsList)
137+
maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
138+
/*symbolCount=*/maxSym + 1, exprs));
139+
return maps;
140+
}
141+
142+
SmallVector<AffineMap, 4>
143+
AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
144+
return ::inferFromExprList(exprsList);
145+
}
146+
147+
SmallVector<AffineMap, 4>
148+
AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
149+
return ::inferFromExprList(exprsList);
150+
}
151+
114152
AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
115153
MLIRContext *context) {
116154
SmallVector<AffineExpr, 4> dimExprs;

mlir/test/EDSC/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ target_link_libraries(mlir-edsc-builder-api-test
1414
MLIRLoopOps
1515
MLIRStandardOps
1616
MLIRTransforms
17+
MLIRVectorOps
1718
LLVMCore
1819
LLVMSupport
1920
)
@@ -25,5 +26,6 @@ whole_archive_link(mlir-edsc-builder-api-test
2526
MLIRLinalgOps
2627
MLIRLoopOps
2728
MLIRStandardOps
29+
MLIRVectorOps
2830
MLIRTransforms
2931
)

mlir/test/EDSC/builder-api-test.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
1313
#include "mlir/Dialect/LoopOps/EDSC/Builders.h"
1414
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
15+
#include "mlir/Dialect/VectorOps/EDSC/Intrinsics.h"
1516
#include "mlir/EDSC/Builders.h"
1617
#include "mlir/EDSC/Intrinsics.h"
1718
#include "mlir/IR/AffineExpr.h"
@@ -981,6 +982,36 @@ TEST_FUNC(linalg_tensors_test) {
981982
f.erase();
982983
}
983984

985+
// CHECK-LABEL: func @vector_matmul_test(
986+
// CHECK-SAME: %[[A:.*]]: vector<4x16xf32>,
987+
// CHECK-SAME: %[[B:.*]]: vector<16x8xf32>,
988+
// CHECK-SAME: %[[C:.*]]: vector<4x8xf32>)
989+
// CHECK: vector.contract {{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
990+
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
991+
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
992+
// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
993+
// CHECK-SAME: %[[A]], %[[B]], %[[C]]
994+
// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
995+
TEST_FUNC(vector_matmul_test) {
996+
using namespace edsc;
997+
using namespace edsc::ops;
998+
999+
int64_t M = 4, N = 8, K = 16;
1000+
auto f32Type = FloatType::getF32(&globalContext());
1001+
auto mkVectorType = VectorType::get({M, K}, f32Type);
1002+
auto knVectorType = VectorType::get({K, N}, f32Type);
1003+
auto mnVectorType = VectorType::get({M, N}, f32Type);
1004+
auto f = makeFunction("vector_matmul_test", {},
1005+
{mkVectorType, knVectorType, mnVectorType});
1006+
1007+
OpBuilder builder(f.getBody());
1008+
ScopedContext scope(builder, f.getLoc());
1009+
ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
1010+
vector_matmul(A, B, C);
1011+
f.print(llvm::outs());
1012+
f.erase();
1013+
}
1014+
9841015
int main() {
9851016
RUN_TESTS();
9861017
return 0;

0 commit comments

Comments
 (0)