Skip to content

Commit a615975

Browse files
authored
[MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor (#141651)
This patch adds an Op to create the shared-memory descriptor for Tcgen05 MMA. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent f8ca9e5 commit a615975

File tree

3 files changed

+146
-0
lines changed

3 files changed

+146
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3373,6 +3373,70 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
33733373
}];
33743374
}
33753375

3376+
def NVVM_Tcgen05MmaSmemDescOp : NVVM_Op<"tcgen05.mma_smem_desc", []> {
3377+
let summary = "Constructs a Shared Memory descriptor for MMA Operands A or B";
3378+
let description = [{
3379+
The `nvvm.tcgen05_mma_smem_desc` constructs a Shared Memory descriptor
3380+
for tcgen05.mma. This descriptor is a 64-bit value which describes the
3381+
properties of multiplicand matrix in shared memory including its location
3382+
in the shared memory of the current CTA.
3383+
3384+
+-----------+------+------------------------------------------------------+
3385+
| Bit-field | Size | Description |
3386+
+-----------+------+------------------------------------------------------+
3387+
| 0-13 | 14 | Matrix start address |
3388+
| 14-15 | 2 | Reserved |
3389+
| 16-29 | 14 | Leading dim relative-offset (or) absolute-address |
3390+
| 30-31 | 2 | Reserved |
3391+
| 32-45 | 14 | Stride dimension byte offset |
3392+
| 46-48 | 3 | Fixed constant value of 0b001 |
3393+
| 49-51 | 3 | Matrix base offset |
3394+
| 52 | 1 | Leading dimension stride mode: |
3395+
| | | 0: byte offset relative |
3396+
| | | 1: byte address absolute |
3397+
| 53-60 | 8 | Fixed constant value of 0xb00000000 |
3398+
| 61-63 | 3 | Swizzling mode: |
3399+
| | | 0: No swizzling |
3400+
| | | 1: 128-Byte with 32B atomic swizzling |
3401+
| | | 2: 128-Byte swizzling |
3402+
| | | 4: 64-Byte swizzling |
3403+
| | | 6: 32-Byte swizzling |
3404+
| | | (Values 3, 5 and 7 are invalid) |
3405+
+-----------+------+------------------------------------------------------+
3406+
3407+
Example:
3408+
```mlir
3409+
%desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset,
3410+
%baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
3411+
```
3412+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
3413+
}];
3414+
3415+
let arguments = (ins
3416+
I32:$startAddr, // Matrix A or B start address (bits 13-0)
3417+
I32:$leadingDimOffset, // Matrix A or B leading dim byte offset (bits 29-16)
3418+
I32:$strideDimOffset, // Matrix A or B stride dim byte offset (bits 45-32)
3419+
I8:$baseOffset, // Matrix A or B base offset (bits 51-49)
3420+
I1:$leadingDimMode, // Matrix A or B leading dim mode (bit 52)
3421+
I8:$swizzleMode // Swizzle mode (bits 63-61)
3422+
);
3423+
3424+
let results = (outs I64:$res);
3425+
3426+
let assemblyFormat = [{
3427+
`(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type($res)
3428+
}];
3429+
3430+
let extraClassDeclaration = [{
3431+
static void createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt,
3432+
llvm::IRBuilderBase& builder);
3433+
}];
3434+
3435+
string llvmBuilder = [{
3436+
NVVM::Tcgen05MmaSmemDescOp::createSmemDescriptor(*op, moduleTranslation, builder);
3437+
}];
3438+
}
3439+
33763440
//===----------------------------------------------------------------------===//
33773441
// NVVM tcgen05 LdSt Shape Attr
33783442
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,50 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
12121212
llvm::Type::getInt32Ty(builder.getContext()));
12131213
}
12141214

1215+
/// Packs the given `field` into the `result`.
1216+
/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
1217+
static llvm::Value *
1218+
packValInto64Bits(llvm::IRBuilderBase &builder,
1219+
llvm::Value *result, // the `result` (unset bits are zero)
1220+
llvm::Value *field, // `field` to pack into `result`
1221+
unsigned sizeInBits, // Size of `field` in bits
1222+
unsigned start) { // Starting bit within `result`
1223+
field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
1224+
1225+
unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
1226+
if (mask != 0xffffffffu)
1227+
field = builder.CreateAnd(field, builder.getInt32(mask));
1228+
1229+
field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
1230+
field = builder.CreateShl(field, start);
1231+
1232+
return builder.CreateOr(result, field);
1233+
}
1234+
1235+
void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
1236+
LLVM::ModuleTranslation &mt,
1237+
llvm::IRBuilderBase &builder) {
1238+
auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
1239+
llvm::Value *smemDesc = builder.getInt64(0);
1240+
1241+
smemDesc = packValInto64Bits(builder, smemDesc,
1242+
mt.lookupValue(thisOp.getStartAddr()), 14, 0);
1243+
smemDesc = packValInto64Bits(
1244+
builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
1245+
smemDesc = packValInto64Bits(
1246+
builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
1247+
1248+
smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
1249+
smemDesc = packValInto64Bits(builder, smemDesc,
1250+
mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
1251+
smemDesc = packValInto64Bits(
1252+
builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
1253+
smemDesc = packValInto64Bits(builder, smemDesc,
1254+
mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
1255+
1256+
mt.mapValue(thisOp.getRes()) = smemDesc;
1257+
}
1258+
12151259
//===----------------------------------------------------------------------===//
12161260
// getIntrinsicID/getIntrinsicIDAndArgs methods
12171261
//===----------------------------------------------------------------------===//
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define i64 @tcgen05_mma_smem_desc_test(i32 %0, i32 %1, i32 %2, i8 %3, i1 %4, i8 %5) {
4+
llvm.func @tcgen05_mma_smem_desc_test(%startAddr: i32, %leadingDimOffset: i32, %strideDimOffset: i32,
5+
%baseOffset: i8, %leadingDimMode: i1, %swizzleMode: i8) -> i64 {
6+
// CHECK-NEXT: %7 = and i32 %0, 16383
7+
// CHECK-NEXT: %8 = zext i32 %7 to i64
8+
// CHECK-NEXT: %9 = shl i64 %8, 0
9+
// CHECK-NEXT: %10 = or i64 0, %9
10+
// CHECK-NEXT: %11 = and i32 %1, 16383
11+
// CHECK-NEXT: %12 = zext i32 %11 to i64
12+
// CHECK-NEXT: %13 = shl i64 %12, 16
13+
// CHECK-NEXT: %14 = or i64 %10, %13
14+
// CHECK-NEXT: %15 = and i32 %2, 16383
15+
// CHECK-NEXT: %16 = zext i32 %15 to i64
16+
// CHECK-NEXT: %17 = shl i64 %16, 32
17+
// CHECK-NEXT: %18 = or i64 %14, %17
18+
// CHECK-NEXT: %19 = or i64 %18, 70368744177664
19+
// CHECK-NEXT: %20 = zext i8 %3 to i32
20+
// CHECK-NEXT: %21 = and i32 %20, 7
21+
// CHECK-NEXT: %22 = zext i32 %21 to i64
22+
// CHECK-NEXT: %23 = shl i64 %22, 49
23+
// CHECK-NEXT: %24 = or i64 %19, %23
24+
// CHECK-NEXT: %25 = zext i1 %4 to i32
25+
// CHECK-NEXT: %26 = and i32 %25, 1
26+
// CHECK-NEXT: %27 = zext i32 %26 to i64
27+
// CHECK-NEXT: %28 = shl i64 %27, 52
28+
// CHECK-NEXT: %29 = or i64 %24, %28
29+
// CHECK-NEXT: %30 = zext i8 %5 to i32
30+
// CHECK-NEXT: %31 = and i32 %30, 7
31+
// CHECK-NEXT: %32 = zext i32 %31 to i64
32+
// CHECK-NEXT: %33 = shl i64 %32, 61
33+
// CHECK-NEXT: %34 = or i64 %29, %33
34+
// CHECK-NEXT: ret i64 %34
35+
// CHECK-NEXT: }
36+
%desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset, %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
37+
llvm.return %desc : i64
38+
}

0 commit comments

Comments
 (0)