Skip to content

Commit ca9a335

Browse files
committed
[mlir][ArmSME] Add tile load op and extend tile store tile size support
This extends the existing 'arm_sme.tile_store' op to support all tile sizes and adds a new op 'arm_sme.tile_load', as well as lowerings from vector -> custom ops and custom ops -> intrinsics. Currently there's no lowering for i128. Depends on D154867 Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D155306
1 parent cee4494 commit ca9a335

File tree

12 files changed

+1069
-66
lines changed

12 files changed

+1069
-66
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -224,21 +224,74 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
224224
let assemblyFormat = "attr-dict `:` type($res)";
225225
}
226226

227+
def TileLoadOp : ArmSME_Op<"tile_load"> {
228+
let summary = "Tile load operation";
229+
let description = [{
230+
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
231+
with the shape defined by the 2D scalable vector type of the result tile.
232+
The slice of memory must be contiguous. The memref must be either rank 1 or
233+
rank 2 with dynamic dimensions, since the operation is scalable, and the
234+
element type must be a scalar that matches the element type of the result.
235+
236+
Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
237+
```mlir
238+
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
239+
```
240+
241+
Example 2: Load a FP 32-bit element ZA tile from memory.
242+
```mlir
243+
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
244+
```
245+
246+
Example 3: Load a 128-bit element ZA tile from memory.
247+
```mlir
248+
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
249+
```
250+
}];
251+
let arguments = (ins
252+
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
253+
Variadic<Index>:$indices);
254+
let results = (outs SMETile:$result);
255+
256+
let extraClassDeclaration = [{
257+
MemRefType getMemRefType() {
258+
return ::llvm::cast<MemRefType>(getBase().getType());
259+
}
260+
VectorType getVectorType() {
261+
return ::llvm::cast<VectorType>(getResult().getType());
262+
}
263+
}];
264+
265+
let assemblyFormat =
266+
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
267+
}
268+
227269
def TileStoreOp : ArmSME_Op<"tile_store"> {
228270
let summary = "Tile store operation";
229271
let description = [{
230-
Store a 2D SME "virtual tile" to memory.
231-
232-
NOTE: At the moment it is assumed that the element type is `i8` and that
233-
there's only one "virtual tile".
272+
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
273+
with the shape defined by the 2D scalable vector type of the tile being
274+
stored. The slice of memory must be contiguous. The memref must be either
275+
rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
276+
and the element type must be a scalar that matches the element type of the
277+
result.
278+
279+
Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
280+
```mlir
281+
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
282+
```
234283

235-
Example:
284+
Example 2: Store a FP 32-bit element ZA tile to memory.
285+
```mlir
286+
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
287+
```
236288

289+
Example 3: Store a 128-bit element ZA tile to memory.
237290
```mlir
238-
arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
291+
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
239292
```
240293
}];
241-
let arguments = (ins nxnxv16i8:$valueToStore,
294+
let arguments = (ins SMETile:$valueToStore,
242295
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
243296
Variadic<Index>:$indices);
244297
let extraClassDeclaration = [{
@@ -304,7 +357,7 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
304357
class ArmSME_IntrLoadOp<string mnemonic>
305358
: ArmSME_IntrOp<mnemonic>,
306359
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
307-
Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
360+
Arg<LLVM_AnyPointer, "Load address">,
308361
Arg<I32, "Virtual tile ID">,
309362
Arg<I32, "Tile slice">)>;
310363

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines prototypes for various utilities for the ArmSME
10+
// dialect. These are not passes by themselves but are used either by passes,
11+
// optimization sequences, or in turn by other transformation utilities.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
16+
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
17+
18+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19+
20+
namespace mlir {
21+
namespace arm_sme {
22+
23+
/// Return minimum number of elements for the given element `type` in
24+
/// a vector of SVL bits.
25+
unsigned getSMETileSliceMinNumElts(Type type);
26+
27+
/// Returns true if `type` is a valid element type for an SME tile or false
28+
/// otherwise.
29+
bool isValidSMETileElementType(Type type);
30+
31+
/// Returns true if `vType` is a valid vector type for an SME tile or false
32+
/// otherwise.
33+
bool isValidSMETileVectorType(VectorType vType);
34+
35+
} // namespace arm_sme
36+
} // namespace mlir
37+
38+
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
1010

1111
LINK_LIBS PUBLIC
1212
MLIRArmSMEDialect
13+
MLIRArmSMEUtils
1314
MLIRLLVMCommonConversion
1415
)

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
1010

1111
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12+
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1213
#include "mlir/IR/BuiltinTypes.h"
1314
#include "llvm/Support/Casting.h"
1415

@@ -76,9 +77,42 @@ struct TransferWriteToArmSMELowering
7677
}
7778
};
7879

80+
/// Conversion pattern for vector.load.
81+
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
82+
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
83+
84+
LogicalResult matchAndRewrite(vector::LoadOp load,
85+
PatternRewriter &rewriter) const override {
86+
if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
87+
return failure();
88+
89+
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
90+
load, load.getVectorType(), load.getBase(), load.getIndices());
91+
92+
return success();
93+
}
94+
};
95+
96+
/// Conversion pattern for vector.store.
97+
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
98+
using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
99+
100+
LogicalResult matchAndRewrite(vector::StoreOp store,
101+
PatternRewriter &rewriter) const override {
102+
if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
103+
return failure();
104+
105+
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
106+
store, store.getValueToStore(), store.getBase(), store.getIndices());
107+
108+
return success();
109+
}
110+
};
111+
79112
} // namespace
80113

81114
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
82115
MLIRContext &ctx) {
83-
patterns.add<TransferWriteToArmSMELowering>(&ctx);
116+
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
117+
VectorStoreToArmSMELowering>(&ctx);
84118
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
22
add_subdirectory(Transforms)
3+
add_subdirectory(Utils)

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
1212

1313
LINK_LIBS PUBLIC
1414
MLIRArmSMEDialect
15+
MLIRArmSMEUtils
1516
MLIRFuncDialect
1617
MLIRLLVMCommonConversion
1718
MLIRVectorDialect

0 commit comments

Comments
 (0)