-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
44aa44b
commit 1b1554c
Showing
2 changed files
with
61 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,37 @@ | ||
# Benchmark to be run on Linux as root | ||
|
||
using BenchmarkTools | ||
using NCDatasets | ||
using Dates | ||
using CommonDataModel: @groupby | ||
using CommonDataModel | ||
|
||
fname = expanduser("~/sample_perf2.nc") | ||
ds = NCDataset(fname) | ||
|
||
v = ds[:data] | ||
data_f64 = Float64.(ds[:data][:,:,:]) | ||
|
||
println("runtime") | ||
gm = @btime begin | ||
write("/proc/sys/vm/drop_caches","3") | ||
mean(@groupby(ds[:data],Dates.Month(time)))[:,:,:]; | ||
end | ||
|
||
# Welford | ||
gs = @btime begin | ||
write("/proc/sys/vm/drop_caches","3") | ||
std(@groupby(ds[:data],Dates.Month(time)))[:,:,:]; | ||
end | ||
|
||
println("accuracy") | ||
|
||
mean_ref = cat( | ||
[mean(v[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3) | ||
[mean(data_f64[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3) | ||
for m in 1:12]...,dims=3); | ||
|
||
std_ref = cat( | ||
[std(v[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3) | ||
[std(data_f64[:,:,findall(Dates.month.(ds[:time][:]) .== m)],dims=3) | ||
for m in 1:12]...,dims=3); | ||
|
||
|
||
gm = @btime mean(@groupby(ds[:data],Dates.Month(time)))[:,:,:]; | ||
# 1.005 s (523137 allocations: 2.67 GiB) | ||
|
||
@show sqrt(mean((gm - mean_ref).^2)) | ||
|
||
# Welford | ||
gs = @btime std(@groupby(ds[:data],Dates.Month(time)))[:,:,:]; | ||
@show sqrt(mean((gs - std_ref).^2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,65 @@ | ||
# Benchmark to be run on Linux as root | ||
# dropping file cache is OS specific and requires root privileges | ||
|
||
import timeit | ||
import xarray as xr | ||
import numpy | ||
import os | ||
import sys | ||
|
||
# xarray-2023.12.0 | ||
# Python 3.10.12 | ||
|
||
# mean | ||
# minimum runtime of 30 trials | ||
# 0.7370511470362544 seconds | ||
|
||
# std | ||
# 3.9330708980560303 seconds | ||
tests = [ | ||
"""vm = ds["data"].groupby("time.month").mean().to_numpy();""", | ||
"""vm = ds["data"].groupby("time.month").std().to_numpy();""", | ||
] | ||
def mean_no_cache(ds): | ||
with open("/proc/sys/vm/drop_caches","w") as f: | ||
f.write("3") | ||
vm = ds["data"].groupby("time.month").mean().to_numpy(); | ||
|
||
def std_no_cache(ds): | ||
with open("/proc/sys/vm/drop_caches","w") as f: | ||
f.write("3") | ||
|
||
print("runtime") | ||
vm = ds["data"].groupby("time.month").std().to_numpy(); | ||
|
||
for tt in tests: | ||
t = timeit.repeat(tt, | ||
setup=""" | ||
import xarray as xr | ||
fname = "/home/abarth/sample_perf2.nc" | ||
fname = os.path.expanduser("~/sample_perf2.nc") | ||
ds = xr.open_dataset(fname) | ||
""", | ||
number=1, | ||
repeat=30, | ||
) | ||
|
||
print("timeit ",min(t),tt) | ||
print("python: ",sys.version) | ||
print("xarray: ",xr.__version__) | ||
print("numpy: ",numpy.__version__) | ||
|
||
|
||
fname = "/home/abarth/sample_perf2.nc" | ||
ds = xr.open_dataset(fname) | ||
if __name__ == "__main__": | ||
print("runtime") | ||
|
||
for test_fun in [mean_no_cache, std_no_cache]: | ||
t = timeit.repeat(lambda: test_fun(ds), | ||
setup="""from __main__ import ds""", | ||
number=1, | ||
repeat=30, | ||
) | ||
|
||
print(" minimum time of ",test_fun,": ",min(t)) | ||
|
||
month = ds["time.month"].to_numpy() | ||
|
||
print("accuracy") | ||
month = ds["time.month"].to_numpy() | ||
|
||
print("accuracy") | ||
|
||
mean_ref = numpy.stack( | ||
[ds["data"].data[(month == mm).nonzero()[0],:,:].mean(axis=0) for mm in range(1,13)],axis=0) | ||
data_f64 = ds["data"].data.astype(dtype="f8") | ||
|
||
std_ref = numpy.stack( | ||
[ds["data"].data[(month == mm).nonzero()[0],:,:].std(axis=0,ddof=1) for mm in range(1,13)],axis=0) | ||
mean_ref = numpy.stack( | ||
[data_f64[(month == mm).nonzero()[0],:,:].mean(axis=0) for mm in range(1,13)],axis=0) | ||
|
||
std_ref = numpy.stack( | ||
[data_f64[(month == mm).nonzero()[0],:,:].std(axis=0,ddof=1) for mm in range(1,13)],axis=0) | ||
|
||
vm = ds["data"].groupby("time.month").mean().to_numpy(); | ||
|
||
print("accuracy of mean", | ||
numpy.sqrt(numpy.mean((mean_ref - vm)**2))) | ||
# output 0 | ||
vm = ds["data"].groupby("time.month").mean().to_numpy(); | ||
|
||
vs = ds["data"].groupby("time.month").std() | ||
print(" accuracy of mean", | ||
numpy.sqrt(numpy.mean((mean_ref - vm)**2))) | ||
|
||
vs = ds["data"].groupby("time.month").std().to_numpy() | ||
|
||
print("accuracy of std", | ||
numpy.sqrt(numpy.mean((std_ref - vs)**2))) | ||
|
||
# 0.00053720415 | ||
print(" accuracy of std", | ||
numpy.sqrt(numpy.mean((std_ref - vs)**2))) |