[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_nodeproxy.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RANDOM_FOREST_NP_HXX
00037 #define VIGRA_RANDOM_FOREST_NP_HXX
00038 
00039 #include <algorithm>
00040 #include <map>
00041 #include <numeric>
00042 #include "vigra/mathutil.hxx"
00043 #include "vigra/array_vector.hxx"
00044 #include "vigra/sized_int.hxx"
00045 #include "vigra/matrix.hxx"
00046 #include "vigra/random.hxx"
00047 #include "vigra/functorexpression.hxx"
00048 
00049 
00050 namespace vigra
00051 {
00052 
00053 
00054 
/** Integer tags used by the node classes in this file.
 *
 *  A node's type ID is one of the i_* values (interior nodes) or one of the
 *  e_* values (exterior nodes). The e_* values are a small index with the
 *  LeafNodeTag bit set.
 */
enum NodeTags
{
    /** type ID written into a node that has just been appended */
    UnFilledNode        = 42,
    /** column header value meaning "use every feature column" */
    AllColumns          = 0x00000000,
    /** NOTE(review): not referenced in this file; by its name it presumably
     *  marks a node for pruning - confirm against the users of this enum */
    ToBePrunedTag       = 0x80000000,
    /** bit that is set in every e_* (exterior node) type ID */
    LeafNodeTag         = 0x40000000,

    // interior node type IDs
    i_ThresholdNode     = 0,
    i_HyperplaneNode    = 1,
    i_HypersphereNode   = 2,

    // exterior node type IDs
    e_ConstProbNode     = LeafNodeTag | 0,
    e_LogRegProbNode    = LeafNodeTag | 1
};
00068 
00069 /** NodeBase class.
00070 
00071     \ingroup DecicionTree
00072 
00073     This class implements common features of all nodes.
00074     Memory Structure:
00075         Int32   Array:  TypeID, ParameterAddr, Child0, Child1, [ColumnData]0_
00076         double  Array:  NodeWeight, [Parameters]1_
00077 
00078         TODO: Throw away the crappy iterators and use vigra::ArrayVectorView
00079              it is not like anybody else is going to use this NodeBase class
00080              is it?
00081 
00082         TODO: use the RF_Traits::ProblemSpec_t to specify the external 
00083              parameters instead of the options.
00084 */
00085 
00086 
00087 class NodeBase
00088 {
00089   public:
00090     typedef Int32                               INT;
00091     typedef ArrayVector<INT>                    T_Container_type;
00092     typedef ArrayVector<double>                 P_Container_type;
00093     typedef T_Container_type::iterator          Topology_type;
00094     typedef P_Container_type::iterator          Parameter_type;
00095 
00096 
00097     mutable Topology_type                       topology_;
00098     int                                         topology_size_;
00099 
00100     mutable Parameter_type                      parameters_;
00101     int                                         parameter_size_ ;
00102 
00103         // Tree Parameters
00104     int                                         featureCount_;
00105     int                                         classCount_;
00106 
00107         // Node Parameters
00108     bool                                        hasData_;
00109 
00110 
00111 
00112 
00113     /** get Node Weight
00114      */
00115     double &      weights()
00116     {
00117             return parameters_begin()[0];
00118     }
00119 
00120     double const &      weights() const
00121     {
00122             return parameters_begin()[0];
00123     }
00124 
00125     /** has the data been set?
00126      * todo: throw this out - bad design
00127      */
00128     bool          data() const
00129     {
00130         return hasData_;
00131     }
00132 
00133     /** get the node type id
00134      * \sa NodeTags
00135      */
00136     INT&          typeID()
00137     {
00138         return topology_[0];
00139     }
00140 
00141     INT const &          typeID() const
00142     {
00143         return topology_[0];
00144     }
00145 
00146     /** Where in the parameter_ array are the weights?
00147      */
00148     INT &          parameter_addr()
00149     {
00150         return topology_[1];
00151     }
00152 
00153     INT const &    parameter_addr() const
00154     {
00155         return topology_[1];
00156     }
00157 
00158     /** Column Range **/
00159     Topology_type  column_data() const
00160     {
00161         return topology_ + 4 ;
00162     }
00163 
00164     /** get the start iterator to the columns
00165      *  - once again - throw out - static members are crap.
00166      */
00167     Topology_type columns_begin() const
00168     {
00169             return column_data()+1;
00170     }
00171 
00172     /** how many columns?
00173      */
00174     int      columns_size() const
00175     {
00176         if(*column_data() == AllColumns)
00177             return featureCount_;
00178         else
00179             return *column_data();;
00180     }
00181 
00182     /** end iterator to the columns
00183      */
00184     Topology_type  columns_end() const
00185     {
00186         return columns_begin() + columns_size();
00187     }
00188 
00189     /** Topology Range - gives access to the raw Topo memory
00190      * the size_ member was added as a result of premature 
00191      * optimisation.
00192      */ 
00193     Topology_type   topology_begin() const
00194     {
00195         return topology_;
00196     }
00197     Topology_type   topology_end() const
00198     {
00199         return topology_begin() + topology_size();
00200     }
00201     int          topology_size() const
00202     {
00203         return topology_size_;
00204     }
00205 
00206     /** Parameter Range **/
00207     Parameter_type  parameters_begin() const
00208     {
00209         return parameters_;
00210     }
00211     Parameter_type  parameters_end() const
00212     {
00213         return parameters_begin() + parameters_size();
00214     }
00215 
00216     int          parameters_size() const
00217     {
00218         return parameter_size_;
00219     }
00220 
00221 
00222     /** where are the child nodes?
00223      */
00224     INT &           child(Int32 l)
00225     {
00226         return topology_begin()[2+l];
00227     }
00228 
00229     /** where are the child nodes?
00230      */
00231     INT const  &           child(Int32 l) const
00232     {
00233         return topology_begin()[2+l];
00234     }
00235 
00236     /** Default Constructor**/
00237     NodeBase()
00238     :
00239                     hasData_(false)
00240     {}
00241     void copy(const NodeBase& o)
00242     {
00243         vigra_precondition(topology_size_==o.topology_size_,"Cannot copy nodes of different sizes");
00244         vigra_precondition(featureCount_==o.featureCount_,"Cannot copy nodes with different feature count");
00245         vigra_precondition(classCount_==o.classCount_,"Cannot copy nodes with different class counts");
00246         vigra_precondition(parameters_size() ==o.parameters_size(),"Cannot copy nodes with different parameter sizes");
00247         std::copy(o.topology_begin(), o.topology_end(), topology_);
00248         std::copy(o.parameters_begin(),o.parameters_end(), parameters_);
00249     }
00250 
00251     /** create ReadOnly Base Node at position n (actual length is unknown)
00252      * only common features i.e. children etc are accessible.
00253      */
00254     NodeBase(   T_Container_type const   &  topology,
00255                 P_Container_type const   &  parameter,
00256                 INT                         n)
00257     :
00258                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
00259                     topology_size_(4),
00260                     parameters_  (const_cast<Parameter_type>(parameter.begin() + parameter_addr())),
00261                     parameter_size_(1),
00262                     featureCount_(topology[0]),
00263                     classCount_(topology[1]),
00264                     hasData_(true)
00265     {
00266         /*while((int)xrange.size() <  featureCount_)
00267             xrange.push_back(xrange.size());*/
00268     }
00269 
00270     /** create ReadOnly node with known length (the parameter range is valid)
00271      */
00272     NodeBase(   int                      tLen,
00273                 int                      pLen,
00274                 T_Container_type const & topology,
00275                 P_Container_type const & parameter,
00276                 INT                         n)
00277     :
00278                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
00279                     topology_size_(tLen),
00280                     parameters_  (const_cast<Parameter_type>(parameter.begin() + parameter_addr())),
00281                     parameter_size_(pLen),
00282                     featureCount_(topology[0]),
00283                     classCount_(topology[1]),
00284                     hasData_(true)
00285     {
00286         /*while((int)xrange.size() <  featureCount_)
00287             xrange.push_back(xrange.size());*/
00288     }
00289     /** create ReadOnly node with known length 
00290      * from existing Node
00291      */
00292     NodeBase(   int                      tLen,
00293                 int                      pLen,
00294                 NodeBase &               node)
00295     :
00296                     topology_   (node.topology_),
00297                     topology_size_(tLen),
00298                     parameters_  (node.parameters_),
00299                     parameter_size_(pLen),
00300                     featureCount_(node.featureCount_),
00301                     classCount_(node.classCount_),
00302                     hasData_(true)
00303     {
00304         /*while((int)xrange.size() <  featureCount_)
00305             xrange.push_back(xrange.size());*/
00306     }
00307 
00308 
00309    /** create new Node at end of vector
00310     * \param tLen number of integers needed in the topolog vector
00311     * \param pLen number of parameters needed (this includes the node
00312     *           weight)
00313     * \param topology reference to Topology array of decision tree.
00314     * \param parameter reference to Parameter array of decision tree.
00315     **/
00316     NodeBase(   int                      tLen,
00317                 int                      pLen,
00318                 T_Container_type   &        topology,
00319                 P_Container_type   &        parameter)
00320     :
00321                     topology_size_(tLen),
00322                     parameter_size_(pLen),
00323                     featureCount_(topology[0]),
00324                     classCount_(topology[1]),
00325                     hasData_(true)
00326     {
00327         /*while((int)xrange.size() <  featureCount_)
00328             xrange.push_back(xrange.size());*/
00329 
00330         size_t n = topology.size();
00331         for(int ii = 0; ii < tLen; ++ii)
00332             topology.push_back(0);
00333         //topology.resize (n  + tLen);
00334 
00335         topology_           =   topology.begin()+ n;
00336         typeID()            =   UnFilledNode;
00337 
00338         parameter_addr()    =   static_cast<int>(parameter.size());
00339 
00340         //parameter.resize(parameter.size() + pLen);
00341         for(int ii = 0; ii < pLen; ++ii)
00342             parameter.push_back(0);
00343 
00344         parameters_          =   parameter.begin()+ parameter_addr();
00345         weights() = 1;
00346     }
00347 
00348 
00349   /** PseudoCopy Constructor  - 
00350    *
00351    * Copy Node to the end of a container. 
00352    * Since each Node views on different data there can't be a real 
00353    * copy constructor (unless both objects should point to the 
00354    * same underlying data.                                  
00355    */
00356     NodeBase(   NodeBase      const  &    toCopy,
00357                 T_Container_type      &    topology,
00358                 P_Container_type     &    parameter)
00359     :
00360                     topology_size_(toCopy.topology_size()),
00361                     parameter_size_(toCopy.parameters_size()),
00362                     featureCount_(topology[0]),
00363                     classCount_(topology[1]),
00364                     hasData_(true)
00365     {
00366         /*while((int)xrange.size() <  featureCount_)
00367             xrange.push_back(xrange.size());*/
00368 
00369         size_t n            = topology.size();
00370         for(int ii = 0; ii < toCopy.topology_size(); ++ii)
00371             topology.push_back(toCopy.topology_begin()[ii]);
00372 //        topology.insert(topology.end(), toCopy.topology_begin(), toCopy.topology_end());
00373         topology_           =   topology.begin()+ n;
00374         parameter_addr()    =   static_cast<int>(parameter.size());
00375         for(int ii = 0; ii < toCopy.parameters_size(); ++ii)
00376             parameter.push_back(toCopy.parameters_begin()[ii]);
00377 //        parameter.insert(parameter.end(), toCopy.parameters_begin(), toCopy.parameters_end());
00378         parameters_          =   parameter.begin()+ parameter_addr();
00379     }
00380 };
00381 
00382 
00383 template<NodeTags NodeType>
00384 class Node;
00385 
00386 template<>
00387 class Node<i_ThresholdNode>
00388 : public NodeBase
00389 {
00390 
00391 
00392     public:
00393     typedef NodeBase BT;
00394 
00395         /**constructors **/
00396 
00397     Node(   BT::T_Container_type &   topology,
00398             BT::P_Container_type &   param)
00399                 :   BT(5,2,topology, param)
00400     {
00401         BT::typeID() = i_ThresholdNode;
00402     }
00403 
00404     Node(   BT::T_Container_type const     &   topology,
00405             BT::P_Container_type const     &   param,
00406                     INT                   n             )
00407                 :   BT(5,2,topology, param, n)
00408     {}
00409 
00410     Node( BT & node_)
00411         :   BT(5, 2, node_) 
00412     {}
00413 
00414     double& threshold()
00415     {
00416         return BT::parameters_begin()[1];
00417     }
00418 
00419     double const & threshold() const
00420     {
00421         return BT::parameters_begin()[1];
00422     }
00423 
00424     BT::INT& column()
00425     {
00426         return BT::column_data()[0];
00427     }
00428     BT::INT const & column() const
00429     {
00430         return BT::column_data()[0];
00431     }
00432 
00433     template<class U, class C>
00434     BT::INT  next(MultiArrayView<2,U,C> const & feature) const
00435     {
00436         return (feature(0, column()) < threshold())? child(0):child(1);
00437     }
00438 };
00439 
00440 
00441 template<>
00442 class Node<i_HyperplaneNode>
00443 : public NodeBase
00444 {
00445     public:
00446 
00447     typedef NodeBase BT;
00448 
00449         /**constructors **/
00450 
00451     Node(           int                      nCol,
00452                     BT::T_Container_type    &   topology,
00453                     BT::P_Container_type    &   split_param)
00454                 :   BT(nCol + 5,nCol + 2,topology, split_param)
00455     {
00456         BT::typeID() = i_HyperplaneNode;
00457     }
00458 
00459     Node(           BT::T_Container_type  const  &   topology,
00460                     BT::P_Container_type  const  &   split_param,
00461                     int                  n             )
00462                 :   NodeBase(5 , 2,topology, split_param, n)
00463     {
00464         //TODO : is there a more elegant way to do this?
00465         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00466                                         0
00467                                     :   BT::column_data()[0];
00468         BT::parameter_size_ += BT::columns_size();
00469     }
00470 
00471     Node( BT & node_)
00472         :   BT(5, 2, node_) 
00473     {
00474         //TODO : is there a more elegant way to do this?
00475         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00476                                         0
00477                                     :   BT::column_data()[0];
00478         BT::parameter_size_ += BT::columns_size();
00479     }
00480 
00481 
00482     double const & intercept() const
00483     {
00484         return BT::parameters_begin()[1];
00485     }
00486     double& intercept()
00487     {
00488         return BT::parameters_begin()[1];
00489     }
00490 
00491     BT::Parameter_type weights() const
00492     {
00493         return BT::parameters_begin()+2;
00494     }
00495 
00496     BT::Parameter_type weights()
00497     {
00498         return BT::parameters_begin()+2;
00499     }
00500 
00501 
00502     template<class U, class C>
00503     BT::INT next(MultiArrayView<2,U,C> const & feature) const
00504     {
00505         double result = -1 * intercept();
00506         if(*(BT::column_data()) == AllColumns)
00507         {
00508             for(int ii = 0; ii < BT::columns_size(); ++ii)
00509             {
00510                 result +=feature[ii] * weights()[ii];
00511             }
00512         }
00513         else
00514         {
00515             for(int ii = 0; ii < BT::columns_size(); ++ii)
00516             {
00517                 result +=feature[BT::columns_begin()[ii]] * weights()[ii];
00518             }
00519         }
00520         return result < 0 ? BT::child(0)
00521                           : BT::child(1);
00522     }
00523 };
00524 
00525 
00526 
00527 template<>
00528 class Node<i_HypersphereNode>
00529 : public NodeBase
00530 {
00531     public:
00532 
00533     typedef NodeBase BT;
00534 
00535         /**constructors **/
00536 
00537     Node(           int                      nCol,
00538                     BT::T_Container_type    &   topology,
00539                     BT::P_Container_type    &   param)
00540                 :   NodeBase(nCol + 5,nCol + 1,topology, param)
00541     {
00542         BT::typeID() = i_HypersphereNode;
00543     }
00544 
00545     Node(           BT::T_Container_type  const  &   topology,
00546                     BT::P_Container_type  const  &  param,
00547                     int                  n             )
00548                 :   NodeBase(5, 1,topology, param, n)
00549     {
00550         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00551                                         0
00552                                     :   BT::column_data()[0];
00553         BT::parameter_size_ += BT::columns_size();
00554     }
00555 
00556     Node( BT & node_)
00557         :   BT(5, 1, node_) 
00558     {
00559         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00560                                         0
00561                                     :   BT::column_data()[0];
00562         BT::parameter_size_ += BT::columns_size();
00563 
00564     }
00565 
00566     double const & squaredRadius() const
00567     {
00568         return BT::parameters_begin()[1];
00569     }
00570 
00571     double& squaredRadius()
00572     {
00573         return BT::parameters_begin()[1];
00574     }
00575 
00576     BT::Parameter_type center() const
00577     {
00578         return BT::parameters_begin()+2;
00579     }
00580 
00581     BT::Parameter_type center()
00582     {
00583         return BT::parameters_begin()+2;
00584     }
00585 
00586     template<class U, class C>
00587     BT::INT next(MultiArrayView<2,U,C> const & feature) const
00588     {
00589         double result = -1 * squaredRadius();
00590         if(*(BT::column_data()) == AllColumns)
00591         {
00592             for(int ii = 0; ii < BT::columns_size(); ++ii)
00593             {
00594                 result += (feature[ii] - center()[ii])*
00595                           (feature[ii] - center()[ii]);
00596             }
00597         }
00598         else
00599         {
00600             for(int ii = 0; ii < BT::columns_size(); ++ii)
00601             {
00602                 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
00603                           (feature[BT::columns_begin()[ii]] - center()[ii]);
00604             }
00605         }
00606         return result < 0 ? BT::child(0)
00607                           : BT::child(1);
00608     }
00609 };
00610 
00611 
/** ExteriorNodeBase class.

    \ingroup DecicionTree

    This class implements common features of all exterior (leaf) nodes.
    All exterior nodes are derived classes of ExteriorNodeBase.

    Note: no class named ExteriorNodeBase is defined in this file; the
    exterior node defined below derives directly from NodeBase.
*/
00619 
00620 
00621 
00622 
00623 
00624 
00625 template<>
00626 class Node<e_ConstProbNode>
00627 : public NodeBase
00628 {
00629     public:
00630 
00631     typedef     NodeBase    BT;
00632 
00633     Node(           BT::T_Container_type    &   topology,
00634                     BT::P_Container_type    &   param)
00635                     :
00636                 BT(2,topology[1]+1, topology, param)
00637 
00638     {
00639         BT::typeID() = e_ConstProbNode;
00640     }
00641 
00642 
00643     Node(           BT::T_Container_type const &   topology,
00644                     BT::P_Container_type const &   param,
00645                     int                  n             )
00646                 :   BT(2, topology[1]+1,topology, param, n)
00647     { }
00648 
00649 
00650     Node( BT & node_)
00651         :   BT(2, node_.classCount_ +1, node_) 
00652     {}
00653     BT::Parameter_type  prob_begin() const
00654     {
00655         return BT::parameters_begin()+1;
00656     }
00657     BT::Parameter_type  prob_end() const
00658     {
00659         return prob_begin() + prob_size();
00660     }
00661     int prob_size() const
00662     {
00663         return BT::classCount_;
00664     }
00665 };
00666 
00667 template<>
00668 class Node<e_LogRegProbNode>;
00669 
00670 } // namespace vigra
00671 
00672 #endif //RF_nodeproxy

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.8.0 (20 Sep 2011)