SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
latent
LatentModel.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Viktor Gal
8
* Copyright (C) 2012 Viktor Gal
9
*/
10
11
#include <
shogun/latent/LatentModel.h
>
12
#include <
shogun/labels/BinaryLabels.h
>
13
14
using namespace
shogun;
15
16
/** Default constructor.
 *
 * Creates a model with no features, no labels, no cached PSI features and
 * PSI caching switched off, then registers the members via
 * register_parameters().
 */
CLatentModel::CLatentModel()
	: m_features(NULL), m_labels(NULL), m_do_caching(false), m_cached_psi(NULL)
{
	register_parameters();
}
24
25
/** Construct a latent model over the given features and labels.
 *
 * @param feats latent features; the model takes a reference (SG_REF)
 * @param labels latent labels; the model takes a reference (SG_REF)
 * @param do_caching whether cache_psi_features() should store the PSI
 *        feature vectors
 */
CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
	: m_features(feats),
	  m_labels(labels),
	  m_do_caching(do_caching),
	  m_cached_psi(NULL)
{
	register_parameters();

	// keep both inputs alive for the lifetime of the model; released in the dtor
	SG_REF(m_labels);
	SG_REF(m_features);
}
35
36
/** Destructor: drops the references held on the labels, the features and
 * the cached PSI features.
 */
CLatentModel::~CLatentModel()
{
	SG_UNREF(m_labels);
	SG_UNREF(m_features);
	SG_UNREF(m_cached_psi);
}
42
43
/** @return number of vectors in the attached latent features.
 *
 * NOTE(review): dereferences m_features unconditionally, so features must
 * have been set (via the ctor or set_features()) before calling.
 */
int32_t CLatentModel::get_num_vectors() const
{
	const int32_t vector_count = m_features->get_num_vectors();
	return vector_count;
}
47
48
/** Replace the latent labels of the model.
 *
 * The new labels are referenced before the old ones are released, which
 * keeps the call safe when labs is the object that is already stored.
 *
 * @param labs new latent labels
 */
void CLatentModel::set_labels(CLatentLabels* labs)
{
	SG_REF(labs);
	SG_UNREF(m_labels);
	m_labels = labs;
}
54
55
/** @return the latent labels of the model, with an extra reference taken
 *          (SG_REF) that the caller is expected to release.
 */
CLatentLabels* CLatentModel::get_labels() const
{
	CLatentLabels* result = m_labels;
	SG_REF(result);
	return result;
}
60
61
/** Replace the latent features of the model.
 *
 * The new features are referenced before the old ones are released, which
 * keeps the call safe when feats is the object that is already stored.
 *
 * @param feats new latent features
 */
void CLatentModel::set_features(CLatentFeatures* feats)
{
	SG_REF(feats);
	SG_UNREF(m_features);
	m_features = feats;
}
67
68
void
CLatentModel::argmax_h
(
const
SGVector<float64_t>
& w)
69
{
70
int32_t num =
get_num_vectors
();
71
CBinaryLabels
* y =
CLabelsFactory::to_binary
(
m_labels
->
get_labels
());
72
ASSERT
(num > 0)
73
ASSERT
(num ==
m_labels
->
get_num_labels
())
74
75
// argmax_h only for positive examples
76
for
(int32_t i = 0; i < num; ++i)
77
{
78
if
(y->
get_label
(i) == 1)
79
{
80
// infer h and set it for the argmax_h <w,psi(x,h)>
81
CData
* latent_data =
infer_latent_variable
(w, i);
82
m_labels
->
set_latent_label
(i, latent_data);
83
}
84
}
85
}
86
87
void
CLatentModel::register_parameters()
88
{
89
m_parameters
->
add
((
CSGObject
**) &
m_features
,
"features"
,
"Latent features"
);
90
m_parameters
->
add
((
CSGObject
**) &
m_labels
,
"labels"
,
"Latent labels"
);
91
m_parameters
->
add
((
CSGObject
**) &
m_cached_psi
,
"cached_psi"
,
"Cached PSI features after argmax_h"
);
92
m_parameters
->
add
(&
m_do_caching
,
"do_caching"
,
"Indicate whether or not do PSI vector caching after argmax_h"
);
93
}
94
95
96
/** @return the latent features of the model, with an extra reference
 *          taken (SG_REF) that the caller is expected to release.
 */
CLatentFeatures* CLatentModel::get_features() const
{
	CLatentFeatures* result = m_features;
	SG_REF(result);
	return result;
}
101
102
void
CLatentModel::cache_psi_features
()
103
{
104
if
(
m_do_caching
)
105
{
106
if
(
m_cached_psi
)
107
SG_UNREF
(
m_cached_psi
);
108
m_cached_psi
= this->
get_psi_feature_vectors
();
109
SG_REF
(
m_cached_psi
);
110
}
111
}
112
113
/** @return the cached PSI features with an extra reference taken (SG_REF)
 *          when caching is enabled, otherwise NULL.
 */
CDotFeatures* CLatentModel::get_cached_psi_features() const
{
	if (!m_do_caching)
		return NULL;

	SG_REF(m_cached_psi);
	return m_cached_psi;
}
SHOGUN
机器学习工具包 - 项目文档