Skip to content

Commit

Permalink
Merge pull request #980 from JuliaAI/constructor
Browse files Browse the repository at this point in the history
Have wrappers overload `constructor` trait
  • Loading branch information
ablaom authored Jun 3, 2024
2 parents 14441aa + 52d16df commit e7afc34
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 23 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.3"
version = "1.4.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down Expand Up @@ -47,7 +47,7 @@ DelimitedFiles = "1"
Distributions = "0.25.3"
InvertedIndices = "1"
LearnAPI = "0.1"
MLJModelInterface = "1.7"
MLJModelInterface = "1.10"
Missings = "0.4, 1"
OrderedCollections = "1.1"
Parameters = "0.12"
Expand All @@ -58,7 +58,7 @@ Reexport = "1.2"
ScientificTypes = "3"
StatisticalMeasures = "0.1.1"
StatisticalMeasuresBase = "0.1.1"
StatisticalTraits = "3.2"
StatisticalTraits = "3.3"
Statistics = "1"
StatsBase = "0.32, 0.33, 0.34"
Tables = "0.2, 1.0"
Expand Down
3 changes: 3 additions & 0 deletions src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,9 @@ end

MMI.target_scitype(p::SupervisedPipeline) = target_scitype(supervised_component(p))

MMI.package_name(::Type{<:SomePipeline}) = "MLJBase"
MMI.load_path(::Type{<:SomePipeline}) = "MLJBase.Pipeline"
MMI.constructor(::Type{<:SomePipeline}) = Pipeline

# ## Training losses

Expand Down
7 changes: 4 additions & 3 deletions src/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,17 @@ function Base.setproperty!(stack::Stack{modelnames}, _name::Symbol, val) where m
end


# # TRAITS

MMI.target_scitype(::Type{<:Stack{modelnames, input_scitype, target_scitype}}) where
{modelnames, input_scitype, target_scitype} = target_scitype


MMI.input_scitype(::Type{<:Stack{modelnames, input_scitype, target_scitype}}) where
{modelnames, input_scitype, target_scitype} = input_scitype


MLJBase.load_path(::Type{<:ProbabilisticStack}) = "MLJBase.ProbabilisticStack"
MLJBase.load_path(::Type{<:DeterministicStack}) = "MLJBase.DeterministicStack"
MMI.constructor(::Type{<:Stack}) = Stack
MLJBase.load_path(::Type{<:Stack}) = "MLJBase.Stack"
MLJBase.package_name(::Type{<:Stack}) = "MLJBase"
MLJBase.package_uuid(::Type{<:Stack}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJBase.package_url(::Type{<:Stack}) = "https://github.com/JuliaAI/MLJBase.jl"
Expand Down
7 changes: 6 additions & 1 deletion src/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ const TT_SUPPORTED_ATOMS = (
:Deterministic,
:DeterministicUnsupervisedDetector,
:DeterministicSupervisedDetector,
:Interval)
:Interval,
)

# Each supported atomic type gets its own wrapper:

Expand Down Expand Up @@ -265,6 +266,10 @@ MMI.package_uuid(::Type{<:SomeTT}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MMI.is_wrapper(::Type{<:SomeTT}) = true
MMI.package_url(::Type{<:SomeTT}) = "https://github.com/JuliaAI/MLJBase.jl"

MMI.load_path(::Type{<:SomeTT}) = "MLJBase.TransformedTargetModel"
MMI.constructor(::Type{<:SomeTT}) = TransformedTargetModel


for New in TT_TYPE_EXS
quote
MMI.iteration_parameter(::Type{<:$New{M}}) where M =
Expand Down
27 changes: 12 additions & 15 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1548,9 +1548,11 @@ end
compact=false,
)
*Private method.* Use at own risk.
Resampling model wrapper, used internally by the `fit` method of `TunedModel` instances
and `IteratedModel` instances. See [`evaluate!](@ref) for options. Not intended for use by
general user, who will ordinarily use [`evaluate!`](@ref) directly.
and `IteratedModel` instances. See [`evaluate!`](@ref) for meaning of the options. Not
intended for use by general user, who will ordinarily use [`evaluate!`](@ref) directly.
Given a machine `mach = machine(resampler, args...)` one obtains a performance evaluation
of the specified `model`, performed according to the prescribed `resampling` strategy and
Expand Down Expand Up @@ -1592,16 +1594,6 @@ mutable struct Resampler{S, L} <: Model
compact::Bool
end

# Some traits are markded as `missing` because we cannot determine
# them from from the type because we have removed `M` (for "model"} as
# a `Resampler` type parameter. See
# https://github.com/JuliaAI/MLJTuning.jl/issues/141#issue-951221466

StatisticalTraits.is_wrapper(::Type{<:Resampler}) = true
StatisticalTraits.supports_weights(::Type{<:Resampler}) = missing
StatisticalTraits.supports_class_weights(::Type{<:Resampler}) = missing
StatisticalTraits.is_pure_julia(::Type{<:Resampler}) = true

function MLJModelInterface.clean!(resampler::Resampler)
warning = ""
if resampler.measure === nothing && resampler.model !== nothing
Expand Down Expand Up @@ -1787,11 +1779,16 @@ function MLJModelInterface.update(

end

# The input and target scitypes cannot be determined from the type
# because we have removed `M` (for "model") as a `Resampler` type
# parameter. See
# Some traits are marked as `missing` because we cannot determine
# them from from the type because we have removed `M` (for "model"} as
# a `Resampler` type parameter. See
# https://github.com/JuliaAI/MLJTuning.jl/issues/141#issue-951221466

StatisticalTraits.is_wrapper(::Type{<:Resampler}) = true
StatisticalTraits.supports_weights(::Type{<:Resampler}) = missing
StatisticalTraits.supports_class_weights(::Type{<:Resampler}) = missing
StatisticalTraits.is_pure_julia(::Type{<:Resampler}) = true
StatisticalTraits.constructor(::Type{<:Resampler}) = Resampler
StatisticalTraits.input_scitype(::Type{<:Resampler}) = Unknown
StatisticalTraits.target_scitype(::Type{<:Resampler}) = Unknown
StatisticalTraits.package_name(::Type{<:Resampler}) = "MLJBase"
Expand Down
4 changes: 3 additions & 1 deletion test/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ end

@testset "public constructor" begin
# un-named components:
@test Pipeline(m, t, u) isa UnsupervisedPipeline
flute = Pipeline(m, t, u)
@test flute isa UnsupervisedPipeline
@test MLJBase.constructor(flute) == Pipeline
@test Pipeline(m, t, u, p) isa ProbabilisticPipeline
@test Pipeline(m, t, u, p, operation=predict_mean) isa DeterministicPipeline
@test Pipeline(u, p, u, operation=predict_mean) isa DeterministicPipeline
Expand Down
3 changes: 3 additions & 0 deletions test/composition/models/stacking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ end
measures=rmse,
resampling=CV(;nfolds=3),
models...)

@test MLJBase.constructor(mystack) == Stack

@test mystack.ridge_lambda.lambda == 0.1
@test mystack.metalearner isa FooBarRegressor
@test mystack.resampling isa CV
Expand Down
1 change: 1 addition & 0 deletions test/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ avg_nonlinear = g(mean(f(y))) # = g(mean(z))

# Test wrapping using f and g:
model = TransformedTargetModel(atom, transformer=f, inverse=g)
@test MLJBase.constructor(model) == TransformedTargetModel
fr1, _, _ = MMI.fit(model, 0, X, y)
@test first(predict(model, fr1, X)) fill(avg_nonlinear, 5)

Expand Down
3 changes: 3 additions & 0 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ end
holdout = Holdout(fraction_train=0.75)
resampler = Resampler(resampling=holdout, model=ridge_model, measure=mae,
acceleration=accel)
@test constructor(resampler) == Resampler
@test package_name(resampler) == "MLJBase"
@test load_path(resampler) == "MLJBase.Resampler"
resampling_machine = machine(resampler, X, y)
@test_logs((:info, r"^Training"), fit!(resampling_machine))
e1=evaluate(resampling_machine).measurement[1]
Expand Down

0 comments on commit e7afc34

Please sign in to comment.