SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
multiclass
tree
BalancedConditionalProbabilityTree.cpp
浏览该文件的文档.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2012 Chiyuan Zhang
8
* Copyright (C) 2012 Chiyuan Zhang
9
*/
10
11
#include <
shogun/multiclass/tree/BalancedConditionalProbabilityTree.h
>
12
13
using namespace
shogun;
14
15
/** Default constructor.
 *
 * Sets the balance trade-off m_alpha to 0.4. It then registers m_alpha under
 * the name "m_alpha" via SG_ADD with MS_NOT_AVAILABLE.
 */
CBalancedConditionalProbabilityTree::CBalancedConditionalProbabilityTree()
{
	// Give the trade-off its default value before the member is registered.
	m_alpha = 0.4;

	SG_ADD(&m_alpha, "m_alpha", "Trade-off parameter of tree balance", MS_NOT_AVAILABLE);
}
20
21
/** Set the trade-off between classifier output and tree balance.
 *
 * In which_subtree() the routing objective is
 *   (1 - m_alpha) * 2 * (pred - 0.5) + m_alpha * (balance term),
 * so alpha = 0 uses only the node prediction and alpha = 1 uses only balance.
 *
 * @param alpha trade-off weight, must lie in [0, 1]
 *
 * Out-of-range values are reported through SG_ERROR. So is NaN, which the
 * negated range test below catches because NaN fails both comparisons.
 * NOTE(review): m_alpha is only left untouched if SG_ERROR aborts the call
 * (e.g. by throwing) — confirm against the SG_ERROR definition.
 */
void CBalancedConditionalProbabilityTree::set_alpha(float64_t alpha)
{
	// Written as !(in range) rather than (below || above) so NaN is rejected too.
	if (!(alpha >= 0 && alpha <= 1))
	{
		SG_ERROR("expect 0 <= alpha <= 1, but got %g\n", alpha);
	}
	m_alpha = alpha;
}
27
28
/** Decide which child of an internal node an example should descend into.
 *
 * The objective blends the node's prediction with a balance term:
 *   obj = (1 - m_alpha) * 2 * (pred - 0.5) + m_alpha * log2(2^dl / 2^dr)
 * Here dl and dr are the tree_depth() values of the left and right children.
 * The balance term log2(2^dl / 2^dr) equals dl - dr exactly, and is computed
 * that way directly.
 *
 * @param node internal node being traversed
 * @param ex   feature vector of the example (passed on to predict_node)
 * @return false to go right (obj > 0), true to go left otherwise
 */
bool CBalancedConditionalProbabilityTree::which_subtree(bnode_t *node, SGVector<float32_t> ex)
{
	float64_t pred = predict_node(ex, node);

	// tree_depth() counts nodes along each child's leftmost chain.
	float64_t depth_left = tree_depth(node->left());
	float64_t depth_right = tree_depth(node->right());

	// Equivalent to log2(pow(2, depth_left) / pow(2, depth_right)). pow(2.0, d)
	// overflows to +inf for d >= 1024, which would turn that expression into NaN
	// (inf/inf, or 0*inf when m_alpha == 0). The plain difference is exact and
	// needs no transcendental calls.
	float64_t balance = depth_left - depth_right;

	float64_t obj_val = (1-m_alpha) * 2 * (pred-0.5) + m_alpha * balance;

	// A deeper left child raises obj_val, steering examples to the right child.
	if (obj_val > 0)
		return false; // go right
	return true; // go left
}
43
44
/** Count the nodes on the leftmost chain starting at a node.
 *
 * Only left() links are followed. Right children are never visited, so this
 * is the length of the leftmost path rather than the subtree's maximum height.
 *
 * @param node starting node (may be NULL)
 * @return number of nodes visited, including node itself; 0 when node is NULL
 */
int32_t CBalancedConditionalProbabilityTree::tree_depth(bnode_t *node)
{
	int32_t levels = 0;
	for (bnode_t *cur = node; cur != NULL; cur = cur->left())
		++levels;
	return levels;
}
SHOGUN
机器学习工具包 - 项目文档