@@ -1049,6 +1049,29 @@ struct TargetPPC64 : public GenericTarget<TargetPPC64> {
1049
1049
AT{});
1050
1050
return marshal;
1051
1051
}
1052
+
1053
+ CodeGenSpecifics::Marshalling
1054
+ structType (mlir::Location loc, fir::RecordType ty, bool isResult) const {
1055
+ CodeGenSpecifics::Marshalling marshal;
1056
+ auto sizeAndAlign{
1057
+ fir::getTypeSizeAndAlignmentOrCrash (loc, ty, getDataLayout (), kindMap)};
1058
+ unsigned short align{
1059
+ std::max (sizeAndAlign.second , static_cast <unsigned short >(8 ))};
1060
+ marshal.emplace_back (fir::ReferenceType::get (ty),
1061
+ AT{align, /* byval*/ !isResult, /* sret*/ isResult});
1062
+ return marshal;
1063
+ }
1064
+
1065
+ CodeGenSpecifics::Marshalling
1066
+ structArgumentType (mlir::Location loc, fir::RecordType ty,
1067
+ const Marshalling &previousArguments) const override {
1068
+ return structType (loc, ty, false );
1069
+ }
1070
+
1071
+ CodeGenSpecifics::Marshalling
1072
+ structReturnType (mlir::Location loc, fir::RecordType ty) const override {
1073
+ return structType (loc, ty, true );
1074
+ }
1052
1075
};
1053
1076
} // namespace
1054
1077
@@ -1060,7 +1083,7 @@ namespace {
1060
1083
struct TargetPPC64le : public GenericTarget <TargetPPC64le> {
1061
1084
using GenericTarget::GenericTarget;
1062
1085
1063
- static constexpr int defaultWidth = 64 ;
1086
+ static constexpr int defaultWidth{ 64 } ;
1064
1087
1065
1088
CodeGenSpecifics::Marshalling
1066
1089
complexArgumentType (mlir::Location, mlir::Type eleTy) const override {
@@ -1081,6 +1104,143 @@ struct TargetPPC64le : public GenericTarget<TargetPPC64le> {
1081
1104
AT{});
1082
1105
return marshal;
1083
1106
}
1107
+
1108
+ unsigned getElemWidth (mlir::Type ty) const {
1109
+ unsigned width{};
1110
+ llvm::TypeSwitch<mlir::Type>(ty)
1111
+ .template Case <mlir::ComplexType>([&](mlir::ComplexType cmplx) {
1112
+ auto elemType{
1113
+ mlir::dyn_cast<mlir::FloatType>(cmplx.getElementType ())};
1114
+ width = elemType.getWidth ();
1115
+ })
1116
+ .template Case <mlir::FloatType>(
1117
+ [&](mlir::FloatType real) { width = real.getWidth (); });
1118
+ return width;
1119
+ }
1120
+
1121
+ // Determine if all derived types components are of the same float type with
1122
+ // the same width. Complex(4) is considered 2 floats and complex(8) 2 doubles.
1123
+ bool hasSameFloatAndWidth (
1124
+ fir::RecordType recTy,
1125
+ std::pair<mlir::Type, unsigned > &firstTypeAndWidth) const {
1126
+ for (auto comp : recTy.getTypeList ()) {
1127
+ mlir::Type compType{comp.second };
1128
+ if (mlir::isa<fir::RecordType>(compType)) {
1129
+ auto rc{hasSameFloatAndWidth (mlir::cast<fir::RecordType>(compType),
1130
+ firstTypeAndWidth)};
1131
+ if (!rc)
1132
+ return false ;
1133
+ } else {
1134
+ mlir::Type ty;
1135
+ bool isFloatType{false };
1136
+ if (mlir::isa<mlir::FloatType, mlir::ComplexType>(compType)) {
1137
+ ty = compType;
1138
+ isFloatType = true ;
1139
+ } else if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(compType)) {
1140
+ ty = seqTy.getEleTy ();
1141
+ isFloatType = mlir::isa<mlir::FloatType, mlir::ComplexType>(ty);
1142
+ }
1143
+
1144
+ if (!isFloatType) {
1145
+ return false ;
1146
+ }
1147
+ auto width{getElemWidth (ty)};
1148
+ if (firstTypeAndWidth.first == nullptr ) {
1149
+ firstTypeAndWidth.first = ty;
1150
+ firstTypeAndWidth.second = width;
1151
+ } else if (width != firstTypeAndWidth.second ) {
1152
+ return false ;
1153
+ }
1154
+ }
1155
+ }
1156
+ return true ;
1157
+ }
1158
+
1159
+ CodeGenSpecifics::Marshalling
1160
+ passOnTheStack (mlir::Location loc, mlir::Type ty, bool isResult) const {
1161
+ CodeGenSpecifics::Marshalling marshal;
1162
+ auto sizeAndAlign{
1163
+ fir::getTypeSizeAndAlignmentOrCrash (loc, ty, getDataLayout (), kindMap)};
1164
+ unsigned short align{
1165
+ std::max (sizeAndAlign.second , static_cast <unsigned short >(8 ))};
1166
+ marshal.emplace_back (fir::ReferenceType::get (ty),
1167
+ AT{align, /* byval=*/ !isResult, /* sret=*/ isResult});
1168
+ return marshal;
1169
+ }
1170
+
1171
+ CodeGenSpecifics::Marshalling
1172
+ structType (mlir::Location loc, fir::RecordType recTy, bool isResult) const {
1173
+ CodeGenSpecifics::Marshalling marshal;
1174
+ auto sizeAndAlign{fir::getTypeSizeAndAlignmentOrCrash (
1175
+ loc, recTy, getDataLayout (), kindMap)};
1176
+ auto recordTypeSize{sizeAndAlign.first };
1177
+ mlir::Type seqTy;
1178
+ std::pair<mlir::Type, unsigned > firstTyAndWidth{nullptr , 0 };
1179
+
1180
+ // If there are less than or equal to 8 floats, the structure is flatten as
1181
+ // an array of floats.
1182
+ constexpr uint64_t maxNoOfFloats{8 };
1183
+
1184
+ // i64 type
1185
+ mlir::Type elemTy{mlir::IntegerType::get (recTy.getContext (), defaultWidth)};
1186
+ uint64_t nElem{static_cast <uint64_t >(
1187
+ std::ceil (static_cast <float >(recordTypeSize * 8 ) / defaultWidth))};
1188
+
1189
+ // If the derived type components contains are all floats with the same
1190
+ // width, the argument is passed as an array of floats.
1191
+ if (hasSameFloatAndWidth (recTy, firstTyAndWidth)) {
1192
+ uint64_t n{};
1193
+ auto firstType{firstTyAndWidth.first };
1194
+
1195
+ // Type is either float or complex
1196
+ if (auto cmplx = mlir::dyn_cast<mlir::ComplexType>(firstType)) {
1197
+ auto fltType{mlir::dyn_cast<mlir::FloatType>(cmplx.getElementType ())};
1198
+ n = static_cast <uint64_t >(8 * recordTypeSize / fltType.getWidth ());
1199
+ if (n <= maxNoOfFloats) {
1200
+ nElem = n;
1201
+ elemTy = fltType;
1202
+ }
1203
+ } else if (mlir::isa<mlir::FloatType>(firstType)) {
1204
+ auto elemSizeAndAlign{fir::getTypeSizeAndAlignmentOrCrash (
1205
+ loc, firstType, getDataLayout (), kindMap)};
1206
+ n = static_cast <uint64_t >(recordTypeSize / elemSizeAndAlign.first );
1207
+ if (n <= maxNoOfFloats) {
1208
+ nElem = n;
1209
+ elemTy = firstType;
1210
+ }
1211
+ }
1212
+ // Neither float nor complex
1213
+ assert (n > 0 && " unexpected type" );
1214
+ }
1215
+
1216
+ // For function returns, only flattened if there are less than 8
1217
+ // floats in total.
1218
+ if (isResult &&
1219
+ ((mlir::isa<mlir::FloatType>(elemTy) && nElem > maxNoOfFloats) ||
1220
+ !mlir::isa<mlir::FloatType>(elemTy))) {
1221
+ return passOnTheStack (loc, recTy, isResult);
1222
+ }
1223
+
1224
+ seqTy = fir::SequenceType::get (nElem, elemTy);
1225
+ marshal.emplace_back (seqTy, AT{});
1226
+ return marshal;
1227
+ }
1228
+
1229
+ CodeGenSpecifics::Marshalling
1230
+ structArgumentType (mlir::Location loc, fir::RecordType recType,
1231
+ const Marshalling &previousArguments) const override {
1232
+ auto sizeAndAlign{fir::getTypeSizeAndAlignmentOrCrash (
1233
+ loc, recType, getDataLayout (), kindMap)};
1234
+ if (sizeAndAlign.first > 64 ) {
1235
+ return passOnTheStack (loc, recType, false );
1236
+ }
1237
+ return structType (loc, recType, false );
1238
+ }
1239
+
1240
+ CodeGenSpecifics::Marshalling
1241
+ structReturnType (mlir::Location loc, fir::RecordType recType) const override {
1242
+ return structType (loc, recType, true );
1243
+ }
1084
1244
};
1085
1245
} // namespace
1086
1246
0 commit comments