Skip to content

Commit 83835e2

Browse files
committed
[RISCV] Implement KCFI operand bundle lowering
With `-fsanitize=kcfi` (Kernel Control-Flow Integrity), Clang emits "kcfi" operand bundles to indirect call instructions. Similarly to the target-specific lowering added in D119296, implement KCFI operand bundle lowering for RISC-V. This patch disables the generic KCFI pass for RISC-V in Clang, and adds the KCFI machine function pass in `RISCVPassConfig::addPreSched` to emit target-specific `KCFI_CHECK` pseudo instructions before calls that have KCFI operand bundles. The machine function pass also bundles the instructions to ensure we emit the checks immediately before the calls, which is not possible with the generic pass. `KCFI_CHECK` instructions are lowered in `RISCVAsmPrinter` to a contiguous code sequence that traps if the expected hash in the operand bundle doesn't match the hash before the target function address. This patch emits an `ebreak` instruction for error handling to match the Linux kernel's `BUG()` implementation. Just like for X86, we also emit trap locations to a `.kcfi_traps` section to support error handling, as we cannot embed additional information to the trap instruction itself. Relands commit 62fa708 with fixed tests. Reviewed By: MaskRay Differential Revision: https://reviews.llvm.org/D148385
1 parent 4d60c65 commit 83835e2

14 files changed

+356
-2
lines changed

clang/lib/CodeGen/BackendUtil.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ static void addKCFIPass(const Triple &TargetTriple, const LangOptions &LangOpts,
631631
PassBuilder &PB) {
632632
// If the back-end supports KCFI operand bundle lowering, skip KCFIPass.
633633
if (TargetTriple.getArch() == llvm::Triple::x86_64 ||
634-
TargetTriple.isAArch64(64))
634+
TargetTriple.isAArch64(64) || TargetTriple.isRISCV())
635635
return;
636636

637637
// Ensure we lower KCFI operand bundles with -O0.

llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "RISCVMachineFunctionInfo.h"
2020
#include "RISCVTargetMachine.h"
2121
#include "TargetInfo/RISCVTargetInfo.h"
22+
#include "llvm/ADT/APInt.h"
2223
#include "llvm/ADT/Statistic.h"
2324
#include "llvm/BinaryFormat/ELF.h"
2425
#include "llvm/CodeGen/AsmPrinter.h"
@@ -72,6 +73,7 @@ class RISCVAsmPrinter : public AsmPrinter {
7273
typedef std::tuple<unsigned, uint32_t> HwasanMemaccessTuple;
7374
std::map<HwasanMemaccessTuple, MCSymbol *> HwasanMemaccessSymbols;
7475
void LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI);
76+
void LowerKCFI_CHECK(const MachineInstr &MI);
7577
void EmitHwasanMemaccessSymbols(Module &M);
7678

7779
// Wrapper needed for tblgenned pseudo lowering.
@@ -150,6 +152,9 @@ void RISCVAsmPrinter::emitInstruction(const MachineInstr *MI) {
150152
case RISCV::HWASAN_CHECK_MEMACCESS_SHORTGRANULES:
151153
LowerHWASAN_CHECK_MEMACCESS(*MI);
152154
return;
155+
case RISCV::KCFI_CHECK:
156+
LowerKCFI_CHECK(*MI);
157+
return;
153158
case RISCV::PseudoRVVInitUndefM1:
154159
case RISCV::PseudoRVVInitUndefM2:
155160
case RISCV::PseudoRVVInitUndefM4:
@@ -305,6 +310,92 @@ void RISCVAsmPrinter::LowerHWASAN_CHECK_MEMACCESS(const MachineInstr &MI) {
305310
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::PseudoCALL).addExpr(Expr));
306311
}
307312

313+
void RISCVAsmPrinter::LowerKCFI_CHECK(const MachineInstr &MI) {
314+
Register AddrReg = MI.getOperand(0).getReg();
315+
assert(std::next(MI.getIterator())->isCall() &&
316+
"KCFI_CHECK not followed by a call instruction");
317+
assert(std::next(MI.getIterator())->getOperand(0).getReg() == AddrReg &&
318+
"KCFI_CHECK call target doesn't match call operand");
319+
320+
// Temporary registers for comparing the hashes. If a register is used
321+
// for the call target, or reserved by the user, we can clobber another
322+
// temporary register as the check is immediately followed by the
323+
// call. The check defaults to X6/X7, but can fall back to X28-X31 if
324+
// needed.
325+
unsigned ScratchRegs[] = {RISCV::X6, RISCV::X7};
326+
unsigned NextReg = RISCV::X28;
327+
auto isRegAvailable = [&](unsigned Reg) {
328+
return Reg != AddrReg && !STI->isRegisterReservedByUser(Reg);
329+
};
330+
for (auto &Reg : ScratchRegs) {
331+
if (isRegAvailable(Reg))
332+
continue;
333+
while (!isRegAvailable(NextReg))
334+
++NextReg;
335+
Reg = NextReg++;
336+
if (Reg > RISCV::X31)
337+
report_fatal_error("Unable to find scratch registers for KCFI_CHECK");
338+
}
339+
340+
if (AddrReg == RISCV::X0) {
341+
// Checking X0 makes no sense. Instead of emitting a load, zero
342+
// ScratchRegs[0].
343+
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::ADDI)
344+
.addReg(ScratchRegs[0])
345+
.addReg(RISCV::X0)
346+
.addImm(0));
347+
} else {
348+
// Adjust the offset for patchable-function-prefix. This assumes that
349+
// patchable-function-prefix is the same for all functions.
350+
int NopSize = STI->hasStdExtCOrZca() ? 2 : 4;
351+
int64_t PrefixNops = 0;
352+
(void)MI.getMF()
353+
->getFunction()
354+
.getFnAttribute("patchable-function-prefix")
355+
.getValueAsString()
356+
.getAsInteger(10, PrefixNops);
357+
358+
// Load the target function type hash.
359+
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::LW)
360+
.addReg(ScratchRegs[0])
361+
.addReg(AddrReg)
362+
.addImm(-(PrefixNops * NopSize + 4)));
363+
}
364+
365+
// Load the expected 32-bit type hash.
366+
const int64_t Type = MI.getOperand(1).getImm();
367+
const int64_t Hi20 = ((Type + 0x800) >> 12) & 0xFFFFF;
368+
const int64_t Lo12 = SignExtend64<12>(Type);
369+
if (Hi20) {
370+
EmitToStreamer(
371+
*OutStreamer,
372+
MCInstBuilder(RISCV::LUI).addReg(ScratchRegs[1]).addImm(Hi20));
373+
}
374+
if (Lo12 || Hi20 == 0) {
375+
EmitToStreamer(*OutStreamer,
376+
MCInstBuilder((STI->hasFeature(RISCV::Feature64Bit) && Hi20)
377+
? RISCV::ADDIW
378+
: RISCV::ADDI)
379+
.addReg(ScratchRegs[1])
380+
.addReg(ScratchRegs[1])
381+
.addImm(Lo12));
382+
}
383+
384+
// Compare the hashes and trap if there's a mismatch.
385+
MCSymbol *Pass = OutContext.createTempSymbol();
386+
EmitToStreamer(*OutStreamer,
387+
MCInstBuilder(RISCV::BEQ)
388+
.addReg(ScratchRegs[0])
389+
.addReg(ScratchRegs[1])
390+
.addExpr(MCSymbolRefExpr::create(Pass, OutContext)));
391+
392+
MCSymbol *Trap = OutContext.createTempSymbol();
393+
OutStreamer->emitLabel(Trap);
394+
EmitToStreamer(*OutStreamer, MCInstBuilder(RISCV::EBREAK));
395+
emitKCFITrapEntry(*MI.getMF(), Trap);
396+
OutStreamer->emitLabel(Pass);
397+
}
398+
308399
void RISCVAsmPrinter::EmitHwasanMemaccessSymbols(Module &M) {
309400
if (HwasanMemaccessSymbols.empty())
310401
return;

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15395,17 +15395,24 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
1539515395
if (Glue.getNode())
1539615396
Ops.push_back(Glue);
1539715397

15398+
assert((!CLI.CFIType || CLI.CB->isIndirectCall()) &&
15399+
"Unexpected CFI type for a direct call");
15400+
1539815401
// Emit the call.
1539915402
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
1540015403

1540115404
if (IsTailCall) {
1540215405
MF.getFrameInfo().setHasTailCall();
1540315406
SDValue Ret = DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
15407+
if (CLI.CFIType)
15408+
Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
1540415409
DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
1540515410
return Ret;
1540615411
}
1540715412

1540815413
Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
15414+
if (CLI.CFIType)
15415+
Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
1540915416
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
1541015417
Glue = Chain.getValue(1);
1541115418

@@ -16864,6 +16871,24 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
1686416871
return true;
1686516872
}
1686616873

16874+
MachineInstr *
16875+
RISCVTargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
16876+
MachineBasicBlock::instr_iterator &MBBI,
16877+
const TargetInstrInfo *TII) const {
16878+
assert(MBBI->isCall() && MBBI->getCFIType() &&
16879+
"Invalid call instruction for a KCFI check");
16880+
assert(is_contained({RISCV::PseudoCALLIndirect, RISCV::PseudoTAILIndirect},
16881+
MBBI->getOpcode()));
16882+
16883+
MachineOperand &Target = MBBI->getOperand(0);
16884+
Target.setIsRenamable(false);
16885+
16886+
return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(RISCV::KCFI_CHECK))
16887+
.addReg(Target.getReg())
16888+
.addImm(MBBI->getCFIType())
16889+
.getInstr();
16890+
}
16891+
1686716892
#define GET_REGISTER_MATCHER
1686816893
#include "RISCVGenAsmMatcher.inc"
1686916894

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,12 @@ class RISCVTargetLowering : public TargetLowering {
759759
bool lowerInterleavedStore(StoreInst *SI, ShuffleVectorInst *SVI,
760760
unsigned Factor) const override;
761761

762+
bool supportKCFIBundles() const override { return true; }
763+
764+
MachineInstr *EmitKCFICheck(MachineBasicBlock &MBB,
765+
MachineBasicBlock::instr_iterator &MBBI,
766+
const TargetInstrInfo *TII) const override;
767+
762768
/// RISCVCCAssignFn - This target-specific function extends the default
763769
/// CCValAssign with additional information used to lower RISC-V calling
764770
/// conventions.

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,13 +1265,27 @@ unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
12651265
}
12661266
}
12671267

1268+
if (Opcode == TargetOpcode::BUNDLE)
1269+
return getInstBundleLength(MI);
1270+
12681271
if (MI.getParent() && MI.getParent()->getParent()) {
12691272
if (isCompressibleInst(MI, STI))
12701273
return 2;
12711274
}
12721275
return get(Opcode).getSize();
12731276
}
12741277

1278+
unsigned RISCVInstrInfo::getInstBundleLength(const MachineInstr &MI) const {
1279+
unsigned Size = 0;
1280+
MachineBasicBlock::const_instr_iterator I = MI.getIterator();
1281+
MachineBasicBlock::const_instr_iterator E = MI.getParent()->instr_end();
1282+
while (++I != E && I->isInsideBundle()) {
1283+
assert(!I->isBundle() && "No nested bundle!");
1284+
Size += getInstSizeInBytes(*I);
1285+
}
1286+
return Size;
1287+
}
1288+
12751289
bool RISCVInstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const {
12761290
const unsigned Opcode = MI.getOpcode();
12771291
switch (Opcode) {

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
237237

238238
protected:
239239
const RISCVSubtarget &STI;
240+
241+
private:
242+
unsigned getInstBundleLength(const MachineInstr &MI) const;
240243
};
241244

242245
namespace RISCV {

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,13 @@ def HWASAN_CHECK_MEMACCESS_SHORTGRANULES
18871887
[(int_hwasan_check_memaccess_shortgranules X5, GPRJALR:$ptr,
18881888
(i32 timm:$accessinfo))]>;
18891889

1890+
// This gets lowered into a 20-byte instruction sequence (at most)
1891+
let hasSideEffects = 0, mayLoad = 1, mayStore = 0,
1892+
Defs = [ X6, X7, X28, X29, X30, X31 ], Size = 20 in {
1893+
def KCFI_CHECK
1894+
: Pseudo<(outs), (ins GPRJALR:$ptr, i32imm:$type), []>, Sched<[]>;
1895+
}
1896+
18901897
/// Simple optimization
18911898
def : Pat<(XLenVT (add GPR:$rs1, (AddiPair:$rs2))),
18921899
(ADDI (ADDI GPR:$rs1, (AddiPairImmLarge AddiPair:$rs2)),

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
7676
RegisterTargetMachine<RISCVTargetMachine> Y(getTheRISCV64Target());
7777
auto *PR = PassRegistry::getPassRegistry();
7878
initializeGlobalISel(*PR);
79+
initializeKCFIPass(*PR);
7980
initializeRISCVMakeCompressibleOptPass(*PR);
8081
initializeRISCVGatherScatterLoweringPass(*PR);
8182
initializeRISCVCodeGenPreparePass(*PR);
@@ -333,7 +334,10 @@ bool RISCVPassConfig::addGlobalInstructionSelect() {
333334
return false;
334335
}
335336

336-
void RISCVPassConfig::addPreSched2() {}
337+
void RISCVPassConfig::addPreSched2() {
338+
// Emit KCFI checks for indirect calls.
339+
addPass(createKCFIPass());
340+
}
337341

338342
void RISCVPassConfig::addPreEmitPass() {
339343
addPass(&BranchRelaxationPassID);
@@ -357,6 +361,11 @@ void RISCVPassConfig::addPreEmitPass2() {
357361
// possibility for other passes to break the requirements for forward
358362
// progress in the LR/SC block.
359363
addPass(createRISCVExpandAtomicPseudoPass());
364+
365+
// KCFI indirect call checks are lowered to a bundle.
366+
addPass(createUnpackMachineBundles([&](const MachineFunction &MF) {
367+
return MF.getFunction().getParent()->getModuleFlag("kcfi");
368+
}));
360369
}
361370

362371
void RISCVPassConfig::addMachineSSAOptimization() {

llvm/test/CodeGen/RISCV/O0-pipeline.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
; CHECK-NEXT: Machine Optimization Remark Emitter
5252
; CHECK-NEXT: Prologue/Epilogue Insertion & Frame Finalization
5353
; CHECK-NEXT: Post-RA pseudo instruction expansion pass
54+
; CHECK-NEXT: Insert KCFI indirect call checks
5455
; CHECK-NEXT: Analyze Machine Code For Garbage Collection
5556
; CHECK-NEXT: Insert fentry calls
5657
; CHECK-NEXT: Insert XRay ops
@@ -66,6 +67,7 @@
6667
; CHECK-NEXT: Stack Frame Layout Analysis
6768
; CHECK-NEXT: RISC-V pseudo instruction expansion pass
6869
; CHECK-NEXT: RISC-V atomic pseudo instruction expansion pass
70+
; CHECK-NEXT: Unpack machine instruction bundles
6971
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
7072
; CHECK-NEXT: Machine Optimization Remark Emitter
7173
; CHECK-NEXT: RISC-V Assembly Printer

llvm/test/CodeGen/RISCV/O3-pipeline.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
; CHECK-NEXT: Tail Duplication
156156
; CHECK-NEXT: Machine Copy Propagation Pass
157157
; CHECK-NEXT: Post-RA pseudo instruction expansion pass
158+
; CHECK-NEXT: Insert KCFI indirect call checks
158159
; CHECK-NEXT: MachineDominator Tree Construction
159160
; CHECK-NEXT: Machine Natural Loop Construction
160161
; CHECK-NEXT: Post RA top-down list latency scheduler
@@ -180,6 +181,7 @@
180181
; CHECK-NEXT: RISC-V Zcmp move merging pass
181182
; CHECK-NEXT: RISC-V pseudo instruction expansion pass
182183
; CHECK-NEXT: RISC-V atomic pseudo instruction expansion pass
184+
; CHECK-NEXT: Unpack machine instruction bundles
183185
; CHECK-NEXT: Lazy Machine Block Frequency Analysis
184186
; CHECK-NEXT: Machine Optimization Remark Emitter
185187
; CHECK-NEXT: RISC-V Assembly Printer
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 2
2+
; RUN: llc -mtriple=riscv64 -stop-after=finalize-isel -verify-machineinstrs -o - %s | FileCheck %s
3+
define void @f1(ptr noundef %x) !kcfi_type !1 {
4+
; CHECK-LABEL: name: f1
5+
; CHECK: bb.0 (%ir-block.0):
6+
; CHECK-NEXT: liveins: $x10
7+
; CHECK-NEXT: {{ $}}
8+
; CHECK-NEXT: [[COPY:%[0-9]+]]:gprjalr = COPY $x10
9+
; CHECK-NEXT: ADJCALLSTACKDOWN 0, 0, implicit-def dead $x2, implicit $x2
10+
; CHECK-NEXT: PseudoCALLIndirect [[COPY]], csr_ilp32_lp64, implicit-def dead $x1, implicit-def $x2, cfi-type 12345678
11+
; CHECK-NEXT: ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
12+
; CHECK-NEXT: PseudoRET
13+
call void %x() [ "kcfi"(i32 12345678) ]
14+
ret void
15+
}
16+
17+
define void @f2(ptr noundef %x) #0 {
18+
; CHECK-LABEL: name: f2
19+
; CHECK: bb.0 (%ir-block.0):
20+
; CHECK-NEXT: liveins: $x10
21+
; CHECK-NEXT: {{ $}}
22+
; CHECK-NEXT: [[COPY:%[0-9]+]]:gprtc = COPY $x10
23+
; CHECK-NEXT: PseudoTAILIndirect [[COPY]], implicit $x2, cfi-type 12345678
24+
tail call void %x() [ "kcfi"(i32 12345678) ]
25+
ret void
26+
}
27+
28+
attributes #0 = { "patchable-function-entry"="2" }
29+
30+
!llvm.module.flags = !{!0}
31+
32+
!0 = !{i32 4, !"kcfi", i32 1}
33+
!1 = !{i32 12345678}

llvm/test/CodeGen/RISCV/kcfi-mir.ll

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 2
2+
; RUN: llc -mtriple=riscv64 -stop-after=kcfi -verify-machineinstrs -o - %s | FileCheck %s
3+
4+
define void @f1(ptr noundef %x) !kcfi_type !1 {
5+
; CHECK-LABEL: name: f1
6+
; CHECK: bb.0 (%ir-block.0):
7+
; CHECK-NEXT: liveins: $x10, $x1
8+
; CHECK-NEXT: {{ $}}
9+
; CHECK-NEXT: $x2 = frame-setup ADDI $x2, -16
10+
; CHECK-NEXT: frame-setup CFI_INSTRUCTION def_cfa_offset 16
11+
; CHECK-NEXT: SD killed $x1, $x2, 8 :: (store (s64) into %stack.0)
12+
; CHECK-NEXT: frame-setup CFI_INSTRUCTION offset $x1, -8
13+
; CHECK-NEXT: BUNDLE implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31, implicit-def dead $x1, implicit-def $x2, implicit killed $x10 {
14+
; CHECK-NEXT: KCFI_CHECK $x10, 12345678, implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31
15+
; CHECK-NEXT: PseudoCALLIndirect killed $x10, csr_ilp32_lp64, implicit-def dead $x1, implicit-def $x2
16+
; CHECK-NEXT: }
17+
; CHECK-NEXT: $x1 = LD $x2, 8 :: (load (s64) from %stack.0)
18+
; CHECK-NEXT: $x2 = frame-destroy ADDI $x2, 16
19+
; CHECK-NEXT: PseudoRET
20+
call void %x() [ "kcfi"(i32 12345678) ]
21+
ret void
22+
}
23+
24+
define void @f2(ptr noundef %x) #0 {
25+
; CHECK-LABEL: name: f2
26+
; CHECK: bb.0 (%ir-block.0):
27+
; CHECK-NEXT: liveins: $x10
28+
; CHECK-NEXT: {{ $}}
29+
; CHECK-NEXT: BUNDLE implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31, implicit killed $x10, implicit $x2 {
30+
; CHECK-NEXT: KCFI_CHECK $x10, 12345678, implicit-def $x6, implicit-def $x7, implicit-def $x28, implicit-def $x29, implicit-def $x30, implicit-def $x31
31+
; CHECK-NEXT: PseudoTAILIndirect killed $x10, implicit $x2
32+
; CHECK-NEXT: }
33+
tail call void %x() [ "kcfi"(i32 12345678) ]
34+
ret void
35+
}
36+
37+
attributes #0 = { "patchable-function-entry"="2" }
38+
39+
!llvm.module.flags = !{!0}
40+
41+
!0 = !{i32 4, !"kcfi", i32 1}
42+
!1 = !{i32 12345678}

0 commit comments

Comments
 (0)