Revise predict to allow all 'type' #172

Open · wants to merge 10 commits into master
57 changes: 53 additions & 4 deletions src/booster.jl
@@ -275,20 +275,17 @@ function deserialize(::Type{Booster}, buf::AbstractVector{UInt8}, data=DMatrix[]
deserialize!(b, buf)
end


# sadly this is type unstable because we might return a transpose
"""
predict(b::Booster, data; margin=false, training=false, ntree_limit=0)

Use the model `b` to run predictions on `data`. This will return a `Vector{Float32}` which can be compared
to training or test target data.

If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction.

## Examples
```julia
(X, y) = (randn(100,3), randn(100))
b = xgboost((X, y), 10)

ŷ = predict(b, X)
```
"""
@@ -314,6 +311,58 @@ function predict(b::Booster, Xy::DMatrix;
end
predict(b::Booster, Xy; kw...) = predict(b, DMatrix(Xy); kw...)
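# A minimal usage sketch for the convenience method above: anything accepted by the
# `DMatrix` constructor (a matrix or a Tables.jl-compatible table) can be passed to
# `predict` directly. The DataFrames usage below is an illustrative assumption, not part
# of this file:
#
#     using DataFrames
#     df = DataFrame(x1=randn(100), x2=randn(100), x3=randn(100))
#     ŷ = predict(b, df)  # equivalent to predict(b, DMatrix(df))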


"""
predictbytype(b::Booster, data::DMatrix; type=0, training=false, ntree_lower_limit=0, ntree_limit=0)

Use the model `b` to run predictions on `data`.

This version of `predict` gives access to the full set of prediction types, including contribution and interaction values.

If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction.

The `type` parameter corresponds to the prediction types specified in the XGBoost documentation.
Options are:
0 => normal (default)
1 => output margin
2 => predict contribution
3 => predict approximate contribution
4 => predict feature interactions
5 => predict approximate feature interactions
6 => predict leaf index (see XGBoost documentation)

The shape of the returned data varies with the `type` option and with the objective used.

## Examples
```julia
(X, y) = (randn(100,3), randn(100))
b = xgboost((X, y), 10)

ŷ = predictbytype(b, DMatrix(X), type=2)
```
"""
function predictbytype(b::Booster, Xy::DMatrix;
type::Integer=0, # 0-normal, 1-margin, 2-contrib, 3-approx. contrib, 4-interact, 5-approx. interact, 6-leaf
training::Bool=false,
ntree_lower_limit::Integer=0,
ntree_limit::Integer=0, # 0 corresponds to no limit
)
# prediction options are passed to libxgboost as a JSON-encoded configuration string
opts = Dict("type"=>type,
"iteration_begin"=>ntree_lower_limit,
"iteration_end"=>ntree_limit,
"strict_shape"=>false,
"training"=>training,
) |> JSON3.write
oshape = Ref{Ptr{Lib.bst_ulong}}()  # pointer to the output shape array, filled by libxgboost
odim = Ref{Lib.bst_ulong}()  # number of output dimensions
o = Ref{Ptr{Cfloat}}()  # pointer to the prediction output buffer
xgbcall(XGBoosterPredictFromDMatrix, b.handle, Xy.handle, opts, oshape, odim, o)
dims = reverse(unsafe_wrap(Array, oshape[], odim[]))  # shape is reported in row-major order
o = unsafe_wrap(Array, o[], tuple(dims...))
length(dims) > 1 ? permutedims(o, reverse(1:ndims(o))) : o  # restore the reported axis order; handles ndims ≥ 3
end
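# A minimal usage sketch for `predictbytype` (assuming the toy model from the docstring
# example above). With `type=2` the output contains per-feature contribution (SHAP-style)
# values plus a bias column, so for a regression model with 3 features the expected shape
# is (100, 4); the exact shape depends on the objective, so treat this as an illustration
# rather than a guarantee:
#
#     X, y = randn(100, 3), randn(100)
#     b = xgboost((X, y), 10)
#     contribs = predictbytype(b, DMatrix(X), type=2)
#     size(contribs)  # expected (100, 4); the last column is the bias term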


function evaliter(b::Booster, watch, n::Integer=1)
o = Ref{Ptr{Int8}}()
names = collect(Iterators.map(string, keys(watch)))