Skip to content

[SPIRV] support for extension SPV_INTEL_maximum_registers #137229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
* - ``SPV_INTEL_maximum_registers``
- This extension adds an execution mode to specify the maximum number of registers a SPIR-V consumer should use when compiling an entry point.

To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:

Expand Down Expand Up @@ -489,6 +491,11 @@ SPIR-V backend, along with their descriptions and argument details.
- `[spirv.Image Image, 32-bit Integer coordinate, vec4 data]`
- Stores the data to the image buffer at the given coordinate. The \
data must be a 4-element vector.
* - `int_spv_max_reg_constant`
- void
- `[32-bit Integer Maximum registers allowed]`
- 32-bit integer indicating the maximum number of registers allowed. \
Used as an operand to the OpExecutionModeId MaximumRegistersIdINTEL.

.. _spirv-builtin-functions:

Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ let TargetPrefix = "spv" in {

// FPMaxErrorDecorationINTEL
def int_spv_assign_fpmaxerror_decoration: Intrinsic<[], [llvm_any_ty, llvm_metadata_ty]>;

// ExecutionModeMaxRedId
def int_spv_max_reg_constant : Intrinsic<[], [llvm_i32_ty], [ImmArg<ArgIndex<0>>]>;

// Convert between the generic storage class and a concrete one.
def int_spv_generic_cast_to_ptr_explicit
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ namespace CooperativeMatrixOperands {
#include "SPIRVGenTables.inc"
} // namespace CooperativeMatrixOperands

namespace NamedMaximumNumberOfRegisters {
#define GET_NamedMaximumNumberOfRegisters_DECL
#include "SPIRVGenTables.inc"
} // namespace NamedMaximumNumberOfRegisters

struct ExtendedBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
Expand Down
18 changes: 17 additions & 1 deletion llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Please fix the formatting issues.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
using namespace llvm::SPIRV;
Expand Down Expand Up @@ -147,7 +148,22 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
case SPIRV::OpMemberDecorate:
printRemainingVariableOps(MI, NumFixedOps, OS);
break;
case SPIRV::OpExecutionMode:
case SPIRV::OpExecutionMode: {
unsigned NumOperands = MI->getNumOperands();
if (NumOperands != NumFixedOps) {
const unsigned MaxRegVal =
MI->getOperand(FirstVariableIndex).getImm();
if (MaxRegVal == 0) {

OS << ' ';
printSymbolicOperand<
OperandCategory::NamedMaximumNumberOfRegistersOperand>(
MI, FirstVariableIndex, OS);
break;
}
}
printRemainingVariableOps(MI, NumFixedOps, OS);
} break;
case SPIRV::OpExecutionModeId:
case SPIRV::OpLoopMerge: {
// Print any literals after the OPERAND_UNKNOWN argument normally.
Expand Down
46 changes: 46 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputExecutionModeFromNumthreadsAttribute(
const MCRegister &Reg, const Attribute &Attr,
SPIRV::ExecutionMode::ExecutionMode EM);
void outputExecutionModeFromRegisterAllocMode(const MCRegister &Reg,
const MDNode *Node,
MachineFunction *MF);
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
Expand Down Expand Up @@ -492,6 +495,45 @@ void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute(
outputMCInst(Inst);
}

// outputs the execution mode for the extension SPV_INTEL_maximum_registers
void SPIRVAsmPrinter::outputExecutionModeFromRegisterAllocMode(
const MCRegister &Reg, const MDNode *Node, MachineFunction *MF) {
MCInst Inst;
auto *RegisterAllocMode = Node->getOperand(0).get();
Inst.setOpcode(SPIRV::OpExecutionMode);
Inst.addOperand(MCOperand::createReg(Reg));

if (auto *MDS = dyn_cast<MDString>(RegisterAllocMode)) {
StringRef Str = MDS->getString();
if (Str.equals_insensitive("AutoINTEL")) {
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(
SPIRV::ExecutionMode::NamedMaximumRegistersINTEL)));
Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(
SPIRV::NamedMaximumNumberOfRegisters::AutoINTEL)));
}
} else if (MDNode *NestedNode = dyn_cast<MDNode>(RegisterAllocMode)) {
if (auto *CMD = dyn_cast<ConstantAsMetadata>(NestedNode->getOperand(0))) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(CMD->getValue())) {
Inst.setOpcode(SPIRV::OpExecutionModeId);
Inst.addOperand(MCOperand::createImm(
SPIRV::ExecutionMode::MaximumRegistersIdINTEL));
auto *GR = ST->getSPIRVGlobalRegistry();
Register MaxOpConstantReg = GR->getMaxRegConstantExtMap(MF);
MCRegister MaxRegister = MAI->getRegisterAlias(MF, MaxOpConstantReg);
Inst.addOperand(MCOperand::createReg(MaxRegister));
}
}
} else {

int64_t RegisterAllocVal =
mdconst::dyn_extract<ConstantInt>(RegisterAllocMode)->getZExtValue();
Inst.addOperand(
MCOperand::createImm(SPIRV::ExecutionMode::MaximumRegistersINTEL));
Inst.addOperand(MCOperand::createImm(RegisterAllocVal));
}
outputMCInst(Inst);
}

void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
Expand Down Expand Up @@ -532,6 +574,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
Inst.addOperand(MCOperand::createImm(TypeCode));
outputMCInst(Inst);
}
if (MDNode *Node = F.getMetadata("RegisterAllocMode")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest not to rely on internal to internal https://github.com/intel/llvm metadata. Instead lets use !spirv.ExecutionMode metadata, see https://github.com/KhronosGroup/SPIRV-LLVM-Translator/blob/main/docs/SPIRVRepresentationInLLVM.rst as 'SPIR-V friendly LLVM IR' is also used in SPIR-V backend.

MachineFunction *MF = MMI->getMachineFunction(F);
outputExecutionModeFromRegisterAllocMode(FReg, Node, MF);
}
if (ST->isOpenCLEnv() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
MCInst Inst;
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::
SPV_INTEL_subgroup_matrix_multiply_accumulate},
{"SPV_INTEL_ternary_bitwise_function",
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function}};
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
{"SPV_INTEL_maximum_registers",
SPIRV::Extension::SPV_INTEL_maximum_registers}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
27 changes: 26 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class SPIRVEmitIntrinsics
unsigned OperandToReplace,
IRBuilder<> &B);
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void insertMaxRegIdExecModeIntrs(Function *F, IRBuilder<> &B);
bool shouldTryToAddMemAliasingDecoration(Instruction *Inst);
void insertSpirvDecorations(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
Expand Down Expand Up @@ -2210,6 +2211,30 @@ void SPIRVEmitIntrinsics::processParamTypesByFunHeader(Function *F,
}
}

void SPIRVEmitIntrinsics::insertMaxRegIdExecModeIntrs(Function *F,
IRBuilder<> &B) {
MDNode *Node = F->getMetadata("RegisterAllocMode");

if (Node) {
Metadata *RegisterAllocMode = Node->getOperand(0).get();
// spv_max_reg_constant is added to add the OpConstant instruction which
// will be then used as operand for OpExecutionMode MaximumRegistersIdINTEL
if (MDNode *NestedNode = dyn_cast<MDNode>(RegisterAllocMode)) {
if (auto *CMD = dyn_cast<ConstantAsMetadata>(NestedNode->getOperand(0))) {
auto *CI = dyn_cast<ConstantInt>(CMD->getValue());
if (!CI)
return;
int32_t MaxRegNumExt = CI->getSExtValue();
B.SetInsertPointPastAllocas(F);
Value *MaxRegNumExtVal =
ConstantInt::get(Type::getInt32Ty(B.getContext()), MaxRegNumExt);
B.CreateIntrinsic(Intrinsic::spv_max_reg_constant, {},
{MaxRegNumExtVal});
}
}
}
}

void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Expand Down Expand Up @@ -2385,7 +2410,6 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
}

processParamTypesByFunHeader(CurrF, B);

// StoreInst's operand type can be changed during the next transformations,
// so we need to store it in the set. Also store already transformed types.
for (auto &I : instructions(Func)) {
Expand Down Expand Up @@ -2419,6 +2443,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
insertAssignTypeIntrs(I, B);
insertPtrCastOrAssignTypeInstr(I, B);
insertSpirvDecorations(I, B);
insertMaxRegIdExecModeIntrs(CurrF, B);
// if instruction requires a pointee type set, let's check if we know it
// already, and force it to be i8 if not
if (Postpone && !GR->findAssignPtrTypeInstr(I))
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Number of bits pointers and size_t integers require.
const unsigned PointerSize;

// Maps each MachineFunction to its associated OpConstant register for the
// SPV_INTEL_maximum_registers extension with ExecutionModeId
// MaximumRegistersIdINTEL.
DenseMap<const MachineFunction *, Register> MaxRegConstantExtMap;

// Holds the maximum ID we have in the module.
unsigned Bound;

Expand Down Expand Up @@ -114,6 +119,15 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
void setBound(unsigned V) { Bound = V; }
unsigned getBound() { return Bound; }

void addMaxRegConstantRegisterExt(const MachineFunction *MF, Register Reg) {
MaxRegConstantExtMap[MF] = Reg;
}

Register getMaxRegConstantExtMap(const MachineFunction *MF) {
assert(MaxRegConstantExtMap.count(MF) &&
"MachineFunction not found in MaxRegConstantExtMap");
return MaxRegConstantExtMap[MF];
}
void addGlobalObject(const Value *V, const MachineFunction *MF, Register R) {
Reg2GO[std::make_pair(MF, R)] = V;
}
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3191,6 +3191,13 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
case Intrinsic::spv_discard: {
return selectDiscard(ResVReg, ResType, I);
}
case Intrinsic::spv_max_reg_constant: {
int32_t MaxRegNum = I.getOperand(1).getImm();
auto ConstMaxRegId = buildI32Constant(MaxRegNum, I);
Register MaxIdRegister = ConstMaxRegId.first;
GR.addMaxRegConstantRegisterExt(MF, MaxIdRegister);
return ConstMaxRegId.second;
} break;
default: {
std::string DiagMsg;
raw_string_ostream OS(DiagMsg);
Expand Down
50 changes: 48 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,50 @@ static bool isImageTypeWithUnknownFormat(SPIRVType *TypeInst) {
return TypeInst->getOperand(7).getImm() == 0;
}

static void setExtMaxRegId(const MachineFunction *MF,
SPIRV::ModuleAnalysisInfo &MAI,
SPIRVGlobalRegistry *GR) {
Register VirtualReg = GR->getMaxRegConstantExtMap(MF);
MCRegister MaxOpConstantReg = MAI.getNextIDRegister();
MAI.setRegisterAlias(MF, VirtualReg, MaxOpConstantReg);
}

static void transFunctionMetadataAsExecutionMode(const Function &F,
const SPIRVSubtarget &ST,
SPIRV::ModuleAnalysisInfo &MAI,
MachineFunction *MF) {
SmallVector<MDNode *, 1> RegisterAllocModeMDs;
F.getMetadata("RegisterAllocMode", RegisterAllocModeMDs);
if (!RegisterAllocModeMDs.empty()) {
MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_maximum_registers);
MAI.Reqs.addCapability(SPIRV::Capability::RegisterLimitsINTEL);
}
for (unsigned I = 0; I < RegisterAllocModeMDs.size(); I++) {

auto *RegisterAllocMode = RegisterAllocModeMDs[I]->getOperand(0).get();
if (auto *MDS = dyn_cast<MDString>(RegisterAllocMode)) {
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand,
SPIRV::ExecutionMode::NamedMaximumRegistersINTEL, ST);
} else if (isa<MDNode>(RegisterAllocMode)) {
MDNode *NestedNode = dyn_cast<MDNode>(RegisterAllocMode);
if (auto *CMD = dyn_cast<ConstantAsMetadata>(NestedNode->getOperand(0))) {
auto *CI = dyn_cast<ConstantInt>(CMD->getValue());
if (!CI)
break;
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand,
SPIRV::ExecutionMode::MaximumRegistersIdINTEL, ST);
setExtMaxRegId(MF, MAI, ST.getSPIRVGlobalRegistry());
}
} else {
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand,
SPIRV::ExecutionMode::MaximumRegistersINTEL, ST);
}
}
}

static void AddDotProductRequirements(const MachineInstr &MI,
SPIRV::RequirementHandler &Reqs,
const SPIRVSubtarget &ST) {
Expand Down Expand Up @@ -1921,7 +1965,10 @@ static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
MAI.Reqs.getAndAddRequirements(
SPIRV::OperandCategory::ExecutionModeOperand,
SPIRV::ExecutionMode::VecTypeHint, ST);

if (F.getMetadata("RegisterAllocMode")) {
MachineFunction *MF = MMI->getMachineFunction(F);
transFunctionMetadataAsExecutionMode(F, ST, MAI, MF);
}
if (F.hasOptNone()) {
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_optnone)) {
MAI.Reqs.addExtension(SPIRV::Extension::SPV_INTEL_optnone);
Expand Down Expand Up @@ -2059,7 +2106,6 @@ bool SPIRVModuleAnalysis::runOnModule(Module &M) {

// Process type/const/global var/func decl instructions, number their
// destination registers from 0 to N, collect Extensions and Capabilities.
collectReqs(M, MAI, MMI, *ST);
collectDeclarations(M);

// Number rest of registers from N+1 onwards.
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def KernelProfilingInfoOperand : OperandCategory;
def OpcodeOperand : OperandCategory;
def CooperativeMatrixLayoutOperand : OperandCategory;
def CooperativeMatrixOperandsOperand : OperandCategory;
def NamedMaximumNumberOfRegistersOperand: OperandCategory;


//===----------------------------------------------------------------------===//
// Multiclass used to define Extesions enum values and at the same time
Expand Down Expand Up @@ -315,6 +317,7 @@ defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
defm SPV_INTEL_maximum_registers : ExtensionOperand<122>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -517,6 +520,7 @@ defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
defm TernaryBitwiseFunctionINTEL : CapabilityOperand<6241, 0, 0, [SPV_INTEL_ternary_bitwise_function], []>;
defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
defm RegisterLimitsINTEL : CapabilityOperand<6460 , 0, 0, [SPV_INTEL_maximum_registers], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down Expand Up @@ -714,6 +718,9 @@ defm RoundingModeRTPINTEL : ExecutionModeOperand<5620, [RoundToInfinityINTEL]>;
defm RoundingModeRTNINTEL : ExecutionModeOperand<5621, [RoundToInfinityINTEL]>;
defm FloatingPointModeALTINTEL : ExecutionModeOperand<5622, [FloatingPointModeINTEL]>;
defm FloatingPointModeIEEEINTEL : ExecutionModeOperand<5623, [FloatingPointModeINTEL]>;
defm MaximumRegistersINTEL : ExecutionModeOperand<6461, [RegisterLimitsINTEL]>;
defm MaximumRegistersIdINTEL : ExecutionModeOperand<6462, [RegisterLimitsINTEL]>;
defm NamedMaximumRegistersINTEL : ExecutionModeOperand<6463, [RegisterLimitsINTEL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define StorageClass enum values and at the same time
Expand Down Expand Up @@ -1746,3 +1753,28 @@ defm MatrixAAndBTF32ComponentsINTEL : CooperativeMatrixOperandsOperand<0x20, [SP
defm MatrixAAndBBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x40, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
defm MatrixCBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x80, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;
defm MatrixResultBFloat16ComponentsINTEL : CooperativeMatrixOperandsOperand<0x100, [SPV_INTEL_joint_matrix], [CooperativeMatrixBFloat16ComponentTypeINTEL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Named Max Number of Reg Operands enum values and at the
// same time SymbolicOperand entries with string mnemonics, extensions and
// capabilities.
//===----------------------------------------------------------------------===//

def NamedMaximumNumberOfRegisters : GenericEnum, Operand<i32> {
let FilterClass = "NamedMaximumNumberOfRegisters";
let NameField = "Name";
let ValueField = "Value";
let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
}

class NamedMaximumNumberOfRegisters<string name, bits<32> value> {
string Name = name;
bits<32> Value = value;
}

multiclass NamedMaximumNumberOfRegistersOperand<bits<32> value, list<Extension> reqExtensions, list<Capability> reqCapabilities> {
def : NamedMaximumNumberOfRegisters<NAME, value>;
defm : SymbolicOperandWithRequirements<NamedMaximumNumberOfRegistersOperand, value, NAME, 0, 0, reqExtensions, reqCapabilities>;
}

defm AutoINTEL : NamedMaximumNumberOfRegistersOperand<0x0, [SPV_INTEL_maximum_registers], [RegisterLimitsINTEL]>;
Loading