Skip to content

Commit

Permalink
use Welford's online algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Dec 22, 2023
1 parent 48f9349 commit 65feb9c
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/CommonDataModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ include("multifile.jl")
include("defer.jl")
include("subvariable.jl")
include("select.jl")
include("aggregation.jl")
include("groupby.jl")

end # module CommonDataModel
Expand Down
33 changes: 33 additions & 0 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

"""
    VarianceWelfordAggegation{Ti,T}

Accumulator state for Welford's online (single-pass) algorithm for the
sample variance: `count` observations seen so far, their running `mean`,
and `M2`, the sum of squared deviations from the current mean.

NOTE(review): the name misspells "Aggregation"; it is kept as-is because
other code refers to this identifier.
"""
struct VarianceWelfordAggegation{Ti,T}
    count::Ti
    mean::T
    M2::T
end

"""
    VarianceWelfordAggegation(T)

Return an empty accumulator for elements of type `T` (zero observations).
"""
function VarianceWelfordAggegation(::Type{T}) where T
    return VarianceWelfordAggegation(0, zero(T), zero(T))
end

"""
    update(ag::VarianceWelfordAggegation, new_value)

Return a new accumulator that also incorporates `new_value` using
Welford's online algorithm. The input `ag` is immutable and not modified.
"""
@inline function update(ag::VarianceWelfordAggegation, new_value)
    count = ag.count + 1
    delta = new_value - ag.mean
    mean = ag.mean + delta / count
    # delta2 uses the *updated* mean; delta * delta2 is the Welford
    # correction term that keeps M2 numerically stable
    delta2 = new_value - mean
    M2 = ag.M2 + delta * delta2
    return VarianceWelfordAggegation(count, mean, M2)
end

"""
    result(ag::VarianceWelfordAggegation)

Sample variance (normalized by `count - 1`) of all accumulated values.
At least two observations are required for a meaningful result
(a single observation yields NaN).
"""
function result(ag::VarianceWelfordAggegation)
    return ag.M2 / (ag.count - 1)
end
49 changes: 49 additions & 0 deletions src/groupby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,40 @@ function _mapreduce(map_fun,reduce_op,gv::GroupedVariable{TV},indices;
return data_by_class,reshape(count,_indices(dim,length(count),1,indices))
end



"""
    _mapreduce_aggregation(map_fun, ag, gv::GroupedVariable, indices)

Reduce the grouped variable `gv` along its grouping dimension using the
streaming accumulator type `ag` (e.g. `VarianceWelfordAggegation`): every
output element starts as `ag(T)`, is updated element-wise with `update`
for each slice belonging to its group, and is finalized with `result`.
`map_fun` is applied to every loaded slice before accumulation.
Returns the array of finalized results restricted to `indices`.
"""
function _mapreduce_aggregation(map_fun,ag,gv::GroupedVariable{TV},indices) where TV <: AbstractArray{T,N} where {T,N}
    data = gv.v
    # position of the grouping dimension within data
    # (assumes gv.coordname matches one of data's dimensions)
    dim = findfirst(==(Symbol(gv.coordname)),Symbol.(dimnames(data)))
    class = _val.(gv.class)
    unique_class = _val.(gv.unique_class[indices[dim]])

    nclass = length(unique_class)
    # output size: grouping dimension replaced by the number of groups,
    # then restricted by the requested indices
    sz_all = ntuple(i -> (i == dim ? nclass : size(data,i) ),ndims(data))
    sz = size_getindex(sz_all,indices...)

    # one independent (immutable) accumulator per output element
    data_by_class = fill(ag(T),sz)

    for k = 1:size(data,dim)
        # group index of slice k; nothing if it falls outside the selection
        ku = findfirst(==(class[k]),unique_class)

        if !isnothing(ku)
            dest_ind = _dest_indices(dim,ku,indices)
            src_ind = ntuple(i -> (i == dim ? k : indices[i] ),ndims(data))
            data_by_class_ind = view(data_by_class,dest_ind...)
            # fold the mapped slice into the group's accumulators
            data_by_class_ind .= update.(data_by_class_ind,map_fun(data[src_ind...]))
        end
    end

    return result.(data_by_class)
end

# Reduce without a mapping step: forward to `_mapreduce` with `identity`.
_reduce(args...; kwargs...) = _mapreduce(identity, args...; kwargs...)
Expand Down Expand Up @@ -408,6 +442,21 @@ function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(mean)},indices:
data ./ count
end


# Sample variance of each group, computed in a single pass over the data
# (Welford's online algorithm), restricted to the requested indices.
function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(var)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV}
    gv = gr.gv
    return _mapreduce_aggregation(gv.map_fun, VarianceWelfordAggegation, gv, indices)
end


# Sample standard deviation of each group: square root of the per-group
# variance obtained via Welford's one-pass aggregation.
function Base.getindex(gr::ReducedGroupedVariable{T,N,TGV,typeof(std)},indices::Union{Integer,Colon,AbstractRange{<:Integer},AbstractVector{<:Integer}}...) where {T,N,TGV}
    gv = gr.gv
    variance = _mapreduce_aggregation(gv.map_fun, VarianceWelfordAggegation, gv, indices)
    return sqrt.(variance)
end


_dim_after_getindex(dim,ind::Union{Colon,AbstractRange,AbstractVector},other...) = _dim_after_getindex(dim+1,other...)
_dim_after_getindex(dim,ind::Integer,other...) = _dim_after_getindex(dim,other...)
_dim_after_getindex(dim) = dim
Expand Down
27 changes: 27 additions & 0 deletions test/perf/test_perf_cdm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Benchmark CommonDataModel's groupby mean/std against a direct
# month-by-month computation on the sample file created by
# test_perf_init.jl.
using BenchmarkTools
using NCDatasets
using Dates
using Statistics: mean, std   # mean/std live in Statistics, not Base
using CommonDataModel: @groupby

fname = expanduser("~/sample_perf2.nc")
ds = NCDataset(fname)

v = ds[:data]

# reference monthly climatology, computed month by month
mean_ref = cat(
    [mean(v[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3)
     for m in 1:12]...,dims=3);

std_ref = cat(
    [std(v[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3)
     for m in 1:12]...,dims=3);


gm = @btime mean(@groupby(ds[:data],Dates.Month(time)))[:,:,:];
# 1.005 s (523137 allocations: 2.67 GiB)

# RMS difference to the reference (should be ~0)
@show sqrt(mean((gm - mean_ref).^2))

# Welford
gs = @btime std(@groupby(ds[:data],Dates.Month(time)))[:,:,:];
@show sqrt(mean((gs - std_ref).^2))
23 changes: 23 additions & 0 deletions test/perf/test_perf_init.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Create the NetCDF test file used by the performance benchmarks
# (test_perf_cdm.jl and test_perf_xarray.py).

using NCDatasets
using Dates

sz = (360,180)   # spatial grid size (lon × lat)
# daily time axis; NOTE: `time` intentionally shadows Base.time here
time = DateTime(1980,1,1):Day(1):DateTime(2010,1,1);

fname = expanduser("~/sample_perf2.nc")

# always start from a clean file
isfile(fname) && rm(fname)

NCDataset(fname,"c") do ds
    defVar(ds,"lon",1:sz[1],("lon",))
    defVar(ds,"lat",1:sz[2],("lat",))
    defVar(ds,"time",time,("time",))
    ncv = defVar(ds,"data",Float32,("lon","lat","time"),attrib=Dict("foo" => "bar"))

    # write one time slice at a time to keep memory use bounded
    for n in eachindex(time)
        ncv[:,:,n] = randn(Float32,sz...) .+ 100
    end
end
63 changes: 63 additions & 0 deletions test/perf/test_perf_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Benchmark xarray's monthly groupby mean/std on the sample file created by
# test_perf_init.jl, then check the results against a direct per-month
# computation with numpy.
import timeit
import xarray as xr
import numpy

# xarray-2023.12.0
# Python 3.10.12

# mean
# minimum runtime of 30 trials
# 0.7370511470362544 seconds

# std
# 3.9330708980560303 seconds
tests = [
    """vm = ds["data"].groupby("time.month").mean().to_numpy();""",
    """vm = ds["data"].groupby("time.month").std().to_numpy();""",
]


print("runtime")

# number=1, repeat=30: report the minimum of 30 single runs; the dataset is
# (re)opened in setup so only the groupby/reduction itself is timed
for tt in tests:
    t = timeit.repeat(tt,
        setup="""
import xarray as xr
fname = "/home/abarth/sample_perf2.nc"
ds = xr.open_dataset(fname)
""",
        number=1,
        repeat=30,
    )

    print("timeit ",min(t),tt)


fname = "/home/abarth/sample_perf2.nc"
ds = xr.open_dataset(fname)

month = ds["time.month"].to_numpy()

print("accuracy")


# reference climatology: stack the 12 per-month reductions along a new axis
mean_ref = numpy.stack(
    [ds["data"].data[(month == mm).nonzero()[0],:,:].mean(axis=0) for mm in range(1,13)],axis=0)

# NOTE(review): the reference uses ddof=1 (sample std) while xarray's
# groupby .std() below defaults to ddof=0 (population std); this mismatch
# presumably accounts for the small non-zero "accuracy of std" value
# recorded at the end -- confirm which normalization is intended.
std_ref = numpy.stack(
    [ds["data"].data[(month == mm).nonzero()[0],:,:].std(axis=0,ddof=1) for mm in range(1,13)],axis=0)


vm = ds["data"].groupby("time.month").mean().to_numpy();

print("accuracy of mean",
    numpy.sqrt(numpy.mean((mean_ref - vm)**2)))
# output 0

# NOTE(review): unlike `vm`, this is left as a DataArray (no .to_numpy());
# the comparison below still works through xarray/numpy broadcasting
vs = ds["data"].groupby("time.month").std()


print("accuracy of std",
    numpy.sqrt(numpy.mean((std_ref - vs)**2)))

# 0.00053720415

0 comments on commit 65feb9c

Please sign in to comment.