Skip to content

Commit ebc079c

Browse files
Try using @propagate_inbounds in MatrixFields
1 parent dec1485 commit ebc079c

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

src/ClimaCore.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -87,7 +87,7 @@ if hasfield(Method, :recursion_relation)
8787
m.recursion_relation = dont_limit
8888
end
8989
for m in methods(MatrixFields.get_subtree_at_name)
90-
m.recursion_relation = MatrixFields.dont_limit
90+
m.recursion_relation = dont_limit
9191
end
9292
for m in methods(MatrixFields.concrete_field_vector_within_subtree)
9393
m.recursion_relation = dont_limit

src/MatrixFields/matrix_multiplication.jl

Lines changed: 10 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -341,15 +341,7 @@ boundary_modified_ud(_, ud, column_space, i) = ud
341341
boundary_modified_ud(::BottomRightMatrixCorner, ud, column_space, i) =
342342
min(Operators.right_idx(column_space) - i, ud)
343343

344-
# TODO: Use @propagate_inbounds here, and remove @inbounds from this function.
345-
# As of Julia 1.8, doing this increases compilation time by more than an order
346-
# of magnitude, and it also makes type inference fail for some complicated
347-
# matrix field broadcast expressions (in particular, those that involve products
348-
# of linear combinations of matrix fields). Not using @propagate_inbounds causes
349-
# matrix field broadcast expressions to take roughly 3 or 4 times longer to
350-
# evaluate, but this is less significant than the decrease in compilation time.
351-
# matrix-matrix multiplication
352-
function multiply_matrix_at_index(
344+
Base.@propagate_inbounds function multiply_matrix_at_index(
353345
space,
354346
idx,
355347
hidx,
@@ -374,9 +366,11 @@ function multiply_matrix_at_index(
374366

375367
# Precompute the row that is needed from matrix1 so that it does not get
376368
# recomputed multiple times.
377-
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)
369+
TM1R = Operators.getidx_return_type(matrix1)
370+
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)::TM1R
378371

379372
matrix2 = arg
373+
TM2R = Operators.getidx_return_type(matrix2)
380374
column_space2 = column_axes(matrix2, column_space1)
381375
ld2, ud2 = outer_diagonals(eltype(matrix2))
382376
prod_ld, prod_ud = outer_diagonals(prod_type)
@@ -395,7 +389,7 @@ function multiply_matrix_at_index(
395389
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
396390
Base.@_inline_meta
397391
if isnothing(bc) || boundary_modified_ld1 <= d <= boundary_modified_ud1
398-
@inbounds Operators.getidx(space, matrix2, idx + d, hidx)
392+
@inbounds Operators.getidx(space, matrix2, idx + d, hidx)::TM2R
399393
else
400394
zero(eltype(matrix2)) # This row is outside the matrix.
401395
end
@@ -437,7 +431,7 @@ function multiply_matrix_at_index(
437431
return BandMatrixRow{prod_ld}(prod_entries...)
438432
end
439433
# matrix-vector multiplication
440-
function multiply_matrix_at_index(
434+
Base.@propagate_inbounds function multiply_matrix_at_index(
441435
space,
442436
idx,
443437
hidx,
@@ -454,6 +448,7 @@ function multiply_matrix_at_index(
454448
arg,
455449
typeof(lg),
456450
)
451+
TM1R = Operators.getidx_return_type(matrix1)
457452

458453
column_space1 = column_axes(matrix1, space)
459454
ld1, ud1 = outer_diagonals(eltype(matrix1))
@@ -462,13 +457,14 @@ function multiply_matrix_at_index(
462457

463458
# Precompute the row that is needed from matrix1 so that it does not get
464459
# recomputed multiple times.
465-
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)
460+
matrix1_row = @inbounds Operators.getidx(space, matrix1, idx, hidx)::TM1R
466461

467462
vector = arg
463+
TVR = Operators.getidx_return_type(vector)
468464
prod_value = rzero(prod_type)
469465
@inbounds for d in boundary_modified_ld1:boundary_modified_ud1
470466
value1 = matrix1_row[d]
471-
value2 = Operators.getidx(space, vector, idx + d, hidx)
467+
value2 = Operators.getidx(space, vector, idx + d, hidx)::TVR
472468
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
473469
prod_value =
474470
radd(prod_value, rmul_with_projection(value1, value2, value2_lg))

0 commit comments

Comments
 (0)