-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodels.py
More file actions
99 lines (78 loc) · 2.09 KB
/
models.py
File metadata and controls
99 lines (78 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from sklearn.externals import joblib
from sklearn.linear_model import BayesianRidge, LinearRegression, ElasticNet, Lasso
from sklearn.svm import SVR
from sklearn.ensemble.gradient_boosting import GradientBoostingRegressor
import os
# Directory used by save_model / load_model when no out_dir is supplied.
default_model_dir = 'models'

# Registry of supported regressors, keyed by their short names.
# Each value is one estimator instance constructed once, at import time.
_models = dict(
    bayesian_ridge=BayesianRidge(),
    linear_regression=LinearRegression(),
    elastic_net=ElasticNet(),
    lasso=Lasso(),
    svr=SVR(kernel='linear'),
    gbr=GradientBoostingRegressor(n_estimators=300, max_depth=5),
)
def get_model_names():
    """
    List the names of all supported models.

    :return: a new list containing every key of the model registry,
        in registry (insertion) order.
    """
    names = list(_models)
    return names
def get_models(model_name):
    """
    Get fresh, unfitted copies of supported models.

    Every call returns new estimator objects cloned from the module-level
    registry. Fitting a returned model therefore never mutates the registry
    or any model handed out by an earlier call.

    :param model_name: a single model name, or a list/tuple of names.
    :return: for a list/tuple, a dict mapping each supported name to a new
        estimator (unsupported names are silently skipped); for a single
        supported name, a new estimator; otherwise None.
    """
    # clone() builds a new, unfitted estimator with the same parameters.
    # It is imported locally; scikit-learn is already a dependency of this module.
    from sklearn.base import clone

    if isinstance(model_name, (list, tuple)):
        return {name: clone(_models[name]) for name in model_name if name in _models}
    if model_name in _models:
        return clone(_models[model_name])
    return None
def save_model(model, model_name, out_dir=None):
    """
    Save a model to ``<out_dir>/<model_name>.model`` with joblib.

    :param model: the object to persist (typically a fitted estimator).
    :param model_name: base file name, without the ``.model`` suffix.
    :param out_dir: target directory; falls back to ``default_model_dir``
        when empty or None. It is created, including any missing parent
        directories, if it does not exist.
    :return: None
    """
    if not out_dir:
        out_dir = default_model_dir
    # makedirs(exist_ok=True) handles nested paths such as 'a/b'.
    # It also avoids the race between checking exists() and calling mkdir().
    os.makedirs(out_dir, exist_ok=True)
    outf = os.path.join(out_dir, model_name + '.model')
    joblib.dump(model, outf)
def save_models(models, model_names, out_dir=None):
    """
    Save several models to out_dir, one file per model.

    :param models: either a sequence of models paired positionally with
        ``model_names``, or a mapping from model name to model (the shape
        ``get_models`` returns when given a list of names).
    :param model_names: names used as the file base names.
    :param out_dir: target directory, passed through to ``save_model``.
    :return: None
    """
    from collections.abc import Mapping

    if isinstance(models, Mapping):
        # Pair models with names by key rather than by position.
        # Names that have no model in the mapping are skipped.
        for name in model_names:
            if name in models:
                save_model(models[name], name, out_dir)
        return
    # zip stops at the shorter input, so surplus names or models are ignored.
    for model, name in zip(models, model_names):
        save_model(model, name, out_dir)
def load_model(model_name, out_dir=None):
    """
    Load a saved model from the model directory.

    The file read is ``<out_dir>/<model_name>.model``.

    Security: the file is deserialized with joblib.load, which is
    pickle-based and can execute arbitrary code. Only load files from a
    trusted directory.

    :param model_name: base file name, without the ``.model`` suffix.
    :param out_dir: directory to read from; falls back to
        ``default_model_dir`` when empty or None.
    :return: the loaded object, or None if the file does not exist.
    """
    if not out_dir:
        out_dir = default_model_dir
    outf = os.path.join(out_dir, model_name + '.model')
    if not os.path.exists(outf):
        return None
    return load_model_from_file(outf)
def load_model_from_file(filepath):
    """
    Deserialize and return the object stored at ``filepath`` via joblib.

    Security note: joblib.load is pickle-based, so never point this at a
    file from an untrusted source.

    :param filepath: path to the file to read (presumably one written by
        joblib.dump, e.g. via save_model).
    :return: the deserialized object.
    """
    loaded = joblib.load(filepath)
    return loaded