@@ -802,28 +802,61 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
802
802
assert (ivs.size () == iteratorTypes.size () && " did not generate enough loops" );
803
803
}
804
804
805
+ static Value materializeTiledShape (OpBuilder &builder, Location loc,
806
+ Value valueToTile,
807
+ const SliceParameters &sliceParams) {
808
+ auto shapedType = valueToTile.getType ().dyn_cast <ShapedType>();
809
+ auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
810
+ .Case ([&](MemRefType) {
811
+ return builder.create <memref::SubViewOp>(
812
+ loc, valueToTile, sliceParams.offsets ,
813
+ sliceParams.sizes , sliceParams.strides );
814
+ })
815
+ .Case ([&](RankedTensorType) {
816
+ return makeComposedExtractSliceOp (
817
+ builder, loc, valueToTile, sliceParams.offsets ,
818
+ sliceParams.sizes , sliceParams.strides );
819
+ })
820
+ .Default ([](ShapedType) -> Operation * {
821
+ llvm_unreachable (" Unexpected shaped type" );
822
+ });
823
+ return sliceOp->getResult (0 );
824
+ }
825
+
805
826
/// Produces the tiled slice of `valueToTile`.
///
/// This is a two-step convenience wrapper: the slice parameters are first
/// obtained from `computeSliceParameters` (all arguments are forwarded
/// unchanged), and the resulting parameters are then turned into IR by
/// `materializeTiledShape`, whose result value is returned.
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                     ArrayRef<OpFoldResult> tileSizes, AffineMap map,
                     ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
                     ArrayRef<OpFoldResult> subShapeSizes,
                     bool omitPartialTileCheck) {
  return materializeTiledShape(
      builder, loc, valueToTile,
      computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
                             ubs, subShapeSizes, omitPartialTileCheck));
}
836
+
837
+ SliceParameters
838
+ computeSliceParameters (OpBuilder &builder, Location loc, Value valueToTile,
839
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
840
+ ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
841
+ ArrayRef<OpFoldResult> subShapeSizes,
842
+ bool omitPartialTileCheck) {
810
843
auto shapedType = valueToTile.getType ().dyn_cast <ShapedType>();
811
844
assert (shapedType && " only shaped types can be tiled" );
812
845
ArrayRef<int64_t > shape = shapedType.getShape ();
813
846
int64_t rank = shapedType.getRank ();
814
847
815
848
// Construct a new subview / extract_slice for the tile.
816
- SmallVector<OpFoldResult, 4 > offsets, sizes, strides ;
817
- offsets.reserve (rank);
818
- sizes.reserve (rank);
819
- strides.reserve (rank);
849
+ SliceParameters sliceParams ;
850
+ sliceParams. offsets .reserve (rank);
851
+ sliceParams. sizes .reserve (rank);
852
+ sliceParams. strides .reserve (rank);
820
853
for (unsigned r = 0 ; r < rank; ++r) {
821
- LLVM_DEBUG (llvm::dbgs () << " makeTiledShape : for dim#" << r);
854
+ LLVM_DEBUG (llvm::dbgs () << " computeSliceParameters : for dim#" << r);
822
855
if (!isTiled (map.getSubMap ({r}), tileSizes)) {
823
- offsets.push_back (builder.getIndexAttr (0 ));
856
+ sliceParams. offsets .push_back (builder.getIndexAttr (0 ));
824
857
OpFoldResult dim = createFoldedDimOp (builder, loc, valueToTile, r);
825
- sizes.push_back (dim);
826
- strides.push_back (builder.getIndexAttr (1 ));
858
+ sliceParams. sizes .push_back (dim);
859
+ sliceParams. strides .push_back (builder.getIndexAttr (1 ));
827
860
LLVM_DEBUG (llvm::dbgs () << " : not tiled: use size: " << dim << " \n " );
828
861
continue ;
829
862
}
@@ -832,26 +865,27 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
832
865
// Tiling creates a new slice at the proper index, the slice step is 1
833
866
// (i.e. the op does not subsample, stepping occurs in the loop).
834
867
auto m = map.getSubMap ({r});
835
- LLVM_DEBUG (llvm::dbgs () << " makeTiledShape : submap: " << m << " \n " );
868
+ LLVM_DEBUG (llvm::dbgs () << " computeSliceParameters : submap: " << m << " \n " );
836
869
IRRewriter rewriter (builder);
837
870
OpFoldResult offset = makeComposedFoldedAffineApply (rewriter, loc, m, lbs);
838
- offsets.push_back (offset);
871
+ sliceParams. offsets .push_back (offset);
839
872
OpFoldResult closedIntSize =
840
873
makeComposedFoldedAffineApply (rewriter, loc, m, subShapeSizes);
841
874
// Resulting size needs to be made half open interval again.
842
875
AffineExpr s0 = getAffineSymbolExpr (0 , builder.getContext ());
843
876
OpFoldResult size =
844
877
makeComposedFoldedAffineApply (rewriter, loc, s0 + 1 , closedIntSize);
845
- LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: raw size: " << size << " \n " );
846
878
LLVM_DEBUG (llvm::dbgs ()
847
- << " makeTiledShape: new offset: " << offset << " \n " );
848
- strides.push_back (builder.getIndexAttr (1 ));
879
+ << " computeSliceParameters: raw size: " << size << " \n " );
880
+ LLVM_DEBUG (llvm::dbgs ()
881
+ << " computeSliceParameters: new offset: " << offset << " \n " );
882
+ sliceParams.strides .push_back (builder.getIndexAttr (1 ));
849
883
850
884
if (omitPartialTileCheck) {
851
885
// We statically know that the partial/boundary tile condition is
852
886
// unnecessary.
853
887
LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: new size: " << size << " \n " );
854
- sizes.push_back (size);
888
+ sliceParams. sizes .push_back (size);
855
889
continue ;
856
890
}
857
891
@@ -903,22 +937,9 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
903
937
makeComposedFoldedAffineMin (rewriter, loc, minMap, {size, d, offset});
904
938
}
905
939
LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: new size: " << size << " \n " );
906
- sizes.push_back (size);
940
+ sliceParams. sizes .push_back (size);
907
941
}
908
-
909
- auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
910
- .Case ([&](MemRefType) {
911
- return builder.create <memref::SubViewOp>(
912
- loc, valueToTile, offsets, sizes, strides);
913
- })
914
- .Case ([&](RankedTensorType) {
915
- return makeComposedExtractSliceOp (
916
- builder, loc, valueToTile, offsets, sizes, strides);
917
- })
918
- .Default ([](ShapedType) -> Operation * {
919
- llvm_unreachable (" Unexpected shaped type" );
920
- });
921
- return sliceOp->getResult (0 );
942
+ return sliceParams;
922
943
}
923
944
924
945
SmallVector<OpFoldResult> computeTileOffsets (OpBuilder &b, Location loc,
@@ -1003,28 +1024,29 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc,
1003
1024
return materializeOpFoldResult (b, opFoldResult);
1004
1025
}
1005
1026
1006
- SmallVector<Value> makeTiledShapes (OpBuilder &b, Location loc,
1007
- LinalgOp linalgOp, ValueRange valuesToTile ,
1008
- ArrayRef<OpFoldResult> ivs,
1009
- ArrayRef<OpFoldResult> tileSizes,
1010
- ArrayRef<OpFoldResult> sizeBounds,
1011
- bool omitPartialTileCheck) {
1027
+ SmallVector<Optional<SliceParameters>>
1028
+ computeAllSliceParameters (OpBuilder &builder, Location loc, LinalgOp linalgOp,
1029
+ ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
1030
+ ArrayRef<OpFoldResult> tileSizes,
1031
+ ArrayRef<OpFoldResult> sizeBounds,
1032
+ bool omitPartialTileCheck) {
1012
1033
assert (ivs.size () == static_cast <size_t >(llvm::count_if (
1013
1034
llvm::make_range (tileSizes.begin (), tileSizes.end ()),
1014
1035
[](OpFoldResult v) { return !isZero (v); })) &&
1015
1036
" expected as many ivs as non-zero sizes" );
1016
1037
1017
1038
// Construct (potentially temporary) mins and maxes on which to apply maps
1018
1039
// that define tile subshapes.
1019
- SmallVector<OpFoldResult> lbs = computeTileOffsets (b, loc, ivs, tileSizes);
1040
+ SmallVector<OpFoldResult> lbs =
1041
+ computeTileOffsets (builder, loc, ivs, tileSizes);
1020
1042
SmallVector<OpFoldResult> subShapeSizes =
1021
- computeTileSizes (b , loc, tileSizes, sizeBounds);
1043
+ computeTileSizes (builder , loc, tileSizes, sizeBounds);
1022
1044
1023
1045
assert (static_cast <int64_t >(valuesToTile.size ()) ==
1024
1046
linalgOp.getNumInputsAndOutputs () &&
1025
1047
" expected one value to tile for every operand" );
1026
- SmallVector<Value> tiledShapes ;
1027
- tiledShapes .reserve (valuesToTile.size ());
1048
+ SmallVector<Optional<SliceParameters>> allSliceParams ;
1049
+ allSliceParams .reserve (valuesToTile.size ());
1028
1050
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands ()) {
1029
1051
Value shapedOp = valuesToTile[opOperand->getOperandNumber ()];
1030
1052
LLVM_DEBUG (llvm::dbgs () << " makeTiledShapes: for operand " << shapedOp);
@@ -1035,18 +1057,39 @@ SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
1035
1057
// extract/insert slice pairs make the accessed iteration argument
1036
1058
// subdomains explicit.
1037
1059
if (!isTiled (map, tileSizes) && !linalgOp.isOutputTensor (opOperand)) {
1038
- tiledShapes .push_back (shapedOp );
1060
+ allSliceParams .push_back (llvm::None );
1039
1061
LLVM_DEBUG (llvm::dbgs () << " : not tiled: use shape: "
1040
1062
<< opOperand->get ().getType () << " \n " );
1041
1063
continue ;
1042
1064
}
1043
1065
LLVM_DEBUG (llvm::dbgs () << " : tiled: figure out subshape...\n " );
1044
1066
1045
- tiledShapes .push_back (makeTiledShape (b, loc, shapedOp, tileSizes, map, lbs,
1046
- sizeBounds, subShapeSizes,
1047
- omitPartialTileCheck));
1067
+ allSliceParams .push_back (computeSliceParameters (
1068
+ builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
1069
+ omitPartialTileCheck));
1048
1070
}
1049
1071
1072
+ return allSliceParams;
1073
+ }
1074
+
1075
+ SmallVector<Value> makeTiledShapes (OpBuilder &builder, Location loc,
1076
+ LinalgOp linalgOp, ValueRange valuesToTile,
1077
+ ArrayRef<OpFoldResult> ivs,
1078
+ ArrayRef<OpFoldResult> tileSizes,
1079
+ ArrayRef<OpFoldResult> sizeBounds,
1080
+ bool omitPartialTileCheck) {
1081
+ SmallVector<Optional<SliceParameters>> allSliceParameter =
1082
+ computeAllSliceParameters (builder, loc, linalgOp, valuesToTile, ivs,
1083
+ tileSizes, sizeBounds, omitPartialTileCheck);
1084
+ SmallVector<Value> tiledShapes;
1085
+ for (auto item : llvm::zip (valuesToTile, allSliceParameter)) {
1086
+ Value valueToTile = std::get<0 >(item);
1087
+ Optional<SliceParameters> sliceParams = std::get<1 >(item);
1088
+ tiledShapes.push_back (
1089
+ sliceParams.hasValue ()
1090
+ ? materializeTiledShape (builder, loc, valueToTile, *sliceParams)
1091
+ : valueToTile);
1092
+ }
1050
1093
return tiledShapes;
1051
1094
}
1052
1095
0 commit comments