Skip to content

Commit

Permalink
Merge pull request #963 from JuliaAI/feature-importances-wrappers
Browse files Browse the repository at this point in the history
Allow `Pipeline` and `TransformedTargetModel` to support feature importances
  • Loading branch information
ablaom authored Mar 17, 2024
2 parents 44ca101 + 059703a commit b9e6ac1
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.1.3"
version = "1.2.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
36 changes: 26 additions & 10 deletions src/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,13 @@ end

# ## Methods to extend a pipeline learning network

# The "front" of a pipeline network, as we grow it, consists of a
# "predict" and a "transform" node. Once the pipeline is complete
# (after a series of `extend` operations - see below) the "transform"
# node is what is used to deliver the output of `transform(pipe)` in
# the exported model, and the "predict" node is what will be used to
# deliver the output of `predict(pipe). Both nodes can be changed by
# `extend` but only the "active" node is propagated. Initially
# "transform" is active and "predict" only becomes active when a
# supervised model is encountered; this change is permanent.
# The "front" of a pipeline network, as we grow it, consists of a "predict" and a
# "transform" node. Once the pipeline is complete (after a series of `extend` operations -
# see below) the "transform" node is what is used to deliver the output of
# `transform(pipe, ...)` in the exported model, and the "predict" node is what will be
# used to deliver the output of `predict(pipe, ...). Both nodes can be changed by `extend`
# but only the "active" node is propagated. Initially "transform" is active and "predict"
# only becomes active when a supervised model is encountered; this change is permanent.
# https://github.com/JuliaAI/MLJClusteringInterface.jl/issues/10

abstract type ActiveNodeOperation end
Expand Down Expand Up @@ -587,7 +585,10 @@ end
# component, only its `abstract_type`. See comment at top of page.

MMI.supports_training_losses(pipe::SupervisedPipeline) =
MMI.supports_training_losses(getproperty(pipe, supervised_component_name(pipe)))
MMI.supports_training_losses(supervised_component(pipe))

MMI.reports_feature_importances(pipe::SupervisedPipeline) =
MMI.reports_feature_importances(supervised_component(pipe))

# This trait cannot be defined at the level of types (see previous comment):
function MMI.iteration_parameter(pipe::SupervisedPipeline)
Expand Down Expand Up @@ -618,3 +619,18 @@ function MMI.training_losses(pipe::SupervisedPipeline, pipe_report)
report = getproperty(pipe_report, supervised_name)
return training_losses(supervised, report)
end


# ## Feature importances

function feature_importances(pipe::SupervisedPipeline, fitresult, report)
# locate the machine associated with the supervised component:
supervised_name = MLJBase.supervised_component_name(pipe)
predict_node = fitresult.interface.predict
mach = only(MLJBase.machines_given_model(predict_node)[supervised_name])

# To extract the feature_importances, we can't do `feature_importances(mach)` because
# `mach.model` is just a symbol; instead we do:
supervised = MLJBase.supervised_component(pipe)
return feature_importances(supervised, mach.fitresult, mach.report[:fit])
end
25 changes: 20 additions & 5 deletions src/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,28 +237,41 @@ const ERR_TT_MISSING_REPORT =
"Cannot find report for `TransformedTargetModel` atomic model, from which "*
"to extract training losses. "

function training_losses(composite::SomeTT, tt_report)
function MMI.training_losses(composite::SomeTT, tt_report)
hasproperty(tt_report, :model) || throw(ERR_TT_MISSING_REPORT)
atomic_report = getproperty(tt_report, :model)
return training_losses(composite.model, atomic_report)
end


# # FEATURE IMPORTANCES

function MMI.feature_importances(composite::SomeTT, fitresult, report)
# locate the machine associated with the supervised component:
predict_node = fitresult.interface.predict
mach = only(MLJBase.machines_given_model(predict_node)[:model])

# To extract the feature_importances, we can't do `feature_importances(mach)` because
# `mach.model` is just a symbol; instead we do:
return feature_importances(composite.model, mach.fitresult, mach.report[:fit])
end


## MODEL TRAITS

MMI.package_name(::Type{<:SomeTT}) = "MLJBase"
MMI.package_license(::Type{<:SomeTT}) = "MIT"
MMI.package_uuid(::Type{<:SomeTT}) = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MMI.is_wrapper(::Type{<:SomeTT}) = true
MMI.package_url(::Type{<:SomeTT}) =
"https://github.com/JuliaAI/MLJBase.jl"
MMI.package_url(::Type{<:SomeTT}) = "https://github.com/JuliaAI/MLJBase.jl"

for New in TT_TYPE_EXS
quote
MMI.iteration_parameter(::Type{<:$New{M}}) where M =
MLJBase.prepend(:model, iteration_parameter(M))
end |> eval
for trait in [:input_scitype,
for trait in [
:input_scitype,
:output_scitype,
:target_scitype,
:fit_data_scitype,
Expand All @@ -270,8 +283,10 @@ for New in TT_TYPE_EXS
:supports_class_weights,
:supports_online,
:supports_training_losses,
:reports_feature_importances,
:is_supervised,
:prediction_type]
:prediction_type
]
quote
MMI.$trait(::Type{<:$New{M}}) where M = MMI.$trait(M)
end |> eval
Expand Down
8 changes: 0 additions & 8 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -808,10 +808,6 @@ julia> fitted_params(mach).logistic_classifier
intercept = 0.0883301599726305,)
```
Additional keys, `machines` and `fitted_params_given_machine`, give a
list of *all* machines in the underlying network, and a dictionary of
fitted parameters keyed on those machines.
See also [`report`](@ref)
"""
Expand Down Expand Up @@ -852,10 +848,6 @@ julia> report(mach).linear_binary_classifier
```
Additional keys, `machines` and `report_given_machine`, give a
list of *all* machines in the underlying network, and a dictionary of
reports keyed on those machines.
See also [`fitted_params`](@ref)
"""
Expand Down
20 changes: 17 additions & 3 deletions test/_models/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor

import MLJBase
import MLJBase: @mlj_model, metadata_pkg, metadata_model

import MLJBase.Tables
using ScientificTypes

using CategoricalArrays
Expand Down Expand Up @@ -98,8 +98,11 @@ function MLJBase.fit(model::DecisionTreeClassifier, verbosity::Int, X, y)
#> empty values):

cache = nothing
report = (classes_seen=classes_seen,
print_tree=TreePrinter(tree))
report = (
classes_seen=classes_seen,
print_tree=TreePrinter(tree),
features=Tables.columnnames(Tables.columns(X)) |> collect,
)

return fitresult, cache, report
end
Expand Down Expand Up @@ -137,6 +140,17 @@ function MLJBase.predict(model::DecisionTreeClassifier
for i in 1:size(y_probabilities, 1)]
end

MLJBase.reports_feature_importances(::Type{<:DecisionTreeClassifier}) = true

function MMI.feature_importances(m::DecisionTreeClassifier, fitresult, report)
features = report.features
fi = DecisionTree.impurity_importance(first(fitresult), normalize=true)
fi_pairs = Pair.(features, fi)
# sort descending
sort!(fi_pairs, by= x->-x[2])

return fi_pairs
end

## REGRESSOR

Expand Down
8 changes: 4 additions & 4 deletions test/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
MLJBase.reporting_operations(::Type{<:ReportingScaler}) = (:transform, )

MLJBase.transform(model::ReportingScaler, _, X) = (
model.alpha*Tables.matrix(X),
Tables.table(model.alpha*Tables.matrix(X)),
(; nrows = size(MLJBase.matrix(X))[1]),
)

Expand Down Expand Up @@ -143,7 +143,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :training_loss, :len])
@test fitr.scaler == (nrows=10,)
@test fitr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(fitr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test fitr.training_loss isa Real
@test fitr.len == 10

Expand All @@ -164,7 +164,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :finalizer])
@test predictr.scaler == (nrows=5,)
@test predictr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test predictr.finalizer == (nrows=5,)

o, predictr = predict(composite, f, selectrows(X, 1:2))
Expand All @@ -174,7 +174,7 @@ composite = WatermelonComposite(
Set([:scaler, :clusterer, :classifier1, :finalizer])
@test predictr.scaler == (nrows=2,) # <----------- different
@test predictr.clusterer == (labels=['A', 'B', 'C'],)
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree])
@test Set(keys(predictr.classifier1)) == Set([:classes_seen, :print_tree, :features])
@test predictr.finalizer == (nrows=2,) # <---------- different

r = MMI.report(composite, Dict(:fit => fitr, :predict=> predictr))
Expand Down
10 changes: 10 additions & 0 deletions test/composition/models/pipelines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,16 @@ end
rm(filename)
end

@testset "feature importances" begin
# the DecisionTreeClassifier in /test/_models/ supports feature importances.
pipe = Standardizer |> DecisionTreeClassifier()
@test reports_feature_importances(pipe)
X, y = @load_iris
fitresult, _, report = MLJBase.fit(pipe, 0, X, y)
features = first.(feature_importances(pipe, fitresult, report))
@test Set(features) == Set(keys(X))
end

end # module

true
9 changes: 9 additions & 0 deletions test/composition/models/transformed_target_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,14 @@ y = rand(5)
@test training_losses(mach) == ones(5)
end

@testset "feature_importances" begin
X, y = @load_iris
atom = DecisionTreeClassifier()
model = TransformedTargetModel(atom, transformer=identity, inverse=identity)
@test reports_feature_importances(model)
fitresult, _, rpt = MMI.fit(model, 0, X, y)
@test Set(first.(feature_importances(model, fitresult, rpt))) == Set(keys(X))
end

end
true

0 comments on commit b9e6ac1

Please sign in to comment.