Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some JET failures #1992

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions src/MatrixFields/field_matrix_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,6 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
end
end

# TODO: we can remove the uniform_vertical_levels
# limitation while still using static shared memory
# once Nv is in the type space.
function uniform_vertical_levels(x, names)
_, _, _, Nv1, _ = size(Fields.field_values(x[first(names)]))
return all(Base.tail(names)) do name
_, _, _, Nv, _ = size(Fields.field_values(x[name]))
Nv == Nv1
end
end

NVTX.@annotate function run_field_matrix_solver!(
::BlockDiagonalSolve,
cache,
Expand All @@ -267,8 +256,7 @@ NVTX.@annotate function run_field_matrix_solver!(
)
names = matrix_row_keys(keys(A))
if length(names) == 1 ||
all(name -> A[name, name] isa UniformScaling, names.values) ||
!uniform_vertical_levels(x, names.values)
all(name -> A[name, name] isa UniformScaling, names.values)
foreach(names) do name
single_field_solve!(cache[name], x[name], A[name, name], b[name])
end
Expand Down
5 changes: 3 additions & 2 deletions src/MatrixFields/field_name_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,11 @@ end
values_string(values) =
length(values) == 2 ? join(values, " and ") : join(values, ", ", ", and ")

combine_eltypes(T1, T2) =
T1 == T2 ? T1 :
@noinline combine_eltypes(::T1, ::T2) where {T1, T2} =
errror("Mismatched FieldNameSets: Cannot combine a $T1 with a $T2")

@inline combine_eltypes(::Type{T}, ::Type{T}) where {T} = T

combine_name_trees(::Nothing, ::Nothing) = nothing
combine_name_trees(name_tree1, ::Nothing) = name_tree1
combine_name_trees(::Nothing, name_tree2) = name_tree2
Expand Down
Loading