diff --git a/src/aggregation.jl b/src/aggregation.jl index 42c474d..5d633ab 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -1,20 +1,20 @@ -struct VarianceWelfordAggegation{Ti,T} +aggregator(::Type{typeof(var)}) = VarianceWelfordAggregation +aggregator(::Type{typeof(maximum)}) = MaxAggregation +aggregator(::Type{typeof(minimum)}) = MinAggregation + +struct VarianceWelfordAggregation{Ti,T} count::Ti mean::T M2::T end -function VarianceWelfordAggegation(T) - VarianceWelfordAggegation{Int,T}(0,zero(T),zero(T)) +function VarianceWelfordAggregation(T) + VarianceWelfordAggregation{Int,T}(0,zero(T),zero(T)) end -#function VarianceWelfordAggegation(::Type{<:Array{T,N}}) where {T,N} -# VarianceWelfordAggegation(0,zero(T),zero(T)) -#end - # Welford's online algorithm -@inline function update(ag::VarianceWelfordAggegation{Ti,T}, new_value) where {Ti,T} +@inline function update(ag::VarianceWelfordAggregation{Ti,T}, new_value) where {Ti,T} count = ag.count mean = ag.mean M2 = ag.M2 @@ -24,10 +24,41 @@ end mean += delta / count delta2 = new_value - mean M2 += delta * delta2 - return VarianceWelfordAggegation{Ti,T}(count, mean, M2) + return VarianceWelfordAggregation{Ti,T}(count, mean, M2) end -function result(ag::VarianceWelfordAggegation) +function result(ag::VarianceWelfordAggregation) sample_variance = ag.M2 / (ag.count - 1) return sample_variance end + + + +for (funAggregation,fun) in ((:MaxAggregation,max),(:MinAggregation,min)) + @eval begin + struct $funAggregation{T} + result::T + init::Bool + end + + function $funAggregation(T) + $funAggregation{T}(zero(T),false) + end + + @inline function update(ag::$funAggregation{T}, new_value) where T + if ag.init + return $funAggregation{T}(max(ag.result,new_value),true) + else + return $funAggregation{T}(new_value,true) + end + end + + function result(ag::$funAggregation) + if ag.init + return ag.result + else + error("reducing over an empty collection is not allowed") + end + end + end +end diff --git a/src/groupby.jl b/src/groupby.jl index 0478a7d..9783048 100644 --- a/src/groupby.jl +++ b/src/groupby.jl @@ -483,17 +483,17 @@ function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(mean)},indices: end -function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(var)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV} +function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,TF},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where TF <: Union{typeof(var),typeof(maximum),typeof(minimum)} where {T,N,TGV} return _mapreduce_aggregation( - gr.gv.map_fun,VarianceWelfordAggegation,gr.gv,indices); + gr.gv.map_fun,aggregator(TF),gr.gv,indices); end function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(std)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV} return sqrt.(_mapreduce_aggregation( - gr.gv.map_fun,VarianceWelfordAggegation,gr.gv,indices)) + gr.gv.map_fun,VarianceWelfordAggregation,gr.gv,indices)) end