Skip to content

Add scalar_fieldmatrix #2289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ preconditioner_cache
check_preconditioner
lazy_or_concrete_preconditioner
apply_preconditioner
get_scalar_keys
get_field_first_index_offset
broadcasted_get_field_type
```

## Utilities
Expand All @@ -98,4 +101,5 @@ column_field2array
column_field2array_view
field2arrays
field2arrays_view
scalar_fieldmatrix
```
1 change: 1 addition & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: ⊠, ⊞, ⊟
import ..DataLayouts
import ..DataLayouts: AbstractData
import ..DataLayouts: vindex
import ..Geometry
Expand Down
23 changes: 23 additions & 0 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ extract_first(::FieldName{name_chain}) where {name_chain} = first(name_chain)
drop_first(::FieldName{name_chain}) where {name_chain} =
FieldName(Base.tail(name_chain)...)

extract_last(::FieldName{name_chain}) where {name_chain} =
name_chain[length(name_chain)]
# drop_last(::FieldName{name_chain}) where {name_chain} =
# FieldName(name_chain[1:(end - 1)]...)

has_field(x, ::FieldName{()}) = true
has_field(x, name::FieldName) =
extract_first(name) in propertynames(x) &&
Expand All @@ -59,6 +64,24 @@ get_field(x, ::FieldName{()}) = x
get_field(x, name::FieldName) =
get_field(getproperty(x, extract_first(name)), drop_first(name))

"""
broadcasted_get_field_type(::Type{X}, name::FieldName)

Returns the type of the field accessed by `name` in the type `X`.
"""
broadcasted_get_field_type(::Type{X}, ::FieldName{()}) where {X} = X
broadcasted_get_field_type(::Type{X}, name::FieldName) where {X} =
broadcasted_get_field_type(
fieldtype(X, extract_first(name)),
drop_first(name),
)
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
for m in methods(broadcasted_get_field_type)
m.recursion_relation = dont_limit
end
end
Comment on lines +78 to +83
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you confirm that this is actually needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I just tried running the tests with this removed, and it results in a

failed to optimize due to recursion: ClimaCore.MatrixFields.broadcasted_get_field_type

with @test_opt

Copy link
Member

@dennisYatunin dennisYatunin Apr 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it makes sense that this is necessary for types with depth >= 2 on Julia 1.10. My only suggestion would be to move this to the bottom of the file, where I've disabled the recursion limits of a bunch of other functions. It's a little cleaner if they're all in one place.


broadcasted_has_field(::Type{X}, ::FieldName{()}) where {X} = true
broadcasted_has_field(::Type{X}, name::FieldName) where {X} =
extract_first(name) in fieldnames(X) &&
Expand Down
Loading
Loading