Skip to content

Commit

Permalink
Merge pull request #33 from awesome-spectral-indices/fm/na
Browse files Browse the repository at this point in the history
Dispatch over NamedTuples
  • Loading branch information
MartinuzziFrancesco authored Feb 8, 2024
2 parents 7b49bf6 + 771eb17 commit c19da0e
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 128 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SpectralIndices"
uuid = "df0093a1-273d-40bc-819a-796ec3476907"
authors = ["MartinuzziFrancesco <[email protected]>"]
version = "0.2.0"
version = "0.2.1"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down Expand Up @@ -34,6 +34,7 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
YAXArrays = "c21b50f5-aa40-41ea-b809-c0f5e47bfa5c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "JET", "JuliaFormatter", "DataFrames", "YAXArrays"]
test = ["Test", "SafeTestsets", "Aqua", "JET", "JuliaFormatter", "DataFrames", "YAXArrays", "Random"]
69 changes: 25 additions & 44 deletions docs/src/tutorials/basic_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ NDVI.bands
### Using the `compute` Function


A more flexible way to calculate indices is through the compute function. This function accepts the `SpectralIndex` struct and parameters as either a dictionary or keyword arguments:
A more flexible way to calculate indices is through the `compute` function. This function accepts the `SpectralIndex` struct and parameters as either a dictionary or keyword arguments:

```@example basic
params = Dict(
Expand Down Expand Up @@ -144,49 +144,6 @@ savi = compute_index("SAVI", params)
savi = compute_index("SAVI"; N=nir, R=red, L=0.5)
```

### Support for Different Float Types

In both examples we see that the returned value is a `Float64`, since this is what we gave the function as input:

```@example basic
savi1 = compute_index("SAVI", params)
savi2 = compute_index("SAVI"; N=nir, R=red, L=0.5)
eltype(savi1) == eltype(savi2) == Float64
```

We can also compute spectral indices with other float types, such as `Float32` or `Float16`. All it needs is to feed the `compute` or the `compute_index` function input points of the chosen `Float` type. This is specifically helpful for machine learning applications, where `Float32` are the default:

```@example basic
T = Float32
params = Dict(
"N" => T(nir),
"R" => T(red),
"L" => 0.5f0
)
savi1 = compute_index("SAVI", params)
savi2 = compute_index("SAVI"; N=T(nir), R=T(red), L=0.5f0)
eltype(savi1) == eltype(savi2) == Float32
```

The same also holds for `Float16`:

```@example basic
T = Float16
params = Dict(
"N" => T(nir),
"R" => T(red),
"L" => T(0.5)
)
savi1 = compute_index("SAVI", params)
savi2 = compute_index("SAVI"; N=T(nir), R=T(red), L=T(0.5))
eltype(savi1) == eltype(savi2) == Float16
```

### Computing Multiple Indices

Now that we have added more indices we can explore how to compute multiple indices at the same time. All is needed is to pass a Vector of `String`s to the `compute_index` function with the chosen spectral indices inside, as well as the chosen parameters of course:
Expand Down Expand Up @@ -223,6 +180,9 @@ After that we can compute either one, or both indices:
```@example basic
ndvi, savi = compute_index(["NDVI", "SAVI"], params)
```

We can use the same params to calculate single indices. The additional bands are just going to be ignored:

```@example basic
ndvi = compute_index("NDVI", params)
```
Expand All @@ -249,4 +209,25 @@ savi = compute_index("SAVI";
N=fill(nir, 10),
R=fill(red, 10),
L=fill(0.5, 10))
```

## Extension to NamedTuples

SpectralIndices.jl allows you to also create indices from `NamedTuples`:

```@example basic
params = (N=fill(0.2, 10), R=fill(0.1, 10), L=fill(0.5, 10))
compute_index("NDVI", params)
```
```@example basic
compute_index(["NDVI", "SAVI"], params)
```

You can also pass the `NamedTuple` as kwargs splatting them, but the output will not be a `NamedTuple`

```@example basic
compute_index("NDVI"; params...)
```
```@example basic
compute_index(["NDVI", "SAVI"]; params...)
```
8 changes: 0 additions & 8 deletions ext/SpectralIndicesYAXArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,11 @@ function SpectralIndices.load_dataset(
end

ds = SpectralIndices._load_json(datasets[dataset])

# Convert each vector of vectors in `ds` into a matrix
matrices = [hcat(ds[i]...) for i in 1:length(ds)]

# Stack these matrices to form a 3D array
data_3d = cat(matrices...; dims=3)

# Define dimensions
x_dim = Dim{:x}(1:300)
y_dim = Dim{:y}(1:300)
bands = Dim{:bands}(["B02", "B03", "B04", "B08"])

# Create the YAXArray
yax_ds = YAXArray((x_dim, y_dim, bands), data_3d)

return yax_ds
Expand Down
122 changes: 110 additions & 12 deletions src/compute_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ function compute_index(
return results
end

function compute_index(index::String, params::Dict; indices=indices)
_check_params(indices[index], params)
params = _order_params(indices[index], params)
result = _compute_index(indices[index], params...)

return result
end

function compute_index(index::Vector{String}, params=nothing, online::Bool=false; kwargs...)
indices = _create_indices(online)
names = keys(indices)
Expand All @@ -91,19 +83,125 @@ function compute_index(index::Vector{String}, params=nothing, online::Bool=false
return results
end

function compute_index(index::String, params::Dict; indices=indices)
_check_params(indices[index], params)
params = _order_params(indices[index], params)
result = _compute_index(indices[index], params...)

return result
end

# TODO: return results in a matrix columnswise
#multi_result = compute_index(["NDVI", "SAVI"], N = fill(0.643, 5), R = fill(0.175, 5), L = fill(0.5, 5))
function compute_index(index::Vector{String}, params::Dict; indices=indices)
results = []

results = []
for (nidx, idx) in enumerate(index)
_check_params(indices[idx], params)
local_params = _order_params(indices[idx], params)
push!(results, _compute_index(indices[idx], local_params...))
result = compute_index(idx, params; indices=indices)
push!(results, result)
end

return results
end



_compute_index(idx::AbstractSpectralIndex, prms::Number...) = idx(prms...)
_compute_index(idx::AbstractSpectralIndex, prms::AbstractArray...) = idx.(prms...)

function compute_index(index::String, params::NamedTuple; indices=indices)
_check_params(indices[index], params)
params = _order_params(indices[index], params)
result = _compute_index(indices[index], params...)
result_nt = (; Dict(Symbol(index) => result)...)
return result_nt
end

function compute_index(index::Vector{String}, params::NamedTuple; indices=indices)
results_dict = Dict{Symbol, Any}()
for idx in index
result_nt = compute_index(idx, params; indices=indices)
# TODO @MartinuzziFrancesco: there has to be a better way
results_dict[fieldnames(typeof(result_nt))[1]] = first(result_nt)
end

# Convert the dictionary to a named tuple
return (;results_dict...)
end


"""
_check_params(index::String, params::Dict, indices::Dict)
Check if the parameters dictionary contains all required bands for spectral index computation.
# Arguments
- `index::String`: The name of the spectral index to check.
- `params::Dict`: The parameters dictionary to check for required bands.
- `indices::Dict`: The dictionary containing information about spectral indices.
# Returns
- `None`
# Examples
```julia
# Check parameters for the NDVI index
index_name = "NDVI"
parameters = Dict("N" => 0.6, "R" => 0.3, "G" => 0.7)
indices = _get_indices()
# Check if parameters contain required bands
_check_params(index_name, parameters, indices)
```
"""
function _check_params(index, params::Dict)
for band in index.bands
if !(band in keys(params))
throw(
ArgumentError(
"'$band' is missing in the parameters for $index computation!"
),
)
end
end
end

function _order_params(index, params)
new_params = []
for (bidx, band) in enumerate(index.bands)
push!(new_params, params[band])
end

return new_params
end

function _create_params(kw_args...)
params = Dict(String(k) => v for (k, v) in kw_args)
return params
end


function _check_params(index, params::NamedTuple)
for band in index.bands
if !(Symbol(band) in keys(params))
throw(
ArgumentError(
"'$band' is missing in the parameters for $index computation!"
),
)
end
end
end

function _order_params(index, params::NamedTuple)
new_params = []
for (bidx, band) in enumerate(index.bands)
band_symbol = Symbol(band)
push!(new_params, params[band_symbol])
end

return new_params
end
15 changes: 15 additions & 0 deletions src/compute_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ function linear(params::Dict{String,T}) where {T<:Union{<:Number,<:AbstractArray
return result
end

function linear(params::NamedTuple)
result = linear(params.a, params.b)
return result
end

"""
poly(a::T, b::T, c::T, p::T) where T <: Number
poly(a::T, b::T, c::T, p::T) where T <: AbstractArray
Expand Down Expand Up @@ -135,6 +140,11 @@ function poly(params::Dict{String,T}) where {T<:Union{<:Number,<:AbstractArray}}
return result
end

function poly(params::NamedTuple)
result = poly(params.a, params.b, params.c, params.p)
return result
end

"""
RBF(a::T, b::T, sigma::T) where T <: Number
RBF(a::T, b::T, sigma::T) where T <: AbstractArray
Expand Down Expand Up @@ -191,3 +201,8 @@ function RBF(params::Dict{String,T}) where {T<:Union{<:Number,<:AbstractArray}}
result = RBF(params["a"], params["b"], params["sigma"])
return result
end

function RBF(params::NamedTuple)
result = RBF(params.a, params.b, params.sigma)
return result
end
53 changes: 0 additions & 53 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,59 +72,6 @@ function _get_indices(
return indices["SpectralIndices"]
end

"""
_check_params(index::String, params::Dict, indices::Dict)
Check if the parameters dictionary contains all required bands for spectral index computation.
# Arguments
- `index::String`: The name of the spectral index to check.
- `params::Dict`: The parameters dictionary to check for required bands.
- `indices::Dict`: The dictionary containing information about spectral indices.
# Returns
- `None`
# Examples
```julia
# Check parameters for the NDVI index
index_name = "NDVI"
parameters = Dict("N" => 0.6, "R" => 0.3, "G" => 0.7)
indices = _get_indices()
# Check if parameters contain required bands
_check_params(index_name, parameters, indices)
```
"""
function _check_params(index, params::Dict)
for band in index.bands
if !(band in keys(params))
throw(
ArgumentError(
"'$band' is missing in the parameters for $index computation!"
),
)
end
end
end

function _order_params(index, params)
new_params = []
for (bidx, band) in enumerate(index.bands)
push!(new_params, params[band])
end

return new_params
end

function _create_params(kw_args...)
params = Dict(String(k) => v for (k, v) in kw_args)
return params
end

function _create_indexfun(
index_dict::Dict{String,Any}=_get_indices();
filename::String="indices_funcs.jl",
Expand Down
Loading

2 comments on commit c19da0e

@MartinuzziFrancesco
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/100480

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" c19da0ebff0c986392b9529e4a31a201ff7bcb88
git push origin v0.2.1

Please sign in to comment.