Skip to content

Commit 76d0648

Browse files
authored
[AutoDiff] [Sema] Include certain 'let' properties in 'Differentiable' derived conformances. (swiftlang#33700)
In `Differentiable` derived conformances, `let` properties are currently treated as if they had `@noDerivative` and excluded from the derived `Differentiable` conformance implementation. This is limiting to properties that have a non-mutating `move(along:)` (e.g. class properties), which can be mathematically treated as differentiable variables. This patch changes the derived conformances behavior such that `let` properties will be included as differentiable variables if they have a non-mutating `move(along:)`. This unblocks the following code: ```swift final class Foo: Differentiable { let x: ClassStuff // Class type with a non-mutating 'move(along:)' // Synthesized code: // struct TangentVector { // var x: ClassStuff.TangentVector // } // ... // func move(along direction: TangentVector) { // x.move(along: direction.x) // } } ``` Resolves SR-13474 (rdar://67982207).
1 parent 1e09ad0 commit 76d0648

File tree

5 files changed

+147
-28
lines changed

5 files changed

+147
-28
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2823,15 +2823,16 @@ WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
28232823
/*nominalCanDeriveAdditiveArithmetic*/ bool))
28242824
WARNING(differentiable_immutable_wrapper_implicit_noderivative_fixit,none,
28252825
"synthesis of the 'Differentiable.move(along:)' requirement for %1 "
2826-
"requires 'wrappedValue' in property wrapper %0 to be mutable; "
2827-
"add an explicit '@noDerivative' attribute"
2826+
"requires 'wrappedValue' in property wrapper %0 to be mutable or have a "
2827+
"non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute"
28282828
"%select{|, or conform %1 to 'AdditiveArithmetic'}2",
28292829
(/*wrapperType*/ Identifier, /*nominalName*/ Identifier,
28302830
/*nominalCanDeriveAdditiveArithmetic*/ bool))
28312831
WARNING(differentiable_let_property_implicit_noderivative_fixit,none,
28322832
"synthesis of the 'Differentiable.move(along:)' requirement for %0 "
28332833
"requires all stored properties not marked with `@noDerivative` to be "
2834-
"mutable; use 'var' instead, or add an explicit '@noDerivative' attribute"
2834+
"mutable or have a non-mutating 'move(along:)'; use 'var' instead, or "
2835+
"add an explicit '@noDerivative' attribute "
28352836
"%select{|, or conform %0 to 'AdditiveArithmetic'}1",
28362837
(/*nominalName*/ Identifier, /*nominalCanDeriveAdditiveArithmetic*/ bool))
28372838

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,37 @@
3232

3333
using namespace swift;
3434

35+
/// Return true if `move(along:)` can be invoked on the given `Differentiable`-
36+
/// conforming property.
37+
///
38+
/// If the given property is a `var`, return true because `move(along:)` can be
39+
/// invoked regardless. Otherwise, return true if and only if the property's
40+
/// type's 'Differentiable.move(along:)' witness is non-mutating.
41+
static bool canInvokeMoveAlongOnProperty(
42+
VarDecl *vd, ProtocolConformanceRef diffableConformance) {
43+
assert(diffableConformance && "Property must conform to 'Differentiable'");
44+
// `var` always supports `move(along:)` since it is mutable.
45+
if (vd->getIntroducer() == VarDecl::Introducer::Var)
46+
return true;
47+
// When the property is a `let`, the only case that would be supported is when
48+
// it has a `move(along:)` protocol requirement witness that is non-mutating.
49+
auto interfaceType = vd->getInterfaceType();
50+
auto &C = vd->getASTContext();
51+
auto witness = diffableConformance.getWitnessByName(
52+
interfaceType, DeclName(C, C.Id_move, {C.Id_along}));
53+
if (!witness)
54+
return false;
55+
auto *decl = cast<FuncDecl>(witness.getDecl());
56+
return decl->isNonMutating();
57+
}
58+
3559
/// Get the stored properties of a nominal type that are relevant for
3660
/// differentiation, except the ones tagged `@noDerivative`.
3761
static void
38-
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
39-
SmallVectorImpl<VarDecl *> &result,
40-
bool includeLetProperties = false) {
62+
getStoredPropertiesForDifferentiation(
63+
NominalTypeDecl *nominal, DeclContext *DC,
64+
SmallVectorImpl<VarDecl *> &result,
65+
bool includeLetPropertiesWithNonmutatingMoveAlong = false) {
4166
auto &C = nominal->getASTContext();
4267
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
4368
for (auto *vd : nominal->getStoredProperties()) {
@@ -53,15 +78,18 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, DeclContext *DC,
5378
// Skip stored properties with `@noDerivative` attribute.
5479
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
5580
continue;
56-
// Skip `let` stored properties if requested.
57-
// `mutating func move(along:)` cannot be synthesized to update `let`
58-
// properties.
59-
if (!includeLetProperties && vd->isLet())
60-
continue;
6181
if (vd->getInterfaceType()->hasError())
6282
continue;
6383
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
64-
if (!TypeChecker::conformsToProtocol(varType, diffableProto, nominal))
84+
auto conformance = TypeChecker::conformsToProtocol(
85+
varType, diffableProto, nominal);
86+
if (!conformance)
87+
continue;
88+
// Skip `let` stored properties with a mutating `move(along:)` if requested.
89+
// `mutating func move(along:)` cannot be synthesized to update `let`
90+
// properties.
91+
if (!includeLetPropertiesWithNonmutatingMoveAlong &&
92+
!canInvokeMoveAlongOnProperty(vd, conformance))
6593
continue;
6694
result.push_back(vd);
6795
}
@@ -782,18 +810,18 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
782810
continue;
783811
// Check whether to diagnose stored property.
784812
auto varType = DC->mapTypeIntoContext(vd->getValueInterfaceType());
785-
bool conformsToDifferentiable =
786-
!TypeChecker::conformsToProtocol(varType, diffableProto, nominal)
787-
.isInvalid();
813+
auto diffableConformance =
814+
TypeChecker::conformsToProtocol(varType, diffableProto, nominal);
788815
// If stored property should not be diagnosed, continue.
789-
if (conformsToDifferentiable && !vd->isLet())
816+
if (diffableConformance &&
817+
canInvokeMoveAlongOnProperty(vd, diffableConformance))
790818
continue;
791819
// Otherwise, add an implicit `@noDerivative` attribute.
792820
vd->getAttrs().add(new (Context) NoDerivativeAttr(/*Implicit*/ true));
793821
auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false);
794822
assert(loc.isValid() && "Expected valid source location");
795823
// Diagnose properties that do not conform to `Differentiable`.
796-
if (!conformsToDifferentiable) {
824+
if (!diffableConformance) {
797825
Context.Diags
798826
.diagnose(
799827
loc,

test/AutoDiff/Sema/DerivedConformances/class_differentiable.swift

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,70 @@ func testEmpty() {
2929
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
3030
}
3131

32+
protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
33+
extension DifferentiableWithNonmutatingMoveAlong {
34+
func move(along _: TangentVector) {}
35+
}
36+
37+
class EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
38+
typealias TangentVector = Empty.TangentVector
39+
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
40+
static func proof_that_i_have_nonmutating_move_along() {
41+
let empty = EmptyWithInheritedNonmutatingMoveAlong()
42+
empty.move(along: .init())
43+
}
44+
}
45+
46+
class EmptyWrapper<T: Differentiable & AnyObject>: Differentiable {}
47+
func testEmptyWrapper() {
48+
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
49+
assertConformsToAdditiveArithmetic(EmptyWrapper<Empty>.TangentVector.self)
50+
}
51+
3252
// Test structs with `let` stored properties.
3353
// Derived conformances fail because `mutating func move` requires all stored
3454
// properties to be mutable.
35-
class ImmutableStoredProperties: Differentiable {
55+
class ImmutableStoredProperties<T: Differentiable & AnyObject>: Differentiable {
3656
var okay: Float
3757

3858
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
3959
let nondiff: Int
4060

41-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
61+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
4262
let diff: Float
4363

44-
init() {
64+
let letClass: Empty // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.
65+
66+
let letClassWithInheritedNonmutatingMoveAlong: EmptyWithInheritedNonmutatingMoveAlong
67+
68+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
69+
let letClassGeneric: T // Error due to lack of non-mutating 'move(along:)'.
70+
71+
let letClassWrappingGeneric: EmptyWrapper<T> // No error on class-bound differentiable `let` with a non-mutating 'move(along:)'.
72+
73+
init(letClassGeneric: T) {
4574
okay = 0
4675
nondiff = 0
4776
diff = 0
77+
letClass = Empty()
78+
self.letClassGeneric = letClassGeneric
79+
self.letClassWrappingGeneric = EmptyWrapper<T>()
4880
}
4981
}
5082
func testImmutableStoredProperties() {
51-
_ = ImmutableStoredProperties.TangentVector(okay: 1)
83+
_ = ImmutableStoredProperties<Empty>.TangentVector(
84+
okay: 1,
85+
letClass: Empty.TangentVector(),
86+
letClassWithInheritedNonmutatingMoveAlong: Empty.TangentVector(),
87+
letClassWrappingGeneric: EmptyWrapper<Empty>.TangentVector())
5288
}
5389
class MutableStoredPropertiesWithInitialValue: Differentiable {
5490
var x = Float(1)
5591
var y = Double(1)
5692
}
5793
// Test class with both an empty constructor and memberwise initializer.
5894
class AllMixedStoredPropertiesHaveInitialValue: Differentiable {
59-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
95+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
6096
let x = Float(1)
6197
var y = Float(1)
6298
// Memberwise initializer should be `init(y:)` since `x` is immutable.
@@ -550,7 +586,7 @@ struct Generic<T> {}
550586
extension Generic: Differentiable where T: Differentiable {}
551587

552588
class WrappedProperties: Differentiable {
553-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
589+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
554590
@ImmutableWrapper var immutableInt: Generic<Int> = Generic()
555591

556592
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}

test/AutoDiff/Sema/DerivedConformances/struct_differentiable.swift

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,35 @@ func testEmpty() {
1111
assertConformsToAdditiveArithmetic(Empty.TangentVector.self)
1212
}
1313

14+
struct EmptyWithConcreteNonmutatingMoveAlong: Differentiable {
15+
typealias TangentVector = Empty.TangentVector
16+
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
17+
func move(along _: TangentVector) {}
18+
static func proof_that_i_have_nonmutating_move_along() {
19+
let empty = Self()
20+
empty.move(along: .init())
21+
}
22+
}
23+
24+
protocol DifferentiableWithNonmutatingMoveAlong: Differentiable {}
25+
extension DifferentiableWithNonmutatingMoveAlong {
26+
func move(along _: TangentVector) {}
27+
}
28+
29+
struct EmptyWithInheritedNonmutatingMoveAlong: DifferentiableWithNonmutatingMoveAlong {
30+
typealias TangentVector = Empty.TangentVector
31+
var zeroTangentVectorInitializer: () -> TangentVector { { .init() } }
32+
static func proof_that_i_have_nonmutating_move_along() {
33+
let empty = Self()
34+
empty.move(along: .init())
35+
}
36+
}
37+
38+
class EmptyClass: Differentiable {}
39+
func testEmptyClass() {
40+
assertConformsToAdditiveArithmetic(EmptyClass.TangentVector.self)
41+
}
42+
1443
// Test interaction with `AdditiveArithmetic` derived conformances.
1544
// Previously, this crashed due to duplicate memberwise initializer synthesis.
1645
struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
@@ -21,22 +50,32 @@ struct EmptyAdditiveArithmetic: AdditiveArithmetic, Differentiable {}
2150
struct ImmutableStoredProperties: Differentiable {
2251
var okay: Float
2352

24-
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }}
53+
// expected-warning @+1 {{stored property 'nondiff' has no derivative because 'Int' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
2554
let nondiff: Int
2655

27-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic}} {{3-3=@noDerivative }}
56+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
2857
let diff: Float
58+
59+
let nonmutatingMoveAlongStruct: EmptyWithConcreteNonmutatingMoveAlong
60+
61+
let inheritedNonmutatingMoveAlongStruct: EmptyWithInheritedNonmutatingMoveAlong
62+
63+
let diffClass: EmptyClass // No error on class-bound `let` with a non-mutating `move(along:)`.
2964
}
3065
func testImmutableStoredProperties() {
31-
_ = ImmutableStoredProperties.TangentVector(okay: 1)
66+
_ = ImmutableStoredProperties.TangentVector(
67+
okay: 1,
68+
nonmutatingMoveAlongStruct: Empty.TangentVector(),
69+
inheritedNonmutatingMoveAlongStruct: Empty.TangentVector(),
70+
diffClass: EmptyClass.TangentVector())
3271
}
3372
struct MutableStoredPropertiesWithInitialValue: Differentiable {
3473
var x = Float(1)
3574
var y = Double(1)
3675
}
3776
// Test struct with both an empty constructor and memberwise initializer.
3877
struct AllMixedStoredPropertiesHaveInitialValue: Differentiable {
39-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
78+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties not marked with `@noDerivative` to be mutable or have a non-mutating 'move(along:)'; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }}
4079
let x = Float(1)
4180
var y = Float(1)
4281
// Memberwise initializer should be `init(y:)` since `x` is immutable.
@@ -363,7 +402,7 @@ struct Generic<T> {}
363402
extension Generic: Differentiable where T: Differentiable {}
364403

365404
struct WrappedProperties: Differentiable {
366-
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable; add an explicit '@noDerivative' attribute}}
405+
// expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'WrappedProperties' requires 'wrappedValue' in property wrapper 'ImmutableWrapper' to be mutable or have a non-mutating 'move(along:)'; add an explicit '@noDerivative' attribute}}
367406
@ImmutableWrapper var immutableInt: Generic<Int>
368407

369408
// expected-warning @+1 {{stored property 'mutableInt' has no derivative because 'Generic<Int>' does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}}

test/AutoDiff/validation-test/class_differentiation.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,4 +524,19 @@ ClassTests.test("ClassProperties") {
524524
gradient(at: Super(base: 2)) { foo in foo.squared })
525525
}
526526

527+
ClassTests.test("LetProperties") {
528+
final class Foo: Differentiable {
529+
var x: Tracked<Float>
530+
init(x: Tracked<Float>) { self.x = x }
531+
}
532+
final class Bar: Differentiable {
533+
let x = Foo(x: 2)
534+
}
535+
let bar = Bar()
536+
let grad = gradient(at: bar) { bar in (bar.x.x * bar.x.x).value }
537+
expectEqual(Bar.TangentVector(x: .init(x: 6.0)), grad)
538+
bar.move(along: grad)
539+
expectEqual(8.0, bar.x.x)
540+
}
541+
527542
runAllTests()

0 commit comments

Comments
 (0)