Skip to content

Commit f1843f3

Browse files
authored
Merge pull request #945 from SciML/ap/lux1.0
fix!: update to Lux & Boltz 1.0
2 parents d2d2b47 + 3586821 commit f1843f3

File tree

15 files changed

+66
-293
lines changed

15 files changed

+66
-293
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ steps:
1818
RETESTITEMS_NWORKERS: 1 # These tests require quite a lot of GPU memory
1919
GROUP: CUDA
2020
DATADEPS_ALWAYS_ACCEPT: 'true'
21-
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
21+
JULIA_PKG_SERVER: ""
2222

2323
- label: "Documentation"
2424
plugins:
@@ -42,5 +42,6 @@ steps:
4242
DATADEPS_ALWAYS_ACCEPT: true
4343
JULIA_DEBUG: "Documenter"
4444
SECRET_DOCUMENTER_KEY: "AdqcYtp4x3U5j1ELurHIoOwURqXcOan+qmihqVjsjhoGUzS/snTyZNQ5fxaJr8Yawm9CyyGvh+Q5O98St1LJ9S+pi9C5TFSbPWnNp/vXabMmeUEVLHVYHUeR2wgMCciSnM/oLw5sNAEj3hrWFjLslEGKQSptUCTWuU5WRizhQONDxeA3tz9biZUYvKanP8GjsHUkD3te15n1t6o78T1+EJxb1znrBSd9aK1Y4UaVjBEfVtLtTD8Z6VP1L4SVXVipxrDdzwzbzUDaTpvjo3z3e9qx2u6Xn5qa/os6JY81jRa5ZTWFkev73DYhoFmordSI85grOPwNpvrNWqOAs5kTDg==;U2FsdGVkX1+TXM0w98SRH5lY0Dw+nmRJ1xtJmffK+GVWHdMjhiIxQoYGGoP065hgl1VOf+oLzqWWFoYIfcz5i/VKD5F7O3EhfLdcmG2Y15u4sxr/hLKKMedSCueiusNd4N9EGGmJ4LLY0I1K6vA8Pa2eRwE+yDE+wfSpqpTC3yxMo40Xk7wra6ZwAybGpSHzOItV+2QGHttr/WLntRbx7GD8HDBC/LGNhtmFzFhAo/2CiQ0qgMDHBxyVqZlAdCWdu9xGmD9FC2+HDGv6QyNge1Ajmg6TNd6tuRhAP6VfDidEtaohcGl2TxbuUcd1OSrSbwmqxcw0IhVriRN8FgB8pKYmol7J71za9ljViyzgAjhQFvute/PYB2nw9MB8yCoNu6X0hqoLdmSxbzerpeYh0yRdi6SedESBJV3PgL7uahnyHbhC4dudFPavobeP9nU0okzKXG7fYfR1aiqgeeed4WFi48u+pciWv6Uo8J4lbBTUXu4xI+yWCpt01LNXyTmsIPYUvqbE8U0DkJNsNie9lw4of3UYkKhtVkLBoQg8++uc6i0w70+/sKZDp2OA5Y1jMFrRQaWUHyaRfpX3pXqvghAfVLEybiSWnpE0JiAnBsDcI3zajc4Cp+lui9G0+E8Lc+NbXOMbjiYHejjN46/03sIHNu0YPlU5p7o2xrGpa3cw6o3yHhBE7yVTcBc7A0AFPGZQxTLOEw1lYf7+B6J5AEpDhxR9+gUhmL+S+2kUw+nxsMxdD0Tunfeg4CIoeB9Tl7uIBrZDQ23uVRrcEyE6t3zf7skBcW3wlrHpAY61CxuGuMolcTl0JaeYFTJPYzOgPa+nD/vKaICsrRDkaSUUHcGufbTgqdJLjIkh+M9a7+DPKpfoT7H4gp4VrocqClFmmPoZZAIKjxXAEnEHWILBw0a9cOar1DKfJVoyN1vQIdVeux90a50Ao62m3sHoYiXY3DeutHkAmfWWDl/5zcU2h1T0XWHmRnjjmAW0fZPL+E38rKXYQECeHMDFEYYfbGyZMJx9T9pwfvxTM4Yzd4nB0qspOXMdeGvnVbzqlnaGJpxs/M5zyxILMQzq979bwSI0TPFRqLojhNezOhZHZaJdFoWq6UqW1kFDyzNIIRCQak0kGuhpCeFqqxiFFtC1M1vskpZ5UfqtCQSgILU+XbJWAxZOqrAxy2T7+h9JMS/jLPW+tyjCJx/bhqSGF9fBt6Q1R4ZL1MjaBSocnMj/5H7IZI2TdH6ulTyigZv7OEMQZRSyTrZgBNPLAiMHVoZVLkZF2NZse/4tHG/7i0Lio+m2z8WsjSQa04LPjtnlCSpYrug8EoGN4NruaRBDBIlTV9w6Rvz4YYB4iDht14ifF6XJVl/uo4jWKHAJC6Bc5IwFD80A/jYmx1vbLwvwVgYGCOW80WSFUGSnBSwVDLsLufXWt9ct8Kql3ICl2/iLO1ZZofELOgddV8yds1vrBdhn8jP1QCrTqtS0ITgLOabDEJMAma7St5R+Oa7kAj2zlVh/A0WXolGD04ReHnuiNN7S6C/ePpSTP/fMSa5bCrmQkw93fEgHdNOpQq9DwYNa9wEijtdEJsw5Kl+B73SNIhG+X7h1sN1DsCao2v1EPtxzaXw51kfJHzhtdCK
nKXYap8Lk+twZ6KKH3QZaSsq2LSL7E7da3ZXwo4oRMjV5OTkWaklGKZmzJaMrUnpbJMQUfb4tyNDBJ/52arcTLOn9b72d+927qUfKNCwm+Ma5tUJZ9PkOUxObbXgguXvgVVBL5Li0kfcCyYQC8HcNu1ZkmP7FbJzYo4m/e8v+jASNorC+49BvDE8WlSw+6dJQoP29S0u2OyZ4CSLYvqVDz5WIiLYR9OH2Dk7dB6n69jgngYkEtTF+1TWwQOR0d+6xungqU86W+4JvPkBx4YwVHmnw+iFRNqJd/OTmBVVDYpEkC6N38SuCRAuZwjcVfl6ERm0C7FEJOlmXqs6UeUuPNxuCE7yKcD4J0JGhVjSxN4c3dbV0aipIt9/ob+I2rXA3TUVOU7G+svsboOo1bHlUfoL1HQcasUHwst6ScsrlzJtSLActVb8QMh2iOw5zlxHGyq/MqU+tcquLZc3ctYZwzXatMjEqNqP4nHF1HHkYavwrhFr4U6lbnPZ6ZlUCdrXKDsn7BrRnn2MQ4My7k/Cau6174Zln9RRB4LRs45P3oDUug3Xc8/erekvf8L1HRFsiHzv+8ssvO/dfOEkD1hUTacNejaWt6HXCPC3zXhnwsk/lw3TLXSuxWO3hDpxBuAy8gRmCVZq4HPLAod/lEIpXQY3Rij8mmU28tCnhXrjxTTCe1ci9F4dy2IigN/1YeA5k+6gVpNmV9NvSDlPKN5vkGymFXokBiHkBKOd/uIkNORPYbjEGq17mO4CTfbJ8uDKneibC22VOaqB5Xk3/Xp4zz4TVCCr4xznVui7OOeuNaUptrypsMRtWfYiJ52s8TYiOsQ4SyKmXdCy51k16wGT9/ZrPpEktYFsC6g4SdY04MvBaowPYsudy9uOyUnZFuxPX+SLusRfLHlgkZU2SxrPMlHbOFpFVHKgCiKXLBoDWL3Bkn/9SUfoFYWf65KEfTeKJc7FCxsH/r6ngLMlJidn/fxI0D5FyVh0FVJnpdhGgv8jXo28nVCRLTNwRl5TzUjdz4EIvSxugmFR38qbpdkYeOlCdggE8safaxy+xXpYX0diec3x9MYvV7cWMtOCmTwKkevlPUTL+UuQiWKNf7fOzmGG67wNKfFtwmMxCbQ20nT2eAT6TjZ9eowW/E7s8i3cpCycJjeshuz7Y1G5ECLjSKe2UjCEGhBqRR8T/cBObQBErrilsRgjdO6w4UvBy4FkuzVpvB2cKhMt3HocBLNMlhUHaIgQfh1C6PMRT5S10+e659+qvF0GZOCJ7Y6b5Z9X9XipJEJo8ISC5uR7Z/WmEqgVtTdfezJvnGoibun2FjFriTP1W+HqHcGTpgwfWrVwHuGXYreckkzZaP//Q5uHajxj4AYQhEmjcQ6jcRNAWxOYLKZWqy7d2sDw79wdvKT6kfKg3wiKB/Arl1Mv33b1FkCfr+MRu1nzWeuAG72su6L91T0CVymOQQCZpA2ubPYrIH3vyKmRLvFIgwIwrWRZtsY8dvznOGmHvlEgUw+C2Ln0loc6fDDQ4Jk20LSXzRtsCeJ+PUahaDT9aW4xgWXHTAjGXjaZUak+5rs3ck7ZH3vbvVtbfg3/PKiCKWKXtt/2ZIZTxgdQVlGhZcz/LVzjIPOSTfKgcyz+pVXFq+hC2fubtsgTTm6DkXzP5pFdens05f27nXdr5FYNoYSJcU7w58gHTZOIez7oiMdkG+FObg9u/cpz923hOeBOmYaoQ9JePCQzQkWyPwHVD1IFsGBKtVIwahdKnpmjxubdr2ehdYbC7SsMD772YdPKOx07ipmz56pGBVqG7nUqYDxX2Mr5hvrHKbnLmYQykBSfazMUmK5/c+dNngn7QU8kN0fxFDmDzaUeQcZwbBNVlfp51Gsc+LvYoyXMUhBu53wQgNmq+ZJCSO4V+XBIgJAKIOSlVDq0GXe1VtyfNC4XcU0ey84K0/mD3RmtNGc2YYNp6OPqwzJAIexQcSr8pehN53fqGuGrRX3EqPzxNZwM7W+V
zzpT+Ky2jpLl2YrQyuROIAMV09P8HoDxBorSHAQXkijs2ByIAGbQqJhwtbcovSPIMqvHPeKKMuFfNzKnmCkUysklNxQynM2clwKwbOchghZOBNH2sQ4atfhHjdo66dXtzSmngyPujZcwinq1b1VUbG1n9BuusgdUrhpt/28MJRYLt4tJFVBqYGu98Ewa1oX+7xqCmhEe2us43fg7EYBpwLBAVDNsohVO69upLR+Yy2C1lhqJSSbO+JLKg70/7onpMI8JcCtiNYOMFYMix9ynkpBf8gN+cM/VgL4cldHYwbaAJXgnD7PxdmDIy7r8oZnGOHE//a3iDyB+Xqy0t9c41OYYn6PkB32BqRHFUvbzU+6kaDpQD/gk0EBTb51SLmy3IBBLKpKw1R0CVfS2wY5XX7vYYpgAMQzsoZpL3Ep0NpcRqtutcec0o0VXkd3B9wXJhDG+en0MaY9vc6V4g+nT8Z2jZw0A6lXnbDxlQN/CmvvrcsexHGGIj6vjpQs/oSyvOYaD1gVTWdQgcPhCYZGVH5O/llnKfxsRFVU3g6XvL4ND0oQ7S98eHRhz+8TqOx9Se47vEEC1O3bNDf6Rnnm+aB9vD0GKQ6iAETWI74yF5HrCpZY2XIDK8OgOQJoiWpDWaDxfNjK8nWMSjV8bnEdIzLP08p8fLWP/+JPeJkUB91mxmi5mhMjKFpKo/lXtvM2E8zmzkjo/VyLGYaij7EgF1XNIWRC5LWlsrrPiqVfQlmftDzjaG+jCx/47NLws="
45+
JULIA_PKG_SERVER: ""
4546
if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft
4647
timeout_in_minutes: 1000

.github/workflows/CI.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ jobs:
4242
${{ runner.os }}-
4343
- uses: julia-actions/julia-buildpkg@v1
4444
- uses: julia-actions/julia-runtest@v1
45-
with:
46-
coverage: false
4745
env:
4846
GROUP: ${{ matrix.group }}
4947
- uses: julia-actions/julia-processcoverage@v1

Project.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqFlux"
22
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
4-
version = "3.6.0"
4+
version = "4.0.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -18,6 +18,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1818
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1919
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2020
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
21+
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2122

2223
[weakdeps]
2324
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
@@ -29,7 +30,7 @@ DiffEqFluxDataInterpolationsExt = "DataInterpolations"
2930
ADTypes = "1.5"
3031
Aqua = "0.8.7"
3132
BenchmarkTools = "1.5.0"
32-
Boltz = "0.4.2"
33+
Boltz = "1"
3334
ChainRulesCore = "1"
3435
ComponentArrays = "0.15.17"
3536
ConcreteStructs = "0.2"
@@ -46,11 +47,10 @@ ForwardDiff = "0.10"
4647
Hwloc = "3"
4748
InteractiveUtils = "<0.0.1, 1"
4849
LinearAlgebra = "1.10"
49-
Lux = "0.5.65"
50+
Lux = "1"
5051
LuxCUDA = "0.3.2"
51-
LuxCore = "0.1"
52-
LuxLib = "0.3.50"
53-
MLDatasets = "0.7.14"
52+
LuxCore = "1"
53+
LuxLib = "1.2"
5454
NNlib = "0.9.22"
5555
OneHotArrays = "0.2.5"
5656
Optimisers = "0.3"
@@ -65,6 +65,7 @@ Reexport = "0.2, 1"
6565
SciMLBase = "2"
6666
SciMLSensitivity = "7"
6767
Setfield = "1.1.1"
68+
Static = "1.1.1"
6869
Statistics = "1.10"
6970
StochasticDiffEq = "6.68.0"
7071
Test = "1.10"
@@ -87,7 +88,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8788
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
8889
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
8990
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
90-
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
9191
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
9292
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
9393
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -105,4 +105,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
105105
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
106106

107107
[targets]
108-
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "DistributionsAD", "ExplicitImports", "ForwardDiff", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "MLDatasets", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test", "Zygote"]
108+
test = ["Aqua", "BenchmarkTools", "ComponentArrays", "DataInterpolations", "DelayDiffEq", "DiffEqCallbacks", "Distances", "Distributed", "DistributionsAD", "ExplicitImports", "ForwardDiff", "Flux", "Hwloc", "InteractiveUtils", "LuxCUDA", "NNlib", "OneHotArrays", "Optimisers", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "OrdinaryDiffEq", "Printf", "Random", "ReTestItems", "Reexport", "Statistics", "StochasticDiffEq", "Test", "Zygote"]

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,18 @@ explore various ways to integrate the two methodologies:
6363

6464
## Breaking Changes
6565

66-
### v4 (upcoming)
66+
### v4
6767

6868
- `TensorLayer` has been removed, use `Boltz.Layers.TensorProductLayer` instead.
6969
- Basis functions in DiffEqFlux have been removed in favor of `Boltz.Basis` module.
7070
- `SplineLayer` has been removed, use `Boltz.Layers.SplineLayer` instead.
7171
- `NeuralHamiltonianDE` has been removed, use `NeuralODE` with `Layers.HamiltonianNN` instead.
7272
- `HamiltonianNN` has been removed in favor of `Layers.HamiltonianNN`.
73+
- `Lux` and `Boltz` are updated to v1.
7374

7475
### v3
7576

76-
- Flux dependency is dropped. If a non Lux `AbstractExplicitLayer` is passed we try to automatically convert it to a Lux model with `FromFluxAdaptor()(model)`.
77+
- Flux dependency is dropped. If a non Lux `AbstractLuxLayer` is passed we try to automatically convert it to a Lux model with `FromFluxAdaptor()(model)`.
7778
- `Flux` is no longer re-exported from `DiffEqFlux`. Instead we reexport `Lux`.
7879
- `NeuralDAE` now allows an optional `du0` as input.
7980
- `TensorLayer` is now a Lux Neural Network.

docs/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ CUDA = "5"
4242
ComponentArrays = "0.15"
4343
DataDeps = "0.7"
4444
DataFrames = "1"
45-
DiffEqFlux = "3"
45+
DiffEqFlux = "4"
4646
Distances = "0.10.7"
4747
Distributions = "0.25.78"
4848
Documenter = "1"
4949
Flux = "0.14"
5050
ForwardDiff = "0.10"
5151
IterTools = "1"
5252
LinearAlgebra = "1"
53-
Lux = "0.5.5"
53+
Lux = "1"
5454
LuxCUDA = "0.3"
55-
MLDatasets = "0.7"
55+
MLDatasets = "0.7.18"
5656
MLUtils = "0.4"
5757
NNlib = "0.9"
5858
OneHotArrays = "0.2"

docs/src/examples/hamiltonian_nn.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dataloader = ncycle(
3333
for i in 1:(size(data, 2) ÷ B)),
3434
NEPOCHS)
3535
36-
hnn = HamiltonianNN(Chain(Dense(2 => 64, relu), Dense(64 => 1)); ad = AutoZygote())
36+
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
3737
ps, st = Lux.setup(Xoshiro(0), hnn)
3838
ps_c = ps |> ComponentArray
3939
@@ -57,7 +57,7 @@ res = Optimization.solve(opt_prob, opt, dataloader; callback)
5757
5858
ps_trained = res.u
5959
60-
model = NeuralHamiltonianDE(
60+
model = NeuralODE(
6161
hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)
6262
6363
pred = Array(first(model(data[:, 1], ps_trained, st)))
@@ -97,10 +97,10 @@ dataloader = ncycle(
9797

9898
### Training the HamiltonianNN
9999

100-
We parameterize the HamiltonianNN with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
100+
We parameterize the HamiltonianNN with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization.
101101

102102
```@example hamiltonian
103-
hnn = HamiltonianNN(Chain(Dense(2 => 64, relu), Dense(64 => 1)); ad = AutoZygote())
103+
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
104104
ps, st = Lux.setup(Xoshiro(0), hnn)
105105
ps_c = ps |> ComponentArray
106106
@@ -127,10 +127,11 @@ ps_trained = res.u
127127

128128
### Solving the ODE using trained HNN
129129

130-
In order to visualize the learned trajectories, we need to solve the ODE. We will use the `NeuralHamiltonianDE` layer, which is essentially a wrapper over `HamiltonianNN` layer, and solves the ODE.
130+
In order to visualize the learned trajectories, we need to solve the ODE. We will use the
131+
`NeuralODE` layer with the `HamiltonianNN` layer to solve the ODE.
131132

132133
```@example hamiltonian
133-
model = NeuralHamiltonianDE(
134+
model = NeuralODE(
134135
hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false, save_start = true, saveat = t)
135136
136137
pred = Array(first(model(data[:, 1], ps_trained, st)))

docs/src/examples/neural_gde.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using GraphNeuralNetworks, DifferentialEquations
1414
using DiffEqFlux: NeuralODE
1515
using GraphNeuralNetworks.GNNGraphs: normalized_adjacency
1616
using Lux, NNlib, Optimisers, Zygote, Random, ComponentArrays
17-
using Lux: AbstractExplicitLayer, glorot_normal, zeros32
17+
using Lux: AbstractLuxLayer, glorot_normal, zeros32
1818
import Lux: initialparameters, initialstates
1919
using SciMLSensitivity
2020
using Statistics: mean
@@ -46,7 +46,7 @@ nout = length(classes)
4646
epochs = 20
4747

4848
# Define the graph neural network
49-
struct ExplicitGCNConv{F1, F2, F3, F4} <: AbstractExplicitLayer
49+
struct ExplicitGCNConv{F1, F2, F3, F4} <: AbstractLuxLayer
5050
in_chs::Int
5151
out_chs::Int
5252
activation::F1
@@ -152,7 +152,7 @@ using GraphNeuralNetworks, DifferentialEquations
152152
using DiffEqFlux: NeuralODE
153153
using GraphNeuralNetworks.GNNGraphs: normalized_adjacency
154154
using Lux, NNlib, Optimisers, Zygote, Random, ComponentArrays
155-
using Lux: AbstractExplicitLayer, glorot_normal, zeros32
155+
using Lux: AbstractLuxLayer, glorot_normal, zeros32
156156
import Lux: initialparameters, initialstates
157157
using SciMLSensitivity
158158
using Statistics: mean
@@ -207,10 +207,10 @@ epochs = 20
207207

208208
## Define the Graph Neural Network
209209

210-
Here, we define a type of graph neural networks called `GCNConv`. We use the name `ExplicitGCNConv` to avoid naming conflicts with `GraphNeuralNetworks`. For more information on defining a layer with `Lux`, please consult to the [doc](http://lux.csail.mit.edu/dev/introduction/overview/#AbstractExplicitLayer-API).
210+
Here, we define a type of graph neural networks called `GCNConv`. We use the name `ExplicitGCNConv` to avoid naming conflicts with `GraphNeuralNetworks`. For more information on defining a layer with `Lux`, please consult to the [doc](http://lux.csail.mit.edu/dev/introduction/overview/#AbstractLuxLayer-API).
211211

212212
```julia
213-
struct ExplicitGCNConv{F1, F2, F3} <: AbstractExplicitLayer
213+
struct ExplicitGCNConv{F1, F2, F3} <: AbstractLuxLayer
214214
::AbstractMatrix # normalized_adjacency matrix
215215
in_chs::Int
216216
out_chs::Int

docs/src/examples/tensor_layer.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ Now, we create a TensorLayer that will be able to perform 10th order expansions
3333
a Legendre Basis:
3434

3535
```@example tensor
36-
A = [LegendreBasis(10), LegendreBasis(10)]
37-
nn = TensorLayer(A, 1)
36+
A = [Basis.Legendre(10), Basis.Legendre(10)]
37+
nn = Layers.TensorProductLayer(A, 1)
3838
ps, st = Lux.setup(Xoshiro(0), nn)
3939
ps = ComponentArray(ps)
4040
nn = StatefulLuxLayer{true}(nn, nothing, st)

src/DiffEqFlux.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using ConcreteStructs: @concrete
66
using Distributions: Distributions, ContinuousMultivariateDistribution, Distribution, logpdf
77
using LinearAlgebra: LinearAlgebra, Diagonal, det, tr, mul!
88
using Lux: Lux, Chain, Dense, StatefulLuxLayer, FromFluxAdaptor
9-
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
9+
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer
1010
using LuxLib: batched_matmul
1111
using Random: Random, AbstractRNG, randn!
1212
using Reexport: @reexport
@@ -20,22 +20,22 @@ using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJ
2020
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
2121
ZygoteVJP
2222
using Setfield: @set!
23+
using Static: True, False
2324

2425
const CRC = ChainRulesCore
2526

2627
@reexport using ADTypes, Lux, Boltz
2728

2829
fixed_state_type(_) = true
29-
fixed_state_type(::Layers.HamiltonianNN{FST}) where {FST} = FST
30+
fixed_state_type(::Layers.HamiltonianNN{True}) = true
31+
fixed_state_type(::Layers.HamiltonianNN{False}) = false
3032

3133
include("ffjord.jl")
3234
include("neural_de.jl")
3335

3436
include("collocation.jl")
3537
include("multiple_shooting.jl")
3638

37-
include("deprecated.jl")
38-
3939
export NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, AugmentedNDELayer,
4040
NeuralODEMM
4141
export FFJORD, FFJORDDistribution

src/deprecated.jl

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/ffjord.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
abstract type CNFLayer <: LuxCore.AbstractExplicitContainerLayer{(:model,)} end
1+
abstract type CNFLayer <: AbstractLuxWrapperLayer{:model} end
22

33
"""
44
FFJORD(model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...)
@@ -21,7 +21,7 @@ for new values of x.
2121
2222
Arguments:
2323
24-
- `model`: A `Flux.Chain` or `Lux.AbstractExplicitLayer` neural network that defines the
24+
- `model`: A `Flux.Chain` or `Lux.AbstractLuxLayer` neural network that defines the
2525
dynamics of the model.
2626
- `basedist`: Distribution of the base variable. Set to the unit normal by default.
2727
- `input_dims`: Input Dimensions of the model.
@@ -49,7 +49,7 @@ Information Processing Systems, pp. 6572-6583. 2018.
4949
preprint arXiv:1810.01367 (2018).
5050
"""
5151
@concrete struct FFJORD <: CNFLayer
52-
model <: AbstractExplicitLayer
52+
model <: AbstractLuxLayer
5353
basedist <: Union{Nothing, Distribution}
5454
ad
5555
input_dims
@@ -65,7 +65,7 @@ end
6565

6666
function FFJORD(
6767
model, tspan, input_dims, args...; ad = nothing, basedist = nothing, kwargs...)
68-
!(model isa AbstractExplicitLayer) && (model = FromFluxAdaptor()(model))
68+
!(model isa AbstractLuxLayer) && (model = FromFluxAdaptor()(model))
6969
return FFJORD(model, basedist, ad, input_dims, tspan, args, kwargs)
7070
end
7171

0 commit comments

Comments
 (0)