SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
machine
gp
InferenceMethod.h
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Roman Votyakov
8
* Written (W) 2013-2014 Heiko Strathmann
9
* Copyright (C) 2012 Jacob Walker
10
* Copyright (C) 2013 Roman Votyakov
11
*/
12
13
#ifndef CINFERENCEMETHOD_H_
14
#define CINFERENCEMETHOD_H_
15
16
#include <
shogun/lib/config.h
>
17
18
#ifdef HAVE_EIGEN3
19
20
#include <
shogun/base/SGObject.h
>
21
#include <
shogun/kernel/Kernel.h
>
22
#include <
shogun/features/Features.h
>
23
#include <
shogun/labels/Labels.h
>
24
#include <
shogun/machine/gp/LikelihoodModel.h
>
25
#include <
shogun/machine/gp/MeanFunction.h
>
26
#include <
shogun/evaluation/DifferentiableFunction.h
>
27
28
namespace
shogun
29
{
30
32
/** Identifiers for the available inference methods.
 *
 * Values are spaced in steps of 10. INF_NONE (0) is the value reported by
 * the CInferenceMethod base class from get_inference_type(); concrete
 * inference methods report one of the other values.
 */
enum EInferenceType
{
	INF_NONE      = 0,  /**< no specific inference method (base-class default) */
	INF_EXACT     = 10, /**< exact inference (by name -- see the implementing class) */
	INF_FITC      = 20, /**< FITC inference (by name -- see the implementing class) */
	INF_LAPLACIAN = 30, /**< Laplacian inference (by name -- see the implementing class) */
	INF_EP        = 40  /**< EP inference (by name -- see the implementing class) */
};
40
50
class
CInferenceMethod
:
public
CDifferentiableFunction
51
{
52
public
:
54
CInferenceMethod
();
55
64
CInferenceMethod
(
CKernel
* kernel,
CFeatures
* features,
65
CMeanFunction
* mean,
CLabels
* labels,
CLikelihoodModel
* model);
66
67
virtual
~CInferenceMethod
();
68
73
virtual
EInferenceType
get_inference_type
()
const
{
return
INF_NONE
; }
74
86
virtual
float64_t
get_negative_log_marginal_likelihood
()=0;
87
123
float64_t
get_marginal_likelihood_estimate
(int32_t num_importance_samples=1,
124
float64_t
ridge_size=1e-15);
125
139
virtual
CMap<TParameter*, SGVector<float64_t>
>*
140
get_negative_log_marginal_likelihood_derivatives
(
CMap
<
TParameter
*,
141
CSGObject
*>* parameters);
142
153
virtual
SGVector<float64_t>
get_alpha
()=0;
154
166
virtual
SGMatrix<float64_t>
get_cholesky
()=0;
167
179
virtual
SGVector<float64_t>
get_diagonal_vector
()=0;
180
196
virtual
SGVector<float64_t>
get_posterior_mean
()=0;
197
213
virtual
SGMatrix<float64_t>
get_posterior_covariance
()=0;
214
222
virtual
CMap<TParameter*, SGVector<float64_t>
>*
get_gradient
(
223
CMap<TParameter*, CSGObject*>
* parameters)
224
{
225
return
get_negative_log_marginal_likelihood_derivatives
(parameters);
226
}
227
232
virtual
SGVector<float64_t>
get_value
()
233
{
234
SGVector<float64_t>
result(1);
235
result[0]=
get_negative_log_marginal_likelihood
();
236
return
result;
237
}
238
243
virtual
CFeatures
*
get_features
() {
SG_REF
(
m_features
);
return
m_features
; }
244
249
virtual
void
set_features
(
CFeatures
* feat)
250
{
251
SG_REF
(feat);
252
SG_UNREF
(
m_features
);
253
m_features
=feat;
254
}
255
260
virtual
CKernel
*
get_kernel
() {
SG_REF
(
m_kernel
);
return
m_kernel
; }
261
266
virtual
void
set_kernel
(
CKernel
* kern)
267
{
268
SG_REF
(kern);
269
SG_UNREF
(
m_kernel
);
270
m_kernel
=kern;
271
}
272
277
virtual
CMeanFunction
*
get_mean
() {
SG_REF
(
m_mean
);
return
m_mean
; }
278
283
virtual
void
set_mean
(
CMeanFunction
* m)
284
{
285
SG_REF
(m);
286
SG_UNREF
(
m_mean
);
287
m_mean
=m;
288
}
289
294
virtual
CLabels
*
get_labels
() {
SG_REF
(
m_labels
);
return
m_labels
; }
295
300
virtual
void
set_labels
(
CLabels
* lab)
301
{
302
SG_REF
(lab);
303
SG_UNREF
(
m_labels
);
304
m_labels
=lab;
305
}
306
311
CLikelihoodModel
*
get_model
() {
SG_REF
(
m_model
);
return
m_model
; }
312
317
virtual
void
set_model
(
CLikelihoodModel
* mod)
318
{
319
SG_REF
(mod);
320
SG_UNREF
(
m_model
);
321
m_model
=mod;
322
}
323
328
virtual
float64_t
get_scale
()
const
{
return
m_scale
; }
329
334
virtual
void
set_scale
(
float64_t
scale) {
m_scale
=scale; }
335
341
virtual
bool
supports_regression
()
const
{
return
false
; }
342
348
virtual
bool
supports_binary
()
const
{
return
false
; }
349
355
virtual
bool
supports_multiclass
()
const
{
return
false
; }
356
358
virtual
void
update
();
359
360
protected
:
362
virtual
void
check_members
()
const
;
363
365
virtual
void
update_alpha
()=0;
366
368
virtual
void
update_chol
()=0;
369
373
virtual
void
update_deriv
()=0;
374
376
virtual
void
update_train_kernel
();
377
385
virtual
SGVector<float64_t>
get_derivative_wrt_inference_method
(
386
const
TParameter
* param)=0;
387
395
virtual
SGVector<float64_t>
get_derivative_wrt_likelihood_model
(
396
const
TParameter
* param)=0;
397
405
virtual
SGVector<float64_t>
get_derivative_wrt_kernel
(
406
const
TParameter
* param)=0;
407
415
virtual
SGVector<float64_t>
get_derivative_wrt_mean
(
416
const
TParameter
* param)=0;
417
421
static
void
*
get_derivative_helper
(
void
* p);
422
423
private
:
424
void
init();
425
426
protected
:
428
CKernel
*
m_kernel
;
429
431
CMeanFunction
*
m_mean
;
432
434
CLikelihoodModel
*
m_model
;
435
437
CFeatures
*
m_features
;
438
440
CLabels
*
m_labels
;
441
443
SGVector<float64_t>
m_alpha
;
444
446
SGMatrix<float64_t>
m_L
;
447
449
float64_t
m_scale
;
450
452
SGMatrix<float64_t>
m_ktrtr
;
453
};
454
}
455
#endif
/* HAVE_EIGEN3 */
456
#endif
/* CINFERENCEMETHOD_H_ */
SHOGUN
机器学习工具包 - 项目文档