SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
classifier
svm
CPLEXSVM.cpp
浏览该文件的文档.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2009 Soeren Sonnenburg
 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
 */
10
11
#include <
shogun/classifier/svm/CPLEXSVM.h
>
12
#include <
shogun/lib/common.h
>
13
14
#ifdef USE_CPLEX
15
#include <
shogun/io/SGIO.h
>
16
#include <
shogun/mathematics/Math.h
>
17
#include <
shogun/mathematics/Cplex.h
>
18
#include <
shogun/labels/Labels.h
>
19
20
using namespace
shogun;
21
22
/** Default constructor: delegates to the CSVM base-class constructor and
 * performs no further work of its own.
 */
CCPLEXSVM::CCPLEXSVM() : CSVM() {}
26
27
/** Destructor: the body is empty, so no explicit cleanup happens here. */
CCPLEXSVM::~CCPLEXSVM() {}
30
31
/** Train the SVM on the current kernel and binary labels.
 *
 * If @p data is given, the number of vectors is checked against the number
 * of labels and the kernel is (re)initialised on (data, data). The kernel
 * matrix is then turned into the label-scaled Hessian y_i*y_j*K_ij and box
 * bounds [0, C1] are prepared for every dual variable.
 *
 * NOTE(review): the step that hands the QP to CPLEX and reads the solution
 * back is not implemented in this function ("feed qp to cplex" is empty).
 * Until it is, the dual variables stay at zero, the resulting model has no
 * support vectors, and the function still reports success when
 * cplex.init() succeeded.
 *
 * @param data optional training features; may be NULL, in which case the
 *             kernel is used as it is
 * @return true if the CPLEX environment could be initialised and the model
 *         was written; otherwise SG_ERROR is raised
 */
bool CCPLEXSVM::train_machine(CFeatures* data)
{
	ASSERT(m_labels)
	ASSERT(m_labels->get_label_type() == LT_BINARY)
	// The kernel is dereferenced unconditionally below.
	ASSERT(kernel)

	bool result = false;
	CCplex cplex;

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
		{
			SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
					" not match number of labels (%d)\n", get_name(),
					data->get_num_vectors(), m_labels->get_num_labels());
		}
		kernel->init(data, data);
	}

	if (cplex.init(E_QP))
	{
		SGVector<float64_t> y=((CBinaryLabels*)m_labels)->get_labels();
		SGMatrix<float64_t> H=kernel->get_kernel_matrix();
		int32_t m=H.num_rows;
		int32_t n=H.num_cols;
		// Take the label count from the label vector itself; it used to be
		// hard-wired to 0, which made the assertion below unsatisfiable.
		int32_t num_label=y.vlen;
		ASSERT(n>0 && n==m && n==num_label)

		float64_t* alphas=SG_MALLOC(float64_t, n);
		float64_t* lb=SG_MALLOC(float64_t, n);
		float64_t* ub=SG_MALLOC(float64_t, n);

		//hessian y'y.*K
		for (int32_t i=0; i<n; i++)
		{
			// Zero the dual variables so they are never read uninitialised.
			alphas[i]=0;
			lb[i]=0;
			ub[i]=get_C1();

			for (int32_t j=0; j<n; j++)
				H[i*n+j]*=y[j]*y[i];
		}

		//feed qp to cplex
		// TODO: pass H, lb and ub to the solver and read the solution back
		// into alphas; nothing is solved at this point.

		// Size the model before writing into it.
		int32_t num_sv=0;
		for (int32_t i=0; i<n; i++)
		{
			if (alphas[i]>0)
				num_sv++;
		}
		create_new_model(num_sv);

		int32_t j=0;
		for (int32_t i=0; i<n; i++)
		{
			if (alphas[i]>0)
			{
				//set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
				set_alpha(j, alphas[i]*((CBinaryLabels*)m_labels)->get_int_label(i));
				set_support_vector(j, i);
				j++;
			}
		}
		//compute_objective();
		SG_INFO("obj = %.16f, rho = %.16f\n", get_objective(), get_bias())
		// %d matches the int32_t count; the previous %ld did not.
		SG_INFO("Number of SV: %d\n", get_num_support_vectors())

		SG_FREE(alphas);
		SG_FREE(lb);
		SG_FREE(ub);

		result = true;
	}

	if (!result)
		SG_ERROR("cplex svm failed")

	return result;
}
103
#endif
SHOGUN
机器学习工具包 - 项目文档