Skip to content

Force specialization in some key places #2017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ Returns a `StaticArray` containing the components of `a` in its stored basis.
"""
components(a::AxisTensor) = getfield(a, :components)

Base.@propagate_inbounds Base.getindex(v::AxisTensor, i::Int...) =
getindex(components(v), i...)
Base.@propagate_inbounds Base.getindex(
v::AxisTensor,
i::Vararg{Int, N},
) where {N} = getindex(components(v), i...)


Base.@propagate_inbounds function Base.getindex(
Expand Down
2 changes: 2 additions & 0 deletions src/Geometry/rmul_with_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,10 @@ rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y} =
rmaptype((X′, Y′) -> mul_return_type(X′, Y′), X, Y)
rmul_return_type(::Type{X}, ::Type{Y}) where {X <: SingleValue, Y} =
rmaptype(Y′ -> mul_return_type(X, Y′), Y)
# rmaptype(Base.Fix1(mul_return_type, X), Y)
rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y <: SingleValue} =
rmaptype(X′ -> mul_return_type(X′, Y), X)
# rmaptype(Base.Fix2(mul_return_type, Y), X)
rmul_return_type(
::Type{X},
::Type{Y},
Expand Down
6 changes: 4 additions & 2 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} =

@inline lower_diagonal(::Tuple{<:BandMatrixRow{ld}}) where {ld} = ld
@inline lower_diagonal(t::Tuple) = lower_diagonal(t...)
@inline lower_diagonal(::BandMatrixRow{ld}, ::BandMatrixRow{ld}...) where {ld} =
ld
@inline lower_diagonal(
::BandMatrixRow{ld},
::Vararg{BandMatrixRow{ld}, N},
) where {ld, N} = ld

"""
band_matrix_row_type(ld, ud, T)
Expand Down
71 changes: 49 additions & 22 deletions src/MatrixFields/operator_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,27 +241,34 @@ Operators.get_boundary(
rbw::Operators.RightBoundaryWindow{name},
) where {name} = Operators.get_boundary(op_matrix.op, rbw)

Operators.stencil_interior_width(op_matrix::FDOperatorMatrix, args...) =
Operators.stencil_interior_width(op_matrix.op, args...)
Operators.stencil_interior_width(
op_matrix::FDOperatorMatrix,
args::Vararg{Any, N},
) where {N} = Operators.stencil_interior_width(op_matrix.op, args...)

Operators.left_interior_idx(
space::Spaces.AbstractSpace,
op_matrix::FDOperatorMatrix,
bc::Operators.AbstractBoundaryCondition,
args...,
) = Operators.left_interior_idx(space, op_matrix.op, bc, args...)
args::Vararg{Any, N},
) where {N} = Operators.left_interior_idx(space, op_matrix.op, bc, args...)

Operators.right_interior_idx(
space::Spaces.AbstractSpace,
op_matrix::FDOperatorMatrix,
bc::Operators.AbstractBoundaryCondition,
args...,
) = Operators.right_interior_idx(space, op_matrix.op, bc, args...)
args::Vararg{Any, N},
) where {N} = Operators.right_interior_idx(space, op_matrix.op, bc, args...)

Operators.return_space(op_matrix::FDOperatorMatrix, spaces...) =
Operators.return_space(op_matrix.op, spaces...)
Operators.return_space(
op_matrix::FDOperatorMatrix,
spaces::Vararg{Any, N},
) where {N} = Operators.return_space(op_matrix.op, spaces...)

function Operators.return_eltype(op_matrix::FDOperatorMatrix, args...)
function Operators.return_eltype(
op_matrix::FDOperatorMatrix,
args::Vararg{Any, N},
) where {N}
args′ = args[1:(end - 1)]
FT = Geometry.undertype(eltype(args[end]))
return op_matrix_row_type(op_matrix.op, FT, args′...)
Expand All @@ -273,8 +280,8 @@ Base.@propagate_inbounds function Operators.stencil_interior(
space,
idx,
hidx,
args...,
)
args::Vararg{Any, N},
) where {N}
args′ = args[1:(end - 1)]
row = op_matrix_interior_row(op_matrix.op, space, loc, idx, hidx, args′...)
return convert(Operators.return_eltype(op_matrix, args...), row)
Expand All @@ -287,8 +294,8 @@ Base.@propagate_inbounds function Operators.stencil_left_boundary(
space,
idx,
hidx,
args...,
)
args::Vararg{Any, N},
) where {N}
args′ = args[1:(end - 1)]
row = op_matrix_first_row(op_matrix.op, bc, space, loc, idx, hidx, args′...)
return convert(Operators.return_eltype(op_matrix, args...), row)
Expand All @@ -301,22 +308,42 @@ Base.@propagate_inbounds function Operators.stencil_right_boundary(
space,
idx,
hidx,
args...,
)
args::Vararg{Any, N},
) where {N}
args′ = args[1:(end - 1)]
row = op_matrix_last_row(op_matrix.op, bc, space, loc, idx, hidx, args′...)
return convert(Operators.return_eltype(op_matrix, args...), row)
end

# Simplified methods for when the operator matrix only depends on FT.
op_matrix_row_type(op, ::Type{FT}, args...) where {FT} =
op_matrix_row_type(op, ::Type{FT}, args::Vararg{Any, N}) where {FT, N} =
typeof(op_matrix_interior_row(op, FT))
op_matrix_interior_row(op, space, loc, idx, hidx, args...) =
op_matrix_interior_row(op, Spaces.undertype(space))
op_matrix_first_row(op, bc, space, loc, idx, hidx, args...) =
op_matrix_first_row(op, bc, Spaces.undertype(space))
op_matrix_last_row(op, bc, space, loc, idx, hidx, args...) =
op_matrix_last_row(op, bc, Spaces.undertype(space))
op_matrix_interior_row(
op,
space,
loc,
idx,
hidx,
args::Vararg{Any, N},
) where {N} = op_matrix_interior_row(op, Spaces.undertype(space))
op_matrix_first_row(
op,
bc,
space,
loc,
idx,
hidx,
args::Vararg{Any, N},
) where {N} = op_matrix_first_row(op, bc, Spaces.undertype(space))
op_matrix_last_row(
op,
bc,
space,
loc,
idx,
hidx,
args::Vararg{Any, N},
) where {N} = op_matrix_last_row(op, bc, Spaces.undertype(space))

################################################################################

Expand Down
15 changes: 8 additions & 7 deletions src/Operators/operator2stencil.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,29 +74,30 @@ get_boundary(op::Operator2Stencil, bw::LeftBoundaryWindow{name}) where {name} =
get_boundary(op::Operator2Stencil, bw::RightBoundaryWindow{name}) where {name} =
get_boundary(op.op, bw)

function return_eltype(op::Operator2Stencil, args...)
function return_eltype(op::Operator2Stencil, args::Vararg{Any, M}) where {M}
lbw, ubw = stencil_interior_width(op.op, args...)[1]
N = ubw - lbw + 1
return StencilCoefs{lbw, ubw, NTuple{N, return_eltype(op.op, args...)}}
end

return_space(op::Operator2Stencil, spaces...) = return_space(op.op, spaces...)
return_space(op::Operator2Stencil, spaces::Vararg{Any, N}) where {N} =
return_space(op.op, spaces...)

stencil_interior_width(op::Operator2Stencil, args...) =
stencil_interior_width(op::Operator2Stencil, args::Vararg{Any, N}) where {N} =
stencil_interior_width(op.op, args...)

left_interior_idx(
space::AbstractSpace,
op::Operator2Stencil,
bc::AbstractBoundaryCondition,
args...,
) = left_interior_idx(space, op.op, bc, args...)
args::Vararg{Any, N},
) where {N} = left_interior_idx(space, op.op, bc, args...)
right_interior_idx(
space::AbstractSpace,
op::Operator2Stencil,
bc::AbstractBoundaryCondition,
args...,
) = right_interior_idx(space, op.op, bc, args...)
args::Vararg{Any, N},
) where {N} = right_interior_idx(space, op.op, bc, args...)

# TODO: find out why using Base.@propagate_inbounds blows up compilation time
function stencil_interior(
Expand Down
20 changes: 17 additions & 3 deletions src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,11 @@ end
Calls `resolve_operator(arg, slabidx)` for each `arg` in `args`
"""
@inline _resolve_operator_args(slabidx) = ()
Base.@propagate_inbounds _resolve_operator_args(slabidx, arg, xargs...) = (
Base.@propagate_inbounds _resolve_operator_args(
slabidx,
arg,
xargs::Vararg{Any, N},
) where {N} = (
resolve_operator(arg, slabidx),
_resolve_operator_args(slabidx, xargs...)...,
)
Expand Down Expand Up @@ -270,7 +274,11 @@ end
end

@inline _reconstruct_placeholder_broadcasted(parent_space) = ()
@inline _reconstruct_placeholder_broadcasted(parent_space, arg, xargs...) = (
@inline _reconstruct_placeholder_broadcasted(
parent_space,
arg,
xargs::Vararg{Any, N},
) where {N} = (
reconstruct_placeholder_broadcasted(parent_space, arg),
_reconstruct_placeholder_broadcasted(parent_space, xargs...)...,
)
Expand Down Expand Up @@ -310,7 +318,13 @@ end
end

@inline _get_node(space, ij, slabidx) = ()
Base.@propagate_inbounds _get_node(space, ij, slabidx, arg, xargs...) = (
Base.@propagate_inbounds _get_node(
space,
ij,
slabidx,
arg,
xargs::Vararg{Any, N},
) where {N} = (
get_node(space, arg, ij, slabidx),
_get_node(space, ij, slabidx, xargs...)...,
)
Expand Down
6 changes: 4 additions & 2 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ rmaptype(

Recursively apply `promote_type` to the input types.
"""
rpromote_type(Ts...) = reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts)
rpromote_type(Ts::Vararg{Any, N}) where {N} =
reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts)
rpromote_type() = Union{}

"""
Expand Down Expand Up @@ -172,7 +173,8 @@ const ⊞ = radd
# Adapted from Base/operators.jl for general nary operator fallbacks
for op in (:rmul, :radd)
@eval begin
($op)(a, b, c, xs...) = Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
($op)(a, b, c, xs::Vararg{Any, N}) where {N} =
Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
end
end

Expand Down
Loading