Skip to content

Commit

Permalink
Merge pull request #57 from awesome-spectral-indices/fm/idx
Browse files Browse the repository at this point in the history
Dispatch compute_index over AbstractSpectralIndicex
  • Loading branch information
MartinuzziFrancesco authored Apr 3, 2024
2 parents 5e5ddd8 + 899e16b commit e64c08d
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 70 deletions.
2 changes: 1 addition & 1 deletion data/spectral-indices-dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -2644,7 +2644,7 @@
"MODIS"
],
"reference": "https://doi.org/10.1016/j.jag.2015.02.010",
"short_name": "NDSoiI"
"short_name": "NDSoI"
},
"NDTI": {
"application_domain": "water",
Expand Down
30 changes: 16 additions & 14 deletions ext/SpectralIndicesDataFramesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ module SpectralIndicesDataFramesExt

using SpectralIndices
using DataFrames
import SpectralIndices: _create_params, AbstractSpectralIndex, compute_index,
_create_indices, linear, poly, RBF, load_dataset, _load_json

function SpectralIndices._create_params(kw_args::Pair{Symbol,DataFrame}...)
function _create_params(kw_args::Pair{Symbol,DataFrame}...)
combined_df = DataFrame()

for pair in kw_args
Expand All @@ -14,52 +16,52 @@ function SpectralIndices._create_params(kw_args::Pair{Symbol,DataFrame}...)
return combined_df
end

function SpectralIndices.compute_index(
index::String, params::DataFrame; indices=SpectralIndices._create_indices()
function compute_index(
index::AbstractSpectralIndex, params::DataFrame; indices=_create_indices()
)
# Convert DataFrame to a dictionary for each row and compute the index
results = [
SpectralIndices.compute_index(
compute_index(
index, Dict(zip(names(params), row)); indices=indices
) for row in eachrow(params)
]

# Return the results as a DataFrame with the column named after the index
return DataFrame(Symbol(index) => results)
return DataFrame(Symbol(index.short_name) => results)
end

function SpectralIndices.compute_index(
index::Vector{String}, params::DataFrame; indices=SpectralIndices._create_indices()
function compute_index(
index::Vector{<:AbstractSpectralIndex}, params::DataFrame; indices=_create_indices()
)
# Similar conversion and computation for a vector of indices
result_dfs = DataFrame()
for idx in index
result_df = SpectralIndices.compute_index(idx, params; indices=indices)
result_dfs[!, Symbol(idx)] = result_df[!, 1]
result_df = compute_index(idx, params; indices=indices)
result_dfs[!, Symbol(idx.short_name)] = result_df[!, 1]
end
# Return the combined DataFrame with columns named after each index
return result_dfs
end

function SpectralIndices.linear(params::DataFrame)
function linear(params::DataFrame)
result = linear(params[!, "a"], params[!, "b"])
result_df = DataFrame(; linear=result)
return result_df
end

function SpectralIndices.poly(params::DataFrame)
function poly(params::DataFrame)
result = poly(params[!, "a"], params[!, "b"], params[!, "c"], params[!, "p"])
result_df = DataFrame(; poly=result)
return result_df
end

function SpectralIndices.RBF(params::DataFrame)
function RBF(params::DataFrame)
result = RBF(params[!, "a"], params[!, "b"], params[!, "sigma"])
result_df = DataFrame(; RBF=result)
return result_df
end

function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:DataFrame}
function load_dataset(dataset::String, ::Type{T}) where {T<:DataFrame}
datasets = Dict("spectral" => "spectral.json")

if dataset in keys(datasets)
Expand All @@ -68,7 +70,7 @@ function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:Data
error("Dataset name not valid. Datasets available for DataFrames: spectral")
end

ds = SpectralIndices._load_json(datasets[dataset])
ds = _load_json(datasets[dataset])
all_indices = Set{Int}()
for val in values(ds)
for idx in keys(val)
Expand Down
44 changes: 24 additions & 20 deletions ext/SpectralIndicesYAXArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ module SpectralIndicesYAXArraysExt
using SpectralIndices
using YAXArrays
using DimensionalData
import SpectralIndices: _check_params, _create_params, _order_params,
AbstractSpectralIndex, compute_index, _create_indices,
linear, poly, RBF, load_dataset, _load_json

function SpectralIndices._check_params(index, params::YAXArray)
function _check_params(index::AbstractSpectralIndex, params::YAXArray)
for band in index.bands
if !(band in params.Variables)
throw(
Expand All @@ -16,7 +19,7 @@ function SpectralIndices._check_params(index, params::YAXArray)
end
end

function SpectralIndices._order_params(index, params::YAXArray)
function _order_params(index::AbstractSpectralIndex, params::YAXArray)
new_params = []
for (bidx, band) in enumerate(index.bands)
push!(new_params, params[Variable=At(band)])
Expand All @@ -25,7 +28,7 @@ function SpectralIndices._order_params(index, params::YAXArray)
return new_params
end

function SpectralIndices._create_params(kw_args::Pair{Symbol,<:YAXArray}...)
function _create_params(kw_args::Pair{Symbol,<:YAXArray}...)
params_yaxa = []
names_yaxa = []
for (key, value) in kw_args
Expand All @@ -39,18 +42,19 @@ end

## TODO: simplify even further
# this is same function contente as dispatch on Dict
function SpectralIndices.compute_index(
index::String, params::YAXArray; indices=SpectralIndices._create_indices()
function compute_index(index::AbstractSpectralIndex,
params::YAXArray;
indices=_create_indices()
)
SpectralIndices._check_params(indices[index], params)
params = SpectralIndices._order_params(indices[index], params)
_check_params(index, params)
params = _order_params(index, params)
T = eltype(first(params))
result = SpectralIndices._compute_index(T, indices[index], params...)
result = _compute_index(T, index, params...)
return result
end

function SpectralIndices.compute_index(
index::Vector{String}, params::YAXArray; indices=SpectralIndices._create_indices()
function compute_index(
index::Vector{<:AbstractSpectralIndex}, params::YAXArray; indices=_create_indices()
)
results = []
for (nidx, idx) in enumerate(index)
Expand All @@ -62,32 +66,32 @@ function SpectralIndices.compute_index(
return result
end

function SpectralIndices._compute_index(
::Type{T}, idx::SpectralIndices.AbstractSpectralIndex, prms::YAXArray...
function _compute_index(
::Type{T}, idx::AbstractSpectralIndex, prms::YAXArray...
) where {T<:Number}
return idx.(T, prms...)
end

function SpectralIndices.linear(params::YAXArray)
return SpectralIndices.linear(params[Variable=At("a")], params[Variable=At("b")])
function linear(params::YAXArray)
return linear(params[Variable=At("a")], params[Variable=At("b")])
end

function SpectralIndices.poly(params::YAXArray)
return SpectralIndices.poly(
function poly(params::YAXArray)
return poly(
params[Variable=At("a")],
params[Variable=At("b")],
params[Variable=At("c")],
params[Variable=At("p")],
)
end

function SpectralIndices.RBF(params::YAXArray)
return SpectralIndices.RBF(
function RBF(params::YAXArray)
return RBF(
params[Variable=At("a")], params[Variable=At("b")], params[Variable=At("sigma")]
)
end

function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:YAXArray}
function load_dataset(dataset::String, ::Type{T}) where {T<:YAXArray}
datasets = Dict("sentinel" => "S2_10m.json")

if dataset in keys(datasets)
Expand All @@ -96,7 +100,7 @@ function SpectralIndices.load_dataset(dataset::String, ::Type{T}) where {T<:YAXA
error("Dataset name not valid. Datasets available for YAXArrays: sentinel")
end

ds = SpectralIndices._load_json(datasets[dataset])
ds = _load_json(datasets[dataset])
matrices = [hcat(ds[i]...) for i in 1:length(ds)]
data_3d = cat(matrices...; dims=3)
x_dim = Dim{:x}(1:300)
Expand Down
82 changes: 52 additions & 30 deletions src/compute_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,47 +53,67 @@ julia> compute_index(
```
"""
function compute_index(
index::AbstractSpectralIndex, params=nothing, online::Bool=false; indices=indices, kwargs...
)

if isnothing(params)
params = _create_params(kwargs...)
end

return compute_index(index, params; indices=indices)
end

function compute_index(
index::String, params=nothing, online::Bool=false; indices=indices, kwargs...
)
names = keys(indices)
@assert index in names "$index is not a valid Spectral Index!"
results = compute_index(indices[index], params; indices=indices, kwargs...)
return results
end

function compute_index(index::Vector{<:AbstractSpectralIndex},
params=nothing,
online::Bool=false;
indices=_create_indices(online),
kwargs...
)

if isnothing(params)
params = _create_params(kwargs...)
end

results = compute_index(index, params; indices=indices)
return results
return compute_index(index, params; indices=indices)
end

function compute_index(index::Vector{String}, params=nothing, online::Bool=false; kwargs...)
indices = _create_indices(online)
function compute_index(index::Vector{String},
params=nothing,
online::Bool=false;
indices = _create_indices(online),
kwargs...)

names = keys(indices)
for idx in index
@assert idx in names "$index is not a valid Spectral Index!"
end

if isnothing(params)
params = _create_params(kwargs...)
end

results = compute_index(index, params; indices=indices)
idxs = [indices[idx] for idx in index]
results = compute_index(idxs, params; indices=indices, kwargs...)
return results
end

function compute_index(index::String, params::Dict; indices=indices)
_check_params(indices[index], params)
params = _order_params(indices[index], params)
function compute_index(index::AbstractSpectralIndex, params::Dict; indices=indices)
_check_params(index, params)
params = _order_params(index, params)
T = eltype(first(values(params)))
result = _compute_index(T, indices[index], params...)
result = _compute_index(T, index, params...)

return result
end

# TODO: return results in a matrix columnswise
#multi_result = compute_index(["NDVI", "SAVI"], N = fill(0.643, 5), R = fill(0.175, 5), L = fill(0.5, 5))
function compute_index(index::Vector{String}, params::Dict; indices=indices)
function compute_index(index::Vector{<:AbstractSpectralIndex}, params::Dict; indices=indices)
results = []
for (nidx, idx) in enumerate(index)
result = compute_index(idx, params; indices=indices)
Expand All @@ -104,29 +124,31 @@ function compute_index(index::Vector{String}, params::Dict; indices=indices)
end

#_compute_index(idx::AbstractSpectralIndex, prms::Number...) = idx(prms...)
function _compute_index(
::Type{T}, idx::AbstractSpectralIndex, prms::Number...
function _compute_index(::Type{T},
idx::AbstractSpectralIndex,
prms::Number...
) where {T<:Number}
return idx(T, prms...)
end

#_compute_index(idx::AbstractSpectralIndex, prms::AbstractArray...) = idx.(prms...)
function _compute_index(
::Type{T}, idx::AbstractSpectralIndex, prms::AbstractArray...
function _compute_index(::Type{T},
idx::AbstractSpectralIndex,
prms::AbstractArray...
) where {T<:Number}
return idx.(T, prms...)
end

function compute_index(index::String, params::NamedTuple; indices=indices)
_check_params(indices[index], params)
params = _order_params(indices[index], params)
function compute_index(index::AbstractSpectralIndex, params::NamedTuple; indices=indices)
_check_params(index, params)
params = _order_params(index, params)
T = eltype(first(values(params)))
result = _compute_index(T, indices[index], params...)
result_nt = (; Dict(Symbol(index) => result)...)
result = _compute_index(T, index, params...)
result_nt = (; Dict(Symbol(index.short_name) => result)...)
return result_nt
end

function compute_index(index::Vector{String}, params::NamedTuple; indices=indices)
function compute_index(index::Vector{<:AbstractSpectralIndex}, params::NamedTuple; indices=indices)
results_dict = Dict{Symbol,Any}()
for idx in index
result_nt = compute_index(idx, params; indices=indices)
Expand Down Expand Up @@ -165,7 +187,7 @@ indices = _get_indices()
_check_params(index_name, parameters, indices)
```
"""
function _check_params(index, params::Dict)
function _check_params(index::AbstractSpectralIndex, params::Dict)
for band in index.bands
if !(band in keys(params))
throw(
Expand All @@ -177,7 +199,7 @@ function _check_params(index, params::Dict)
end
end

function _order_params(index, params)
function _order_params(index::AbstractSpectralIndex, params)
T = eltype(params)
new_params = T[]
for (bidx, band) in enumerate(index.bands)
Expand All @@ -187,7 +209,7 @@ function _order_params(index, params)
return new_params
end

function _order_params(index, params::Dict)
function _order_params(index::AbstractSpectralIndex, params::Dict)
T = eltype(values(params))
new_params = T[]
for (bidx, band) in enumerate(index.bands)
Expand All @@ -202,7 +224,7 @@ function _create_params(kw_args...)
return params
end

function _check_params(index, params::NamedTuple)
function _check_params(index::AbstractSpectralIndex, params::NamedTuple)
for band in index.bands
if !(Symbol(band) in keys(params))
throw(
Expand All @@ -214,7 +236,7 @@ function _check_params(index, params::NamedTuple)
end
end

function _order_params(index, params::NamedTuple)
function _order_params(index::AbstractSpectralIndex, params::NamedTuple)
T = eltype(values(params))
new_params = T[]
for (bidx, band) in enumerate(index.bands)
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ indices = _get_indices(true)
function _get_indices(
online::Bool=false;
filename="spectral-indices-dict.json",
fileloc=joinpath(dirname(@__FILE__), "..", "data", filename),
fileloc=joinpath(dirname(@__FILE__), "..", "data"),
)
final_file = joinpath(fileloc, filename)
if online
indices_loc = Downloads.download(
"https://raw.githubusercontent.com/awesome-spectral-indices/awesome-spectral-indices/main/output/spectral-indices-dict.json",
fileloc,
final_file,
)
indices = JSON.parsefile(indices_loc)
else
Expand Down
Loading

0 comments on commit e64c08d

Please sign in to comment.