Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 43b6b6b

Browse files
Merge pull request #275 from avik-pal/ap/scimlbase_nullparams
Default to the older behavior
2 parents 9a42b50 + 77a337a commit 43b6b6b

File tree

4 files changed

+58
-23
lines changed

4 files changed

+58
-23
lines changed

Project.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ Zygote = "0.6"
6262
julia = "1.6"
6363

6464
[extras]
65-
ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d"
66-
ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a"
6765
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
6866
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
6967
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -77,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7775
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7876

7977
[targets]
80-
test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]
78+
test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"]

src/differentiation/common.jl

+52-16
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,70 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop
3838
(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
3939
(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u)
4040

41-
function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
41+
# NOTE: `use_deprecated_ordering` is a way for external libraries to update to the correct
42+
# style. In the next release, we will drop the first check
43+
function JacFunctionWrapper(f::F, fu_, u, p, t;
44+
use_deprecated_ordering::Val{deporder} = Val(true)) where {F, deporder}
4245
# The warning instead of error ensures a non-breaking change for users relying on an
4346
# undefined / undocumented feature
4447
fu = fu_ === nothing ? copy(u) : copy(fu_)
48+
49+
if deporder
50+
# Check this first else we were breaking things
51+
# In the next breaking release, we will fix the ordering of the checks
52+
iip = static_hasmethod(f, typeof((fu, u)))
53+
oop = static_hasmethod(f, typeof((u,)))
54+
if iip || oop
55+
if p !== nothing || t !== nothing
56+
Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we
57+
potentially detected `f(du, u)` or `f(u)`. This can be caused by:
58+
59+
1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not
60+
be supplied.
61+
2. `f(args...)` is defined, in which case `hasmethod` can be spurious.
62+
63+
Currently, we perform the check for `f(du, u)` and `f(u)` first, but in
64+
future breaking releases, this check will be performed last, which means
65+
that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given
66+
precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be
67+
given precedence.""", :JacFunctionWrapper)
68+
end
69+
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
70+
fu, p, t)
71+
end
72+
end
73+
4574
if t !== nothing
4675
iip = static_hasmethod(f, typeof((fu, u, p, t)))
4776
oop = static_hasmethod(f, typeof((u, p, t)))
4877
if !iip && !oop
49-
@warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined
50-
for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
51-
else
52-
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
53-
fu, p, t)
78+
throw(ArgumentError("""`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)`
79+
not defined for `f`!"""))
5480
end
81+
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
82+
fu, p, t)
5583
elseif p !== nothing
5684
iip = static_hasmethod(f, typeof((fu, u, p)))
5785
oop = static_hasmethod(f, typeof((u, p)))
5886
if !iip && !oop
59-
@warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will
60-
fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
61-
else
62-
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
63-
fu, p, t)
87+
throw(ArgumentError("""`p` is provided but `f(u, p)` or `f(fu, u, p)`
88+
not defined for `f`!"""))
89+
end
90+
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
91+
fu, p, t)
92+
end
93+
94+
if !deporder
95+
iip = static_hasmethod(f, typeof((fu, u)))
96+
oop = static_hasmethod(f, typeof((u,)))
97+
if !iip && !oop
98+
throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for
99+
`f`!"""))
64100
end
101+
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
102+
fu, p, t)
103+
else
104+
throw(ArgumentError("""Couldn't determine the function signature of `f` to
105+
construct a JacobianWrapper!"""))
65106
end
66-
iip = static_hasmethod(f, typeof((fu, u)))
67-
oop = static_hasmethod(f, typeof((u,)))
68-
!iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
69-
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
70-
fu, p, t)
71107
end

src/differentiation/jaches_products.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,9 @@ f(du, u) # Otherwise
263263
```
264264
"""
265265
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
266-
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
267-
ff = JacFunctionWrapper(f, fu, u, p, t)
266+
autodiff = AutoForwardDiff(), tag = DeivVecTag(),
267+
use_deprecated_ordering::Val = Val(true), kwargs...)
268+
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)
268269
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)
269270

270271
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff

src/differentiation/vecjac_products.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ f(du, u) # Otherwise
7272
```
7373
"""
7474
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
75-
autodiff = AutoFiniteDiff(), kwargs...)
76-
ff = JacFunctionWrapper(f, fu, u, p, t)
75+
autodiff = AutoFiniteDiff(), use_deprecated_ordering::Val = Val(true), kwargs...)
76+
ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering)
7777

7878
if !__internal_oop(ff) && autodiff isa AutoZygote
7979
msg = "Zygote requires an out of place method with signature f(u)."

0 commit comments

Comments
 (0)