SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
modelselection
RandomSearchModelSelection.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Copyright (C) 2011 Heiko Strathmann
8
* Copyright (C) 2012 Sergey Lisitsyn
9
*/
10
11
#include <
shogun/modelselection/RandomSearchModelSelection.h
>
12
#include <
shogun/modelselection/ParameterCombination.h
>
13
#include <
shogun/modelselection/ModelSelectionParameters.h
>
14
#include <
shogun/evaluation/CrossValidation.h
>
15
#include <
shogun/mathematics/Statistics.h
>
16
#include <
shogun/machine/Machine.h
>
17
18
using namespace
shogun;
19
20
/* Default constructor: default-constructs the CModelSelection base and
 * uses a sampling ratio of 0.5 (half of all parameter combinations). */
CRandomSearchModelSelection::CRandomSearchModelSelection()
	: CModelSelection()
{
	this->set_ratio(0.5);
}
24
25
/* Constructor.
 *
 * @param machine_eval evaluation object forwarded to the CModelSelection base
 * @param model_parameters parameter tree forwarded to the CModelSelection base
 * @param ratio fraction of all parameter combinations to sample in select_model
 */
CRandomSearchModelSelection::CRandomSearchModelSelection(
		CMachineEvaluation* machine_eval,
		CModelSelectionParameters* model_parameters, float64_t ratio)
	: CModelSelection(machine_eval, model_parameters)
{
	this->set_ratio(ratio);
}
32
33
/* Destructor. */
CRandomSearchModelSelection::~CRandomSearchModelSelection()
{
	/* intentionally empty: this class releases nothing itself; any cleanup
	 * is presumably done by the CModelSelection base destructor */
}
36
37
/* Randomly samples a subset of all parameter combinations, evaluates each
 * one with m_machine_eval and returns the best one.
 *
 * The number of sampled combinations is n_all_combinations*m_ratio
 * (truncated). It is clamped to at least one whenever any combination
 * exists, so a small ratio on a small grid still yields a result.
 *
 * @param print_state if true, progress and per-combination results are printed
 * @return best combination found (caller owns the returned reference), or
 * NULL if no combination was evaluated or none improved on the initial bound
 */
CParameterCombination* CRandomSearchModelSelection::select_model(bool print_state)
{
	if (print_state)
		SG_PRINT("Generating parameter combinations\n");

	/* Retrieve all possible parameter combinations */
	CDynamicObjectArray* all_combinations=
			(CDynamicObjectArray*)m_model_parameters->get_combinations();

	int32_t n_all_combinations=all_combinations->get_num_elements();

	/* how many combinations to draw; never zero when there is something to draw */
	int32_t n_sampled=(int32_t)(n_all_combinations*m_ratio);
	if (n_sampled<1 && n_all_combinations>0)
		n_sampled=1;

	SGVector<index_t> combinations_indices=
			CStatistics::sample_indices(n_sampled, n_all_combinations);

	/* copy the randomly drawn combinations (use the sampled index, not the
	 * loop counter, otherwise always the first ones would be taken) */
	CDynamicObjectArray* combinations=new CDynamicObjectArray();
	for (index_t i=0; i<combinations_indices.vlen; i++)
	{
		CParameterCombination* sampled=(CParameterCombination*)
				all_combinations->get_element(combinations_indices[i]);
		combinations->append_element(sampled);
		/* drop the reference returned by get_element; assumes append_element
		 * keeps its own reference (CDynamicObjectArray convention) */
		SG_UNREF(sampled);
	}

	/* the full set is no longer needed; the sampled ones live on in combinations */
	SG_UNREF(all_combinations);

	/* the evaluation direction decides how results are compared */
	bool maximize=m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE;

	CCrossValidationResult* best_result=new CCrossValidationResult();
	if (maximize)
	{
		if (print_state)
			SG_PRINT("Direction is maximize\n");
		best_result->mean=CMath::ALMOST_NEG_INFTY;
	}
	else
	{
		if (print_state)
			SG_PRINT("Direction is minimize\n");
		best_result->mean=CMath::ALMOST_INFTY;
	}

	CParameterCombination* best_combination=NULL;

	/* underlying learning machine */
	CMachine* machine=m_machine_eval->get_machine();

	/* apply all sampled combinations and search for the best one */
	for (index_t i=0; i<combinations->get_num_elements(); ++i)
	{
		CParameterCombination* current_combination=(CParameterCombination*)
				combinations->get_element(i);

		if (print_state)
		{
			SG_PRINT("trying combination:\n");
			current_combination->print_tree();
		}

		current_combination->apply_to_modsel_parameter(
				machine->m_model_selection_parameters);

		/* note that this may implicitly lock and unlock the machine */
		CCrossValidationResult* result=
				(CCrossValidationResult*)(m_machine_eval->evaluate());

		if (result->get_result_type()!=CROSSVALIDATION_RESULT)
			SG_ERROR("Evaluation result is not of type CCrossValidationResult!");

		if (print_state)
			result->print_result();

		bool improved=maximize ? result->mean>best_result->mean
				: result->mean<best_result->mean;

		if (improved)
		{
			/* take a reference to the new winner before releasing the old one */
			SG_REF(current_combination);
			if (best_combination)
				SG_UNREF(best_combination);
			best_combination=current_combination;

			SG_REF(result);
			SG_UNREF(best_result);
			best_result=result;
		}

		SG_UNREF(result);
		SG_UNREF(current_combination);
	}

	SG_UNREF(best_result);
	SG_UNREF(machine);
	SG_UNREF(combinations);

	return best_combination;
}
SHOGUN
机器学习工具包 - 项目文档