From 1fe3ffa0516ad873e4ab0ded2233e7a0494dfbd2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 2 May 2025 23:53:51 -0600 Subject: [PATCH] Fix default adjoint for AppleAccelerate and MKL This can't be easily tested because it's very architecture-dependent, but it fixes https://github.com/SciML/LinearSolve.jl/issues/601 --- src/adjoint.jl | 8 ++++++++ src/default.jl | 14 +++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/adjoint.jl b/src/adjoint.jl index 4781602fb..02e4b068b 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -28,6 +28,14 @@ specific structure distinct from ``A`` then passing in a `linsolve` will be more linsolve::L = missing end +function CRC.rrule(T::typeof(SciMLBase.solve), prob::LinearProblem, alg::Nothing, args...; kwargs...) + @show "here?" + assump = OperatorAssumptions(issquare(prob.A)) + alg = defaultalg(prob.A, prob.b, assump) + @show alg + CRC.rrule(T, prob, alg, args...; kwargs...) +end + function CRC.rrule(::typeof(SciMLBase.solve), prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A( alg, prob.A, prob.b), kwargs...) diff --git a/src/default.jl b/src/default.jl index f5772ade4..7b642425d 100644 --- a/src/default.jl +++ b/src/default.jl @@ -364,12 +364,20 @@ end @generated function defaultalg_adjoint_eval(cache::LinearCache, dy) ex = :() for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T)) - newex = if alg in Symbol.((DefaultAlgorithmChoice.MKLLUFactorization, - DefaultAlgorithmChoice.AppleAccelerateLUFactorization, - DefaultAlgorithmChoice.RFLUFactorization)) + newex = if alg == Symbol(DefaultAlgorithmChoice.RFLUFactorization) quote getproperty(cache.cacheval, $(Meta.quot(alg)))[1]' \ dy end + elseif alg == Symbol(DefaultAlgorithmChoice.MKLLUFactorization) + quote + A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1] + getrs!('T', A.factors, A.ipiv, dy) + end + elseif alg == Symbol(DefaultAlgorithmChoice.AppleAccelerateLUFactorization) + quote + A = getproperty(cache.cacheval, $(Meta.quot(alg)))[1] + aa_getrs!('T', A.factors, A.ipiv, dy) + end elseif alg in Symbol.((DefaultAlgorithmChoice.LUFactorization, DefaultAlgorithmChoice.QRFactorization, DefaultAlgorithmChoice.KLUFactorization,