Skip to content

[CIR] Upstream global initialization for ComplexType #141369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return cir::IntAttr::get(ty, 0);
if (cir::isAnyFloatingPointType(ty))
return cir::FPAttr::getZero(ty);
if (auto complexType = mlir::dyn_cast<cir::ComplexType>(ty))
return cir::ZeroAttr::get(complexType);
if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
return cir::ZeroAttr::get(arrTy);
if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))
Expand Down
42 changes: 42 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,48 @@ def ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> {
}];
}

//===----------------------------------------------------------------------===//
// ConstComplexAttr
//===----------------------------------------------------------------------===//

def ConstComplexAttr : CIR_Attr<"ConstComplex", "const_complex",
[TypedAttrInterface]> {
let summary = "An attribute that contains a constant complex value";
let description = [{
The `#cir.const_complex` attribute contains a constant value of complex
number type. The `real` parameter gives the real part of the complex number
and the `imag` parameter gives the imaginary part of the complex number.

The `real` and `imag` parameters must both reference the same type and must
be either IntAttr or FPAttr.

```mlir
%ci = #cir.const_complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i>
: !cir.complex<!s32i>
%cf = #cir.const_complex<#cir.fp<1.000000e+00> : !cir.float,
#cir.fp<2.000000e+00> : !cir.float> : !cir.complex<!cir.float>
```
}];

let parameters = (ins
AttributeSelfTypeParameter<"", "cir::ComplexType">:$type,
"mlir::TypedAttr":$real, "mlir::TypedAttr":$imag);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we create a predicate that will verify the IntAttr or FPAttr requirement? I don't think this is sufficient as it is, right?


let builders = [
AttrBuilderWithInferredContext<(ins "cir::ComplexType":$type,
"mlir::TypedAttr":$real,
"mlir::TypedAttr":$imag), [{
return $_get(type.getContext(), type, real, imag);
}]>,
];

let genVerifyDecl = 1;

let assemblyFormat = [{
`<` qualified($real) `,` qualified($imag) `>`
}];
}

//===----------------------------------------------------------------------===//
// VisibilityAttr
//===----------------------------------------------------------------------===//
Expand Down
48 changes: 48 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,54 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
}];
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

def CIR_ComplexType : CIR_Type<"Complex", "complex",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR complex type";
let description = [{
CIR type that represents a C complex number. `cir.complex` models the C type
`T _Complex`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here mentioning how this relates to std::complex would be useful. Note that the representation of std::complex depends on the header file and I believe can even vary depending on which preprocessor symbols are defined (which maybe depends on the target). I don't want a comprehensive explanation of all possibilities here, just a general note explaining that this doesn't directly map to std::complex.

I'm not sure what guarantees the C++ standard makes about the implementation of std::complex, but it would be very nice if we could in some way indicate when std::complex is being used. It looks like std::complex is generally unsupported in the incubator, so there's some work to be done to figure that out, I think.


The type models complex values, per C99 6.2.5p11. It supports the C99
complex float types as well as the GCC integer complex extensions.

The parameter `elementType` gives the type of the real and imaginary part of
the complex number. `elementType` must be either a CIR integer type or a CIR
floating-point type.

```mlir
!cir.complex<!s32i>
!cir.complex<!cir.float>
```
}];

let parameters = (ins CIR_AnyIntOrFloatType:$elementType);

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return $_get(elementType.getContext(), elementType);
}]>,
];

let assemblyFormat = [{
`<` $elementType `>`
}];

let extraClassDeclaration = [{
bool isFloatingPointComplex() const {
return isAnyFloatingPointType(getElementType());
}

bool isIntegerComplex() const {
return mlir::isa<cir::IntType>(getElementType());
}
}];
}

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 25 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,12 +577,33 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &value,
case APValue::Union:
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate struct or union");
return {};
case APValue::FixedPoint:
case APValue::ComplexInt:
case APValue::ComplexFloat:
case APValue::ComplexFloat: {
mlir::Type desiredType = cgm.convertType(destType);
cir::ComplexType complexType =
mlir::dyn_cast<cir::ComplexType>(desiredType);

mlir::Type complexElemTy = complexType.getElementType();
if (isa<cir::IntType>(complexElemTy)) {
llvm::APSInt real = value.getComplexIntReal();
llvm::APSInt imag = value.getComplexIntImag();
return builder.getAttr<cir::ConstComplexAttr>(
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should you assert isa<cir::FPType>(complexElemTy)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think assert(isa<cir::CIRFPTypeInterface>(complexElemTy) && "...")

assert(isa<cir::CIRFPTypeInterface>(complexElemTy) &&
"expected floating-point type");
llvm::APFloat real = value.getComplexFloatReal();
llvm::APFloat imag = value.getComplexFloatImag();
return builder.getAttr<cir::ConstComplexAttr>(
complexType, builder.getAttr<cir::FPAttr>(complexElemTy, real),
builder.getAttr<cir::FPAttr>(complexElemTy, imag));
}
case APValue::FixedPoint:
case APValue::AddrLabelDiff:
cgm.errorNYI("ConstExprEmitter::tryEmitPrivate fixed point, complex int, "
"complex float, addr label diff");
cgm.errorNYI(
"ConstExprEmitter::tryEmitPrivate fixed point, addr label diff");
return {};
}
llvm_unreachable("Unknown APValue kind");
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,13 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
break;
}

case Type::Complex: {
const auto *ct = cast<clang::ComplexType>(ty);
mlir::Type elementTy = convertType(ct->getElementType());
resultType = cir::ComplexType::get(elementTy);
break;
}

case Type::LValueReference:
case Type::RValueReference: {
const ReferenceType *refTy = cast<ReferenceType>(ty);
Expand Down
20 changes: 20 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,26 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

//===----------------------------------------------------------------------===//
// ConstComplexAttr definitions
//===----------------------------------------------------------------------===//

LogicalResult
ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,
cir::ComplexType type, mlir::TypedAttr real,
mlir::TypedAttr imag) {
mlir::Type elemType = type.getElementType();
if (real.getType() != elemType)
return emitError()
<< "type of the real part does not match the complex type";

if (imag.getType() != elemType)
return emitError()
<< "type of the imaginary part does not match the complex type";

return success();
}

//===----------------------------------------------------------------------===//
// CIR ConstArrayAttr
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 6 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (isa<cir::ZeroAttr>(attrType)) {
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType>(opType))
if (isa<cir::RecordType, cir::ArrayType, cir::VectorType, cir::ComplexType>(
opType))
return success();
return op->emitOpError("zero expects struct or array type");
return op->emitOpError(
"zero expects struct, array, vector, or complex type");
}

if (mlir::isa<cir::BoolAttr>(attrType)) {
Expand All @@ -252,7 +254,8 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return success();
}

if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr>(attrType))
if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstComplexAttr>(attrType))
return success();

assert(isa<TypedAttr>(attrType) && "What else could we be looking at here?");
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,32 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
.getABIAlignment(dataLayout, params);
}

//===----------------------------------------------------------------------===//
// ComplexType Definitions
//===----------------------------------------------------------------------===//

llvm::TypeSize
cir::ComplexType::getTypeSizeInBits(const mlir::DataLayout &dataLayout,
mlir::DataLayoutEntryListRef params) const {
// C17 6.2.5p13:
// Each complex type has the same representation and alignment requirements
// as an array type containing exactly two elements of the corresponding
// real type.

return dataLayout.getTypeSizeInBits(getElementType()) * 2;
}

uint64_t
cir::ComplexType::getABIAlignment(const mlir::DataLayout &dataLayout,
mlir::DataLayoutEntryListRef params) const {
// C17 6.2.5p13:
// Each complex type has the same representation and alignment requirements
// as an array type containing exactly two elements of the corresponding
// real type.

return dataLayout.getTypeABIAlignment(getElementType());
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//
Expand Down
65 changes: 53 additions & 12 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,15 @@ class CIRAttrToValue {

mlir::Value visit(mlir::Attribute attr) {
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
.Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr,
cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>(
[&](auto attrT) { return visitCirAttr(attrT); })
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
.Default([&](auto attrT) { return mlir::Value(); });
}

mlir::Value visitCirAttr(cir::IntAttr intAttr);
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
Expand Down Expand Up @@ -226,6 +227,42 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
}

/// FPAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}

/// ConstComplexAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
mlir::Type complexElemTy = complexType.getElementType();
mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy);

mlir::Attribute components[2];
if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
components[0] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getIntegerAttr(
complexElemLLVMTy,
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
} else {
components[0] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
components[1] = rewriter.getFloatAttr(
complexElemLLVMTy,
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
}

mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(complexAttr.getType()),
rewriter.getArrayAttr(components));
}

/// ConstPtrAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
mlir::Location loc = parentOp->getLoc();
Expand All @@ -241,13 +278,6 @@ mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
loc, converter->convertType(ptrAttr.getType()), ptrVal);
}

/// FPAttr visitor.
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
mlir::Location loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::ConstantOp>(
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
}

// ConstArrayAttr visitor
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
mlir::Type llvmTy = converter->convertType(attr.getType());
Expand Down Expand Up @@ -341,9 +371,11 @@ class GlobalInitAttrRewriter {
mlir::Attribute visitCirAttr(cir::IntAttr attr) {
return rewriter.getIntegerAttr(llvmType, attr.getValue());
}

mlir::Attribute visitCirAttr(cir::FPAttr attr) {
return rewriter.getFloatAttr(llvmType, attr.getValue());
}

mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
return rewriter.getBoolAttr(attr.getValue());
}
Expand Down Expand Up @@ -986,7 +1018,7 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
mlir::ConversionPatternRewriter &rewriter) const {
// TODO: Generalize this handling when more types are needed here.
assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
cir::ZeroAttr>(init)));
cir::ConstComplexAttr, cir::ZeroAttr>(init)));

// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
Expand Down Expand Up @@ -1038,7 +1070,8 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
return mlir::failure();
}
} else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) {
cir::ConstPtrAttr, cir::ConstComplexAttr,
cir::ZeroAttr>(init.value())) {
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
// should be updated. For now, we use a custom op to initialize globals
// to the appropriate value.
Expand Down Expand Up @@ -1549,6 +1582,14 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
return mlir::BFloat16Type::get(type.getContext());
});
converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
// A complex type is lowered to an LLVM struct that contains the real and
// imaginary part as data fields.
mlir::Type elementTy = converter.convertType(type.getElementType());
mlir::Type structFields[2] = {elementTy, elementTy};
return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
structFields);
});
converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> {
auto result = converter.convertType(type.getReturnType());
llvm::SmallVector<mlir::Type> arguments;
Expand Down
Loading
Loading