Skip to content

Commit 492ad84

Browse files
authored
[SPIRV] Add explicit layout (#135789)
Adds code to add offset decorations when needed. This could cause a type mismatch for memory instructions. We add code to fix up OpLoad instructions, so that we could get some tests. Other memory operations will be handled in another PR. Part of #134119.
1 parent b86b529 commit 492ad84

File tree

7 files changed

+523
-124
lines changed

7 files changed

+523
-124
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 188 additions & 105 deletions
Large diffs are not rendered by default.

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,14 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
9090
// Add a new OpTypeXXX instruction without checking for duplicates.
9191
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
9292
SPIRV::AccessQualifier::AccessQualifier AQ,
93-
bool EmitIR);
93+
bool ExplicitLayoutRequired, bool EmitIR);
9494
SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
9595
SPIRV::AccessQualifier::AccessQualifier accessQual,
96-
bool EmitIR);
96+
bool ExplicitLayoutRequired, bool EmitIR);
9797
SPIRVType *
9898
restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
9999
SPIRV::AccessQualifier::AccessQualifier AccessQual,
100-
bool EmitIR);
100+
bool ExplicitLayoutRequired, bool EmitIR);
101101

102102
// Internal function creating the an OpType at the correct position in the
103103
// function by tweaking the passed "MIRBuilder" insertion point and restoring
@@ -298,10 +298,19 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
298298
// EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
299299
// because this method may be called from InstructionSelector and we don't
300300
// want to emit extra IR instructions there.
301+
SPIRVType *getOrCreateSPIRVType(const Type *Type, MachineInstr &I,
302+
SPIRV::AccessQualifier::AccessQualifier AQ,
303+
bool EmitIR) {
304+
MachineIRBuilder MIRBuilder(I);
305+
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, EmitIR);
306+
}
307+
301308
SPIRVType *getOrCreateSPIRVType(const Type *Type,
302309
MachineIRBuilder &MIRBuilder,
303310
SPIRV::AccessQualifier::AccessQualifier AQ,
304-
bool EmitIR);
311+
bool EmitIR) {
312+
return getOrCreateSPIRVType(Type, MIRBuilder, AQ, false, EmitIR);
313+
}
305314

306315
const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
307316
auto Res = SPIRVToLLVMType.find(Ty);
@@ -364,6 +373,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
364373
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
365374
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
366375

376+
// Returns true if `Type` is a resource type. This could be an image type
377+
// or a struct for a buffer decorated with the block decoration.
378+
bool isResourceType(SPIRVType *Type) const;
379+
367380
// Return number of elements in a vector if the argument is associated with
368381
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
369382
unsigned getScalarOrVectorComponentCount(Register VReg) const;
@@ -414,6 +427,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
414427
const Type *adjustIntTypeByWidth(const Type *Ty) const;
415428
unsigned adjustOpTypeIntWidth(unsigned Width) const;
416429

430+
SPIRVType *getOrCreateSPIRVType(const Type *Type,
431+
MachineIRBuilder &MIRBuilder,
432+
SPIRV::AccessQualifier::AccessQualifier AQ,
433+
bool ExplicitLayoutRequired, bool EmitIR);
434+
417435
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
418436
bool IsSigned = false);
419437

@@ -425,14 +443,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
425443
MachineIRBuilder &MIRBuilder);
426444

427445
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
428-
MachineIRBuilder &MIRBuilder, bool EmitIR);
446+
MachineIRBuilder &MIRBuilder,
447+
bool ExplicitLayoutRequired, bool EmitIR);
429448

430449
SPIRVType *getOpTypeOpaque(const StructType *Ty,
431450
MachineIRBuilder &MIRBuilder);
432451

433452
SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
434453
SPIRV::AccessQualifier::AccessQualifier AccQual,
435-
bool EmitIR);
454+
bool ExplicitLayoutRequired, bool EmitIR);
436455

437456
SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
438457
SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -475,6 +494,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
475494
MachineIRBuilder &MIRBuilder,
476495
SPIRV::StorageClass::StorageClass SC);
477496

497+
void addStructOffsetDecorations(Register Reg, StructType *Ty,
498+
MachineIRBuilder &MIRBuilder);
499+
void addArrayStrideDecorations(Register Reg, Type *ElementType,
500+
MachineIRBuilder &MIRBuilder);
501+
bool hasBlockDecoration(SPIRVType *Type) const;
502+
478503
public:
479504
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
480505
SPIRVType *SpvType, bool EmitIR,
@@ -545,9 +570,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
545570
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
546571
unsigned NumElements, MachineInstr &I,
547572
const SPIRVInstrInfo &TII);
548-
SPIRVType *getOrCreateSPIRVArrayType(SPIRVType *BaseType,
549-
unsigned NumElements, MachineInstr &I,
550-
const SPIRVInstrInfo &TII);
551573

552574
// Returns a pointer to a SPIR-V pointer type with the given base type and
553575
// storage class. The base type will be translated to a SPIR-V type, and the

llvm/lib/Target/SPIRV/SPIRVIRMapping.h

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum SpecialTypeKind {
6666
STK_Value,
6767
STK_MachineInstr,
6868
STK_VkBuffer,
69+
STK_ExplictLayoutType,
6970
STK_Last = -1
7071
};
7172

@@ -150,6 +151,11 @@ inline IRHandle irhandle_vkbuffer(const Type *ElementType,
150151
SpecialTypeKind::STK_VkBuffer);
151152
}
152153

154+
inline IRHandle irhandle_explict_layout_type(const Type *Ty) {
155+
const Type *WrpTy = unifyPtrType(Ty);
156+
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_ExplictLayoutType);
157+
}
158+
153159
inline IRHandle handle(const Type *Ty) {
154160
const Type *WrpTy = unifyPtrType(Ty);
155161
return irhandle_ptr(WrpTy, Ty->getTypeID(), STK_Type);
@@ -163,6 +169,10 @@ inline IRHandle handle(const MachineInstr *KeyMI) {
163169
return irhandle_ptr(KeyMI, SPIRV::to_hash(KeyMI), STK_MachineInstr);
164170
}
165171

172+
inline bool type_has_layout_decoration(const Type *T) {
173+
return (isa<StructType>(T) || isa<ArrayType>(T));
174+
}
175+
166176
} // namespace SPIRV
167177

168178
// Bi-directional mappings between LLVM entities and (v-reg, machine function)
@@ -238,14 +248,49 @@ class SPIRVIRMapping {
238248
return findMI(SPIRV::irhandle_pointee(PointeeTy, AddressSpace), MF);
239249
}
240250

241-
template <typename T> bool add(const T *Obj, const MachineInstr *MI) {
251+
bool add(const Value *V, const MachineInstr *MI) {
252+
return add(SPIRV::handle(V), MI);
253+
}
254+
255+
bool add(const Type *T, bool RequiresExplicitLayout, const MachineInstr *MI) {
256+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T)) {
257+
return add(SPIRV::irhandle_explict_layout_type(T), MI);
258+
}
259+
return add(SPIRV::handle(T), MI);
260+
}
261+
262+
bool add(const MachineInstr *Obj, const MachineInstr *MI) {
242263
return add(SPIRV::handle(Obj), MI);
243264
}
244-
template <typename T> Register find(const T *Obj, const MachineFunction *MF) {
245-
return find(SPIRV::handle(Obj), MF);
265+
266+
Register find(const Value *V, const MachineFunction *MF) {
267+
return find(SPIRV::handle(V), MF);
268+
}
269+
270+
Register find(const Type *T, bool RequiresExplicitLayout,
271+
const MachineFunction *MF) {
272+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
273+
return find(SPIRV::irhandle_explict_layout_type(T), MF);
274+
return find(SPIRV::handle(T), MF);
275+
}
276+
277+
Register find(const MachineInstr *MI, const MachineFunction *MF) {
278+
return find(SPIRV::handle(MI), MF);
279+
}
280+
281+
const MachineInstr *findMI(const Value *Obj, const MachineFunction *MF) {
282+
return findMI(SPIRV::handle(Obj), MF);
283+
}
284+
285+
const MachineInstr *findMI(const Type *T, bool RequiresExplicitLayout,
286+
const MachineFunction *MF) {
287+
if (RequiresExplicitLayout && SPIRV::type_has_layout_decoration(T))
288+
return findMI(SPIRV::irhandle_explict_layout_type(T), MF);
289+
return findMI(SPIRV::handle(T), MF);
246290
}
247-
template <typename T>
248-
const MachineInstr *findMI(const T *Obj, const MachineFunction *MF) {
291+
292+
const MachineInstr *findMI(const MachineInstr *Obj,
293+
const MachineFunction *MF) {
249294
return findMI(SPIRV::handle(Obj), MF);
250295
}
251296
};

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,42 @@
2525

2626
using namespace llvm;
2727

28+
// Returns true of the types logically match, as defined in
29+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCopyLogical.
30+
static bool typesLogicallyMatch(const SPIRVType *Ty1, const SPIRVType *Ty2,
31+
SPIRVGlobalRegistry &GR) {
32+
if (Ty1->getOpcode() != Ty2->getOpcode())
33+
return false;
34+
35+
if (Ty1->getNumOperands() != Ty2->getNumOperands())
36+
return false;
37+
38+
if (Ty1->getOpcode() == SPIRV::OpTypeArray) {
39+
// Array must have the same size.
40+
if (Ty1->getOperand(2).getReg() != Ty2->getOperand(2).getReg())
41+
return false;
42+
43+
SPIRVType *ElemType1 = GR.getSPIRVTypeForVReg(Ty1->getOperand(1).getReg());
44+
SPIRVType *ElemType2 = GR.getSPIRVTypeForVReg(Ty2->getOperand(1).getReg());
45+
return ElemType1 == ElemType2 ||
46+
typesLogicallyMatch(ElemType1, ElemType2, GR);
47+
}
48+
49+
if (Ty1->getOpcode() == SPIRV::OpTypeStruct) {
50+
for (unsigned I = 1; I < Ty1->getNumOperands(); I++) {
51+
SPIRVType *ElemType1 =
52+
GR.getSPIRVTypeForVReg(Ty1->getOperand(I).getReg());
53+
SPIRVType *ElemType2 =
54+
GR.getSPIRVTypeForVReg(Ty2->getOperand(I).getReg());
55+
if (ElemType1 != ElemType2 &&
56+
!typesLogicallyMatch(ElemType1, ElemType2, GR))
57+
return false;
58+
}
59+
return true;
60+
}
61+
return false;
62+
}
63+
2864
unsigned SPIRVTargetLowering::getNumRegistersForCallingConv(
2965
LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
3066
// This code avoids CallLowering fail inside getVectorTypeBreakdown
@@ -374,6 +410,9 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
374410
// implies that %Op is a pointer to <ResType>
375411
case SPIRV::OpLoad:
376412
// OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
413+
if (enforcePtrTypeCompatibility(MI, 2, 0))
414+
break;
415+
377416
validatePtrTypes(STI, MRI, GR, MI, 2,
378417
GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
379418
break;
@@ -531,3 +570,58 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
531570
ProcessedMF.insert(&MF);
532571
TargetLowering::finalizeLowering(MF);
533572
}
573+
574+
// Modifies either operand PtrOpIdx or OpIdx so that the pointee type of
575+
// PtrOpIdx matches the type for operand OpIdx. Returns true if they already
576+
// match or if the instruction was modified to make them match.
577+
bool SPIRVTargetLowering::enforcePtrTypeCompatibility(
578+
MachineInstr &I, unsigned int PtrOpIdx, unsigned int OpIdx) const {
579+
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
580+
SPIRVType *PtrType = GR.getResultType(I.getOperand(PtrOpIdx).getReg());
581+
SPIRVType *PointeeType = GR.getPointeeType(PtrType);
582+
SPIRVType *OpType = GR.getResultType(I.getOperand(OpIdx).getReg());
583+
584+
if (PointeeType == OpType)
585+
return true;
586+
587+
if (typesLogicallyMatch(PointeeType, OpType, GR)) {
588+
// Apply OpCopyLogical to OpIdx.
589+
if (I.getOperand(OpIdx).isDef() &&
590+
insertLogicalCopyOnResult(I, PointeeType)) {
591+
return true;
592+
}
593+
594+
llvm_unreachable("Unable to add OpCopyLogical yet.");
595+
return false;
596+
}
597+
598+
return false;
599+
}
600+
601+
bool SPIRVTargetLowering::insertLogicalCopyOnResult(
602+
MachineInstr &I, SPIRVType *NewResultType) const {
603+
MachineRegisterInfo *MRI = &I.getMF()->getRegInfo();
604+
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
605+
606+
Register NewResultReg =
607+
createVirtualRegister(NewResultType, &GR, MRI, *I.getMF());
608+
Register NewTypeReg = GR.getSPIRVTypeID(NewResultType);
609+
610+
assert(std::distance(I.defs().begin(), I.defs().end()) == 1 &&
611+
"Expected only one def");
612+
MachineOperand &OldResult = *I.defs().begin();
613+
Register OldResultReg = OldResult.getReg();
614+
MachineOperand &OldType = *I.uses().begin();
615+
Register OldTypeReg = OldType.getReg();
616+
617+
OldResult.setReg(NewResultReg);
618+
OldType.setReg(NewTypeReg);
619+
620+
MachineIRBuilder MIB(*I.getNextNode());
621+
return MIB.buildInstr(SPIRV::OpCopyLogical)
622+
.addDef(OldResultReg)
623+
.addUse(OldTypeReg)
624+
.addUse(NewResultReg)
625+
.constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
626+
*STI.getRegBankInfo());
627+
}

llvm/lib/Target/SPIRV/SPIRVISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ class SPIRVTargetLowering : public TargetLowering {
7171
EVT ConditionVT) const override {
7272
return ConditionVT.getSimpleVT();
7373
}
74+
75+
bool enforcePtrTypeCompatibility(MachineInstr &I, unsigned PtrOpIdx,
76+
unsigned OpIdx) const;
77+
bool insertLogicalCopyOnResult(MachineInstr &I,
78+
SPIRVType *NewResultType) const;
7479
};
7580
} // namespace llvm
7681

llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@ declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handle
1111

1212
; CHECK: OpDecorate [[BufferVar:%.+]] DescriptorSet 0
1313
; CHECK: OpDecorate [[BufferVar]] Binding 0
14-
; CHECK: OpDecorate [[BufferType:%.+]] Block
15-
; CHECK: OpMemberDecorate [[BufferType]] 0 Offset 0
14+
; CHECK: OpMemberDecorate [[BufferType:%.+]] 0 Offset 0
15+
; CHECK: OpDecorate [[BufferType]] Block
1616
; CHECK: OpMemberDecorate [[BufferType]] 0 NonWritable
1717
; CHECK: OpDecorate [[RWBufferVar:%.+]] DescriptorSet 0
1818
; CHECK: OpDecorate [[RWBufferVar]] Binding 1
19-
; CHECK: OpDecorate [[RWBufferType:%.+]] Block
20-
; CHECK: OpMemberDecorate [[RWBufferType]] 0 Offset 0
19+
; CHECK: OpDecorate [[ArrayType:%.+]] ArrayStride 4
20+
; CHECK: OpMemberDecorate [[RWBufferType:%.+]] 0 Offset 0
21+
; CHECK: OpDecorate [[RWBufferType]] Block
2122

2223

2324
; CHECK: [[int:%[0-9]+]] = OpTypeInt 32 0
24-
; CHECK: [[ArrayType:%.+]] = OpTypeRuntimeArray
25+
; CHECK: [[ArrayType]] = OpTypeRuntimeArray
2526
; CHECK: [[RWBufferType]] = OpTypeStruct [[ArrayType]]
2627
; CHECK: [[RWBufferPtrType:%.+]] = OpTypePointer StorageBuffer [[RWBufferType]]
2728
; CHECK: [[BufferType]] = OpTypeStruct [[ArrayType]]

0 commit comments

Comments
 (0)