SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
regression
svr
LibSVR.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 1999-2009 Soeren Sonnenburg
8
* Written (W) 2013 Heiko Strathmann
9
* Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
10
*/
11
12
#include <
shogun/regression/svr/LibSVR.h
>
13
#include <
shogun/labels/RegressionLabels.h
>
14
#include <
shogun/io/SGIO.h
>
15
16
using namespace
shogun;
17
18
/** Default constructor.
 *
 * Selects the epsilon-SVR solver and starts without a libsvm model.
 */
CLibSVR::CLibSVR()
	: CSVM()
{
	solver_type = LIBSVR_EPSILON_SVR;
	model = NULL;
}
24
25
/** Constructor.
 *
 * @param C penalty; used for both C1 and C2
 * @param svr_param solver-specific parameter: the tube epsilon for
 *        LIBSVR_EPSILON_SVR, or nu for LIBSVR_NU_SVR
 * @param k kernel to train with
 * @param lab training labels
 * @param st solver type; any other value raises an error
 */
CLibSVR::CLibSVR(float64_t C, float64_t svr_param, CKernel* k, CLabels* lab,
		LIBSVR_SOLVER_TYPE st)
	: CSVM()
{
	model = NULL;

	set_C(C, C);

	// Interpret svr_param according to the requested formulation.
	if (st == LIBSVR_EPSILON_SVR)
	{
		set_tube_epsilon(svr_param);
	}
	else if (st == LIBSVR_NU_SVR)
	{
		set_nu(svr_param);
	}
	else
	{
		SG_ERROR("CLibSVR::CLibSVR(): Unknown solver type!\n");
	}

	set_labels(lab);
	set_kernel(k);
	solver_type = st;
}
50
51
/** Destructor: releases whatever memory is still referenced by model. */
CLibSVR::~CLibSVR()
{
	SG_FREE(this->model);
}
55
56
/** @return the machine type identifier of this class, CT_LIBSVR */
EMachineType CLibSVR::get_classifier_type()
{
	const EMachineType machine_type = CT_LIBSVR;
	return machine_type;
}
60
61
/** Train the support vector regression model with libsvm.
 *
 * Each training example is handed to libsvm as a single node carrying the
 * example's index (plus a -1 terminator node). param.kernel is set to this
 * machine's kernel, and the support-vector indices are read back from those
 * nodes after training.
 *
 * @param data optional training features; when given, their number of
 *        vectors must equal the number of labels and the kernel is
 *        re-initialised on (data, data)
 * @return true if libsvm produced a model (copied into this machine),
 *         false otherwise
 */
bool CLibSVR::train_machine(CFeatures* data)
{
	ASSERT(kernel)
	ASSERT(m_labels && m_labels->get_num_labels())
	ASSERT(m_labels->get_label_type() == LT_REGRESSION)

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
		{
			SG_ERROR("Number of training vectors does not match number of labels\n")
			return false;
		}
		kernel->init(data, data);
	}

	SG_FREE(model);

	// Pick the libsvm formulation before allocating any training buffers, so
	// that an unknown solver type cannot leak them.
	switch (solver_type)
	{
		case LIBSVR_EPSILON_SVR:
			param.svm_type = EPSILON_SVR;
			break;
		case LIBSVR_NU_SVR:
			param.svm_type = NU_SVR;
			break;
		default:
			SG_ERROR("%s::train_machine(): Unknown solver type!\n", get_name());
			return false;
	}

	problem.l = m_labels->get_num_labels();
	SG_INFO("%d trainlabels\n", problem.l)

	problem.y = SG_MALLOC(float64_t, problem.l);
	problem.x = SG_MALLOC(struct svm_node*, problem.l);
	// Two nodes per example: [index i, terminator -1].
	struct svm_node* x_space = SG_MALLOC(struct svm_node, 2*problem.l);

	// The label type was asserted to be LT_REGRESSION above.
	CRegressionLabels* labels = static_cast<CRegressionLabels*>(m_labels);
	for (int32_t i = 0; i < problem.l; i++)
	{
		problem.y[i] = labels->get_label(i);
		problem.x[i] = &x_space[2*i];
		x_space[2*i].index = i;
		x_space[2*i+1].index = -1;
	}

	// NOTE(review): param keeps pointers to these stack arrays after this
	// function returns; they must not be dereferenced outside of training.
	int32_t weights_label[2] = {-1, +1};
	float64_t weights[2] = {1.0, get_C2()/get_C1()};

	param.kernel_type = LINEAR;
	param.degree = 3;
	param.gamma = 0; // 1/k
	param.coef0 = 0;
	param.nu = nu;
	param.kernel = kernel;
	param.cache_size = kernel->get_cache_size();
	param.max_train_time = m_max_train_time;
	param.C = get_C1();
	param.eps = epsilon;
	param.p = tube_epsilon;
	param.shrinking = 1;
	param.nr_weight = 2;
	param.weight_label = weights_label;
	param.weight = weights;
	param.use_bias = get_bias_enabled();

	const char* error_msg = svm_check_parameter(&problem, &param);
	if (error_msg)
	{
		// Release the training buffers before reporting the error so the
		// error path does not leak them.
		SG_FREE(problem.x);
		SG_FREE(problem.y);
		SG_FREE(x_space);
		problem.x = NULL;
		problem.y = NULL;
		SG_ERROR("Error: %s\n", error_msg)
		return false;
	}

	model = svm_train(&problem, &param);

	bool success = false;
	if (model)
	{
		ASSERT(model->nr_class==2)
		ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]))

		int32_t num_sv = model->l;
		create_new_model(num_sv);

		CSVM::set_objective(model->objective);
		set_bias(-model->rho[0]);

		// model->SV holds the nodes whose indices we wrote into x_space, so
		// read them before x_space is released below.
		for (int32_t i = 0; i < num_sv; i++)
		{
			set_support_vector(i, (model->SV[i])->index);
			set_alpha(i, model->sv_coef[0][i]);
		}

		// Destroy the model while the buffers it may reference still exist.
		svm_destroy_model(model);
		model = NULL;
		success = true;
	}

	// Released on both the success and the failure path.
	SG_FREE(problem.x);
	SG_FREE(problem.y);
	SG_FREE(x_space);
	problem.x = NULL;
	problem.y = NULL;

	return success;
}
SHOGUN
机器学习工具包 - 项目文档