SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
kernel
normalizer
ScatterKernelNormalizer.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2010 Soeren Sonnenburg
8
* Copyright (C) 2010 Berlin Institute of Technology
9
*/
10
11
#ifndef _SCATTERKERNELNORMALIZER_H___
12
#define _SCATTERKERNELNORMALIZER_H___
13
14
#include <
shogun/kernel/normalizer/KernelNormalizer.h
>
15
#include <
shogun/kernel/normalizer/IdentityKernelNormalizer.h
>
16
#include <
shogun/kernel/Kernel.h
>
17
#include <
shogun/labels/Labels.h
>
18
#include <
shogun/labels/MulticlassLabels.h
>
19
#include <
shogun/io/SGIO.h
>
20
21
namespace
shogun
22
{
24
class
CScatterKernelNormalizer
:
public
CKernelNormalizer
25
{
26
27
public
:
29
CScatterKernelNormalizer
() :
CKernelNormalizer
()
30
{
31
init
();
32
}
33
36
CScatterKernelNormalizer
(
float64_t
const_diag,
float64_t
const_offdiag,
37
CLabels
* labels,
CKernelNormalizer
* normalizer=NULL)
38
:
CKernelNormalizer
()
39
{
40
init
();
41
42
m_testing_class
=-1;
43
m_const_diag
=const_diag;
44
m_const_offdiag
=const_offdiag;
45
46
ASSERT
(labels)
47
SG_REF
(labels);
48
m_labels
=labels;
49
ASSERT
(labels->
get_label_type
()==
LT_MULTICLASS
)
50
labels->
ensure_valid
();
51
52
if
(normalizer==NULL)
53
normalizer=
new
CIdentityKernelNormalizer
();
54
SG_REF
(normalizer);
55
m_normalizer
=normalizer;
56
57
SG_DEBUG
(
"Constructing ScatterKernelNormalizer with const_diag=%g"
58
" const_offdiag=%g num_labels=%d and normalizer='%s'\n"
,
59
const_diag, const_offdiag, labels->
get_num_labels
(),
60
normalizer->get_name());
61
}
62
64
virtual
~CScatterKernelNormalizer
()
65
{
66
SG_UNREF
(
m_labels
);
67
SG_UNREF
(
m_normalizer
);
68
}
69
72
virtual
bool
init
(
CKernel
* k)
73
{
74
m_normalizer
->
init
(k);
75
return
true
;
76
}
77
82
int32_t
get_testing_class
()
83
{
84
return
m_testing_class
;
85
}
86
91
void
set_testing_class
(int32_t c)
92
{
93
m_testing_class
=c;
94
}
95
101
virtual
float64_t
normalize
(
float64_t
value, int32_t idx_lhs,
102
int32_t idx_rhs)
103
{
104
value=
m_normalizer
->
normalize
(value, idx_lhs, idx_rhs);
105
float64_t
c=
m_const_offdiag
;
106
107
if
(
m_testing_class
>=0)
108
{
109
if
(((
CMulticlassLabels
*)
m_labels
)->get_label(idx_lhs) ==
m_testing_class
)
110
c=
m_const_diag
;
111
}
112
else
113
{
114
if
(((
CMulticlassLabels
*)
m_labels
)->get_label(idx_lhs) == ((
CMulticlassLabels
*)
m_labels
)->get_label(idx_rhs))
115
c=
m_const_diag
;
116
117
}
118
return
value*c;
119
}
120
125
virtual
float64_t
normalize_lhs
(
float64_t
value, int32_t idx_lhs)
126
{
127
SG_ERROR
(
"normalize_lhs not implemented"
)
128
return
0;
129
}
130
135
virtual
float64_t
normalize_rhs
(
float64_t
value, int32_t idx_rhs)
136
{
137
SG_ERROR
(
"normalize_rhs not implemented"
)
138
return
0;
139
}
140
142
virtual
const
char
*
get_name
()
const
143
{
144
return
"ScatterKernelNormalizer"
;
145
}
146
147
private
:
148
void
init()
149
{
150
m_const_diag
= 1.0;
151
m_const_offdiag
= 1.0;
152
153
m_labels
= NULL;
154
m_normalizer
= NULL;
155
156
m_testing_class
= -1;
157
158
SG_ADD
(&
m_testing_class
,
"m_testing_class"
,
159
"Testing Class."
,
MS_NOT_AVAILABLE
);
160
SG_ADD
(&
m_const_diag
,
"m_const_diag"
,
161
"Factor to multiply to diagonal elements."
,
MS_AVAILABLE
);
162
SG_ADD
(&
m_const_offdiag
,
"m_const_offdiag"
,
163
"Factor to multiply to off-diagonal elements."
,
MS_AVAILABLE
);
164
165
SG_ADD
((
CSGObject
**) &
m_labels
,
"m_labels"
,
"Labels"
,
MS_NOT_AVAILABLE
);
166
SG_ADD
((
CSGObject
**) &
m_normalizer
,
"m_normalizer"
,
"Kernel normalizer."
,
167
MS_AVAILABLE
);
168
}
169
170
protected
:
171
173
float64_t
m_const_diag
;
175
float64_t
m_const_offdiag
;
176
178
CLabels
*
m_labels
;
179
181
CKernelNormalizer
*
m_normalizer
;
182
184
int32_t
m_testing_class
;
185
};
186
}
187
#endif
188
SHOGUN
机器学习工具包 - 项目文档