-
Notifications
You must be signed in to change notification settings - Fork 16
Add scalar_fieldmatrix #2289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add scalar_fieldmatrix #2289
Conversation
67fd30e
to
65b036a
Compare
65b036a
to
3f81735
Compare
if hasfield(Method, :recursion_relation) | ||
dont_limit = (args...) -> true | ||
for m in methods(broadcasted_get_field_type) | ||
m.recursion_relation = dont_limit | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you confirm that this is actually needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I just tried running the tests with this removed, and it results in a
failed to optimize due to recursion: ClimaCore.MatrixFields.broadcasted_get_field_type
with @test_opt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it makes sense that this is necessary for types with depth >= 2 on Julia 1.10. My only suggestion would be to move this to the bottom of the file, where I've disabled the recursion limits of a bunch of other functions. It's a little cleaner if they're all in one place.
src/MatrixFields/field_name_dict.jl
Outdated
entry = | ||
entry isa ColumnwiseBandMatrixField ? entry.entries.:(1) : entry |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is correct for a ColumnwiseBandMatrixField
or a UniformScaling
, but not for a DiagonalMatrixRow
(the other type of ScalingFieldMatrixEntry
). When we have a single tensor in a FieldMatrix
, we need to extract a number from that tensor for scalar_fieldmatrix
. You'll also need to modify the relevant method for get_internal_entry
to do this, and you'll probably want to add/modify a block in dycore_prognostic_EDMF_FieldMatrix
to test this functionality.
2405aca
to
02bb025
Compare
src/MatrixFields/field_name_dict.jl
Outdated
if name_pair[1] == name_pair[2] | ||
entry | ||
elseif name_pair[2] == @name() && has_field(entry, name_pair[1]) | ||
DiagonalMatrixRow(get_field(entry, name_pair[1])) | ||
elseif is_overlapping_name(name_pair[1], name_pair[2]) | ||
throw(key_error) | ||
else | ||
zero(entry) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dennisYatunin I'm not sure if these conditionals are correct. If they are, this can probably be re-combined into a single get_internal_entry(::UniformScaling, ...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is slightly wrong; you want to check the entry's value entry[0]
instead of entry
itself, so your condition should be has_field(entry[0], name_pair[1])
and your result value should be get_field(entry[0], name_pair[1])
. It would be good to add a DiagonalMatrixRow(::Axis2Tensor)
block to dycore_prognostic_EDMF_FieldMatrix
to test this functionality.
You can also do a similar optimization for scalars here like you did in the method for ColumnwiseBandMatrixField
, wrapping get_field(entry[0], name_pair[1])
in a UniformScaling
instead of a DiagonalMatrixRow
when get_field(entry[0], name_pair[1]) isa Number
.
And yes, it would be nice to combine this with the method for UniformScaling
so that there is just one method for ScalingFieldMatrixEntry
. I think the method for DiagonalMatrixRow
will work for both types of scaling entries, so you can just remove the other method.
Add a function to convert a FieldMatrix where each matrix entry has an eltype of some struct into a FieldMatrix where each entry has an eltype of a scalar. Add additional tests for scalar_matrixfields Use @test_all in tests
d91cf59
to
e1c1c22
Compare
function bypass_adjoint(field::Fields.Field) | ||
if eltype(field) <: Adjoint | ||
return eltype(field.parent) | ||
else | ||
return eltype(field) | ||
end | ||
end | ||
|
||
function bypass_adjoint(d::T) where {T} | ||
if T <: Adjoint | ||
return typeof(d.parent) | ||
else | ||
return typeof(d) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
function bypass_adjoint(field::Fields.Field) | |
if eltype(field) <: Adjoint | |
return eltype(field.parent) | |
else | |
return eltype(field) | |
end | |
end | |
function bypass_adjoint(d::T) where {T} | |
if T <: Adjoint | |
return typeof(d.parent) | |
else | |
return typeof(d) | |
end | |
end | |
bypass_adjoint(x::Fields.Field) = bypass_adjoint(Fields.field_values(x)) | |
bypass_adjoint(x::DataLayouts.AbstractData{<:Adjoint}) = eltype(x.parent) | |
bypass_adjoint(x::DataLayouts.AbstractData) = eltype(x) | |
bypass_adjoint(x::Adjoint) = typeof(x.parent) | |
bypass_adjoint(x) = typeof(x) |
This might be easier to extend?
closes #2306
This PR adds
scalar_fieldmatrix
, a function which takes in a field matrix, and returns a field matrix with keys that are associated with matrix fields of either uniform scaling or columnwisebandmatrices of scalars.This also makes the
get_index
for field matrices return a field instead of a broadcasted object when the keys used to index the field matrix will result in a columnwisebandmatrices of scalars.