Skip to content

Commit

Permalink
aggregator for min and max
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Jan 7, 2024
1 parent f64d422 commit 848863d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 13 deletions.
51 changes: 41 additions & 10 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@

struct VarianceWelfordAggegation{Ti,T}
# Map the type of a reduction function to the aggregation type used to
# compute it; the choice is made by dispatch on `Type{typeof(f)}`.
function aggregator(::Type{typeof(var)})
    return VarianceWelfordAggregation
end

function aggregator(::Type{typeof(maximum)})
    return MaxAggregation
end

function aggregator(::Type{typeof(minimum)})
    return MinAggregation
end

# Running state of Welford's online variance algorithm.
# `Ti` is the type of the sample counter, `T` the type of the accumulated
# statistics. The struct is immutable: `update` returns a new instance.
struct VarianceWelfordAggregation{Ti,T}
# counter of the values taken into account (starts at 0)
count::Ti
# running mean
mean::T
# accumulated `delta * delta2` terms; `result` divides it by `count - 1`
M2::T
end

function VarianceWelfordAggegation(T)
VarianceWelfordAggegation{Int,T}(0,zero(T),zero(T))
# Create an empty variance aggregation for values of type `T`:
# an `Int` counter set to 0 and zero-valued `mean` and `M2`.
function VarianceWelfordAggregation(T)
    initial_mean = zero(T)
    initial_M2 = zero(T)
    return VarianceWelfordAggregation{Int,T}(0, initial_mean, initial_M2)
end

#function VarianceWelfordAggegation(::Type{<:Array{T,N}}) where {T,N}
# VarianceWelfordAggegation(0,zero(T),zero(T))
#end

# Welford's online algorithm
@inline function update(ag::VarianceWelfordAggegation{Ti,T}, new_value) where {Ti,T}
# Take `new_value` into account and return a new aggregation holding the
# updated `count`, `mean` and `M2`; `ag` itself is not modified.
@inline function update(ag::VarianceWelfordAggregation{Ti,T}, new_value) where {Ti,T}
count = ag.count
mean = ag.mean
M2 = ag.M2
# NOTE(review): the diff view collapses a few lines at this point; presumably
# they increment `count` and set `delta = new_value - mean` -- confirm against
# the full file.
Expand All @@ -24,10 +24,41 @@ end
mean += delta / count
# deviation of the new value from the already updated mean
delta2 = new_value - mean
M2 += delta * delta2
return VarianceWelfordAggegation{Ti,T}(count, mean, M2)
return VarianceWelfordAggregation{Ti,T}(count, mean, M2)
end

function result(ag::VarianceWelfordAggegation)
# Final value of the aggregation: the sample variance, i.e. the accumulated
# `M2` divided by `count - 1`.
function result(ag::VarianceWelfordAggregation)
    degrees_of_freedom = ag.count - 1
    return ag.M2 / degrees_of_freedom
end



# Generate the running-extremum aggregations. `MaxAggregation` and
# `MinAggregation` share the same layout and differ only in the binary
# reduction (`max` or `min`) used to take a new value into account.
for (funAggregation, fun) in ((:MaxAggregation, max), (:MinAggregation, min))
    @eval begin
        # `result` holds the extremum seen so far. `init` stays false until
        # the first value has been taken into account, so the placeholder
        # stored by the constructor never takes part in a comparison.
        struct $funAggregation{T}
            result::T
            init::Bool
        end

        # Empty aggregation for values of type `T`; `zero(T)` is only a
        # placeholder that is replaced by the first value.
        function $funAggregation(T)
            return $funAggregation{T}(zero(T), false)
        end

        # Take `new_value` into account and return the updated aggregation.
        @inline function update(ag::$funAggregation{T}, new_value) where {T}
            if ag.init
                # Interpolate the reduction of this aggregation (`$fun`), not
                # a fixed `max`, so that `MinAggregation` keeps the minimum.
                return $funAggregation{T}($fun(ag.result, new_value), true)
            else
                return $funAggregation{T}(new_value, true)
            end
        end

        # Return the extremum; raises an error when no value was ever
        # taken into account.
        function result(ag::$funAggregation)
            if ag.init
                return ag.result
            else
                error("reducing over an empty collection is not allowed")
            end
        end
    end
end
6 changes: 3 additions & 3 deletions src/groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,17 +483,17 @@ function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(mean)},indices:
end


function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(var)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV}
# Index a grouped variable reduced with `var`, `maximum` or `minimum`.
# The reducer type `TF` selects the aggregation type through `aggregator(TF)`;
# the computation itself is delegated to `_mapreduce_aggregation`.
function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,TF},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where TF <: Union{typeof(var),typeof(maximum),typeof(minimum)} where {T,N,TGV}

return _mapreduce_aggregation(
gr.gv.map_fun,VarianceWelfordAggegation,gr.gv,indices);
gr.gv.map_fun,aggregator(TF),gr.gv,indices);
end


# Index a grouped variable reduced with `std`: the variance aggregation is
# evaluated through `_mapreduce_aggregation` and the square root is applied
# element-wise to what it returns.
function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(std)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV}

return sqrt.(_mapreduce_aggregation(
gr.gv.map_fun,VarianceWelfordAggegation,gr.gv,indices))
gr.gv.map_fun,VarianceWelfordAggregation,gr.gv,indices))
end


Expand Down

0 comments on commit 848863d

Please sign in to comment.