Skip to content

Commit 00a76de

Browse files
Force specialization in some key places
1 parent 846edd8 commit 00a76de

File tree

7 files changed

+88
-38
lines changed

7 files changed

+88
-38
lines changed

src/Geometry/axistensors.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,10 @@ Returns a `StaticArray` containing the components of `a` in its stored basis.
199199
"""
200200
components(a::AxisTensor) = getfield(a, :components)
201201

202-
Base.@propagate_inbounds Base.getindex(v::AxisTensor, i::Int...) =
203-
getindex(components(v), i...)
202+
Base.@propagate_inbounds Base.getindex(
203+
v::AxisTensor,
204+
i::Vararg{Int, N},
205+
) where {N} = getindex(components(v), i...)
204206

205207

206208
Base.@propagate_inbounds function Base.getindex(

src/Geometry/rmul_with_projection.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,10 @@ rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y} =
140140
rmaptype((X′, Y′) -> mul_return_type(X′, Y′), X, Y)
141141
rmul_return_type(::Type{X}, ::Type{Y}) where {X <: SingleValue, Y} =
142142
rmaptype(Y′ -> mul_return_type(X, Y′), Y)
143+
# rmaptype(Base.Fix1(mul_return_type, X), Y)
143144
rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y <: SingleValue} =
144145
rmaptype(X′ -> mul_return_type(X′, Y), X)
146+
# rmaptype(Base.Fix2(mul_return_type, Y), X)
145147
rmul_return_type(
146148
::Type{X},
147149
::Type{Y},

src/MatrixFields/band_matrix_row.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} =
4242

4343
@inline lower_diagonal(::Tuple{<:BandMatrixRow{ld}}) where {ld} = ld
4444
@inline lower_diagonal(t::Tuple) = lower_diagonal(t...)
45-
@inline lower_diagonal(::BandMatrixRow{ld}, ::BandMatrixRow{ld}...) where {ld} =
46-
ld
45+
@inline lower_diagonal(
46+
::BandMatrixRow{ld},
47+
::Vararg{BandMatrixRow{ld}, N},
48+
) where {ld, N} = ld
4749

4850
"""
4951
band_matrix_row_type(ld, ud, T)

src/MatrixFields/operator_matrices.jl

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -241,27 +241,34 @@ Operators.get_boundary(
241241
rbw::Operators.RightBoundaryWindow{name},
242242
) where {name} = Operators.get_boundary(op_matrix.op, rbw)
243243

244-
Operators.stencil_interior_width(op_matrix::FDOperatorMatrix, args...) =
245-
Operators.stencil_interior_width(op_matrix.op, args...)
244+
Operators.stencil_interior_width(
245+
op_matrix::FDOperatorMatrix,
246+
args::Vararg{Any, N},
247+
) where {N} = Operators.stencil_interior_width(op_matrix.op, args...)
246248

247249
Operators.left_interior_idx(
248250
space::Spaces.AbstractSpace,
249251
op_matrix::FDOperatorMatrix,
250252
bc::Operators.AbstractBoundaryCondition,
251-
args...,
252-
) = Operators.left_interior_idx(space, op_matrix.op, bc, args...)
253+
args::Vararg{Any, N},
254+
) where {N} = Operators.left_interior_idx(space, op_matrix.op, bc, args...)
253255

254256
Operators.right_interior_idx(
255257
space::Spaces.AbstractSpace,
256258
op_matrix::FDOperatorMatrix,
257259
bc::Operators.AbstractBoundaryCondition,
258-
args...,
259-
) = Operators.right_interior_idx(space, op_matrix.op, bc, args...)
260+
args::Vararg{Any, N},
261+
) where {N} = Operators.right_interior_idx(space, op_matrix.op, bc, args...)
260262

261-
Operators.return_space(op_matrix::FDOperatorMatrix, spaces...) =
262-
Operators.return_space(op_matrix.op, spaces...)
263+
Operators.return_space(
264+
op_matrix::FDOperatorMatrix,
265+
spaces::Vararg{Any, N},
266+
) where {N} = Operators.return_space(op_matrix.op, spaces...)
263267

264-
function Operators.return_eltype(op_matrix::FDOperatorMatrix, args...)
268+
function Operators.return_eltype(
269+
op_matrix::FDOperatorMatrix,
270+
args::Vararg{Any, N},
271+
) where {N}
265272
args′ = args[1:(end - 1)]
266273
FT = Geometry.undertype(eltype(args[end]))
267274
return op_matrix_row_type(op_matrix.op, FT, args′...)
@@ -273,8 +280,8 @@ Base.@propagate_inbounds function Operators.stencil_interior(
273280
space,
274281
idx,
275282
hidx,
276-
args...,
277-
)
283+
args::Vararg{Any, N},
284+
) where {N}
278285
args′ = args[1:(end - 1)]
279286
row = op_matrix_interior_row(op_matrix.op, space, loc, idx, hidx, args′...)
280287
return convert(Operators.return_eltype(op_matrix, args...), row)
@@ -287,8 +294,8 @@ Base.@propagate_inbounds function Operators.stencil_left_boundary(
287294
space,
288295
idx,
289296
hidx,
290-
args...,
291-
)
297+
args::Vararg{Any, N},
298+
) where {N}
292299
args′ = args[1:(end - 1)]
293300
row = op_matrix_first_row(op_matrix.op, bc, space, loc, idx, hidx, args′...)
294301
return convert(Operators.return_eltype(op_matrix, args...), row)
@@ -301,22 +308,42 @@ Base.@propagate_inbounds function Operators.stencil_right_boundary(
301308
space,
302309
idx,
303310
hidx,
304-
args...,
305-
)
311+
args::Vararg{Any, N},
312+
) where {N}
306313
args′ = args[1:(end - 1)]
307314
row = op_matrix_last_row(op_matrix.op, bc, space, loc, idx, hidx, args′...)
308315
return convert(Operators.return_eltype(op_matrix, args...), row)
309316
end
310317

311318
# Simplified methods for when the operator matrix only depends on FT.
312-
op_matrix_row_type(op, ::Type{FT}, args...) where {FT} =
319+
op_matrix_row_type(op, ::Type{FT}, args::Vararg{Any, N}) where {FT, N} =
313320
typeof(op_matrix_interior_row(op, FT))
314-
op_matrix_interior_row(op, space, loc, idx, hidx, args...) =
315-
op_matrix_interior_row(op, Spaces.undertype(space))
316-
op_matrix_first_row(op, bc, space, loc, idx, hidx, args...) =
317-
op_matrix_first_row(op, bc, Spaces.undertype(space))
318-
op_matrix_last_row(op, bc, space, loc, idx, hidx, args...) =
319-
op_matrix_last_row(op, bc, Spaces.undertype(space))
321+
op_matrix_interior_row(
322+
op,
323+
space,
324+
loc,
325+
idx,
326+
hidx,
327+
args::Vararg{Any, N},
328+
) where {N} = op_matrix_interior_row(op, Spaces.undertype(space))
329+
op_matrix_first_row(
330+
op,
331+
bc,
332+
space,
333+
loc,
334+
idx,
335+
hidx,
336+
args::Vararg{Any, N},
337+
) where {N} = op_matrix_first_row(op, bc, Spaces.undertype(space))
338+
op_matrix_last_row(
339+
op,
340+
bc,
341+
space,
342+
loc,
343+
idx,
344+
hidx,
345+
args::Vararg{Any, N},
346+
) where {N} = op_matrix_last_row(op, bc, Spaces.undertype(space))
320347

321348
################################################################################
322349

src/Operators/operator2stencil.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,29 +74,30 @@ get_boundary(op::Operator2Stencil, bw::LeftBoundaryWindow{name}) where {name} =
7474
get_boundary(op::Operator2Stencil, bw::RightBoundaryWindow{name}) where {name} =
7575
get_boundary(op.op, bw)
7676

77-
function return_eltype(op::Operator2Stencil, args...)
77+
function return_eltype(op::Operator2Stencil, args::Vararg{T}) where {T}
7878
lbw, ubw = stencil_interior_width(op.op, args...)[1]
7979
N = ubw - lbw + 1
8080
return StencilCoefs{lbw, ubw, NTuple{N, return_eltype(op.op, args...)}}
8181
end
8282

83-
return_space(op::Operator2Stencil, spaces...) = return_space(op.op, spaces...)
83+
return_space(op::Operator2Stencil, spaces::Vararg{T}) where {T} =
84+
return_space(op.op, spaces...)
8485

85-
stencil_interior_width(op::Operator2Stencil, args...) =
86+
stencil_interior_width(op::Operator2Stencil, args::Vararg{T}) where {T} =
8687
stencil_interior_width(op.op, args...)
8788

8889
left_interior_idx(
8990
space::AbstractSpace,
9091
op::Operator2Stencil,
9192
bc::AbstractBoundaryCondition,
92-
args...,
93-
) = left_interior_idx(space, op.op, bc, args...)
93+
args::Vararg{T},
94+
) where {T} = left_interior_idx(space, op.op, bc, args...)
9495
right_interior_idx(
9596
space::AbstractSpace,
9697
op::Operator2Stencil,
9798
bc::AbstractBoundaryCondition,
98-
args...,
99-
) = right_interior_idx(space, op.op, bc, args...)
99+
args::Vararg{T},
100+
) where {T} = right_interior_idx(space, op.op, bc, args...)
100101

101102
# TODO: find out why using Base.@propagate_inbounds blows up compilation time
102103
function stencil_interior(

src/Operators/spectralelement.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,11 @@ end
223223
Calls `resolve_operator(arg, slabidx)` for each `arg` in `args`
224224
"""
225225
@inline _resolve_operator_args(slabidx) = ()
226-
Base.@propagate_inbounds _resolve_operator_args(slabidx, arg, xargs...) = (
226+
Base.@propagate_inbounds _resolve_operator_args(
227+
slabidx,
228+
arg,
229+
xargs::Vararg{Any, N},
230+
) where {N} = (
227231
resolve_operator(arg, slabidx),
228232
_resolve_operator_args(slabidx, xargs...)...,
229233
)
@@ -270,7 +274,11 @@ end
270274
end
271275

272276
@inline _reconstruct_placeholder_broadcasted(parent_space) = ()
273-
@inline _reconstruct_placeholder_broadcasted(parent_space, arg, xargs...) = (
277+
@inline _reconstruct_placeholder_broadcasted(
278+
parent_space,
279+
arg,
280+
xargs::Vararg{Any, N},
281+
) where {N} = (
274282
reconstruct_placeholder_broadcasted(parent_space, arg),
275283
_reconstruct_placeholder_broadcasted(parent_space, xargs...)...,
276284
)
@@ -310,7 +318,13 @@ end
310318
end
311319

312320
@inline _get_node(space, ij, slabidx) = ()
313-
Base.@propagate_inbounds _get_node(space, ij, slabidx, arg, xargs...) = (
321+
Base.@propagate_inbounds _get_node(
322+
space,
323+
ij,
324+
slabidx,
325+
arg,
326+
xargs::Vararg{Any, N},
327+
) where {N} = (
314328
get_node(space, arg, ij, slabidx),
315329
_get_node(space, ij, slabidx, xargs...)...,
316330
)

src/RecursiveApply/RecursiveApply.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ rmaptype(
123123
124124
Recursively apply `promote_type` to the input types.
125125
"""
126-
rpromote_type(Ts...) = reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts)
126+
rpromote_type(Ts::Vararg{T, N}) where {T, N} =
127+
reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts)
127128
rpromote_type() = Union{}
128129

129130
"""
@@ -172,7 +173,8 @@ const ⊞ = radd
172173
# Adapted from Base/operators.jl for general nary operator fallbacks
173174
for op in (:rmul, :radd)
174175
@eval begin
175-
($op)(a, b, c, xs...) = Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
176+
($op)(a, b, c, xs::Vararg{T, N}) where {T, N} =
177+
Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
176178
end
177179
end
178180

0 commit comments

Comments
 (0)