SHOGUN
3.2.1
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
classifier
vw
learners
VwAdaptiveLearner.cpp
浏览该文件的文档.
1
/*
 * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
 * embodied in the content of this file are licensed under the BSD
 * (revised) open source license.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2011 Shashwat Lal Das
 * Adaptation of Vowpal Wabbit v5.1.
 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
 */
15
16
#include <shogun/classifier/vw/learners/VwAdaptiveLearner.h>

using namespace shogun;
19
20
/** Default constructor.
 *
 * Performs no work of its own: all setup is delegated to the
 * CVwLearner default constructor.
 */
CVwAdaptiveLearner::CVwAdaptiveLearner()
	: CVwLearner()
{
	// Intentionally empty: nothing beyond the base-class setup.
}
24
25
/** Construct a learner bound to a regressor and an environment.
 *
 * Both pointers are forwarded unchanged to the CVwLearner constructor;
 * this class does no additional setup.
 *
 * @param regressor regressor handed to CVwLearner
 * @param vw_env VW environment handed to CVwLearner
 */
CVwAdaptiveLearner::CVwAdaptiveLearner(CVwRegressor* regressor, CVwEnvironment* vw_env)
	: CVwLearner(regressor, vw_env)
{
	// Intentionally empty: the base class receives both arguments.
}
29
30
/** Destructor.
 *
 * The body is empty: nothing is released at this level.
 */
CVwAdaptiveLearner::~CVwAdaptiveLearner()
{
	// Intentionally empty.
}
33
34
/** Update the weights for one example using per-feature adaptive step sizes.
 *
 * Each weight slot is addressed as a pair: w[0] is the weight itself and
 * w[1] is an accumulator. For every feature with value x, the accumulator
 * first grows by g * x * x, and then the weight moves by
 * update * x / sqrt(w[1]), where g is the value returned by
 * reg->loss->get_square_grad() scaled by the example's importance weight.
 *
 * Linear features of every namespace listed in ex->indices are handled
 * first. Then, for every namespace pair configured in env->pairs, each
 * feature of the first namespace is crossed with the features of the
 * second namespace via quad_update().
 *
 * @param ex example to learn from (reads final_prediction, ld->label,
 *        ld->weight, indices and atomics)
 * @param update update multiplier for this example; when it is zero the
 *        function returns without touching any weight
 */
void CVwAdaptiveLearner::train(VwExample* &ex, float32_t update)
{
	// A zero step cannot change any weight, so skip the work entirely.
	if (fabs(update) == 0.)
		return;

	// Always operates on weight vector 0.
	vw_size_t thread_num = 0;
	vw_size_t thread_mask = env->thread_mask;
	float32_t* weights = reg->weight_vectors[thread_num];

	float32_t g = reg->loss->get_square_grad(ex->final_prediction, ex->ld->label) * ex->ld->weight;

	// Only forwarded to quad_update(), which does not modify it.
	vw_size_t ctr = 0;

	// Linear terms.
	for (vw_size_t* i = ex->indices.begin; i != ex->indices.end; i++)
	{
		for (VwFeature* f = ex->atomics[*i].begin; f != ex->atomics[*i].end; f++)
		{
			float32_t* w = &weights[f->weight_index & thread_mask];
			w[1] += g * f->x * f->x;
			float32_t t = f->x * CMath::invsqrt(w[1]);
			w[0] += update * t;
		}
	}

	// Quadratic (namespace-pair) terms.
	for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
	{
		char* pair = env->pairs.get_element(k);

		// Namespace identifiers are raw bytes. Go through unsigned char so a
		// byte >= 0x80 cannot become a negative index where char is signed.
		v_array<VwFeature>& first = ex->atomics[(unsigned char) pair[0]];
		v_array<VwFeature>& second = ex->atomics[(unsigned char) pair[1]];

		// Snapshot the bounds of the first namespace (as the original copy did)
		// and cross each of its features with the whole second namespace.
		VwFeature* const first_end = first.end;
		for (VwFeature* page = first.begin; page != first_end; page++)
			quad_update(weights, *page, second, thread_mask, update, g, ex, ctr);
	}
}
68
69
void
CVwAdaptiveLearner::quad_update(
float32_t
* weights,
VwFeature
& page_feature,
70
v_array<VwFeature>
&offer_features,
vw_size_t
mask,
71
float32_t
update
,
float32_t
g,
VwExample
* ex,
vw_size_t
& ctr)
72
{
73
vw_size_t
halfhash =
quadratic_constant
* page_feature.
weight_index
;
74
update *= page_feature.
x
;
75
float32_t
update2 = g * page_feature.
x
* page_feature.
x
;
76
77
for
(
VwFeature
* elem = offer_features.
begin
; elem != offer_features.
end
; elem++)
78
{
79
float32_t
* w = &weights[(halfhash + elem->weight_index) & mask];
80
w[1] += update2 * elem->x * elem->x;
81
float32_t
t = elem->x *
CMath::invsqrt
(w[1]);
82
w[0] += update * t;
83
}
84
}
SHOGUN
机器学习工具包 - 项目文档