SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
evaluation
CrossValidationMulticlassStorage.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
8
*/
9
10
#include <
shogun/evaluation/CrossValidationMulticlassStorage.h
>
11
#include <
shogun/evaluation/ROCEvaluation.h
>
12
#include <
shogun/evaluation/PRCEvaluation.h
>
13
#include <
shogun/evaluation/MulticlassAccuracy.h
>
14
15
using namespace
shogun;
16
17
/** Constructor.
 *
 * @param compute_ROC whether a ROC graph is stored per run, fold and class
 * @param compute_PRC whether a PRC graph is stored per run, fold and class
 * @param compute_conf_matrices whether a confusion matrix is stored per
 * run and fold
 *
 * The graph and confusion matrix arrays are not allocated here but in
 * post_init(); until then they are kept NULL.
 */
CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC,
		bool compute_PRC, bool compute_conf_matrices) :
	CCrossValidationOutput()
{
	m_initialized=false;
	m_compute_ROC=compute_ROC;
	m_compute_PRC=compute_PRC;
	m_compute_conf_matrices=compute_conf_matrices;
	m_pred_labels=NULL;
	m_true_labels=NULL;
	m_num_classes=0;
	m_binary_evaluations=new CDynamicObjectArray();

	/* every lazily allocated array gets a well-defined value, including the
	 * PRC one, so that nothing indeterminate is ever freed */
	m_fold_ROC_graphs=NULL;
	m_fold_PRC_graphs=NULL;
	m_conf_matrices=NULL;
}
32
33
34
/** Destructor.
 *
 * Destroys and frees the ROC/PRC graph arrays and the confusion matrix
 * array that post_init() allocated, then releases the array of binary
 * evaluations.
 */
CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
{
	/* The arrays only exist once post_init() has run (it sets m_initialized
	 * at its end). Without this guard, a storage object whose number of
	 * classes was set but which was never initialized would dereference
	 * NULL or unassigned pointers below. */
	if (m_initialized)
	{
		/* counts must match the ones used for allocation in post_init() */
		const int32_t num_run_folds=m_num_folds*m_num_runs;
		const int32_t num_graphs=num_run_folds*m_num_classes;

		if (m_compute_ROC)
		{
			/* elements were placement-new constructed, so destroy them
			 * explicitly before releasing the raw memory */
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_ROC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_ROC_graphs);
		}

		if (m_compute_PRC)
		{
			for (int32_t i=0; i<num_graphs; i++)
				m_fold_PRC_graphs[i].~SGMatrix<float64_t>();

			SG_FREE(m_fold_PRC_graphs);
		}

		if (m_compute_conf_matrices)
		{
			for (int32_t i=0; i<num_run_folds; i++)
				m_conf_matrices[i].~SGMatrix<int32_t>();

			SG_FREE(m_conf_matrices);
		}
	}

	SG_UNREF(m_binary_evaluations);
}
62
63
64
void
CCrossValidationMulticlassStorage::post_init
()
65
{
66
if
(
m_initialized
)
67
SG_ERROR
(
"CrossValidationMulticlassStorage was already initialized once\n"
)
68
69
if
(
m_compute_ROC
)
70
{
71
SG_DEBUG
(
"Allocating %d ROC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
)
72
m_fold_ROC_graphs
= SG_MALLOC(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
73
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
74
new
(&
m_fold_ROC_graphs
[i])
SGMatrix<float64_t>
();
75
}
76
77
if
(
m_compute_PRC
)
78
{
79
SG_DEBUG
(
"Allocating %d PRC graphs\n"
,
m_num_folds
*
m_num_runs
*
m_num_classes
)
80
m_fold_PRC_graphs
= SG_MALLOC(
SGMatrix<float64_t>
,
m_num_folds
*
m_num_runs
*m_num_classes);
81
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
*
m_num_classes
; i++)
82
new
(&
m_fold_PRC_graphs
[i])
SGMatrix<float64_t>
();
83
}
84
85
if
(
m_binary_evaluations
->
get_num_elements
())
86
m_evaluations_results
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
*
m_num_classes
*
m_binary_evaluations
->
get_num_elements
());
87
88
m_accuracies
=
SGVector<float64_t>
(
m_num_folds
*
m_num_runs
);
89
90
if
(
m_compute_conf_matrices
)
91
{
92
m_conf_matrices
= SG_MALLOC(
SGMatrix<int32_t>
,
m_num_folds
*
m_num_runs
);
93
for
(int32_t i=0; i<
m_num_folds
*
m_num_runs
; i++)
94
new
(&
m_conf_matrices
[i])
SGMatrix<int32_t>
();
95
}
96
97
m_initialized
=
true
;
98
}
99
100
/** Reads the number of classes from the labels of the whole data set.
 *
 * @param labels labels of the data; must be non-NULL and of multiclass
 * type, otherwise an assertion error is raised
 */
void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
{
	ASSERT(labels);
	/* A cast never fails, so asserting on the cast pointer only tests for
	 * NULL. Verify the actual label type before treating the labels as
	 * multiclass labels. */
	ASSERT(labels->get_label_type()==LT_MULTICLASS);
	m_num_classes=static_cast<CMulticlassLabels*>(labels)->get_num_classes();
}
105
106
void
CCrossValidationMulticlassStorage::post_update_results
()
107
{
108
CROCEvaluation
eval_ROC;
109
CPRCEvaluation
eval_PRC;
110
int32_t n_evals =
m_binary_evaluations
->
get_num_elements
();
111
for
(int32_t c=0; c<
m_num_classes
; c++)
112
{
113
SG_DEBUG
(
"Computing ROC for run %d fold %d class %d"
,
m_current_run_index
,
m_current_fold_index
, c)
114
CBinaryLabels
* pred_labels_binary =
m_pred_labels
->
get_binary_for_class
(c);
115
CBinaryLabels
* true_labels_binary =
m_true_labels
->
get_binary_for_class
(c);
116
if
(
m_compute_ROC
)
117
{
118
eval_ROC.
evaluate
(pred_labels_binary, true_labels_binary);
119
m_fold_ROC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
120
eval_ROC.
get_ROC
();
121
}
122
if
(
m_compute_PRC
)
123
{
124
eval_PRC.
evaluate
(pred_labels_binary, true_labels_binary);
125
m_fold_PRC_graphs
[
m_current_run_index
*
m_num_folds
*m_num_classes+
m_current_fold_index
*m_num_classes+c] =
126
eval_PRC.
get_PRC
();
127
}
128
129
for
(int32_t i=0; i<n_evals; i++)
130
{
131
CBinaryClassEvaluation
* evaluator = (
CBinaryClassEvaluation
*)
m_binary_evaluations
->
get_element_safe
(i);
132
m_evaluations_results
[
m_current_run_index
*
m_num_folds
*m_num_classes*n_evals+
m_current_fold_index
*m_num_classes*n_evals+c*n_evals+i] =
133
evaluator->
evaluate
(pred_labels_binary, true_labels_binary);
134
SG_UNREF
(evaluator);
135
}
136
137
SG_UNREF
(pred_labels_binary);
138
SG_UNREF
(true_labels_binary);
139
}
140
CMulticlassAccuracy
accuracy;
141
142
m_accuracies
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] = accuracy.
evaluate
(
m_pred_labels
,
m_true_labels
);
143
144
if
(
m_compute_conf_matrices
)
145
{
146
m_conf_matrices
[
m_current_run_index
*
m_num_folds
+
m_current_fold_index
] =
CMulticlassAccuracy::get_confusion_matrix
(
m_pred_labels
,
m_true_labels
);
147
}
148
}
149
150
/** Remembers the labels predicted for the current fold.
 *
 * @param results predicted labels; they are treated as multiclass labels
 * without a type check, and the pointer is stored without taking a
 * reference
 * @param prefix not used by this implementation
 */
void CCrossValidationMulticlassStorage::update_test_result(CLabels* results,
		const char* prefix)
{
	CMulticlassLabels* predicted=static_cast<CMulticlassLabels*>(results);
	m_pred_labels=predicted;
}
154
155
/** Remembers the ground-truth labels of the current fold.
 *
 * @param results true labels; they are treated as multiclass labels
 * without a type check, and the pointer is stored without taking a
 * reference
 * @param prefix not used by this implementation
 */
void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results,
		const char* prefix)
{
	CMulticlassLabels* ground_truth=static_cast<CMulticlassLabels*>(results);
	m_true_labels=ground_truth;
}
159
SHOGUN
机器学习工具包 - 项目文档