SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
QDA.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Fernando José Iglesias García
8
* Copyright (C) 2012 Fernando José Iglesias García
9
*/
10
11
#ifndef _QDA_H__
12
#define _QDA_H__
13
14
#include <
shogun/lib/config.h
>
15
16
#ifdef HAVE_EIGEN3
17
18
#include <
shogun/features/DotFeatures.h
>
19
#include <
shogun/features/DenseFeatures.h
>
20
#include <
shogun/machine/NativeMulticlassMachine.h
>
21
#include <
shogun/lib/SGNDArray.h
>
22
23
namespace
shogun
24
{
25
26
//#define DEBUG_QDA
27
36
class
CQDA
:
public
CNativeMulticlassMachine
37
{
38
public
:
39
MACHINE_PROBLEM_TYPE
(
PT_MULTICLASS
)
40
41
46
CQDA
(
float64_t
tolerance = 1e-4,
bool
store_covs =
false
);
47
55
CQDA
(
CDenseFeatures<float64_t>
* traindat,
CLabels
* trainlab,
float64_t
tolerance = 1e-4,
bool
store_covs =
false
);
56
57
virtual
~CQDA
();
58
64
virtual
CMulticlassLabels
*
apply_multiclass
(
CFeatures
* data=NULL);
65
70
inline
void
set_store_covs
(
bool
store_covs) { m_store_covs = store_covs; }
71
76
inline
bool
get_store_covs
() {
return
m_store_covs; }
77
82
inline
void
set_tolerance
(
float64_t
tolerance) { m_tolerance = tolerance; }
83
88
inline
bool
get_tolerance
() {
return
m_tolerance; }
89
94
virtual
EMachineType
get_classifier_type
() {
return
CT_QDA
; }
95
100
virtual
void
set_features
(
CDotFeatures
* feat)
101
{
102
if
(feat->
get_feature_class
() !=
C_DENSE
||
103
feat->
get_feature_type
() !=
F_DREAL
)
104
SG_ERROR
(
"QDA requires SIMPLE REAL valued features\n"
)
105
106
SG_REF
(feat);
107
SG_UNREF
(m_features);
108
m_features = feat;
109
}
110
115
virtual
CDotFeatures
*
get_features
() {
SG_REF
(m_features);
return
m_features; }
116
121
virtual
const
char
*
get_name
()
const
{
return
"QDA"
; }
122
129
inline
SGVector< float64_t >
get_mean
(int32_t c)
const
130
{
131
return
SGVector< float64_t >
(m_means.
get_column_vector
(c), m_dim,
false
);
132
}
133
140
inline
SGMatrix< float64_t >
get_cov
(int32_t c)
const
141
{
142
return
SGMatrix< float64_t >
(m_covs.
get_matrix
(c), m_dim, m_dim,
false
);
143
}
144
145
protected
:
152
virtual
bool
train_machine
(
CFeatures
* data = NULL);
153
154
private
:
155
void
init();
156
157
void
cleanup();
158
159
private
:
161
CDotFeatures
* m_features;
162
164
float64_t
m_tolerance;
165
167
bool
m_store_covs;
168
170
int32_t m_num_classes;
171
173
int32_t m_dim;
174
178
SGNDArray< float64_t >
m_covs;
179
181
SGMatrix< float64_t >
m_means;
182
184
SGNDArray< float64_t >
m_M;
185
187
SGVector< float32_t >
m_slog;
188
189
};
/* class QDA */
190
}
/* namespace shogun */
191
192
#endif
/* HAVE_EIGEN3 */
193
#endif
/* _QDA_H__ */
SHOGUN
机器学习工具包 - 项目文档