Merge pull request #973 from JuliaAI/compact-performance-evaluations
Add `CompactPerformanceEvaluation` objects and the option in `evaluate!` to construct them
ablaom authored Apr 30, 2024
2 parents 6e77d6a + c83cae9 commit f811dc3
Showing 7 changed files with 179 additions and 38 deletions.
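To make the intent of the change concrete, here is a minimal usage sketch (not part of the commit). The `MeanRegressor` stand-in model, the synthetic data, and the use of `rms` from StatisticalMeasures.jl are illustrative assumptions only:

```julia
using MLJBase, Statistics
import MLJModelInterface as MMI
using StatisticalMeasures  # assumed installed; supplies the `rms` measure

# Throwaway deterministic model (illustration only): always predicts the
# training-target mean.
mutable struct MeanRegressor <: MMI.Deterministic end
MMI.fit(::MeanRegressor, verbosity, X, y) = (mean(y), nothing, NamedTuple())
MMI.predict(::MeanRegressor, fitresult, Xnew) = fill(fitresult, MMI.nrows(Xnew))

X, y = make_regression(100, 3)  # synthetic-data helper exported by MLJBase
mach = machine(MeanRegressor(), X, y, scitype_check_level=0)  # toy model declares no traits

# Default behaviour (unchanged): the full evaluation object.
e_full = evaluate!(mach, resampling=CV(nfolds=5), measure=rms)
e_full isa PerformanceEvaluation           # true

# New in this PR: a slimmed-down object that omits fitted_params_per_fold,
# report_per_fold and train_test_rows.
e_small = evaluate!(mach, resampling=CV(nfolds=5), measure=rms, compact=true)
e_small isa CompactPerformanceEvaluation   # true
```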
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -49,7 +49,7 @@ jobs:
env:
JULIA_NUM_THREADS: 2
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
- uses: codecov/codecov-action@v3
with:
file: lcov.info
docs:
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "1.2.1"
version = "1.3"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
2 changes: 1 addition & 1 deletion src/MLJBase.jl
@@ -292,7 +292,7 @@ export TransformedTargetModel

# resampling.jl:
export ResamplingStrategy, Holdout, CV, StratifiedCV, TimeSeriesCV,
evaluate!, Resampler, PerformanceEvaluation
evaluate!, Resampler, PerformanceEvaluation, CompactPerformanceEvaluation

# `MLJType` and the abstract `Model` subtypes are exported from within
# src/composition/abstract_types.jl
8 changes: 6 additions & 2 deletions src/machines.jl
@@ -51,8 +51,10 @@ mutable struct Machine{M,OM,C} <: MLJType

model::M
old_model::OM # for remembering the model used in last call to `fit!`

# the next two refer to objects returned by `MLJModelInterface.fit(::M, ...)`.
fitresult
cache
cache # relevant to `MLJModelInterface.update`, not to be confused with type param `C`

# training arguments (`Node`s or user-specified data wrapped in
# `Source`s):
@@ -81,7 +83,7 @@ mutable struct Machine{M,OM,C} <: MLJType
# In the case of a symbolic model, the machine cannot know the type of model to be fit
# at time of construction:
OM = M == Symbol ? Any : M
mach = new{M,OM,cache}(model)
mach = new{M,OM,cache}(model) # (this `cache` is not the *field* `cache`)
mach.frozen = false
mach.state = 0
mach.args = args
@@ -92,6 +94,8 @@ mutable struct Machine{M,OM,C} <: MLJType

end

caches_data(::Machine{<:Any, <:Any, C}) where C = C

"""
age(mach::Machine)
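The new `caches_data` helper introduced above (not exported) simply reads off the machine's third type parameter, which records the `cache` keyword supplied at construction. A rough illustration, with `Echo` a hypothetical do-nothing `Static` model invented only for this sketch:

```julia
using MLJBase
import MLJModelInterface as MMI

# Hypothetical stand-in model, just enough to construct machines with:
mutable struct Echo <: MMI.Static end
MMI.transform(::Echo, ::Nothing, X) = X

mach1 = machine(Echo())                # default is cache=true
mach2 = machine(Echo(), cache=false)

MLJBase.caches_data(mach1)  # true
MLJBase.caches_data(mach2)  # false
```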
173 changes: 140 additions & 33 deletions src/resampling.jl
@@ -349,7 +349,6 @@ Stratified cross-validation resampling strategy, for use in
problems (`OrderedFactor` or `Multiclass` targets).
train_test_pairs(stratified_cv, rows, y)
Returns an `nfolds`-length iterator of `(train, test)` pairs of
vectors (row indices) where each `train` and `test` is a sub-vector of
`rows`. The `test` vectors are mutually exclusive and exhaust
@@ -465,12 +464,18 @@ end
# ================================================================
## EVALUATION RESULT TYPE

abstract type AbstractPerformanceEvaluation <: MLJType end

"""
PerformanceEvaluation
PerformanceEvaluation <: AbstractPerformanceEvaluation
Type of object returned by [`evaluate`](@ref) (for models plus data) or
[`evaluate!`](@ref) (for machines). Such objects encode estimates of the performance
(generalization error) of a supervised model or outlier detection model.
(generalization error) of a supervised model or outlier detection model, and store other
information ancillary to the computation.
If [`evaluate`](@ref) or [`evaluate!`](@ref) is called with the `compact=true` option,
then a [`CompactPerformanceEvaluation`](@ref) object is returned instead.
When `evaluate`/`evaluate!` is called, a number of train/test pairs ("folds") of row
indices are generated, according to the options provided, which are discussed in the
Expand All @@ -479,7 +484,7 @@ pairs are recorded in the `train_test_rows` field of the `PerformanceEvaluation`
and the corresponding estimates, aggregated over all train/test pairs, are recorded in
`measurement`, a vector with one entry for each measure (metric) recorded in `measure`.
When displayed, a `PerformanceEvalution` object includes a value under the heading
When displayed, a `PerformanceEvaluation` object includes a value under the heading
`1.96*SE`, derived from the standard error of the `per_fold` entries. This value is
suitable for constructing a formal 95% confidence interval for the given
`measurement`. Such intervals should be interpreted with caution. See, for example, Bates
@@ -526,10 +531,13 @@ These fields are part of the public API of the `PerformanceEvaluation` struct.
and `test` are vectors of row (observation) indices for training and evaluation
respectively.
- `resampling`: the resampling strategy used to generate the train/test pairs.
- `resampling`: the user-specified resampling strategy to generate the train/test pairs
(or literal train/test pairs if that was directly specified).
- `repeats`: the number of times the resampling strategy was repeated.
See also [`CompactPerformanceEvaluation`](@ref).
"""
struct PerformanceEvaluation{M,
Measure,
@@ -539,7 +547,7 @@ struct PerformanceEvaluation{M,
PerObservation,
FittedParamsPerFold,
ReportPerFold,
R} <: MLJType
R} <: AbstractPerformanceEvaluation
model::M
measure::Measure
measurement::Measurement
@@ -553,6 +561,47 @@ struct PerformanceEvaluation{M,
repeats::Int
end

"""
CompactPerformanceEvaluation <: AbstractPerformanceEvaluation
Type of object returned by [`evaluate`](@ref) (for models plus data) or
[`evaluate!`](@ref) (for machines) when called with the option `compact = true`. Such
objects have the same structure as the [`PerformanceEvaluation`](@ref) objects returned by
default, except that the following fields are omitted to save memory:
`fitted_params_per_fold`, `report_per_fold`, `train_test_rows`.
For more on the remaining fields, see [`PerformanceEvaluation`](@ref).
"""
struct CompactPerformanceEvaluation{M,
Measure,
Measurement,
Operation,
PerFold,
PerObservation,
R} <: AbstractPerformanceEvaluation
model::M
measure::Measure
measurement::Measurement
operation::Operation
per_fold::PerFold
per_observation::PerObservation
resampling::R
repeats::Int
end

compactify(e::CompactPerformanceEvaluation) = e
compactify(e::PerformanceEvaluation) = CompactPerformanceEvaluation(
e.model,
e.measure,
e.measurement,
e.operation,
e.per_fold,
e.per_observation,
e.resampling,
e.repeats,
)

# pretty printing:
round3(x) = x
round3(x::AbstractFloat) = round(x, sigdigits=3)
@@ -562,7 +611,7 @@ const SE_FACTOR = 1.96 # For a 95% confidence interval.
_standard_error(v::AbstractVector{<:Real}) = SE_FACTOR*std(v) / sqrt(length(v) - 1)
_standard_error(v) = "N/A"

function _standard_errors(e::PerformanceEvaluation)
function _standard_errors(e::AbstractPerformanceEvaluation)
measure = e.measure
length(e.per_fold[1]) == 1 && return [nothing]
std_errors = map(_standard_error, e.per_fold)
@@ -573,42 +622,81 @@ end
_repr_(f::Function) = repr(f)
_repr_(x) = repr("text/plain", x)

function Base.show(io::IO, ::MIME"text/plain", e::PerformanceEvaluation)
# helper for row labels: _label(1) = "A", _label(2) = "B", _label(27) = "BA", etc.
const alphabet = Char.(65:90)
_label(i) = map(digits(i - 1, base=26)) do d alphabet[d + 1] end |> join |> reverse

function Base.show(io::IO, ::MIME"text/plain", e::AbstractPerformanceEvaluation)
_measure = [_repr_(m) for m in e.measure]
_measurement = round3.(e.measurement)
_per_fold = [round3.(v) for v in e.per_fold]
_sterr = round3.(_standard_errors(e))
row_labels = _label.(eachindex(e.measure))

# Only show the standard error if the number of folds is higher than 1.
show_sterr = any(!isnothing, _sterr)
data = show_sterr ?
hcat(_measure, e.operation, _measurement, _sterr, _per_fold) :
hcat(_measure, e.operation, _measurement, _per_fold)
header = show_sterr ?
["measure", "operation", "measurement", "1.96*SE", "per_fold"] :
["measure", "operation", "measurement", "per_fold"]
# Define header and data for main table

println(io, "PerformanceEvaluation object "*
data = hcat(_measure, e.operation, _measurement)
header = ["measure", "operation", "measurement"]
if length(row_labels) > 1
data = hcat(row_labels, data)
header = ["", header...]
end

if e isa PerformanceEvaluation
println(io, "PerformanceEvaluation object "*
"with these fields:")
println(io, " model, measure, operation, measurement, per_fold,\n"*
" per_observation, fitted_params_per_fold,\n"*
" report_per_fold, train_test_rows, resampling, repeats")
println(io, " model, measure, operation,\n"*
" measurement, per_fold, per_observation,\n"*
" fitted_params_per_fold, report_per_fold,\n"*
" train_test_rows, resampling, repeats")
else
println(io, "CompactPerformanceEvaluation object "*
"with these fields:")
println(io, " model, measure, operation,\n"*
" measurement, per_fold, per_observation,\n"*
" train_test_rows, resampling, repeats")
end

println(io, "Extract:")
show_color = MLJBase.SHOW_COLOR[]
color_off()
PrettyTables.pretty_table(io,
data;
header,
header_crayon=PrettyTables.Crayon(bold=false),
alignment=:l,
linebreaks=true)
PrettyTables.pretty_table(
io,
data;
header,
header_crayon=PrettyTables.Crayon(bold=false),
alignment=:l,
linebreaks=true,
)

# Show the per-fold table if needed:

if length(first(e.per_fold)) > 1
show_sterr = any(!isnothing, _sterr)
data2 = hcat(_per_fold, _sterr)
header2 = ["per_fold", "1.96*SE"]
if length(row_labels) > 1
data2 = hcat(row_labels, data2)
header2 = ["", header2...]
end
PrettyTables.pretty_table(
io,
data2;
header=header2,
header_crayon=PrettyTables.Crayon(bold=false),
alignment=:l,
linebreaks=true,
)
end
show_color ? color_on() : color_off()
end

function Base.show(io::IO, e::PerformanceEvaluation)
summary = Tuple(round3.(e.measurement))
print(io, "PerformanceEvaluation$summary")
end
_summary(e) = Tuple(round3.(e.measurement))
Base.show(io::IO, e::PerformanceEvaluation) =
print(io, "PerformanceEvaluation$(_summary(e))")
Base.show(io::IO, e::CompactPerformanceEvaluation) =
print(io, "CompactPerformanceEvaluation$(_summary(e))")


# ===============================================================
## EVALUATION METHODS
@@ -931,7 +1019,11 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated.
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref)
- `compact=false` - if `true`, the returned evaluation object excludes these fields:
`fitted_params_per_fold`, `report_per_fold`, `train_test_rows`.
See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
[`CompactPerformanceEvaluation`](@ref).
"""
function evaluate!(
@@ -951,6 +1043,7 @@ function evaluate!(
per_observation=true,
verbosity=1,
logger=nothing,
compact=false,
)

# this method just checks validity of options, preprocess the
@@ -1017,6 +1110,7 @@ function evaluate!(
per_observation,
logger,
resampling,
compact,
)
end

@@ -1198,6 +1292,7 @@ function evaluate!(
per_observation_flag,
logger,
user_resampling,
compact,
)

# Note: `user_resampling` keyword argument is the user-defined resampling strategy,
@@ -1352,7 +1447,8 @@ function evaluate!(
)
log_evaluation(logger, evaluation)

evaluation
compact && return compactify(evaluation)
return evaluation
end

# ----------------------------------------------------------------
@@ -1399,6 +1495,7 @@ end
check_measure=true,
per_observation=true,
logger=nothing,
compact=false,
)
Resampling model wrapper, used internally by the `fit` method of `TunedModel` instances
@@ -1442,6 +1539,7 @@ mutable struct Resampler{S, L} <: Model
cache::Bool
per_observation::Bool
logger::L
compact::Bool
end

# Some traits are marked as `missing` because we cannot determine
@@ -1485,6 +1583,7 @@ function Resampler(
cache=true,
per_observation=true,
logger=nothing,
compact=false,
)
resampler = Resampler(
model,
@@ -1499,6 +1598,7 @@ function Resampler(
cache,
per_observation,
logger,
compact,
)
message = MLJModelInterface.clean!(resampler)
isempty(message) || @warn message
@@ -1532,6 +1632,10 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)

_acceleration = _process_accel_settings(resampler.acceleration)

# the value of `compact` below is always `false`, because we need
# `e.train_test_rows` in `update`. (If `resampler.compact=true`, then
# `evaluate(resampler, ...)` returns the compactified version of the current
# `PerformanceEvaluation` object.)
e = evaluate!(
mach,
resampler.resampling,
@@ -1547,6 +1651,7 @@ function MLJModelInterface.fit(resampler::Resampler, verbosity::Int, args...)
resampler.per_observation,
resampler.logger,
resampler.resampling,
false, # compact
)

fitresult = (machine = mach, evaluation = e)
@@ -1620,6 +1725,7 @@ function MLJModelInterface.update(
resampler.per_observation,
resampler.logger,
resampler.resampling,
false # we use `compact=false`; see comment in `fit` above
)
report = (evaluation = e, )
fitresult = (machine=mach2, evaluation=e)
@@ -1643,7 +1749,8 @@ StatisticalTraits.load_path(::Type{<:Resampler}) = "MLJBase.Resampler"

fitted_params(::Resampler, fitresult) = fitresult

evaluate(resampler::Resampler, fitresult) = fitresult.evaluation
evaluate(resampler::Resampler, fitresult) = resampler.compact ?
compactify(fitresult.evaluation) : fitresult.evaluation

function evaluate(machine::Machine{<:Resampler})
if isdefined(machine, :fitresult)
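Finally, a sketch of how the change surfaces through the `Resampler` wrapper, reusing the hypothetical `MeanRegressor`, `X`, `y` and `rms` from the first sketch above. As the comments in `fit` note, the wrapper always computes the full evaluation internally (its `update` method needs `train_test_rows`) and only compactifies on the way out:

```julia
resampler = Resampler(
    model=MeanRegressor(),
    resampling=CV(nfolds=3),
    measure=rms,
    compact=true,
)
rmach = machine(resampler, X, y, scitype_check_level=0)
fit!(rmach, verbosity=0)

# `evaluate` on the fitted wrapper hands back the compact object, because
# `resampler.compact == true`:
e = evaluate(rmach)
e isa CompactPerformanceEvaluation   # true

# The conversion itself is the (unexported) `compactify`; on an already
# compact object it is a no-op:
MLJBase.compactify(e) === e          # true
```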