diff --git a/src/DomainBuffers.jl b/src/DomainBuffers.jl index 208e5f7b..0026655c 100644 --- a/src/DomainBuffers.jl +++ b/src/DomainBuffers.jl @@ -135,19 +135,19 @@ Get the set of items stored in `db` or `dbs[domain]` """ getset(b::DomainBuffers, domain) = getset(b[domain]) -struct DomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer +struct DomainBuffer{I,B,SV<:StateVariables,SDH<:SubDofHandler} <: AbstractDomainBuffer set::Vector{I} itembuffer::B - states::StateVariables{S} + states::SV sdh::SDH end -struct ThreadedDomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer +struct ThreadedDomainBuffer{I,B,SV<:StateVariables,SDH<:SubDofHandler} <: AbstractDomainBuffer chunks::Vector{Vector{Vector{I}}} # I=Int (cell), I=FacetIndex (facet), or set::Vector{I} # I=NTuple{2,FacetIndex} (interface) num_tasks::Int itembuffer::TaskLocals{B,B} # cell, facet, or interface buffer - states::StateVariables{S} + states::SV sdh::SDH end function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states::StateVariables, sdh::SubDofHandler, colors_or_chunks=nothing; num_tasks = Threads.nthreads()) diff --git a/src/FerriteAssembly.jl b/src/FerriteAssembly.jl index c7887e4f..8a227b60 100644 --- a/src/FerriteAssembly.jl +++ b/src/FerriteAssembly.jl @@ -1,5 +1,6 @@ module FerriteAssembly using Ferrite, ForwardDiff +using Ferrite.CollectionsOfViews: ArrayOfVectorViews using ConstructionBase: setproperties include("Multithreading/TaskLocals.jl") # Task-local storage model diff --git a/src/Utils/MaterialModelsBase.jl b/src/Utils/MaterialModelsBase.jl index 689c701b..d5f1753b 100644 --- a/src/Utils/MaterialModelsBase.jl +++ b/src/Utils/MaterialModelsBase.jl @@ -2,7 +2,7 @@ import MaterialModelsBase as MMB """ FerriteAssembly.element_routine!( - Ke, re, state::Vector{<:MMB.AbstractMaterialState}, ae, + Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, m::MMB.AbstractMaterial, cv::AbstractCellValues, buffer) Solve the weak form @@ -16,13 +16,13 @@ where ``\\sigma`` is calculated with the `material_response` function from Note that `create_cell_state` is already implemented for `<:AbstractMaterial`. """ function FerriteAssembly.element_routine!( - Ke, re, state::Vector{<:MMB.AbstractMaterialState}, + Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) return mechanical_element_routine!(MMB.get_tensorbase(material), Ke, re, state, ae, material, cellvalues, buffer) end function mechanical_element_routine!(::Type{<:SymmetricTensor{2}}, - Ke, re, state::Vector{<:MMB.AbstractMaterialState}, + Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) cache = FerriteAssembly.get_user_cache(buffer) Δt = FerriteAssembly.get_time_increment(buffer) @@ -47,7 +47,7 @@ function mechanical_element_routine!(::Type{<:SymmetricTensor{2}}, end function mechanical_element_routine!(::Type{<:Tensor{2}}, - Ke, re, state::Vector{<:MMB.AbstractMaterialState}, + Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) cache = FerriteAssembly.get_user_cache(buffer) Δt = FerriteAssembly.get_time_increment(buffer) @@ -73,20 +73,20 @@ end """ FerriteAssembly.element_residual!( - re, state::Vector{<:MMB.AbstractMaterialState}, ae, + re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, m::MMB.AbstractMaterial, cv::AbstractCellValues, buffer) The `element_residual!` implementation corresponding to the `element_routine!` implementation for a `MaterialModelsBase.AbstractMaterial` """ function FerriteAssembly.element_residual!( - re, state::Vector{<:MMB.AbstractMaterialState}, + re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) return mechanical_element_residual!(MMB.get_tensorbase(material), re, state, ae, material, cellvalues, buffer) end function mechanical_element_residual!(::Type{<:SymmetricTensor{2}}, - re, state::Vector{<:MMB.AbstractMaterialState}, + re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) cache = FerriteAssembly.get_user_cache(buffer) Δt = FerriteAssembly.get_time_increment(buffer) @@ -105,7 +105,7 @@ function mechanical_element_residual!(::Type{<:SymmetricTensor{2}}, end function mechanical_element_residual!(::Type{<:Tensor{2}}, - re, state::Vector{<:MMB.AbstractMaterialState}, + re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer) cache = FerriteAssembly.get_user_cache(buffer) Δt = FerriteAssembly.get_time_increment(buffer) diff --git a/src/Workers/QuadPointEvaluator.jl b/src/Workers/QuadPointEvaluator.jl index dd98348c..c3e6ec71 100644 --- a/src/Workers/QuadPointEvaluator.jl +++ b/src/Workers/QuadPointEvaluator.jl @@ -1,5 +1,3 @@ -using Ferrite.CollectionsOfViews: ArrayOfVectorViews - """ QuadPointEvaluator{VT}(db::Union{DomainBuffer, DomainBuffers}, f::Function) diff --git a/src/setup.jl b/src/setup.jl index 24c1c35b..75381ff1 100644 --- a/src/setup.jl +++ b/src/setup.jl @@ -108,7 +108,7 @@ function setup_domainbuffer(domain::DomainSpec; threading=Val(false), kwargs...) end create_states(domain::DomainSpec{Int}, a) = create_states(domain.sdh, domain.material, domain.fe_values, a, domain.set, create_dofrange(domain.sdh)) -create_states(::DomainSpec{FacetIndex}, ::Any) = Dict{Int,Nothing}() +create_states(::DomainSpec{FacetIndex}, ::Any) = (s = Dict{Int,Nothing}(); StateVariables(Int[], s, s)) function setup_itembuffer(adb, domain::DomainSpec{FacetIndex}, args...) dofrange = create_dofrange(domain.sdh) @@ -120,10 +120,9 @@ function setup_itembuffer(adb, domain::DomainSpec{Int}, states) end function _setup_domainbuffer(threaded, domain; a=nothing, autodiffbuffer=Val(false), kwargs...) - new_states = create_states(domain, a) - old_states = create_states(domain, a) - itembuffer = setup_itembuffer(autodiffbuffer, domain, new_states) - return _setup_domainbuffer(threaded, domain.set, itembuffer, StateVariables(old_states, new_states), domain.sdh, domain.colors_or_chunks; kwargs...) + statevars = create_states(domain, a) + itembuffer = setup_itembuffer(autodiffbuffer, domain, statevars.old) + return _setup_domainbuffer(threaded, domain.set, itembuffer, statevars, domain.sdh, domain.colors_or_chunks; kwargs...) end # Type-unstable switch diff --git a/src/states.jl b/src/states.jl index 364dad6d..3547fa63 100644 --- a/src/states.jl +++ b/src/states.jl @@ -1,16 +1,57 @@ # Minimal interface for a vector, storage format will probably be updated later. -mutable struct StateVector{SV} - vals::Dict{Int, SV} +mutable struct StateVector{SV, VV <: AbstractVector{SV}} + vals::VV + inds::Vector{Int} # Can as an optimization be shared between all `StateVectors` (also across domains) end -Base.getindex(s::StateVector, cellnum::Int) = s.vals[cellnum] -Base.setindex!(s::StateVector, v, cellnum::Int) = setindex!(s.vals, v, cellnum) +Base.getindex(s::StateVector, cellnum::Int) = s.vals[s.inds[cellnum]] +Base.setindex!(s::StateVector, v, cellnum::Int) = setindex!(s.vals, v, s.inds[cellnum]) Base.:(==)(a::StateVector, b::StateVector) = (a.vals == b.vals) +function Base.iterate(s::StateVector, i::Int = 1) + i > length(s.vals) && return nothing + return s.vals[i], i + 1 +end -struct StateVariables{SV} - old::StateVector{SV} # Rule: Referenced during assembly, not changed (ever) - new::StateVector{SV} # Rule: Updated during assembly, not referenced (before updated) +struct StateVariables{SV, VV} + old::StateVector{SV, VV} # Rule: Referenced during assembly, not changed (ever) + new::StateVector{SV, VV} # Rule: Updated during assembly, not referenced (before updated) +end +function StateVariables(inds::Vector{Int}, old::Dict{K, SV}, new::Dict{K, SV}) where {K, T, SV <: Vector{T}} + # if eltype(old) isa Vector => ArrayOfVectorViews + # else => Vector{eltype(old)} + num = length(old) + num_total = sum(length, values(old)) + old_data = Vector{T}(undef, num_total) + new_data = Vector{T}(undef, num_total) + indices = Vector{Int}(undef, num + 1) + i = 1 + j = 1 + for key in sort(collect(keys(old))) + indices[i] = j + inds[key] = i + for (oldval, newval) in zip(old[key], new[key]) + old_data[j] = oldval + new_data[j] = newval + j += 1 + end + i += 1 + end + indices[i] = j + oldvals = ArrayOfVectorViews(indices, old_data, LinearIndices((num,))) + newvals = ArrayOfVectorViews(indices, new_data, LinearIndices((num,))) + StateVariables(StateVector(oldvals, inds), StateVector(newvals, inds)) +end +function StateVariables(inds::Vector{Int}, old::Dict{K, SV}, new::Dict{K, SV}) where {K, SV} + oldvals = Vector{SV}(undef, length(old)) + newvals = Vector{SV}(undef, length(new)) + i = 1 + for key in sort(collect(keys(old))) + oldvals[i] = old[key] + newvals[i] = new[key] + inds[key] = i + i += 1 + end + StateVariables(StateVector(oldvals, inds), StateVector(newvals, inds)) end -StateVariables(old::Dict, new::Dict) = StateVariables(StateVector(old), StateVector(new)) function update_states!(sv::StateVariables) tmp = sv.old.vals @@ -66,7 +107,12 @@ define the [`create_cell_state`](@ref) function for their `material` (and corres """ function create_states(sdh::SubDofHandler, material, cellvalues, a, cellset, dofrange) ae = zeros(ndofs_per_cell(sdh)) - coords = getcoordinates(_getgrid(sdh), first(cellset)) + grid = _getgrid(sdh) + coords = getcoordinates(grid, first(cellset)) dofs = zeros(Int, ndofs_per_cell(sdh)) - return Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset) + # Could make construction more efficient by doing this when creating the ArrayOfVectorViews + old = Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset) + new = Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset) + inds = zeros(Int, getncells(grid)) # Could be moved out and shared between all domains... + return StateVariables(inds, old, new) end diff --git a/test/quadpoint_evaluation.jl b/test/quadpoint_evaluation.jl index 07863eeb..fb8c4557 100644 --- a/test/quadpoint_evaluation.jl +++ b/test/quadpoint_evaluation.jl @@ -79,11 +79,13 @@ foo(::QEMat{4}, u, ∇u, qp_state) = 3 * qp_state[2] qe = QuadPointEvaluator{Float64}(db, foo) work!(qe, db) - for (i, s) in states["left"].vals # TODO: Using internals here - @test qe.data[i] ≈ 3 * s + for cellnr in FerriteAssembly.getset(db, "left") + s = states["left"][cellnr] + @test qe.data[cellnr] ≈ 3 * s end - for (i, s) in states["right"].vals # TODO: Using internals here - @test qe.data[i] ≈ 3 * last.(s) + for cellnr in FerriteAssembly.getset(db, "right") + s = states["right"][cellnr] + @test qe.data[cellnr] ≈ 3 * last.(s) @test all(first.(s) .≥ 0) @test all(last.(s) .≤ 0) end diff --git a/test/runtests.jl b/test/runtests.jl index 32035017..8ca7d2a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,12 @@ import MaterialModelsBase as MMB import MechanicalMaterialModels as MMM using Logging +# Sometimes calling `@allocated f(args...)` allocates when called in global scope, +# calling inside a function using local variables solves this +function get_allocations(f::F, args::Vararg{N}) where {F, N} + @allocated f(args...) +end + include("replacements.jl") include("states.jl") include("threading_utils.jl") @@ -45,8 +51,8 @@ include("errors.jl") @test FerriteAssembly.get_dofhandler(buffer_threaded) === dh @test FerriteAssembly.get_dofhandler(buffers) === dh - @test isa(FerriteAssembly.get_state(buffer, 1), Vector{Nothing}) - @test isa(FerriteAssembly.get_old_state(buffer, 1), Vector{Nothing}) + @test isa(FerriteAssembly.get_state(buffer, 1), AbstractVector{Nothing}) + @test isa(FerriteAssembly.get_old_state(buffer, 1), AbstractVector{Nothing}) @test length(FerriteAssembly.get_state(buffer, 1)) == getnquadpoints(cv) @test length(FerriteAssembly.get_old_state(buffer, 1)) == getnquadpoints(cv) diff --git a/test/states.jl b/test/states.jl index 02e4f733..f8cf60f4 100644 --- a/test/states.jl +++ b/test/states.jl @@ -6,7 +6,7 @@ module TestStateModule quadnr::Int end FerriteAssembly.create_cell_state(::MatA, cv, args...) = [StateA(-1, 0) for _ in 1:getnquadpoints(cv)] - function FerriteAssembly.element_residual!(re, states::Vector{StateA}, ae, ::MatA, cv, buffer) + function FerriteAssembly.element_residual!(re, states::AbstractVector{StateA}, ae, ::MatA, cv, buffer) cellnr = cellid(buffer) for i in 1:getnquadpoints(cv) states[i] = StateA(cellnr, i) @@ -33,7 +33,7 @@ module TestStateModule counter::Int end FerriteAssembly.create_cell_state(::MatC, cv, args...) = [StateC(0) for _ in 1:getnquadpoints(cv)] - function FerriteAssembly.element_residual!(re, states::Vector{StateC}, ae, ::MatC, cv, buffer) + function FerriteAssembly.element_residual!(re, states::AbstractVector{StateC}, ae, ::MatC, cv, buffer) old_states = FerriteAssembly.get_old_state(buffer) for i in 1:getnquadpoints(cv) states[i] = StateC(old_states[i].counter + 1) @@ -69,7 +69,7 @@ end buffer = setup_domainbuffer(DomainSpec(dh, MatA(), cv)) states = FerriteAssembly.get_state(buffer) old_states = FerriteAssembly.get_old_state(buffer) - @test isa(old_states, FerriteAssembly.StateVector{Vector{StateA}}) + @test isa(old_states, FerriteAssembly.StateVector{<:AbstractVector{StateA}}) @test old_states == states @test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)] work!(r_assembler, buffer) @@ -83,8 +83,8 @@ end @test old_states == states_dc # Correctly updated values states[1][1] = StateA(0,0) @test old_states[1][1] == StateA(1,1) # But not aliased - allocs = @allocated update_states!(container) - @test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatA fulfills this) + # Vector{T} where isbitstype(T) should not allocate (MatA fulfills this) + @test get_allocations(update_states!, container) == 0 end # MatB (not bitstype) @@ -111,15 +111,15 @@ end x_values = [spatial_coordinate(cv, i, coords) for i in 1:getnquadpoints(cv)] states[cellnr] = StateB(0, -x_values) @test old_states[cellnr] == StateB(cellnr, x_values) # But not aliased - allocs = @allocated update_states!(buffer) - @test allocs == 0 # Vector{T} where !isbitstype(T) should no longer allocate + # Vector{T} where !isbitstype(T) should no longer allocate + @test get_allocations(update_states!, buffer) == 0 # MatC (accumulation), using threading as well colors = create_coloring(grid) buffer = setup_domainbuffer(DomainSpec(dh, MatC(), cv; colors=colors)) states = FerriteAssembly.get_state(buffer) old_states = FerriteAssembly.get_old_state(buffer) - @test isa(old_states, FerriteAssembly.StateVector{Vector{StateC}}) + @test isa(old_states, FerriteAssembly.StateVector{<:AbstractVector{StateC}}) @test old_states == states @test old_states[1][1] == StateC(0) work!(kr_assembler, buffer) @@ -134,8 +134,8 @@ end for cellnr in 1:getncells(grid) @test states[cellnr][2] == StateC(2) # Check that all are updated end - allocs = @allocated update_states!(buffer) - @test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatC fulfills this) + # Vector{T} where isbitstype(T) should not allocate (MatC fulfills this) + @test get_allocations(update_states!, buffer) == 0 end end @@ -145,15 +145,13 @@ end # Smoke-test of update_states! for nothing states (and check no allocations) cv = CellValues(QuadratureRule{RefTriangle}(2), ip) buffer = setup_domainbuffer(DomainSpec(dh, nothing, cv)) - @test isa(FerriteAssembly.get_state(buffer), FerriteAssembly.StateVector{Vector{Nothing}}) + @test isa(FerriteAssembly.get_state(buffer), FerriteAssembly.StateVector{<:AbstractVector{Nothing}}) update_states!(buffer) # Compile - allocs = @allocated update_states!(buffer) - @test allocs == 0 + @test get_allocations(update_states!, buffer) == 0 gda = DomainSpec(dh, nothing, cv; set=1:getncells(dh.grid)÷2) gdb = DomainSpec(dh, nothing, cv; set=setdiff!(Set(1:getncells(dh.grid)), gda.set)) buffers = setup_domainbuffers(Dict("a"=>gda, "b"=>gdb)) update_states!(buffers) # Compile - allocs = @allocated update_states!(buffers) - @test allocs == 0 + @test get_allocations(update_states!, buffer) == 0 end