diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def index f974cfc78c8dd..1d5ae099c76b4 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -1418,3 +1418,26 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false) #undef __OMP_ASSUME_CLAUSE #undef OMP_ASSUME_CLAUSE ///} + + +/// Callback specification +/// +///{ + +#ifndef OMP_CALLBACK +#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...) +#endif + +#define __OMP_CALLBACK(Name, VarArgsArePassed, CallbackArgNo, ...) \ + OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__) + +__OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1) +__OMP_CALLBACK(__kmpc_fork_call_if, false, 2, -1, -1) +__OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1) +__OMP_CALLBACK(__kmpc_omp_task_alloc, false, 5, -1, -1) + +#undef __OMP_PTR_TYPE + +#undef __OMP_TYPE +#undef OMP_CALLBACK +///} diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index ca3d8438654dc..ac42cd9ab3297 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -614,21 +614,27 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) { #include "llvm/Frontend/OpenMP/OMPKinds.def" } - // Add information if the runtime function takes a callback function - if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) { - if (!Fn->hasMetadata(LLVMContext::MD_callback)) { - LLVMContext &Ctx = Fn->getContext(); - MDBuilder MDB(Ctx); - // Annotate the callback behavior of the runtime function: - // - The callback callee is argument number 2 (microtask). - // - The first two arguments of the callback callee are unknown (-1). - // - All variadic arguments to the runtime function are passed to the - // callback callee. - Fn->addMetadata( - LLVMContext::MD_callback, - *MDNode::get(Ctx, {MDB.createCallbackEncoding( - 2, {-1, -1}, /* VarArgsArePassed */ true)})); - } + // Annotate the callback behavior of the runtime function: + // - First the callback callee argument number + // - Then the arguments passed on to the callback (-1 for unknown), + // variadic + // - Finally, whether variadic args are passed on to the callback. + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + switch (FnID) { +#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...) \ + case Enum: { \ + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { \ + Fn->addMetadata(LLVMContext::MD_callback, \ + *MDNode::get(Ctx, {MDB.createCallbackEncoding( \ + CallbackArgNo, {__VA_ARGS__}, \ + VarArgsArePassed)})); \ + } \ + break; \ + } +#include "llvm/Frontend/OpenMP/OMPKinds.def" + default: + break; } LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName() diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index be98be260c9dc..3281c4cd32e16 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7702,4 +7702,57 @@ TEST_F(OpenMPIRBuilderTest, splitBB) { EXPECT_TRUE(DL == AllocaBB->getTerminator()->getStableDebugLoc()); } +TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) { + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + + FunctionCallee ForkCall = OMPBuilder.getOrCreateRuntimeFunction( + *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call); + FunctionCallee ForkCallIf = OMPBuilder.getOrCreateRuntimeFunction( + *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call_if); + FunctionCallee ForkTeam = OMPBuilder.getOrCreateRuntimeFunction( + *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_teams); + FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction( + *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc); + + M->dump(); + for (auto [FC, VarArg, ArgNo] : + zip(SmallVector( + {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}), + SmallVector({true, false, true, false}), + SmallVector({2, 2, 2, 5}))) { + MDNode *CallbackMD = + cast(FC.getCallee())->getMetadata(LLVMContext::MD_callback); + EXPECT_NE(CallbackMD, nullptr); + unsigned Num = 0; + CallbackMD->dump(); + M->dump(); + for (const MDOperand &Op : CallbackMD->operands()) { + Num++; + MDNode *OpMD = cast(Op.get()); + auto *CBCalleeIdxAsCM = cast(OpMD->getOperand(0)); + uint64_t CBCalleeIdx = + cast(CBCalleeIdxAsCM->getValue())->getZExtValue(); + EXPECT_EQ(CBCalleeIdx, ArgNo); + + uint64_t Arg0 = + cast( + cast(OpMD->getOperand(1))->getValue()) + ->getZExtValue(); + uint64_t Arg1 = + cast( + cast(OpMD->getOperand(2))->getValue()) + ->getZExtValue(); + uint64_t _VarArg = + cast( + cast(OpMD->getOperand(3))->getValue()) + ->getZExtValue(); + EXPECT_EQ(Arg0, -1); + EXPECT_EQ(Arg1, -1); + EXPECT_EQ(_VarArg, VarArg); + } + EXPECT_EQ(Num, 1); + } +} + } // namespace