SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
evaluation
CrossValidationMKLStorage.cpp
浏览该文件的文档.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Sergey Lisitsyn
 * Written (W) 2012 Heiko Strathmann
 */

#include <
shogun/evaluation/CrossValidationMKLStorage.h
>
12
#include <
shogun/kernel/CombinedKernel.h
>
13
#include <
shogun/classifier/mkl/MKL.h
>
14
#include <
shogun/classifier/mkl/MKLMulticlass.h
>
15
16
using namespace
shogun;
17
18
void
CCrossValidationMKLStorage::update_trained_machine
(
19
CMachine
* machine,
const
char
* prefix)
20
{
21
REQUIRE
(machine,
"%s::update_trained_machine(): Provided Machine is NULL!\n"
,
22
get_name
());
23
24
CMKL
* mkl=
dynamic_cast<
CMKL
*
>
(machine);
25
CMKLMulticlass
* mkl_multiclass=
dynamic_cast<
CMKLMulticlass
*
>
(machine);
26
REQUIRE
(mkl || mkl_multiclass,
"%s::update_trained_machine(): This method is only usable "
27
"with CMKL derived machines. This one is \"%s\"\n"
,
get_name
(),
28
machine->
get_name
());
29
30
CKernel
* kernel = NULL;
31
if
(mkl)
32
kernel = mkl->
get_kernel
();
33
else
34
kernel = mkl_multiclass->
get_kernel
();
35
36
REQUIRE
(kernel,
"%s::update_trained_machine(): No kernel assigned to "
37
"machine of type \"%s\"\n"
,
get_name
(), machine->
get_name
());
38
39
CCombinedKernel
* combined_kernel=
dynamic_cast<
CCombinedKernel
*
>
(kernel);
40
REQUIRE
(combined_kernel,
"%s::update_trained_machine(): This method is only"
41
" usable with CCombinedKernel on machines. This one is \"s\"\n"
,
42
get_name
(), kernel->get_name());
43
44
SGVector<float64_t>
w=combined_kernel->get_subkernel_weights();
45
46
/* evtl re-allocate memory (different number of runs from evaluation before) */
47
if
(
m_mkl_weights
.
num_rows
!=w.vlen ||
48
m_mkl_weights
.
num_cols
!=
m_num_folds
*
m_num_runs
)
49
{
50
if
(
m_mkl_weights
.
matrix
)
51
{
52
SG_DEBUG
(
"deleting memory for mkl weight matrix\n"
)
53
m_mkl_weights
=
SGMatrix<float64_t>
();
54
}
55
}
56
57
/* evtl allocate memory (first call) */
58
if
(!
m_mkl_weights
.
matrix
)
59
{
60
SG_DEBUG
(
"allocating memory for mkl weight matrix\n"
)
61
m_mkl_weights
=
SGMatrix<float64_t>
(w.vlen,
m_num_folds
*
m_num_runs
);
62
}
63
64
/* put current mkl weights into matrix, copy memory vector wise to make
65
* things fast. Compute index of address to where vector goes */
66
67
/* number of runs is w.vlen*m_num_folds shift */
68
index_t
run_shift=
m_current_run_index
*w.vlen*
m_num_folds
;
69
70
/* fold shift is m_current_fold_index*w-vlen */
71
index_t
fold_shift=
m_current_fold_index
*w.vlen;
72
73
/* add both index shifts */
74
index_t
first_idx=run_shift+fold_shift;
75
SG_DEBUG
(
"run %d, fold %d, matrix index %d\n"
,
m_current_run_index
,
76
m_current_fold_index
, first_idx);
77
78
/* copy memory */
79
memcpy(&
m_mkl_weights
.
matrix
[first_idx], w.vector,
80
w.vlen*
sizeof
(
float64_t
));
81
82
SG_UNREF
(kernel);
83
}
SHOGUN
机器学习工具包 - 项目文档