-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
69 lines (52 loc) · 2.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pickle
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
def convert_data_and_label_to_train(data_df, label_df):
    """Join features with labels on ``area`` and split them for training.

    Args:
        data_df: Feature table; must contain an ``area`` column.
        label_df: Label table; must contain ``area`` and ``mean_value``
            columns.

    Returns:
        A ``(train_data, train_label)`` tuple. ``train_data`` is the merged
        frame without the ``mean_value`` and ``area`` columns;
        ``train_label`` is the ``mean_value`` column of the merged frame.
    """
    merged = pd.merge(data_df, label_df, on='area')
    train_label = merged['mean_value']
    # Use the ``columns=`` keyword: pandas 2.0 no longer accepts ``axis`` as
    # a positional argument to ``drop``. Dropping both columns in one call
    # returns a new frame, so ``merged`` itself is left untouched.
    train_data = merged.drop(columns=['mean_value', 'area'])
    return train_data, train_label
def linear_reg_test(train_data, train_label, test_data, index):
    """Fit a linear regression and predict on ``test_data``.

    Prints the fitted intercept and coefficients, then the full prediction
    series with pandas' row/column display limits lifted.

    Args:
        train_data: Feature matrix used to fit the model.
        train_label: Target values for ``train_data``.
        test_data: Feature matrix to predict on.
        index: Index assigned to the prediction series; needs one entry per
            predicted row.

    Returns:
        A ``(predictions, model)`` tuple: the predictions as a
        ``pandas.Series`` labelled with ``index``, and the fitted
        ``LinearRegression`` instance.
    """
    # The ``normalize`` keyword was removed from LinearRegression in
    # scikit-learn 1.2, so passing it raises TypeError on current releases.
    # NOTE(review): for a plain least-squares fit the predictions should not
    # depend on that rescaling; confirm against a model trained with the
    # old version if exact coefficient parity matters.
    lr = LinearRegression()
    lr.fit(train_data, train_label)
    print(lr.intercept_, '\n', lr.coef_)

    predictions = pd.Series(lr.predict(test_data), index=index)
    # Lift the display limits so the whole series is printed, not a summary.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        print(predictions)
    return predictions, lr
def polynomial_reg_test(train_data, train_label, test_data, index):
    """Fit a degree-2 polynomial regression and predict on ``test_data``.

    The inputs are expanded with ``PolynomialFeatures(degree=2)`` and a
    linear regression is fitted on the expanded training features. Prints
    the fitted intercept and coefficients, then the full prediction series
    with pandas' row/column display limits lifted.

    Args:
        train_data: Feature matrix used to fit the featurizer and the model.
        train_label: Target values for ``train_data``.
        test_data: Feature matrix to predict on; expanded with the
            featurizer fitted on ``train_data``.
        index: Index assigned to the prediction series; needs one entry per
            predicted row.

    Returns:
        A ``(predictions, model)`` tuple: the predictions as a
        ``pandas.Series`` labelled with ``index``, and the fitted
        ``LinearRegression`` instance. The model expects inputs that have
        already been expanded to degree-2 polynomial features.
    """
    quadratic_featurizer = PolynomialFeatures(degree=2)
    X_train_quadratic = quadratic_featurizer.fit_transform(train_data)

    # The ``normalize`` keyword was removed from LinearRegression in
    # scikit-learn 1.2, so passing it raises TypeError on current releases.
    # NOTE(review): the old rescaling may have helped numerical conditioning
    # on large-valued polynomial terms; confirm the fit quality is unchanged.
    regressor_quadratic = LinearRegression()
    regressor_quadratic.fit(X_train_quadratic, train_label)
    print(regressor_quadratic.intercept_, '\n', regressor_quadratic.coef_)

    # Only transform the test data: the featurizer was already fitted on the
    # training data and must not be re-fitted on the data being predicted.
    X_test_quadratic = quadratic_featurizer.transform(test_data)
    predictions = pd.Series(regressor_quadratic.predict(X_test_quadratic), index=index)
    # Lift the display limits so the whole series is printed, not a summary.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        print(predictions)
    return predictions, regressor_quadratic
if __name__ == "__main__":
    # Script entry point: load the spreadsheet, fit the polynomial model on
    # it, and save the fitted model to disk with pickle.
    df = pd.read_excel('../ndf20190919.xlsx')

    # Training data. Missing target values are replaced with 0; assigning the
    # result avoids filling in place through a column selected from ``df``.
    Y = df["原评分"].fillna(0)
    X = df[["自行处理案件总数", "其他案件总数", "强制结案总数",
            "立案耗时总长(分钟)", "计划内耗时总长(分钟)", "计划外耗时总长(分钟)"]]

    # Linear regression (currently disabled)
    # linear_reg_test(X, Y, X, X.index)

    # Polynomial regression: fitted on, and evaluated against, the same rows.
    _, model = polynomial_reg_test(X, Y, X, X.index)

    # A context manager guarantees the file is flushed and closed even if
    # pickling fails part-way through.
    with open('../polynomial.sav', 'wb') as model_file:
        pickle.dump(model, model_file)