diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 756a72e6d97bc..3d684ef82ca7c 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -34,6 +34,7 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MatrixBuilder.h" @@ -1337,6 +1338,32 @@ class LowerMatrixIntrinsics { return Builder.CreateAdd(Sum, Mul); } + bool tryLowerExtractElement(ExtractElementInst *Inst) { + uint64_t Index; + if (!match(Inst->getOperand(1), m_ConstantInt(Index))) + return false; + + Value *Op0 = Inst->getOperand(0); + auto *VTy = cast(Op0->getType()); + + if (VTy->getElementCount().getKnownMinValue() < Index) { + Inst->replaceAllUsesWith(PoisonValue::get(VTy->getElementType())); + Inst->eraseFromParent(); + return true; + } + + auto *I = Inst2ColumnMatrix.find(Op0); + if (I == Inst2ColumnMatrix.end()) + return false; + + const MatrixTy &M = I->second; + + IRBuilder<> Builder(Inst); + Inst->setOperand(0, M.getVector(Index / M.getStride())); + Inst->setOperand(1, Builder.getInt32(Index % M.getStride())); + return true; + } + /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For /// users with shape information, there's nothing to do: they will use the /// cached value when they are lowered. For other users, \p Matrix is @@ -1351,11 +1378,16 @@ class LowerMatrixIntrinsics { ToRemove.push_back(Inst); Value *Flattened = nullptr; for (Use &U : llvm::make_early_inc_range(Inst->uses())) { - if (!ShapeMap.contains(U.getUser())) { - if (!Flattened) - Flattened = Matrix.embedInVector(Builder); - U.set(Flattened); - } + if (ShapeMap.contains(U.getUser())) + continue; + + if (auto *Extract = dyn_cast(U.getUser())) + if (tryLowerExtractElement(Extract)) + continue; + + if (!Flattened) + Flattened = Matrix.embedInVector(Builder); + U.set(Flattened); } } diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll new file mode 100644 index 0000000000000..db5444ca036ae --- /dev/null +++ b/llvm/test/Transforms/LowerMatrixIntrinsics/extract.ll @@ -0,0 +1,41 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s + +define float @extract_static(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_static( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <2 x float> [[COL_LOAD1]], i32 1 +; CHECK-NEXT: ret float [[EXTRACT]] +; + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 3 + ret float %extract +} + +define float @extract_static_outofbounds(ptr %in, ptr %out) { +; CHECK-LABEL: @extract_static_outofbounds( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: ret float poison +; + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 5 + ret float %extract +} + +define float @extract_dynamic(ptr %in, i32 %idx, ptr %out) { +; CHECK-LABEL: @extract_dynamic( +; CHECK-NEXT: [[COL_LOAD:%.*]] = load volatile <2 x float>, ptr [[IN:%.*]], align 4 +; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2 +; CHECK-NEXT: [[COL_LOAD1:%.*]] = load volatile <2 x float>, ptr [[VEC_GEP]], align 4 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD1]], <4 x i32> +; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <4 x float> [[TMP1]], i32 [[IDX:%.*]] +; CHECK-NEXT: ret float [[EXTRACT]] +; + %inv = call <4 x float> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 1, i32 2, i32 2) + %extract = extractelement <4 x float> %inv, i32 %idx + ret float %extract +}