SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
modelselection
GridSearchModelSelection.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2011-2012 Heiko Strathmann
8
* Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
9
*/
10
11
#include <
shogun/modelselection/GridSearchModelSelection.h
>
12
#include <
shogun/modelselection/ParameterCombination.h
>
13
#include <
shogun/modelselection/ModelSelectionParameters.h
>
14
#include <
shogun/evaluation/CrossValidation.h
>
15
#include <
shogun/machine/Machine.h
>
16
17
using namespace
shogun;
18
19
/** Default constructor: delegates all initialisation to the CModelSelection
 * default constructor; grid search adds no state of its own here.
 */
CGridSearchModelSelection::CGridSearchModelSelection()
	: CModelSelection()
{
	/* intentionally empty */
}
22
23
/** Constructs a grid search over the given parameter tree.
 *
 * Both arguments are handed unchanged to the CModelSelection base constructor;
 * this class does nothing further with them at construction time.
 *
 * @param machine_eval evaluation object used to score each combination
 * @param model_parameters tree of model-selection parameters to enumerate
 */
CGridSearchModelSelection::CGridSearchModelSelection(
		CMachineEvaluation* machine_eval,
		CModelSelectionParameters* model_parameters)
	: CModelSelection(machine_eval, model_parameters)
{
	/* intentionally empty */
}
29
30
/** Destructor: this class releases nothing itself; any cleanup is left to
 * the base-class destructors.
 */
CGridSearchModelSelection::~CGridSearchModelSelection()
{
	/* intentionally empty */
}
33
34
/** Exhaustive grid search: evaluates every parameter combination generated
 * from the model-selection parameter tree and returns the best one.
 *
 * Each combination is applied to the machine's model-selection parameters and
 * scored with m_machine_eval->evaluate(). The result must be a
 * CCrossValidationResult. Its mean is compared according to the evaluation
 * direction: larger is better for ED_MAXIMIZE, smaller otherwise. Ties keep
 * the earlier combination.
 *
 * The first non-NaN result is always accepted as the initial best. The
 * outcome therefore does not depend on the size of a sentinel start value.
 * NaN means are never selected.
 *
 * @param print_state if true, progress, every tried combination and every
 * evaluation result are printed
 * @return the best combination, or NULL if there were no combinations or no
 * result had a comparable (non-NaN) mean. The reference taken on the returned
 * combination is not released here, so the caller is responsible for
 * SG_UNREF-ing it.
 */
CParameterCombination* CGridSearchModelSelection::select_model(bool print_state)
{
	if (print_state)
		SG_PRINT("Generating parameter combinations\n")

	/* Retrieve all possible parameter combinations */
	CDynamicObjectArray* combinations=
			(CDynamicObjectArray*)m_model_parameters->get_combinations();

	/* the direction cannot change during the search, so query it only once */
	const bool maximize=
			m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE;

	if (print_state)
	{
		if (maximize)
		{
			SG_PRINT("Direction is maximize\n")
		}
		else
		{
			SG_PRINT("Direction is minimize\n")
		}
	}

	/* underlying learning machine */
	CMachine* machine=m_machine_eval->get_machine();

	/* best result/combination seen so far; each holds its own reference */
	CCrossValidationResult* best_result=NULL;
	CParameterCombination* best_combination=NULL;

	/* apply all combinations and search for best one */
	for (index_t i=0; i<combinations->get_num_elements(); ++i)
	{
		CParameterCombination* current_combination=
				(CParameterCombination*)combinations->get_element(i);

		/* eventually print */
		if (print_state)
		{
			SG_PRINT("trying combination:\n")
			current_combination->print_tree();
		}

		current_combination->apply_to_modsel_parameter(
				machine->m_model_selection_parameters);

		/* note that this may implicitly lock and unlock the machine */
		CCrossValidationResult* result=
				(CCrossValidationResult*)(m_machine_eval->evaluate());

		if (result->get_result_type()!=CROSSVALIDATION_RESULT)
		{
			/* Release every reference held so far before reporting, so the
			 * error path does not leak. SG_ERROR is expected to abort
			 * (Shogun raises an exception); the return is only a fallback. */
			SG_UNREF(result);
			SG_UNREF(current_combination);
			SG_UNREF(best_combination);
			SG_UNREF(best_result);
			SG_UNREF(machine);
			SG_UNREF(combinations);
			SG_ERROR("Evaluation result is not of type CCrossValidationResult!")
			return NULL;
		}

		if (print_state)
			result->print_result();

		/* Strict comparisons keep the earlier combination on ties. Any
		 * comparison involving NaN is false, so NaN is never selected. */
		bool is_better;
		if (!best_result)
			is_better=result->mean==result->mean; /* false only for NaN */
		else if (maximize)
			is_better=result->mean>best_result->mean;
		else
			is_better=result->mean<best_result->mean;

		if (is_better)
		{
			/* take references on the new best pair, drop the old one */
			SG_REF(current_combination);
			SG_UNREF(best_combination);
			best_combination=current_combination;

			SG_REF(result);
			SG_UNREF(best_result);
			best_result=result;
		}

		/* release the references owned by this iteration */
		SG_UNREF(result);
		SG_UNREF(current_combination);
	}

	SG_UNREF(best_result);
	SG_UNREF(machine);
	SG_UNREF(combinations);

	return best_combination;
}
SHOGUN
机器学习工具包 - 项目文档