diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..f964ce2d5e499 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -32,8 +32,10 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/MatrixBuilder.h"
@@ -232,6 +234,34 @@ static bool isUniformShape(Value *V) {
   if (I->isBinaryOp())
     return true;
 
+  if (auto *Cast = dyn_cast<CastInst>(V)) {
+    switch (Cast->getOpcode()) {
+    case llvm::Instruction::Trunc:
+    case llvm::Instruction::ZExt:
+    case llvm::Instruction::SExt:
+    case llvm::Instruction::FPToUI:
+    case llvm::Instruction::FPToSI:
+    case llvm::Instruction::UIToFP:
+    case llvm::Instruction::SIToFP:
+    case llvm::Instruction::FPTrunc:
+    case llvm::Instruction::FPExt:
+      return true;
+    case llvm::Instruction::AddrSpaceCast:
+    case CastInst::PtrToInt:
+    case CastInst::IntToPtr:
+      return false;
+    case CastInst::BitCast: {
+      if (auto *SrcVTy = dyn_cast<FixedVectorType>(Cast->getSrcTy()))
+        if (auto *DestVTy = dyn_cast<FixedVectorType>(Cast->getDestTy()))
+          return SrcVTy->getNumElements() == DestVTy->getNumElements();
+      return false;
+    }
+    case llvm::Instruction::CastOpsEnd:
+      llvm_unreachable("not an actual cast op");
+    }
+    llvm_unreachable("unhandled cast opcode");
+  }
+
   switch (I->getOpcode()) {
   case Instruction::FNeg:
     return true;
@@ -1066,9 +1096,11 @@ class LowerMatrixIntrinsics {
       Value *Op2;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
-      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+      else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
         Changed |= VisitUnaryOperator(UnOp);
-      if (match(Inst, m_Load(m_Value(Op1))))
+      else if (auto *Cast = dyn_cast<CastInst>(Inst))
+        Changed |= VisitCastInstruction(Cast);
+      else if (match(Inst, m_Load(m_Value(Op1))))
         Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
         Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
@@ -2198,6 +2230,37 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower cast instructions, if shape information is available.
+  bool VisitCastInstruction(CastInst *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end())
+      return false;
+
+    Value *Op = Inst->getOperand(0);
+
+    IRBuilder<> Builder(Inst);
+    ShapeInfo &Shape = I->second;
+
+    MatrixTy Result;
+    MatrixTy M = getMatrix(Op, Shape, Builder);
+
+    Builder.setFastMathFlags(getFastMathFlags(Inst));
+
+    auto *OrigVTy = cast<VectorType>(Inst->getType());
+    auto *NewVTy = VectorType::get(OrigVTy->getElementType(),
+                                   ElementCount::getFixed(M.getStride()));
+
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+      Result.addVector(
+          Builder.CreateCast(Inst->getOpcode(), M.getVector(I), NewVTy));
+
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
+    return true;
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
new file mode 100644
index 0000000000000..a4bd516868bcd
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/unary.ll
@@ -0,0 +1,250 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @fneg_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fneg_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x float> [[COL_LOAD]]
+; CHECK-NEXT:    [[TMP2:%.*]] = fneg <2 x float> [[COL_LOAD1]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fneg <4 x float> %inv
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @trunc_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @trunc_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i64> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc <2 x i64> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = trunc <4 x i64> %inv to <4 x i32>
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @zext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @zext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i16>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i16, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i16>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <2 x i16> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <2 x i16> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i16>, ptr %in
+  %op = zext <4 x i16> %inv to <4 x i32>
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @sext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @sext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i8, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i8>, ptr [[VEC_GEP]], align 2
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <2 x i8> [[COL_LOAD]] to <2 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <2 x i8> [[COL_LOAD1]] to <2 x i16>
+; CHECK-NEXT:    store <2 x i16> [[TMP1]], ptr [[OUT:%.*]], align 2
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i16, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i16> [[TMP2]], ptr [[VEC_GEP2]], align 2
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i8>, ptr %in
+  %op = sext <4 x i8> %inv to <4 x i16>
+  call void @llvm.matrix.column.major.store(<4 x i16> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @fptoui_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptoui_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fptoui <2 x float> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptoui <2 x float> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fptoui <4 x float> %inv to <4 x i32>
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @fptosi_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptosi_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fptosi <2 x float> [[COL_LOAD]] to <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptosi <2 x float> [[COL_LOAD1]] to <2 x i32>
+; CHECK-NEXT:    store <2 x i32> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i32, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fptosi <4 x float> %inv to <4 x i32>
+  call void @llvm.matrix.column.major.store(<4 x i32> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @uitofp_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @uitofp_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = uitofp <2 x i64> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = uitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = uitofp <4 x i64> %inv to <4 x double>
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @sitofp_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @sitofp_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x i64>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i64, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x i64>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = sitofp <2 x i64> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = sitofp <2 x i64> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x i64>, ptr %in
+  %op = sitofp <4 x i64> %inv to <4 x double>
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @fptrunc_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fptrunc_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD]] to <2 x float>
+; CHECK-NEXT:    [[TMP2:%.*]] = fptrunc nnan <2 x double> [[COL_LOAD1]] to <2 x float>
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = fptrunc nnan <4 x double> %inv to <4 x float>
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @fpext_2x2(ptr %in, ptr %out) {
+; CHECK-LABEL: @fpext_2x2(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[IN:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = fpext <2 x float> [[COL_LOAD]] to <2 x double>
+; CHECK-NEXT:    [[TMP2:%.*]] = fpext <2 x float> [[COL_LOAD1]] to <2 x double>
+; CHECK-NEXT:    store <2 x double> [[TMP1]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x float>, ptr %in
+  %op = fpext <4 x float> %inv to <4 x double>
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @bitcast_2x2_v4f64_to_v4i64(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_v4f64_to_v4i64(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x double> [[COL_LOAD]] to <2 x i64>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x double> [[COL_LOAD1]] to <2 x i64>
+; CHECK-NEXT:    store <2 x i64> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr i64, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x i64> [[TMP2]], ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = bitcast <4 x double> %inv to <4 x i64>
+  call void @llvm.matrix.column.major.store(<4 x i64> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @bitcast_2x2_v4f64_to_v8i32(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_v4f64_to_v8i32(
+; CHECK-NEXT:    [[INV:%.*]] = load <4 x double>, ptr [[IN:%.*]], align 32
+; CHECK-NEXT:    [[OP:%.*]] = bitcast <4 x double> [[INV]] to <8 x i32>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[OP]], <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    store <4 x i32> [[SPLIT]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr i32, ptr [[OUT]], i64 4
+; CHECK-NEXT:    store <4 x i32> [[SPLIT1]], ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = load <4 x double>, ptr %in
+  %op = bitcast <4 x double> %inv to <8 x i32>
+  call void @llvm.matrix.column.major.store(<8 x i32> %op, ptr %out, i64 4, i1 false, i32 4, i32 2)
+  ret void
+}
+
+define void @bitcast_2x2_i256_to_v4i64(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_i256_to_v4i64(
+; CHECK-NEXT:    [[INV:%.*]] = load i256, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[OP:%.*]] = bitcast i256 [[INV]] to <4 x double>
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <4 x double> [[OP]], <4 x double> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    store <2 x double> [[SPLIT]], ptr [[OUT:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x double> [[SPLIT1]], ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    ret void
+;
+  %inv = load i256, ptr %in
+  %op = bitcast i256 %inv to <4 x double>
+  call void @llvm.matrix.column.major.store(<4 x double> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+  ret void
+}
+
+define void @bitcast_2x2_4i64_to_i256(ptr %in, ptr %out) {
+; CHECK-LABEL: @bitcast_2x2_4i64_to_i256(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[IN:%.*]], align 8
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[IN]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <2 x double> [[COL_LOAD]], <2 x double> [[COL_LOAD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[OP:%.*]] = bitcast <4 x double> [[TMP1]] to i256
+; CHECK-NEXT:    store i256 [[OP]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    ret void
+;
+  %inv = call <4 x double> @llvm.matrix.column.major.load(ptr %in, i64 2, i1 false, i32 2, i32 2)
+  %op = bitcast <4 x double> %inv to i256
+  store i256 %op, ptr %out
+  ret void
+}
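
For illustration (not part of the patch): casts that isUniformShape rejects, i.e. ptrtoint/inttoptr/addrspacecast and bitcasts that change the element count, do not receive shape information, so they are left in flattened form and shaped consumers re-split the result with shufflevectors, as in the mismatched-bitcast tests above. A minimal sketch of such a negative case (the @ptrtoint_2x2 name and the described fallback are assumptions, not autogenerated output):

define void @ptrtoint_2x2(ptr %in, ptr %out) {
  ; The ptrtoint result has no shape, so the column-major store is expected
  ; to split %op with shufflevectors rather than lower per-column casts.
  %inv = load <4 x ptr>, ptr %in
  %op = ptrtoint <4 x ptr> %inv to <4 x i64>
  call void @llvm.matrix.column.major.store(<4 x i64> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
  ret void
}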