Skip to content

Commit

Permalink
Merge pull request #28 from invenia/bes/InputTrait
Browse files Browse the repository at this point in the history
Add PredictInputTrait to Models API
  • Loading branch information
BSnelling authored Mar 2, 2021
2 parents 857af7d + 9252321 commit 7a62068
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 56 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Models"
uuid = "e6388cff-ecff-480c-9b53-83211bf7812a"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.4"
version = "0.2.5"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ predict
submodels
estimate_type
output_type
predict_input_type
```

## Traits
Expand All @@ -23,4 +24,7 @@ DistributionEstimate
OutputTrait
SingleOutput
MultiOutput
PredictInputTrait
PointPredictInput
PointOrDistributionPredictInput
```
4 changes: 4 additions & 0 deletions docs/src/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ Here are the current [`Model`](@ref) traits in use and their possible values:
- [`output_type`](@ref) - determines how many output variates a [`Model`](@ref) can learn
- [`SingleOutput`](@ref): Fits and predicts on a single output only.
- [`MultiOutput`](@ref): Fits and predicts on multiple outputs at a time.
- [`predict_input_type`](@ref) - determines which datatypes a [`Model`](@ref) can accept at predict time.
- [`PointPredictInput`](@ref): Real valued input variables accepted at predict time.
- [`PointOrDistributionPredictInput`](@ref): Either real valued or distributions of input variables accepted at predict time.

The traits always agree between the [`Model`](@ref) and the [`Template`](@ref).
Every [`Model`](@ref) and [`Template`](@ref) should define all the listed traits.
If left undefined, the [`PredictInputTrait`](@ref) will have the default value of [`PointPredictInput`](@ref).

This package uses traits implemented such that the trait function returns an `abstract type` (rather than an instance).
That means to check a trait one uses:
Expand Down
17 changes: 17 additions & 0 deletions docs/src/testutils.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,23 @@ TestUtils
test_interface
```

### Note on PredictInputTrait Interface Tests

In the case where the [`PredictInputTrait`](@ref) is [`PointOrDistributionPredictInput`](@ref) the the Models API requires only that the distribution in question is `Sampleable`.
When using [`Models.TestUtils.test_interface`](@ref) to test a model where distributions can be passed to [`predict`](@ref), the user should provide `inputs` of the distribution type appropriate to their model.
In the example below the `CustomModel` accepts `MvNormal` distributions to `predict`.

```julia
using CustomModels
using Distributions
using Models.TestUtils

test_interface(
CustomModelTemplate();
distribution_inputs=[MvNormal(5, 1) for _ in 1:5],
)
```

## Test Fakes
```@autodocs
Modules = [Models.TestUtils]
Expand Down
12 changes: 11 additions & 1 deletion src/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ module Models
import StatsBase: fit, predict

export Model, Template
export fit, predict, submodels, estimate_type, output_type
export fit, predict, submodels, estimate_type, output_type, predict_input_type
export EstimateTrait, PointEstimate, DistributionEstimate
export OutputTrait, SingleOutput, MultiOutput
export PredictInputTrait, PointPredictInput, PointOrDistributionPredictInput

"""
Template
Expand Down Expand Up @@ -38,9 +39,14 @@ function fit end

"""
predict(model::Model, inputs::AbstractMatrix)
predict(model::Model, inputs::AbstractVector{<:AbstractVector})
Predict targets for the provided the collection of `inputs` and [`Model`](@ref).
A [`Model`](@ref) subtype for which the `predict_input_type(model)` is
[`PointPredictInput`](@ref) will only need to implement a `predict` function that operates
on an `AbstractMatrix` of inputs.
If the `estimate_type(model)` is [`PointEstimate`](@ref) then this function should return
another `AbstractMatrix` in which each column contains the prediction for a single input.
Expand All @@ -49,6 +55,10 @@ return a `AbstractVector{<:Distribution}`.
"""
function predict end

function predict(model::Model, inputs::AbstractVector{<:AbstractVector})
return predict(model, reduce(hcat, inputs))
end

"""
submodels(::Union{Template, Model})
Expand Down
112 changes: 71 additions & 41 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ end
"""
FakeTemplate{PointEstimate, SingleOutput}()
A [`Template`](@ref) whose [`Model`](@ref) will predict 0 for each observation.
A [`Template`](@ref) whose [`Model`](@ref) will accept real value variables to predict 0
for each observation.
"""
function FakeTemplate{PointEstimate, SingleOutput}()
FakeTemplate{PointEstimate, SingleOutput}() do num_variates, inputs
@assert(num_variates == 1, "$num_variates != 1")
inputs = NamedDimsArray{(:features, :observations)}(inputs)
inputs = _handle_inputs(inputs)
return NamedDimsArray{(:variates, :observations)}(
zeros(1, size(inputs, :observations))
)
Expand All @@ -53,12 +54,12 @@ end
"""
FakeTemplate{PointEstimate, MultiOutput}()
A [`Template`](@ref) whose [`Model`](@ref) will predict a vector of 0s for each observation.
The input and output will have the same dimension.
A [`Template`](@ref) whose [`Model`](@ref) will accept real value variables to predict a
vector of 0s for each observation. The input and output will have the same dimension.
"""
function FakeTemplate{PointEstimate, MultiOutput}()
FakeTemplate{PointEstimate, MultiOutput}() do num_variates, inputs
inputs = NamedDimsArray{(:features, :observations)}(inputs)
inputs = _handle_inputs(inputs)
return NamedDimsArray{(:variates, :observations)}(
zeros(num_variates, size(inputs, :observations))
)
Expand All @@ -68,30 +69,46 @@ end
"""
FakeTemplate{DistributionEstimate, SingleOutput}()
A [`Template`](@ref) whose [`Model`](@ref) will predict a univariate normal
distribution (with zero mean and unit standard deviation) for each observation.
A [`Template`](@ref) whose [`Model`](@ref) will accept real value variables to predict a
univariate normal distribution (with zero mean and unit standard deviation) for each
observation.
"""
function FakeTemplate{DistributionEstimate, SingleOutput}()
FakeTemplate{DistributionEstimate, SingleOutput}() do num_variates, inputs
@assert(num_variates == 1, "$num_variates != 1")
inputs = NamedDimsArray{(:features, :observations)}(inputs)
inputs = _handle_inputs(inputs)
return NoncentralT.(3.0, zeros(size(inputs, :observations)))
end
end

"""
FakeTemplate{DistributionEstimate, MultiOutput}()
A [`Template`](@ref) whose [`Model`](@ref) will predict a multivariate normal
distribution (with zero-vector mean and identity covariance matrix) for each observation.
A [`Template`](@ref) whose [`Model`](@ref) will accept real value variables to predict a
multivariate normal distribution (with zero-vector mean and identity covariance matrix) for
each observation.
"""
function FakeTemplate{DistributionEstimate, MultiOutput}()
FakeTemplate{DistributionEstimate, MultiOutput}() do num_variates, inputs
std_dev = ones(num_variates)
inputs = _handle_inputs(inputs)
return [Product(Normal.(0, std_dev)) for _ in 1:size(inputs, 2)]
end
end

"""
_handle_inputs(inputs::AbstractMatrix)
_handle_inputs(inputs::AbstractVector{<:Sampleable})
Process the inputs to `predict` appropriately depending on whether they are real valued or
distributions over input variables.
"""
function _handle_inputs(inputs::AbstractVector{<:Sampleable})
return NamedDimsArray{(:features, :observations)}(hcat([mean(inputs[i]) for i in 1:size(inputs, 1)]...))
end

_handle_inputs(inputs::AbstractMatrix) = NamedDimsArray{(:features, :observations)}(inputs)

"""
FakeModel
Expand Down Expand Up @@ -119,58 +136,70 @@ function StatsBase.fit(
return FakeModel{E, O}(template.predictor, num_variates)
end

StatsBase.predict(m::FakeModel, inputs) = m.predictor(m.num_variates, inputs)
StatsBase.predict(m::FakeModel, inputs::AbstractMatrix) = m.predictor(m.num_variates, inputs)
StatsBase.predict(m::FakeModel, inputs::AbstractVector{<:Sampleable}) = m.predictor(m.num_variates, inputs)

"""
test_interface(template::Template; inputs=rand(5, 5), outputs=rand(5, 5))
Test that subtypes of [`Template`](@ref) and [`Model`](@ref) implement the expected API.
Can be used as an initial test to verify the API has been correctly implemented.
"""
function test_interface(template::Template; kwargs...)
function test_interface(
template::Template;
inputs=rand(5,5),
outputs=_default_outputs(template),
distribution_inputs=[MvNormal(5, m) for m in 1:5],
kwargs...
)
@testset "Models API Interface Test: $(nameof(typeof(template)))" begin
test_interface(template, estimate_type(template), output_type(template); kwargs...)
predictions = test_common(template, inputs, outputs)
test_estimate_type(estimate_type(template), predictions, outputs)
test_output_type(output_type(template), predictions, outputs)
test_predict_input_type(predict_input_type(template), template, outputs, inputs, distribution_inputs)
end
end

function test_interface(
template::Template, ::Type{PointEstimate}, ::Type{SingleOutput};
inputs=rand(5, 5), outputs=rand(1, 5),
)
predictions = test_common(template, inputs, outputs)
_default_outputs(template) = _default_outputs(output_type(template))
_default_outputs(::Type{SingleOutput}) = rand(1, 5)
_default_outputs(::Type{MultiOutput}) = rand(2, 5)

function test_estimate_type(::Type{PointEstimate}, predictions, outputs)
@test predictions isa NamedDimsArray{(:variates, :observations), <:Real, 2}
@test size(predictions) == size(outputs)
@test size(predictions, 1) == 1
end

function test_interface(
template::Template, ::Type{PointEstimate}, ::Type{MultiOutput};
inputs=rand(5, 5), outputs=rand(2, 5),
)
predictions = test_common(template, inputs, outputs)
@test predictions isa NamedDimsArray{(:variates, :observations), <:Real, 2}
@test size(predictions) == size(outputs)
function test_estimate_type(::Type{DistributionEstimate}, predictions, outputs)
@test predictions isa AbstractVector{<:ContinuousDistribution}
@test length(predictions) == size(outputs, 2)
end

function test_interface(
template::Template, ::Type{DistributionEstimate}, ::Type{SingleOutput};
inputs=rand(5, 5), outputs=rand(1, 5),
)
predictions = test_common(template, inputs, outputs)
@test predictions isa AbstractVector{<:ContinuousUnivariateDistribution}
@test length(predictions) == size(outputs, 2)
function test_output_type(::Type{SingleOutput}, predictions, outputs)
@test all(length.(predictions) .== size(outputs, 1))
@test all(length.(predictions) .== 1)
end

function test_interface(
template::Template, ::Type{DistributionEstimate}, ::Type{MultiOutput};
inputs=rand(5, 5), outputs=rand(3, 5)
)
predictions = test_common(template, inputs, outputs)
@test predictions isa AbstractVector{<:ContinuousMultivariateDistribution}
@test length(predictions) == size(outputs, 2)
@test all(length.(predictions) .== size(outputs, 1))
function test_output_type(::Type{MultiOutput}, predictions, outputs)
if eltype(predictions) <: Distribution
@test all(length.(predictions) .== size(outputs, 1))
@test all(length.(predictions) .> 1)
else
@test size(predictions, 1) == size(outputs, 1)
@test size(predictions, 1) > 1
end
end

function test_predict_input_type(::Type{PointPredictInput}, template, outputs, inputs, distribution_inputs)
model = fit(template, outputs, inputs)
@test hasmethod(Models.predict, (typeof(model), typeof(inputs)))
end

function test_predict_input_type(::Type{PointOrDistributionPredictInput}, template, outputs, inputs, distribution_inputs)
model = fit(template, outputs, inputs)
@test hasmethod(Models.predict, (typeof(model), typeof(distribution_inputs)))
predictions = predict(model, distribution_inputs)
test_estimate_type(estimate_type(template), predictions, outputs)
test_output_type(output_type(template), predictions, outputs)
end

function test_names(template, model)
Expand Down Expand Up @@ -210,6 +239,7 @@ function test_common(template, inputs, outputs)
@testset "traits" begin
@test estimate_type(template) == estimate_type(model)
@test output_type(template) == output_type(model)
@test predict_input_type(template) == predict_input_type(model)
end

@testset "submodels" begin
Expand Down
34 changes: 34 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,37 @@ Return the [`OutputTrait`] of the [`Model`](@ref) or [`Template`](@ref).
"""
output_type(::T) where T = output_type(T)
output_type(T::Type) = throw(MethodError(output_type, (T,))) # to prevent recursion

"""
PredictInputTrait
The `PredictInputTrait` specifies if the model supports point or distribution inputs to `predict`,
denoted by [`PointPredictInput`](@ref) or [`PointOrDistributionPredictInput`](@ref).
"""
abstract type PredictInputTrait end

"""
PointPredictInput <: PredictInputTrait
Specifies that the [`Model`](@ref) accepts real-valued input variables to `predict`.
"""
abstract type PointPredictInput <: PredictInputTrait end

"""
PointOrDistributionPredictInput <: PredictInputTrait
Specifies that the [`Model`](@ref) accepts real-values or a joint distribution over the input
variables to `predict`.
"""
abstract type PointOrDistributionPredictInput <: PredictInputTrait end

"""
predict_input_type(::T) where T = predict_input_type(T)
Return the [`PredictInputTrait`] of the [`Model`](@ref) or [`Template`](@ref).
"""
predict_input_type(::T) where T = predict_input_type(T)
predict_input_type(T::Type) = throw(MethodError(predict_input, (T,))) # to prevent recursion

predict_input_type(::Type{<:Model}) = PointPredictInput
predict_input_type(::Type{<:Template}) = PointPredictInput
35 changes: 23 additions & 12 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
@testset "test_utils.jl" begin

@testset "FakeTemplate{PointEstimate, SingleOutput}" begin
temp = FakeTemplate{PointEstimate, SingleOutput}()
test_interface(temp)
end
estimates = (PointEstimate, DistributionEstimate)
outputs = (SingleOutput, MultiOutput)

@testset "FakeTemplate{PointEstimate, MultiOutput}" begin
temp = FakeTemplate{PointEstimate, MultiOutput}()
@testset "$est, $out, PointPredictInput" for (est, out) in Iterators.product(estimates, outputs)
temp = FakeTemplate{est, out}()
test_interface(temp)
end
end

@testset "FakeTemplate{DistributionEstimate, SingleOutput}" begin
temp = FakeTemplate{DistributionEstimate, SingleOutput}()
test_interface(temp)
@testset "Vector inputs case" begin
temp = FakeTemplate{PointEstimate, SingleOutput}()
test_interface(temp; inputs=[rand(5), rand(5)], outputs=rand(1, 2))
end

@testset "FakeTemplate{DistributionEstimate, MultiOutput}" begin
temp = FakeTemplate{DistributionEstimate, MultiOutput}()
@testset "$est, $out, PointOrDistributionPredictInput" for (est, out) in Iterators.product(estimates, outputs)
Models.predict_input_type(m::Type{<:FakeTemplate}) = PointOrDistributionPredictInput
Models.predict_input_type(m::Type{<:FakeModel}) = PointOrDistributionPredictInput

temp = FakeTemplate{est, out}()
test_interface(temp)
end

@testset "Vector inputs case" begin
temp = FakeTemplate{PointEstimate, SingleOutput}()
test_interface(
temp;
inputs=[rand(5), rand(5)],
outputs=rand(1, 2),
distribution_inputs=[MvNormal(5, m) for m in 1:2]
)
end

end
Loading

2 comments on commit 7a62068

@BSnelling
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/31149

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.5 -m "<description of version>" 7a620689c22ef1ebc82e7ac8aaaa981c69b66b99
git push origin v0.2.5

Please sign in to comment.