SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
machine
Machine.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 1999-2009 Soeren Sonnenburg
8
* Written (W) 2011-2012 Heiko Strathmann
9
* Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
10
*/
11
12
#include <
shogun/machine/Machine.h
>
13
#include <
shogun/base/Parameter.h
>
14
#include <
shogun/mathematics/Math.h
>
15
#include <
shogun/base/ParameterMap.h
>
16
17
using namespace
shogun;
18
19
/** Default constructor.
 *
 * Starts with no labels (NULL), m_max_train_time = 0, solver type ST_AUTO,
 * unlocked data and model-feature storage switched off. Every member is then
 * registered through SG_ADD (all flagged MS_NOT_AVAILABLE). The registration
 * order is kept exactly as before, since the parameter framework may depend
 * on it (NOTE(review): confirm whether order matters for serialization).
 */
CMachine::CMachine()
	: CSGObject(), m_max_train_time(0), m_labels(NULL), m_solver_type(ST_AUTO)
{
	// Flags that are not set in the initializer list.
	m_data_locked = false;
	m_store_model_features = false;

	SG_ADD(&m_max_train_time, "max_train_time",
			"Maximum training time.", MS_NOT_AVAILABLE);
	// The enum member is registered through its integer representation.
	SG_ADD(reinterpret_cast<machine_int_t*>(&m_solver_type), "solver_type",
			"Type of solver.", MS_NOT_AVAILABLE);
	// The labels pointer is registered as a generic CSGObject pointer.
	SG_ADD(reinterpret_cast<CSGObject**>(&m_labels), "labels",
			"Labels to be used.", MS_NOT_AVAILABLE);
	SG_ADD(&m_store_model_features, "store_model_features",
			"Should feature data of model be stored after training?",
			MS_NOT_AVAILABLE);
	SG_ADD(&m_data_locked, "data_locked",
			"Indicates whether data is locked", MS_NOT_AVAILABLE);

	// Parameter-map entry for "data_locked" (scalar bool, version 1) mapped to
	// an empty SGParamInfo. Presumably this records that the parameter has no
	// predecessor in older parameter versions (TODO confirm against
	// ParameterMap docs).
	m_parameter_map->put(
			new SGParamInfo("data_locked", CT_SCALAR, ST_NONE, PT_BOOL, 1),
			new SGParamInfo());

	m_parameter_map->finalize_map();
}
43
44
/** Destructor: releases the reference this machine holds on its labels. */
CMachine::~CMachine()
{
	SG_UNREF(m_labels);
}
48
49
/** Trains the machine on the given features.
 *
 * Raises SG_ERROR when the data has been locked via data_lock(); in that
 * state only train_locked() may be used. When train_require_labels()
 * returns true, labels must already be set and are validated with
 * CLabels::ensure_valid(). The actual work is delegated to train_machine().
 * If model-feature storage is enabled, store_model_features() runs after
 * training, whatever train_machine() returned.
 *
 * @param data features to train on (whether NULL is allowed depends on the
 *             subclass's train_machine(); not visible here)
 * @return the value returned by train_machine()
 */
bool CMachine::train(CFeatures* data)
{
	// Locked data may only be trained through train_locked().
	if (m_data_locked)
	{
		SG_ERROR("%s::train data_lock() was called, only train_locked() is"
				" possible. Call data_unlock if you want to call train()\n",
				get_name());
	}

	if (train_require_labels())
	{
		if (!m_labels)
		{
			SG_ERROR("%s@%p: No labels given", get_name(), this);
		}

		m_labels->ensure_valid(get_name());
	}

	const bool trained = train_machine(data);

	if (m_store_model_features)
	{
		store_model_features();
	}

	return trained;
}
74
75
/** Replaces the labels held by this machine.
 *
 * A NULL argument is accepted and leaves the machine without labels. Any
 * non-NULL labels must pass is_label_valid(); otherwise SG_ERROR is raised.
 *
 * @param lab new labels (may be NULL)
 */
void CMachine::set_labels(CLabels* lab)
{
	// Short-circuit so is_label_valid() is never called with NULL.
	const bool rejected = (lab != NULL) && !is_label_valid(lab);
	if (rejected)
	{
		SG_ERROR("Invalid label for %s", get_name());
	}

	// Take the new reference before dropping the old one, so passing the
	// currently stored labels again cannot release them prematurely.
	SG_REF(lab);
	SG_UNREF(m_labels);
	m_labels = lab;
}
85
86
/** Returns the labels currently held by this machine (possibly NULL).
 *
 * The reference count is incremented before returning. Presumably the caller
 * owns that reference and must SG_UNREF it (TODO confirm with API docs).
 */
CLabels* CMachine::get_labels()
{
	CLabels* const current = m_labels;
	SG_REF(current);
	return current;
}
91
92
/** Stores the maximum training time (stored as given; interpretation is up to subclasses). */
void CMachine::set_max_train_time(float64_t t)
{
	this->m_max_train_time = t;
}
96
97
/** Returns the stored maximum training time (0 unless changed by set_max_train_time()). */
float64_t CMachine::get_max_train_time()
{
	return this->m_max_train_time;
}
101
102
/** Base-class default machine type: always CT_NONE (presumably overridden by concrete machines). */
EMachineType CMachine::get_classifier_type()
{
	const EMachineType type = CT_NONE;
	return type;
}
106
107
/** Selects the solver type to use (initially ST_AUTO). */
void CMachine::set_solver_type(ESolverType st)
{
	this->m_solver_type = st;
}
111
112
/** Returns the currently selected solver type. */
ESolverType CMachine::get_solver_type()
{
	return this->m_solver_type;
}
116
117
void
CMachine::set_store_model_features
(
bool
store_model)
118
{
119
m_store_model_features
= store_model;
120
}
121
122
/** Locks the machine on the given labels and features.
 *
 * While locked, train() refuses to run (it checks m_data_locked), so only the
 * *_locked variants can be used until data_unlock() is called. After all
 * checks pass, the labels are stored via set_labels(), the lock flag is set,
 * and post_lock() is called with both arguments.
 *
 * Raises SG_ERROR (before modifying any state) when:
 *  - supports_locking() returns false,
 *  - labs is NULL,
 *  - the machine is already locked.
 *
 * @param labs     labels to lock on (must not be NULL)
 * @param features features forwarded to post_lock()
 */
void CMachine::data_lock(CLabels* labs, CFeatures* features)
{
	SG_DEBUG("entering %s::data_lock\n", get_name());

	if (!supports_locking())
	{
		SG_ERROR("%s::data_lock(): Machine does not support data locking!\n",
				get_name());
	}

	if (!labs)
	{
		SG_ERROR("%s::data_lock() is not possible with NULL labels!\n",
				get_name());
	}

	// Reject a second lock *before* touching any state. Previously the labels
	// were replaced first, so a rejected double lock still overwrote them.
	if (m_data_locked)
	{
		SG_ERROR("%s::data_lock() was already called. Dont lock twice!",
				get_name());
	}

	/* first set labels */
	set_labels(labs);

	m_data_locked = true;
	post_lock(labs, features);
	SG_DEBUG("leaving %s::data_lock\n", get_name());
}
152
153
void
CMachine::data_unlock
()
154
{
155
SG_DEBUG
(
"entering %s::data_lock\n"
,
get_name
())
156
if
(
m_data_locked
)
157
m_data_locked
=
false
;
158
159
SG_DEBUG
(
"leaving %s::data_lock\n"
,
get_name
())
160
}
161
162
/** Applies the machine to data, dispatching on get_machine_problem_type().
 *
 * Binary, regression, multiclass, structured and latent problems are routed
 * to the matching apply_*() method. Any other problem type raises
 * SG_ERROR("Unknown problem type").
 *
 * @param data features to apply the machine to (may be NULL; the debug
 *             output prints "NULL" in that case)
 * @return labels produced by the selected apply_*() method, or NULL if the
 *         default branch does not abort
 */
CLabels* CMachine::apply(CFeatures* data)
{
	SG_DEBUG("entering %s::apply(%s at %p)\n",
			get_name(), data ? data->get_name() : "NULL", data);

	CLabels* output = NULL;

	switch (get_machine_problem_type())
	{
		case PT_BINARY:
			output = apply_binary(data);
			break;
		case PT_REGRESSION:
			output = apply_regression(data);
			break;
		case PT_MULTICLASS:
			output = apply_multiclass(data);
			break;
		case PT_STRUCTURED:
			output = apply_structured(data);
			break;
		case PT_LATENT:
			output = apply_latent(data);
			break;
		default:
			SG_ERROR("Unknown problem type");
			break;
	}

	SG_DEBUG("leaving %s::apply(%s at %p)\n",
			get_name(), data ? data->get_name() : "NULL", data);

	return output;
}
196
197
/** Applies a locked machine to the subset of locked data selected by indices.
 *
 * Dispatches on get_machine_problem_type() to the matching
 * apply_locked_*() method. Unknown problem types raise
 * SG_ERROR("Unknown problem type").
 *
 * @param indices indices into the locked data
 * @return labels from the selected apply_locked_*() method, or NULL if the
 *         default branch does not abort
 */
CLabels* CMachine::apply_locked(SGVector<index_t> indices)
{
	CLabels* output = NULL;

	switch (get_machine_problem_type())
	{
		case PT_BINARY:
			output = apply_locked_binary(indices);
			break;
		case PT_REGRESSION:
			output = apply_locked_regression(indices);
			break;
		case PT_MULTICLASS:
			output = apply_locked_multiclass(indices);
			break;
		case PT_STRUCTURED:
			output = apply_locked_structured(indices);
			break;
		case PT_LATENT:
			output = apply_locked_latent(indices);
			break;
		default:
			SG_ERROR("Unknown problem type");
			break;
	}

	return output;
}
217
218
/** Base-class fallback: binary prediction is unsupported; raises SG_ERROR and returns NULL. */
CBinaryLabels* CMachine::apply_binary(CFeatures* /* data */)
{
	SG_ERROR("This machine does not support apply_binary()\n");
	return NULL;
}
223
224
/** Base-class fallback: regression is unsupported; raises SG_ERROR and returns NULL. */
CRegressionLabels* CMachine::apply_regression(CFeatures* /* data */)
{
	SG_ERROR("This machine does not support apply_regression()\n");
	return NULL;
}
229
230
/** Base-class fallback: multiclass prediction is unsupported; raises SG_ERROR and returns NULL. */
CMulticlassLabels* CMachine::apply_multiclass(CFeatures* /* data */)
{
	SG_ERROR("This machine does not support apply_multiclass()\n");
	return NULL;
}
235
236
/** Base-class fallback: structured prediction is unsupported; raises SG_ERROR and returns NULL. */
CStructuredLabels* CMachine::apply_structured(CFeatures* /* data */)
{
	SG_ERROR("This machine does not support apply_structured()\n");
	return NULL;
}
241
242
/** Base-class fallback: latent prediction is unsupported; raises SG_ERROR and returns NULL. */
CLatentLabels* CMachine::apply_latent(CFeatures* /* data */)
{
	SG_ERROR("This machine does not support apply_latent()\n");
	return NULL;
}
247
248
/** Base-class fallback for locked binary prediction: raises SG_ERROR naming this machine and returns NULL. */
CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> /* indices */)
{
	SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return NULL;
}
254
255
/** Base-class fallback for locked regression: raises SG_ERROR naming this machine and returns NULL. */
CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> /* indices */)
{
	SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return NULL;
}
261
262
/** Base-class fallback for locked multiclass prediction: raises SG_ERROR naming this machine and returns NULL. */
CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> /* indices */)
{
	SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return NULL;
}
268
269
/** Base-class fallback for locked structured prediction: raises SG_ERROR naming this machine and returns NULL. */
CStructuredLabels* CMachine::apply_locked_structured(SGVector<index_t> /* indices */)
{
	SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return NULL;
}
275
276
/** Base-class fallback for locked latent prediction: raises SG_ERROR naming this machine and returns NULL. */
CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> /* indices */)
{
	SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented "
			"for %s\n", get_name());
	return NULL;
}
282
283
SHOGUN
机器学习工具包 - 项目文档