From d6a04024034fd24b0772a3d8caaed5cb63cfab98 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Mon, 30 Sep 2024 10:36:44 -0400 Subject: [PATCH] Force specialization in some key places --- src/Geometry/axistensors.jl | 6 ++- src/Geometry/rmul_with_projection.jl | 2 + src/MatrixFields/band_matrix_row.jl | 6 ++- src/MatrixFields/operator_matrices.jl | 71 ++++++++++++++++++--------- src/Operators/operator2stencil.jl | 15 +++--- src/Operators/spectralelement.jl | 20 ++++++-- src/RecursiveApply/RecursiveApply.jl | 6 ++- 7 files changed, 88 insertions(+), 38 deletions(-) diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index 534e628380..9e1c45e43a 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -199,8 +199,10 @@ Returns a `StaticArray` containing the components of `a` in its stored basis. """ components(a::AxisTensor) = getfield(a, :components) -Base.@propagate_inbounds Base.getindex(v::AxisTensor, i::Int...) = - getindex(components(v), i...) +Base.@propagate_inbounds Base.getindex( + v::AxisTensor, + i::Vararg{Int, N}, +) where {N} = getindex(components(v), i...) Base.@propagate_inbounds function Base.getindex( diff --git a/src/Geometry/rmul_with_projection.jl b/src/Geometry/rmul_with_projection.jl index 0d9ce3b77e..aee493a29f 100644 --- a/src/Geometry/rmul_with_projection.jl +++ b/src/Geometry/rmul_with_projection.jl @@ -140,8 +140,10 @@ rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y} = rmaptype((X′, Y′) -> mul_return_type(X′, Y′), X, Y) rmul_return_type(::Type{X}, ::Type{Y}) where {X <: SingleValue, Y} = rmaptype(Y′ -> mul_return_type(X, Y′), Y) +# rmaptype(Base.Fix1(mul_return_type, X), Y) rmul_return_type(::Type{X}, ::Type{Y}) where {X, Y <: SingleValue} = rmaptype(X′ -> mul_return_type(X′, Y), X) +# rmaptype(Base.Fix2(mul_return_type, Y), X) rmul_return_type( ::Type{X}, ::Type{Y}, diff --git a/src/MatrixFields/band_matrix_row.jl b/src/MatrixFields/band_matrix_row.jl index 2b2355f16e..75e49fbcef 100644 --- a/src/MatrixFields/band_matrix_row.jl +++ b/src/MatrixFields/band_matrix_row.jl @@ -42,8 +42,10 @@ outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} = @inline lower_diagonal(::Tuple{<:BandMatrixRow{ld}}) where {ld} = ld @inline lower_diagonal(t::Tuple) = lower_diagonal(t...) -@inline lower_diagonal(::BandMatrixRow{ld}, ::BandMatrixRow{ld}...) where {ld} = - ld +@inline lower_diagonal( + ::BandMatrixRow{ld}, + ::Vararg{BandMatrixRow{ld}, N}, +) where {ld, N} = ld """ band_matrix_row_type(ld, ud, T) diff --git a/src/MatrixFields/operator_matrices.jl b/src/MatrixFields/operator_matrices.jl index 7eaa097968..533bc3da31 100644 --- a/src/MatrixFields/operator_matrices.jl +++ b/src/MatrixFields/operator_matrices.jl @@ -241,27 +241,34 @@ Operators.get_boundary( rbw::Operators.RightBoundaryWindow{name}, ) where {name} = Operators.get_boundary(op_matrix.op, rbw) -Operators.stencil_interior_width(op_matrix::FDOperatorMatrix, args...) = - Operators.stencil_interior_width(op_matrix.op, args...) +Operators.stencil_interior_width( + op_matrix::FDOperatorMatrix, + args::Vararg{Any, N}, +) where {N} = Operators.stencil_interior_width(op_matrix.op, args...) Operators.left_interior_idx( space::Spaces.AbstractSpace, op_matrix::FDOperatorMatrix, bc::Operators.AbstractBoundaryCondition, - args..., -) = Operators.left_interior_idx(space, op_matrix.op, bc, args...) + args::Vararg{Any, N}, +) where {N} = Operators.left_interior_idx(space, op_matrix.op, bc, args...) Operators.right_interior_idx( space::Spaces.AbstractSpace, op_matrix::FDOperatorMatrix, bc::Operators.AbstractBoundaryCondition, - args..., -) = Operators.right_interior_idx(space, op_matrix.op, bc, args...) + args::Vararg{Any, N}, +) where {N} = Operators.right_interior_idx(space, op_matrix.op, bc, args...) -Operators.return_space(op_matrix::FDOperatorMatrix, spaces...) = - Operators.return_space(op_matrix.op, spaces...) +Operators.return_space( + op_matrix::FDOperatorMatrix, + spaces::Vararg{Any, N}, +) where {N} = Operators.return_space(op_matrix.op, spaces...) -function Operators.return_eltype(op_matrix::FDOperatorMatrix, args...) +function Operators.return_eltype( + op_matrix::FDOperatorMatrix, + args::Vararg{Any, N}, +) where {N} args′ = args[1:(end - 1)] FT = Geometry.undertype(eltype(args[end])) return op_matrix_row_type(op_matrix.op, FT, args′...) @@ -273,8 +280,8 @@ Base.@propagate_inbounds function Operators.stencil_interior( space, idx, hidx, - args..., -) + args::Vararg{Any, N}, +) where {N} args′ = args[1:(end - 1)] row = op_matrix_interior_row(op_matrix.op, space, loc, idx, hidx, args′...) return convert(Operators.return_eltype(op_matrix, args...), row) @@ -287,8 +294,8 @@ Base.@propagate_inbounds function Operators.stencil_left_boundary( space, idx, hidx, - args..., -) + args::Vararg{Any, N}, +) where {N} args′ = args[1:(end - 1)] row = op_matrix_first_row(op_matrix.op, bc, space, loc, idx, hidx, args′...) return convert(Operators.return_eltype(op_matrix, args...), row) @@ -301,22 +308,42 @@ Base.@propagate_inbounds function Operators.stencil_right_boundary( space, idx, hidx, - args..., -) + args::Vararg{Any, N}, +) where {N} args′ = args[1:(end - 1)] row = op_matrix_last_row(op_matrix.op, bc, space, loc, idx, hidx, args′...) return convert(Operators.return_eltype(op_matrix, args...), row) end # Simplified methods for when the operator matrix only depends on FT. -op_matrix_row_type(op, ::Type{FT}, args...) where {FT} = +op_matrix_row_type(op, ::Type{FT}, args::Vararg{Any, N}) where {FT, N} = typeof(op_matrix_interior_row(op, FT)) -op_matrix_interior_row(op, space, loc, idx, hidx, args...) = - op_matrix_interior_row(op, Spaces.undertype(space)) -op_matrix_first_row(op, bc, space, loc, idx, hidx, args...) = - op_matrix_first_row(op, bc, Spaces.undertype(space)) -op_matrix_last_row(op, bc, space, loc, idx, hidx, args...) = - op_matrix_last_row(op, bc, Spaces.undertype(space)) +op_matrix_interior_row( + op, + space, + loc, + idx, + hidx, + args::Vararg{Any, N}, +) where {N} = op_matrix_interior_row(op, Spaces.undertype(space)) +op_matrix_first_row( + op, + bc, + space, + loc, + idx, + hidx, + args::Vararg{Any, N}, +) where {N} = op_matrix_first_row(op, bc, Spaces.undertype(space)) +op_matrix_last_row( + op, + bc, + space, + loc, + idx, + hidx, + args::Vararg{Any, N}, +) where {N} = op_matrix_last_row(op, bc, Spaces.undertype(space)) ################################################################################ diff --git a/src/Operators/operator2stencil.jl b/src/Operators/operator2stencil.jl index 111a7e4f25..cb11379ae7 100644 --- a/src/Operators/operator2stencil.jl +++ b/src/Operators/operator2stencil.jl @@ -74,29 +74,30 @@ get_boundary(op::Operator2Stencil, bw::LeftBoundaryWindow{name}) where {name} = get_boundary(op::Operator2Stencil, bw::RightBoundaryWindow{name}) where {name} = get_boundary(op.op, bw) -function return_eltype(op::Operator2Stencil, args...) +function return_eltype(op::Operator2Stencil, args::Vararg{Any, M}) where {M} lbw, ubw = stencil_interior_width(op.op, args...)[1] N = ubw - lbw + 1 return StencilCoefs{lbw, ubw, NTuple{N, return_eltype(op.op, args...)}} end -return_space(op::Operator2Stencil, spaces...) = return_space(op.op, spaces...) +return_space(op::Operator2Stencil, spaces::Vararg{Any, N}) where {N} = + return_space(op.op, spaces...) -stencil_interior_width(op::Operator2Stencil, args...) = +stencil_interior_width(op::Operator2Stencil, args::Vararg{Any, N}) where {N} = stencil_interior_width(op.op, args...) left_interior_idx( space::AbstractSpace, op::Operator2Stencil, bc::AbstractBoundaryCondition, - args..., -) = left_interior_idx(space, op.op, bc, args...) + args::Vararg{Any, N}, +) where {N} = left_interior_idx(space, op.op, bc, args...) right_interior_idx( space::AbstractSpace, op::Operator2Stencil, bc::AbstractBoundaryCondition, - args..., -) = right_interior_idx(space, op.op, bc, args...) + args::Vararg{Any, N}, +) where {N} = right_interior_idx(space, op.op, bc, args...) # TODO: find out why using Base.@propagate_inbounds blows up compilation time function stencil_interior( diff --git a/src/Operators/spectralelement.jl b/src/Operators/spectralelement.jl index 39b3eef714..0aee4eed83 100644 --- a/src/Operators/spectralelement.jl +++ b/src/Operators/spectralelement.jl @@ -223,7 +223,11 @@ end Calls `resolve_operator(arg, slabidx)` for each `arg` in `args` """ @inline _resolve_operator_args(slabidx) = () -Base.@propagate_inbounds _resolve_operator_args(slabidx, arg, xargs...) = ( +Base.@propagate_inbounds _resolve_operator_args( + slabidx, + arg, + xargs::Vararg{Any, N}, +) where {N} = ( resolve_operator(arg, slabidx), _resolve_operator_args(slabidx, xargs...)..., ) @@ -270,7 +274,11 @@ end end @inline _reconstruct_placeholder_broadcasted(parent_space) = () -@inline _reconstruct_placeholder_broadcasted(parent_space, arg, xargs...) = ( +@inline _reconstruct_placeholder_broadcasted( + parent_space, + arg, + xargs::Vararg{Any, N}, +) where {N} = ( reconstruct_placeholder_broadcasted(parent_space, arg), _reconstruct_placeholder_broadcasted(parent_space, xargs...)..., ) @@ -310,7 +318,13 @@ end end @inline _get_node(space, ij, slabidx) = () -Base.@propagate_inbounds _get_node(space, ij, slabidx, arg, xargs...) = ( +Base.@propagate_inbounds _get_node( + space, + ij, + slabidx, + arg, + xargs::Vararg{Any, N}, +) where {N} = ( get_node(space, arg, ij, slabidx), _get_node(space, ij, slabidx, xargs...)..., ) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index eb433acecc..e051b255ff 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -123,7 +123,8 @@ rmaptype( Recursively apply `promote_type` to the input types. """ -rpromote_type(Ts...) = reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts) +rpromote_type(Ts::Vararg{Any, N}) where {N} = + reduce((T1, T2) -> rmaptype(promote_type, T1, T2), Ts) rpromote_type() = Union{} """ @@ -172,7 +173,8 @@ const ⊞ = radd # Adapted from Base/operators.jl for general nary operator fallbacks for op in (:rmul, :radd) @eval begin - ($op)(a, b, c, xs...) = Base.afoldl($op, ($op)(($op)(a, b), c), xs...) + ($op)(a, b, c, xs::Vararg{Any, N}) where {N} = + Base.afoldl($op, ($op)(($op)(a, b), c), xs...) end end