SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
structure
MAPInference.cpp
浏览该文件的文档.
1
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Shell Hu
 * Copyright (C) 2013 Shell Hu
 */
10
11
#include <
shogun/structure/MAPInference.h
>
12
#include <
shogun/structure/BeliefPropagation.h
>
13
#include <
shogun/labels/FactorGraphLabels.h
>
14
15
using namespace
shogun;
16
17
/** Default constructor.
 *
 * Emits an "unstable" notice (this constructor leaves the object without a
 * factor graph or inference implementation) and then registers parameters and
 * resets all members through init().
 */
CMAPInference::CMAPInference() : CSGObject()
{
	SG_UNSTABLE("CMAPInference::CMAPInference()", "\n");

	init();
}
23
24
/** Construct a MAP inference wrapper for a factor graph.
 *
 * @param fg               factor graph to run inference on; must not be NULL.
 *                         A reference is taken (SG_REF) and released in the
 *                         destructor.
 * @param inference_method which MAP algorithm to use. Only TREE_MAX_PROD is
 *                         implemented (it builds a CTreeMaxProduct on @p fg);
 *                         every other value is reported through SG_ERROR
 *                         (presumably aborting construction by throwing —
 *                         confirm against the SG_ERROR definition).
 */
CMAPInference::CMAPInference(CFactorGraph* fg, EMAPInferType inference_method)
	: CSGObject()
{
	init();

	// Validate before storing, so m_fg is only ever assigned a usable graph.
	REQUIRE(fg != NULL, "%s::CMAPInference(): fg cannot be NULL!\n", get_name());
	m_fg = fg;

	switch (inference_method)
	{
		case TREE_MAX_PROD:
			m_infer_impl = new CTreeMaxProduct(fg);
			break;
		case LOOPY_MAX_PROD:
			SG_ERROR("%s::CMAPInference(): LoopyMaxProduct has not been implemented!\n",
				get_name());
			break;
		case LP_RELAXATION:
			SG_ERROR("%s::CMAPInference(): LPRelaxation has not been implemented!\n",
				get_name());
			break;
		case TRWS_MAX_PROD:
			SG_ERROR("%s::CMAPInference(): TRW-S has not been implemented!\n",
				get_name());
			break;
		case ITER_COND_MODE:
			SG_ERROR("%s::CMAPInference(): ICM has not been implemented!\n",
				get_name());
			break;
		case NAIVE_MEAN_FIELD:
			SG_ERROR("%s::CMAPInference(): NaiveMeanField has not been implemented!\n",
				get_name());
			break;
		case STRUCT_MEAN_FIELD:
			SG_ERROR("%s::CMAPInference(): StructMeanField has not been implemented!\n",
				get_name());
			break;
		default:
			SG_ERROR("%s::CMAPInference(): unsupported inference method!\n",
				get_name());
			break;
	}

	// Both references are released in ~CMAPInference().
	SG_REF(m_infer_impl);
	SG_REF(m_fg);
}
70
71
/** Destructor: drops the references held on the inference implementation,
 * the last structured output and the factor graph.
 */
CMAPInference::~CMAPInference()
{
	// The implementation was constructed from m_fg (see the constructor), so
	// release it before the graph. Keep this order unless the implementation
	// is confirmed not to touch the graph on destruction.
	SG_UNREF(m_infer_impl);
	SG_UNREF(m_outputs);
	SG_UNREF(m_fg);
}
77
78
void
CMAPInference::init()
79
{
80
SG_ADD
((
CSGObject
**)&
m_fg
,
"fg"
,
"factor graph"
,
MS_NOT_AVAILABLE
);
81
SG_ADD
((
CSGObject
**)&
m_outputs
,
"outputs"
,
"Structured outputs"
,
MS_NOT_AVAILABLE
);
82
SG_ADD
((
CSGObject
**)&
m_infer_impl
,
"infer_impl"
,
"Inference implementation"
,
MS_NOT_AVAILABLE
);
83
SG_ADD
(&
m_energy
,
"energy"
,
"Minimized energy"
,
MS_NOT_AVAILABLE
);
84
85
m_outputs
= NULL;
86
m_infer_impl
= NULL;
87
m_fg
= NULL;
88
m_energy
= 0;
89
}
90
91
void
CMAPInference::inference
()
92
{
93
SGVector<int32_t>
assignment(
m_fg
->
get_num_vars
());
94
assignment.
zero
();
95
m_energy
=
m_infer_impl
->
inference
(assignment);
96
97
// create structured output, with default normalized hamming loss
98
SG_UNREF
(
m_outputs
);
99
SGVector<float64_t>
loss_weights(
m_fg
->
get_num_vars
());
100
SGVector<float64_t>::fill_vector
(loss_weights.vector, loss_weights.vlen, 1.0 / loss_weights.vlen);
101
m_outputs
=
new
CFactorGraphObservation
(assignment, loss_weights);
// already ref() in constructor
102
SG_REF
(
m_outputs
);
103
}
104
105
/** Structured output produced by the most recent inference() call.
 *
 * @return the cached observation with its reference count incremented; the
 *         caller owns that reference and should SG_UNREF it. NULL if
 *         inference() has not been run yet (init() sets m_outputs to NULL).
 */
CFactorGraphObservation* CMAPInference::get_structured_outputs() const
{
	SG_REF(m_outputs);
	return m_outputs;
}
110
111
/** Energy returned by the inference implementation in the most recent
 * inference() call; 0 before the first call (set by init()).
 */
float64_t CMAPInference::get_energy() const
{
	return m_energy;
}
115
116
//-----------------------------------------------------------------
117
118
/** Default constructor: registers parameters; the factor graph stays NULL. */
CMAPInferImpl::CMAPInferImpl() : CSGObject()
{
	register_parameters();
}
122
123
/** Construct an inference implementation bound to @p fg.
 *
 * @param fg factor graph to operate on. It is stored without SG_REF, and the
 *           destructor does not release it. The caller (e.g. CMAPInference,
 *           which holds its own reference) presumably has to keep it alive
 *           for this object's lifetime — confirm at other call sites.
 */
CMAPInferImpl::CMAPInferImpl(CFactorGraph* fg) : CSGObject()
{
	register_parameters();
	m_fg = fg;
}
129
130
/** Destructor: intentionally releases nothing, because the constructor does
 * not take a reference on m_fg.
 */
CMAPInferImpl::~CMAPInferImpl()
{
}
133
134
void
CMAPInferImpl::register_parameters()
135
{
136
SG_ADD
((
CSGObject
**)&
m_fg
,
"fg"
,
137
"Factor graph pointer"
,
MS_NOT_AVAILABLE
);
138
139
m_fg
= NULL;
140
}
141
SHOGUN
机器学习工具包 - 项目文档