[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_nodeproxy.hxx | ![]() |
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 00036 #ifndef VIGRA_RANDOM_FOREST_NP_HXX 00037 #define VIGRA_RANDOM_FOREST_NP_HXX 00038 00039 #include <algorithm> 00040 #include <map> 00041 #include <numeric> 00042 #include "vigra/mathutil.hxx" 00043 #include "vigra/array_vector.hxx" 00044 #include "vigra/sized_int.hxx" 00045 #include "vigra/matrix.hxx" 00046 #include "vigra/random.hxx" 00047 #include "vigra/functorexpression.hxx" 00048 00049 00050 namespace vigra 00051 { 00052 00053 00054 00055 enum NodeTags 00056 { 00057 UnFilledNode = 42, 00058 AllColumns = 0x00000000, 00059 ToBePrunedTag = 0x80000000, 00060 LeafNodeTag = 0x40000000, 00061 00062 i_ThresholdNode = 0, 00063 i_HyperplaneNode = 1, 00064 i_HypersphereNode = 2, 00065 e_ConstProbNode = 0 | LeafNodeTag, 00066 e_LogRegProbNode = 1 | LeafNodeTag 00067 }; 00068 00069 /** NodeBase class. 00070 00071 \ingroup DecicionTree 00072 00073 This class implements common features of all nodes. 00074 Memory Structure: 00075 Int32 Array: TypeID, ParameterAddr, Child0, Child1, [ColumnData]0_ 00076 double Array: NodeWeight, [Parameters]1_ 00077 00078 TODO: Throw away the crappy iterators and use vigra::ArrayVectorView 00079 it is not like anybody else is going to use this NodeBase class 00080 is it? 00081 00082 TODO: use the RF_Traits::ProblemSpec_t to specify the external 00083 parameters instead of the options. 00084 */ 00085 00086 00087 class NodeBase 00088 { 00089 public: 00090 typedef Int32 INT; 00091 typedef ArrayVector<INT> T_Container_type; 00092 typedef ArrayVector<double> P_Container_type; 00093 typedef T_Container_type::iterator Topology_type; 00094 typedef P_Container_type::iterator Parameter_type; 00095 00096 00097 mutable Topology_type topology_; 00098 int topology_size_; 00099 00100 mutable Parameter_type parameters_; 00101 int parameter_size_ ; 00102 00103 // Tree Parameters 00104 int featureCount_; 00105 int classCount_; 00106 00107 // Node Parameters 00108 bool hasData_; 00109 00110 00111 00112 00113 /** get Node Weight 00114 */ 00115 double & weights() 00116 { 00117 return parameters_begin()[0]; 00118 } 00119 00120 double const & weights() const 00121 { 00122 return parameters_begin()[0]; 00123 } 00124 00125 /** has the data been set? 00126 * todo: throw this out - bad design 00127 */ 00128 bool data() const 00129 { 00130 return hasData_; 00131 } 00132 00133 /** get the node type id 00134 * \sa NodeTags 00135 */ 00136 INT& typeID() 00137 { 00138 return topology_[0]; 00139 } 00140 00141 INT const & typeID() const 00142 { 00143 return topology_[0]; 00144 } 00145 00146 /** Where in the parameter_ array are the weights? 00147 */ 00148 INT & parameter_addr() 00149 { 00150 return topology_[1]; 00151 } 00152 00153 INT const & parameter_addr() const 00154 { 00155 return topology_[1]; 00156 } 00157 00158 /** Column Range **/ 00159 Topology_type column_data() const 00160 { 00161 return topology_ + 4 ; 00162 } 00163 00164 /** get the start iterator to the columns 00165 * - once again - throw out - static members are crap. 00166 */ 00167 Topology_type columns_begin() const 00168 { 00169 return column_data()+1; 00170 } 00171 00172 /** how many columns? 00173 */ 00174 int columns_size() const 00175 { 00176 if(*column_data() == AllColumns) 00177 return featureCount_; 00178 else 00179 return *column_data();; 00180 } 00181 00182 /** end iterator to the columns 00183 */ 00184 Topology_type columns_end() const 00185 { 00186 return columns_begin() + columns_size(); 00187 } 00188 00189 /** Topology Range - gives access to the raw Topo memory 00190 * the size_ member was added as a result of premature 00191 * optimisation. 00192 */ 00193 Topology_type topology_begin() const 00194 { 00195 return topology_; 00196 } 00197 Topology_type topology_end() const 00198 { 00199 return topology_begin() + topology_size(); 00200 } 00201 int topology_size() const 00202 { 00203 return topology_size_; 00204 } 00205 00206 /** Parameter Range **/ 00207 Parameter_type parameters_begin() const 00208 { 00209 return parameters_; 00210 } 00211 Parameter_type parameters_end() const 00212 { 00213 return parameters_begin() + parameters_size(); 00214 } 00215 00216 int parameters_size() const 00217 { 00218 return parameter_size_; 00219 } 00220 00221 00222 /** where are the child nodes? 00223 */ 00224 INT & child(Int32 l) 00225 { 00226 return topology_begin()[2+l]; 00227 } 00228 00229 /** where are the child nodes? 00230 */ 00231 INT const & child(Int32 l) const 00232 { 00233 return topology_begin()[2+l]; 00234 } 00235 00236 /** Default Constructor**/ 00237 NodeBase() 00238 : 00239 hasData_(false) 00240 {} 00241 void copy(const NodeBase& o) 00242 { 00243 vigra_precondition(topology_size_==o.topology_size_,"Cannot copy nodes of different sizes"); 00244 vigra_precondition(featureCount_==o.featureCount_,"Cannot copy nodes with different feature count"); 00245 vigra_precondition(classCount_==o.classCount_,"Cannot copy nodes with different class counts"); 00246 vigra_precondition(parameters_size() ==o.parameters_size(),"Cannot copy nodes with different parameter sizes"); 00247 std::copy(o.topology_begin(), o.topology_end(), topology_); 00248 std::copy(o.parameters_begin(),o.parameters_end(), parameters_); 00249 } 00250 00251 /** create ReadOnly Base Node at position n (actual length is unknown) 00252 * only common features i.e. children etc are accessible. 00253 */ 00254 NodeBase( T_Container_type const & topology, 00255 P_Container_type const & parameter, 00256 INT n) 00257 : 00258 topology_ (const_cast<Topology_type>(topology.begin()+ n)), 00259 topology_size_(4), 00260 parameters_ (const_cast<Parameter_type>(parameter.begin() + parameter_addr())), 00261 parameter_size_(1), 00262 featureCount_(topology[0]), 00263 classCount_(topology[1]), 00264 hasData_(true) 00265 { 00266 /*while((int)xrange.size() < featureCount_) 00267 xrange.push_back(xrange.size());*/ 00268 } 00269 00270 /** create ReadOnly node with known length (the parameter range is valid) 00271 */ 00272 NodeBase( int tLen, 00273 int pLen, 00274 T_Container_type const & topology, 00275 P_Container_type const & parameter, 00276 INT n) 00277 : 00278 topology_ (const_cast<Topology_type>(topology.begin()+ n)), 00279 topology_size_(tLen), 00280 parameters_ (const_cast<Parameter_type>(parameter.begin() + parameter_addr())), 00281 parameter_size_(pLen), 00282 featureCount_(topology[0]), 00283 classCount_(topology[1]), 00284 hasData_(true) 00285 { 00286 /*while((int)xrange.size() < featureCount_) 00287 xrange.push_back(xrange.size());*/ 00288 } 00289 /** create ReadOnly node with known length 00290 * from existing Node 00291 */ 00292 NodeBase( int tLen, 00293 int pLen, 00294 NodeBase & node) 00295 : 00296 topology_ (node.topology_), 00297 topology_size_(tLen), 00298 parameters_ (node.parameters_), 00299 parameter_size_(pLen), 00300 featureCount_(node.featureCount_), 00301 classCount_(node.classCount_), 00302 hasData_(true) 00303 { 00304 /*while((int)xrange.size() < featureCount_) 00305 xrange.push_back(xrange.size());*/ 00306 } 00307 00308 00309 /** create new Node at end of vector 00310 * \param tLen number of integers needed in the topolog vector 00311 * \param pLen number of parameters needed (this includes the node 00312 * weight) 00313 * \param topology reference to Topology array of decision tree. 00314 * \param parameter reference to Parameter array of decision tree. 00315 **/ 00316 NodeBase( int tLen, 00317 int pLen, 00318 T_Container_type & topology, 00319 P_Container_type & parameter) 00320 : 00321 topology_size_(tLen), 00322 parameter_size_(pLen), 00323 featureCount_(topology[0]), 00324 classCount_(topology[1]), 00325 hasData_(true) 00326 { 00327 /*while((int)xrange.size() < featureCount_) 00328 xrange.push_back(xrange.size());*/ 00329 00330 size_t n = topology.size(); 00331 for(int ii = 0; ii < tLen; ++ii) 00332 topology.push_back(0); 00333 //topology.resize (n + tLen); 00334 00335 topology_ = topology.begin()+ n; 00336 typeID() = UnFilledNode; 00337 00338 parameter_addr() = static_cast<int>(parameter.size()); 00339 00340 //parameter.resize(parameter.size() + pLen); 00341 for(int ii = 0; ii < pLen; ++ii) 00342 parameter.push_back(0); 00343 00344 parameters_ = parameter.begin()+ parameter_addr(); 00345 weights() = 1; 00346 } 00347 00348 00349 /** PseudoCopy Constructor - 00350 * 00351 * Copy Node to the end of a container. 00352 * Since each Node views on different data there can't be a real 00353 * copy constructor (unless both objects should point to the 00354 * same underlying data. 00355 */ 00356 NodeBase( NodeBase const & toCopy, 00357 T_Container_type & topology, 00358 P_Container_type & parameter) 00359 : 00360 topology_size_(toCopy.topology_size()), 00361 parameter_size_(toCopy.parameters_size()), 00362 featureCount_(topology[0]), 00363 classCount_(topology[1]), 00364 hasData_(true) 00365 { 00366 /*while((int)xrange.size() < featureCount_) 00367 xrange.push_back(xrange.size());*/ 00368 00369 size_t n = topology.size(); 00370 for(int ii = 0; ii < toCopy.topology_size(); ++ii) 00371 topology.push_back(toCopy.topology_begin()[ii]); 00372 // topology.insert(topology.end(), toCopy.topology_begin(), toCopy.topology_end()); 00373 topology_ = topology.begin()+ n; 00374 parameter_addr() = static_cast<int>(parameter.size()); 00375 for(int ii = 0; ii < toCopy.parameters_size(); ++ii) 00376 parameter.push_back(toCopy.parameters_begin()[ii]); 00377 // parameter.insert(parameter.end(), toCopy.parameters_begin(), toCopy.parameters_end()); 00378 parameters_ = parameter.begin()+ parameter_addr(); 00379 } 00380 }; 00381 00382 00383 template<NodeTags NodeType> 00384 class Node; 00385 00386 template<> 00387 class Node<i_ThresholdNode> 00388 : public NodeBase 00389 { 00390 00391 00392 public: 00393 typedef NodeBase BT; 00394 00395 /**constructors **/ 00396 00397 Node( BT::T_Container_type & topology, 00398 BT::P_Container_type & param) 00399 : BT(5,2,topology, param) 00400 { 00401 BT::typeID() = i_ThresholdNode; 00402 } 00403 00404 Node( BT::T_Container_type const & topology, 00405 BT::P_Container_type const & param, 00406 INT n ) 00407 : BT(5,2,topology, param, n) 00408 {} 00409 00410 Node( BT & node_) 00411 : BT(5, 2, node_) 00412 {} 00413 00414 double& threshold() 00415 { 00416 return BT::parameters_begin()[1]; 00417 } 00418 00419 double const & threshold() const 00420 { 00421 return BT::parameters_begin()[1]; 00422 } 00423 00424 BT::INT& column() 00425 { 00426 return BT::column_data()[0]; 00427 } 00428 BT::INT const & column() const 00429 { 00430 return BT::column_data()[0]; 00431 } 00432 00433 template<class U, class C> 00434 BT::INT next(MultiArrayView<2,U,C> const & feature) const 00435 { 00436 return (feature(0, column()) < threshold())? child(0):child(1); 00437 } 00438 }; 00439 00440 00441 template<> 00442 class Node<i_HyperplaneNode> 00443 : public NodeBase 00444 { 00445 public: 00446 00447 typedef NodeBase BT; 00448 00449 /**constructors **/ 00450 00451 Node( int nCol, 00452 BT::T_Container_type & topology, 00453 BT::P_Container_type & split_param) 00454 : BT(nCol + 5,nCol + 2,topology, split_param) 00455 { 00456 BT::typeID() = i_HyperplaneNode; 00457 } 00458 00459 Node( BT::T_Container_type const & topology, 00460 BT::P_Container_type const & split_param, 00461 int n ) 00462 : NodeBase(5 , 2,topology, split_param, n) 00463 { 00464 //TODO : is there a more elegant way to do this? 00465 BT::topology_size_ += BT::column_data()[0]== AllColumns ? 00466 0 00467 : BT::column_data()[0]; 00468 BT::parameter_size_ += BT::columns_size(); 00469 } 00470 00471 Node( BT & node_) 00472 : BT(5, 2, node_) 00473 { 00474 //TODO : is there a more elegant way to do this? 00475 BT::topology_size_ += BT::column_data()[0]== AllColumns ? 00476 0 00477 : BT::column_data()[0]; 00478 BT::parameter_size_ += BT::columns_size(); 00479 } 00480 00481 00482 double const & intercept() const 00483 { 00484 return BT::parameters_begin()[1]; 00485 } 00486 double& intercept() 00487 { 00488 return BT::parameters_begin()[1]; 00489 } 00490 00491 BT::Parameter_type weights() const 00492 { 00493 return BT::parameters_begin()+2; 00494 } 00495 00496 BT::Parameter_type weights() 00497 { 00498 return BT::parameters_begin()+2; 00499 } 00500 00501 00502 template<class U, class C> 00503 BT::INT next(MultiArrayView<2,U,C> const & feature) const 00504 { 00505 double result = -1 * intercept(); 00506 if(*(BT::column_data()) == AllColumns) 00507 { 00508 for(int ii = 0; ii < BT::columns_size(); ++ii) 00509 { 00510 result +=feature[ii] * weights()[ii]; 00511 } 00512 } 00513 else 00514 { 00515 for(int ii = 0; ii < BT::columns_size(); ++ii) 00516 { 00517 result +=feature[BT::columns_begin()[ii]] * weights()[ii]; 00518 } 00519 } 00520 return result < 0 ? BT::child(0) 00521 : BT::child(1); 00522 } 00523 }; 00524 00525 00526 00527 template<> 00528 class Node<i_HypersphereNode> 00529 : public NodeBase 00530 { 00531 public: 00532 00533 typedef NodeBase BT; 00534 00535 /**constructors **/ 00536 00537 Node( int nCol, 00538 BT::T_Container_type & topology, 00539 BT::P_Container_type & param) 00540 : NodeBase(nCol + 5,nCol + 1,topology, param) 00541 { 00542 BT::typeID() = i_HypersphereNode; 00543 } 00544 00545 Node( BT::T_Container_type const & topology, 00546 BT::P_Container_type const & param, 00547 int n ) 00548 : NodeBase(5, 1,topology, param, n) 00549 { 00550 BT::topology_size_ += BT::column_data()[0]== AllColumns ? 00551 0 00552 : BT::column_data()[0]; 00553 BT::parameter_size_ += BT::columns_size(); 00554 } 00555 00556 Node( BT & node_) 00557 : BT(5, 1, node_) 00558 { 00559 BT::topology_size_ += BT::column_data()[0]== AllColumns ? 00560 0 00561 : BT::column_data()[0]; 00562 BT::parameter_size_ += BT::columns_size(); 00563 00564 } 00565 00566 double const & squaredRadius() const 00567 { 00568 return BT::parameters_begin()[1]; 00569 } 00570 00571 double& squaredRadius() 00572 { 00573 return BT::parameters_begin()[1]; 00574 } 00575 00576 BT::Parameter_type center() const 00577 { 00578 return BT::parameters_begin()+2; 00579 } 00580 00581 BT::Parameter_type center() 00582 { 00583 return BT::parameters_begin()+2; 00584 } 00585 00586 template<class U, class C> 00587 BT::INT next(MultiArrayView<2,U,C> const & feature) const 00588 { 00589 double result = -1 * squaredRadius(); 00590 if(*(BT::column_data()) == AllColumns) 00591 { 00592 for(int ii = 0; ii < BT::columns_size(); ++ii) 00593 { 00594 result += (feature[ii] - center()[ii])* 00595 (feature[ii] - center()[ii]); 00596 } 00597 } 00598 else 00599 { 00600 for(int ii = 0; ii < BT::columns_size(); ++ii) 00601 { 00602 result += (feature[BT::columns_begin()[ii]] - center()[ii])* 00603 (feature[BT::columns_begin()[ii]] - center()[ii]); 00604 } 00605 } 00606 return result < 0 ? BT::child(0) 00607 : BT::child(1); 00608 } 00609 }; 00610 00611 00612 /** ExteriorNodeBase class. 00613 00614 \ingroup DecicionTree 00615 00616 This class implements common features of all interior nodes. 00617 All interior nodes are derived classes of ExteriorNodeBase. 00618 */ 00619 00620 00621 00622 00623 00624 00625 template<> 00626 class Node<e_ConstProbNode> 00627 : public NodeBase 00628 { 00629 public: 00630 00631 typedef NodeBase BT; 00632 00633 Node( BT::T_Container_type & topology, 00634 BT::P_Container_type & param) 00635 : 00636 BT(2,topology[1]+1, topology, param) 00637 00638 { 00639 BT::typeID() = e_ConstProbNode; 00640 } 00641 00642 00643 Node( BT::T_Container_type const & topology, 00644 BT::P_Container_type const & param, 00645 int n ) 00646 : BT(2, topology[1]+1,topology, param, n) 00647 { } 00648 00649 00650 Node( BT & node_) 00651 : BT(2, node_.classCount_ +1, node_) 00652 {} 00653 BT::Parameter_type prob_begin() const 00654 { 00655 return BT::parameters_begin()+1; 00656 } 00657 BT::Parameter_type prob_end() const 00658 { 00659 return prob_begin() + prob_size(); 00660 } 00661 int prob_size() const 00662 { 00663 return BT::classCount_; 00664 } 00665 }; 00666 00667 template<> 00668 class Node<e_LogRegProbNode>; 00669 00670 } // namespace vigra 00671 00672 #endif //RF_nodeproxy
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|