diff --git a/ext/KrylovExt.jl b/ext/KrylovExt.jl index 746ab4de97..2aa8ea0d23 100644 --- a/ext/KrylovExt.jl +++ b/ext/KrylovExt.jl @@ -1,10 +1,14 @@ module KrylovExt -import ClimaComms -import ClimaCore: Fields +import ClimaCore: DataLayouts, Fields import Krylov -Krylov.ktypeof(x::Fields.FieldVector) = - ClimaComms.array_type(x){eltype(parent(x)), 1} +function Krylov.ktypeof(x::Fields.FieldVector) + array_type_unknown_N = typeof(parent(Fields.representative_field(x))) + array_type_variable_N = DataLayouts.parent_array_type(array_type_unknown_N) + return typeintersect(array_type_variable_N, AbstractVector) # Set N = 1. +end + +Krylov.kcopy!(::Integer, y::AbstractVector, x::Fields.FieldVector) = y .= x end diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index ec3578f5fa..ff91c3e715 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -456,20 +456,23 @@ end import ClimaComms -ClimaComms.array_type(x::FieldVector) = - promote_type(unrolled_map(ClimaComms.array_type, _values(x))...) - -ClimaComms.device(x::FieldVector) = ClimaComms.device(ClimaComms.context(x)) -function ClimaComms.context(x::FieldVector) - isempty(_values(x)) && error("Empty FieldVector has no device or context") - # We don't have promotion for devices or contexts, so we use the first value - # that isn't a PointField (a PointField's data can be stored on a different - # device from other Fields to avoid scalar indexing on GPUs). If there is no - # such value, fall back to using the first PointField. - index = unrolled_findfirst(Base.Fix1(!isa, PointField), _values(x)) - return ClimaComms.context(_values(x)[isnothing(index) ? 1 : index]) +# To infer the ClimaComms device and its properties, use the first Field in a +# FieldVector that isn't a PointField, since a PointField's data can be stored +# on a different device from other Fields to avoid scalar indexing on GPUs. If +# the FieldVector only contains PointFields, fall back to using the first one. +function representative_field(x) + all_fields = _values(x) + isempty(all_fields) && error("Empty FieldVector has no ClimaComms device") + field_index = unrolled_findfirst(Base.Fix2(!isa, PointField), all_fields) + return all_fields[isnothing(field_index) ? 1 : field_index] end +ClimaComms.array_type(x::FieldVector) = + ClimaComms.array_type(representative_field(x)) +ClimaComms.device(x::FieldVector) = ClimaComms.device(representative_field(x)) +ClimaComms.context(x::FieldVector) = ClimaComms.context(representative_field(x)) + + function __rprint_diff( io::IO, x::T,