Skip to content

Commit a538906

Browse files
committed
Use archetype inspection to enforce homogenous operations
1 parent 259a2b1 commit a538906

File tree

3 files changed

+45
-31
lines changed

3 files changed

+45
-31
lines changed

Sources/LLVM/Constant.swift

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import cllvm
33
#endif
44

5-
public struct Constant: IRValue {
5+
public protocol ConstantRepresentation {}
6+
public enum Unsigned: ConstantRepresentation {}
7+
public enum Signed: ConstantRepresentation {}
8+
public enum Floating: ConstantRepresentation {}
9+
10+
public struct Constant<Repr: ConstantRepresentation>: IRValue {
611
internal enum Representation {
712
case unsigned
813
case signed
@@ -17,9 +22,18 @@ public struct Constant: IRValue {
1722
return llvm
1823
}
1924

20-
internal init(llvm: LLVMValueRef!, representation: Representation) {
25+
public init(llvm: LLVMValueRef!) {
2126
self.llvm = llvm
22-
self.repr = representation
27+
28+
if ObjectIdentifier(Repr.self) == ObjectIdentifier(Unsigned.self) {
29+
self.repr = .unsigned
30+
} else if ObjectIdentifier(Repr.self) == ObjectIdentifier(Signed.self) {
31+
self.repr = .signed
32+
} else if ObjectIdentifier(Repr.self) == ObjectIdentifier(Floating.self) {
33+
self.repr = .floating
34+
} else {
35+
fatalError("Invalid representation \(type(of: Repr.self))")
36+
}
2337
}
2438

2539
public static func +(lhs: Constant, rhs: Constant) -> Constant {
@@ -28,9 +42,9 @@ public struct Constant: IRValue {
2842
switch lhs.repr {
2943
case .signed: fallthrough
3044
case .unsigned:
31-
return Constant(llvm: LLVMConstAdd(lhs.llvm, rhs.llvm), representation: lhs.repr)
45+
return Constant(llvm: LLVMConstAdd(lhs.llvm, rhs.llvm))
3246
case .floating:
33-
return Constant(llvm: LLVMConstFAdd(lhs.llvm, rhs.llvm), representation: lhs.repr)
47+
return Constant(llvm: LLVMConstFAdd(lhs.llvm, rhs.llvm))
3448
}
3549
}
3650

@@ -40,9 +54,9 @@ public struct Constant: IRValue {
4054
switch lhs.repr {
4155
case .signed: fallthrough
4256
case .unsigned:
43-
return Constant(llvm: LLVMConstSub(lhs.llvm, rhs.llvm), representation: lhs.repr)
57+
return Constant(llvm: LLVMConstSub(lhs.llvm, rhs.llvm))
4458
case .floating:
45-
return Constant(llvm: LLVMConstFSub(lhs.llvm, rhs.llvm), representation: lhs.repr)
59+
return Constant(llvm: LLVMConstFSub(lhs.llvm, rhs.llvm))
4660
}
4761
}
4862

@@ -52,9 +66,9 @@ public struct Constant: IRValue {
5266
switch lhs.repr {
5367
case .signed: fallthrough
5468
case .unsigned:
55-
return Constant(llvm: LLVMConstMul(lhs.llvm, rhs.llvm), representation: lhs.repr)
69+
return Constant(llvm: LLVMConstMul(lhs.llvm, rhs.llvm))
5670
case .floating:
57-
return Constant(llvm: LLVMConstFMul(lhs.llvm, rhs.llvm), representation: lhs.repr)
71+
return Constant(llvm: LLVMConstFMul(lhs.llvm, rhs.llvm))
5872
}
5973
}
6074

@@ -63,30 +77,30 @@ public struct Constant: IRValue {
6377

6478
switch lhs.repr {
6579
case .signed:
66-
return Constant(llvm: LLVMConstSDiv(lhs.llvm, rhs.llvm), representation: lhs.repr)
80+
return Constant(llvm: LLVMConstSDiv(lhs.llvm, rhs.llvm))
6781
case .unsigned:
68-
return Constant(llvm: LLVMConstUDiv(lhs.llvm, rhs.llvm), representation: lhs.repr)
82+
return Constant(llvm: LLVMConstUDiv(lhs.llvm, rhs.llvm))
6983
case .floating:
70-
return Constant(llvm: LLVMConstFDiv(lhs.llvm, rhs.llvm), representation: lhs.repr)
84+
return Constant(llvm: LLVMConstFDiv(lhs.llvm, rhs.llvm))
7185
}
7286
}
7387

7488
public static func ==(lhs: Constant, rhs: Constant) -> Constant {
7589
precondition(lhs.repr == rhs.repr, "Mixed-representation constant operations are disallowed")
7690

77-
return Constant(llvm: LLVMConstICmp(IntPredicate.eq.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
91+
return Constant(llvm: LLVMConstICmp(IntPredicate.eq.llvm, lhs.llvm, rhs.llvm))
7892
}
7993

8094
public static func <(lhs: Constant, rhs: Constant) -> Constant {
8195
precondition(lhs.repr == rhs.repr, "Mixed-representation constant operations are disallowed")
8296

8397
switch lhs.repr {
8498
case .signed:
85-
return Constant(llvm: LLVMConstICmp(IntPredicate.slt.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
99+
return Constant(llvm: LLVMConstICmp(IntPredicate.slt.llvm, lhs.llvm, rhs.llvm))
86100
case .unsigned:
87-
return Constant(llvm: LLVMConstICmp(IntPredicate.ult.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
101+
return Constant(llvm: LLVMConstICmp(IntPredicate.ult.llvm, lhs.llvm, rhs.llvm))
88102
case .floating:
89-
return Constant(llvm: LLVMConstFCmp(RealPredicate.olt.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
103+
return Constant(llvm: LLVMConstFCmp(RealPredicate.olt.llvm, lhs.llvm, rhs.llvm))
90104
}
91105
}
92106

@@ -95,11 +109,11 @@ public struct Constant: IRValue {
95109

96110
switch lhs.repr {
97111
case .signed:
98-
return Constant(llvm: LLVMConstICmp(IntPredicate.sgt.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
112+
return Constant(llvm: LLVMConstICmp(IntPredicate.sgt.llvm, lhs.llvm, rhs.llvm))
99113
case .unsigned:
100-
return Constant(llvm: LLVMConstICmp(IntPredicate.ugt.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
114+
return Constant(llvm: LLVMConstICmp(IntPredicate.ugt.llvm, lhs.llvm, rhs.llvm))
101115
case .floating:
102-
return Constant(llvm: LLVMConstFCmp(RealPredicate.ogt.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
116+
return Constant(llvm: LLVMConstFCmp(RealPredicate.ogt.llvm, lhs.llvm, rhs.llvm))
103117
}
104118
}
105119

@@ -108,11 +122,11 @@ public struct Constant: IRValue {
108122

109123
switch lhs.repr {
110124
case .signed:
111-
return Constant(llvm: LLVMConstICmp(IntPredicate.sle.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
125+
return Constant(llvm: LLVMConstICmp(IntPredicate.sle.llvm, lhs.llvm, rhs.llvm))
112126
case .unsigned:
113-
return Constant(llvm: LLVMConstICmp(IntPredicate.ule.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
127+
return Constant(llvm: LLVMConstICmp(IntPredicate.ule.llvm, lhs.llvm, rhs.llvm))
114128
case .floating:
115-
return Constant(llvm: LLVMConstFCmp(RealPredicate.ole.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
129+
return Constant(llvm: LLVMConstFCmp(RealPredicate.ole.llvm, lhs.llvm, rhs.llvm))
116130
}
117131
}
118132

@@ -121,11 +135,11 @@ public struct Constant: IRValue {
121135

122136
switch lhs.repr {
123137
case .signed:
124-
return Constant(llvm: LLVMConstICmp(IntPredicate.sge.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
138+
return Constant(llvm: LLVMConstICmp(IntPredicate.sge.llvm, lhs.llvm, rhs.llvm))
125139
case .unsigned:
126-
return Constant(llvm: LLVMConstICmp(IntPredicate.uge.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
140+
return Constant(llvm: LLVMConstICmp(IntPredicate.uge.llvm, lhs.llvm, rhs.llvm))
127141
case .floating:
128-
return Constant(llvm: LLVMConstFCmp(RealPredicate.oge.llvm, lhs.llvm, rhs.llvm), representation: lhs.repr)
142+
return Constant(llvm: LLVMConstFCmp(RealPredicate.oge.llvm, lhs.llvm, rhs.llvm))
129143
}
130144
}
131145
}

Sources/LLVM/FloatType.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ public enum FloatType: IRType {
1919
case ppcFP128
2020

2121
/// Creates a constant floating value of this type from a Swift `Double` value.
22-
public func constant(_ value: Double) -> Constant {
23-
return Constant(llvm: LLVMConstReal(asLLVM(), value), representation: .floating)
22+
public func constant(_ value: Double) -> Constant<Floating> {
23+
return Constant(llvm: LLVMConstReal(asLLVM(), value))
2424
}
2525

2626
/// Retrieves the underlying LLVM type object.

Sources/LLVM/IntType.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,21 @@ public struct IntType: IRType {
4040
/// - parameter value: A Swift integer value.
4141
/// - parameter signExtend: Whether to sign-extend this value to fit this
4242
/// type's bit width. Defaults to `false`.
43-
public func constant<IntTy: UnsignedInteger>(_ value: IntTy, signExtend: Bool = false) -> Constant {
43+
public func constant<IntTy: UnsignedInteger>(_ value: IntTy, signExtend: Bool = false) -> Constant<Unsigned> {
4444
return Constant(llvm: LLVMConstInt(asLLVM(),
4545
unsafeBitCast(value.toIntMax(), to: UInt64.self),
46-
signExtend.llvm), representation: .unsigned)
46+
signExtend.llvm))
4747
}
4848

4949
/// Creates a signed integer constant value with the given Swift integer value.
5050
///
5151
/// - parameter value: A Swift integer value.
5252
/// - parameter signExtend: Whether to sign-extend this value to fit this
5353
/// type's bit width. Defaults to `false`.
54-
public func constant<IntTy: SignedInteger>(_ value: IntTy, signExtend: Bool = false) -> Constant {
54+
public func constant<IntTy: SignedInteger>(_ value: IntTy, signExtend: Bool = false) -> Constant<Signed> {
5555
return Constant(llvm: LLVMConstInt(asLLVM(),
5656
unsafeBitCast(value.toIntMax(), to: UInt64.self),
57-
signExtend.llvm), representation: .signed)
57+
signExtend.llvm))
5858
}
5959

6060

0 commit comments

Comments
 (0)