-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_inference.py
31 lines (22 loc) · 968 Bytes
/
demo_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
import joblib
import pandas as pd
from flare.preprocessing import VaniilaLGBMPreprocessor
from flare.inference import ProbabilisticBinaryClassifier
# Input artifacts used by this demo: the test CSV and the serialized LightGBM model.
TEST_DATA_PATH = "merged_data/brfss_combine_test_v2.csv"
MODEL_PATH = "models/LGBMClassifier-testing-2022-07-04 01:36:33.461375.pkl"
# Threshold handed to the inference wrapper.
# NOTE(review): presumably the positive-class probability cut-off — confirm in
# flare.inference.ProbabilisticBinaryClassifier.
PROB_THRESHOLD = 0.3

data = pd.read_csv(TEST_DATA_PATH)
preprocessor = VaniilaLGBMPreprocessor(data_mode="testing")
data = preprocessor.preprocess_df(data)

# NumPy data is acceptable for inference: we don't need to provide the column
# information for the model, i.e. we don't need to wrap the data in a DataFrame.
# Select the first row by *position* (iloc) rather than by index label (loc):
# if preprocessing drops or reorders rows, label 0 may be missing or may not be
# the first row. The list-style indexer [[0]] keeps a single-row 2-D selection
# (assuming preprocess_df returns a DataFrame), which is what predict expects.
first_row = data.iloc[[0]].to_numpy()

# Load the raw estimator, then wrap it in an inference wrapper that adjusts the
# threshold used for inference.
# NOTE(review): joblib.load unpickles the file — only load trusted model files.
base_model = joblib.load(MODEL_PATH)
model = ProbabilisticBinaryClassifier(base_model, prob_threshold=PROB_THRESHOLD)

# You can use the sklearn-compatible API to do inference
pred = model.predict(first_row)
print(pred)

# Or the PyTorch-compatible API (the wrapper is callable)
pred = model(first_row)
print(pred)