$treeview $search $mathjax
Eigen
3.2.5
$projectbrief
|
$projectbrief
|
$searchbox |
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_PRODUCTBASE_H 00011 #define EIGEN_PRODUCTBASE_H 00012 00013 namespace Eigen { 00014 00020 namespace internal { 00021 template<typename Derived, typename _Lhs, typename _Rhs> 00022 struct traits<ProductBase<Derived,_Lhs,_Rhs> > 00023 { 00024 typedef MatrixXpr XprKind; 00025 typedef typename remove_all<_Lhs>::type Lhs; 00026 typedef typename remove_all<_Rhs>::type Rhs; 00027 typedef typename scalar_product_traits<typename Lhs::Scalar, typename Rhs::Scalar>::ReturnType Scalar; 00028 typedef typename promote_storage_type<typename traits<Lhs>::StorageKind, 00029 typename traits<Rhs>::StorageKind>::ret StorageKind; 00030 typedef typename promote_index_type<typename traits<Lhs>::Index, 00031 typename traits<Rhs>::Index>::type Index; 00032 enum { 00033 RowsAtCompileTime = traits<Lhs>::RowsAtCompileTime, 00034 ColsAtCompileTime = traits<Rhs>::ColsAtCompileTime, 00035 MaxRowsAtCompileTime = traits<Lhs>::MaxRowsAtCompileTime, 00036 MaxColsAtCompileTime = traits<Rhs>::MaxColsAtCompileTime, 00037 Flags = (MaxRowsAtCompileTime==1 ? RowMajorBit : 0) 00038 | EvalBeforeNestingBit | EvalBeforeAssigningBit | NestByRefBit, 00039 // Note that EvalBeforeNestingBit and NestByRefBit 00040 // are not used in practice because nested is overloaded for products 00041 CoeffReadCost = 0 // FIXME why is it needed ? 00042 }; 00043 }; 00044 } 00045 00046 #define EIGEN_PRODUCT_PUBLIC_INTERFACE(Derived) \ 00047 typedef ProductBase<Derived, Lhs, Rhs > Base; \ 00048 EIGEN_DENSE_PUBLIC_INTERFACE(Derived) \ 00049 typedef typename Base::LhsNested LhsNested; \ 00050 typedef typename Base::_LhsNested _LhsNested; \ 00051 typedef typename Base::LhsBlasTraits LhsBlasTraits; \ 00052 typedef typename Base::ActualLhsType ActualLhsType; \ 00053 typedef typename Base::_ActualLhsType _ActualLhsType; \ 00054 typedef typename Base::RhsNested RhsNested; \ 00055 typedef typename Base::_RhsNested _RhsNested; \ 00056 typedef typename Base::RhsBlasTraits RhsBlasTraits; \ 00057 typedef typename Base::ActualRhsType ActualRhsType; \ 00058 typedef typename Base::_ActualRhsType _ActualRhsType; \ 00059 using Base::m_lhs; \ 00060 using Base::m_rhs; 00061 00062 template<typename Derived, typename Lhs, typename Rhs> 00063 class ProductBase : public MatrixBase<Derived> 00064 { 00065 public: 00066 typedef MatrixBase<Derived> Base; 00067 EIGEN_DENSE_PUBLIC_INTERFACE(ProductBase) 00068 00069 typedef typename Lhs::Nested LhsNested; 00070 typedef typename internal::remove_all<LhsNested>::type _LhsNested; 00071 typedef internal::blas_traits<_LhsNested> LhsBlasTraits; 00072 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; 00073 typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType; 00074 typedef typename internal::traits<Lhs>::Scalar LhsScalar; 00075 00076 typedef typename Rhs::Nested RhsNested; 00077 typedef typename internal::remove_all<RhsNested>::type _RhsNested; 00078 typedef internal::blas_traits<_RhsNested> RhsBlasTraits; 00079 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; 00080 typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType; 00081 typedef typename internal::traits<Rhs>::Scalar RhsScalar; 00082 00083 // Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once 00084 typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType; 00085 00086 public: 00087 00088 #ifndef EIGEN_NO_MALLOC 00089 typedef typename Base::PlainObject BasePlainObject; 00090 typedef Matrix<Scalar,RowsAtCompileTime==1?1:Dynamic,ColsAtCompileTime==1?1:Dynamic,BasePlainObject::Options> DynPlainObject; 00091 typedef typename internal::conditional<(BasePlainObject::SizeAtCompileTime==Dynamic) || (BasePlainObject::SizeAtCompileTime*int(sizeof(Scalar)) < int(EIGEN_STACK_ALLOCATION_LIMIT)), 00092 BasePlainObject, DynPlainObject>::type PlainObject; 00093 #else 00094 typedef typename Base::PlainObject PlainObject; 00095 #endif 00096 00097 ProductBase(const Lhs& a_lhs, const Rhs& a_rhs) 00098 : m_lhs(a_lhs), m_rhs(a_rhs) 00099 { 00100 eigen_assert(a_lhs.cols() == a_rhs.rows() 00101 && "invalid matrix product" 00102 && "if you wanted a coeff-wise or a dot product use the respective explicit functions"); 00103 } 00104 00105 inline Index rows() const { return m_lhs.rows(); } 00106 inline Index cols() const { return m_rhs.cols(); } 00107 00108 template<typename Dest> 00109 inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst,Scalar(1)); } 00110 00111 template<typename Dest> 00112 inline void addTo(Dest& dst) const { scaleAndAddTo(dst,Scalar(1)); } 00113 00114 template<typename Dest> 00115 inline void subTo(Dest& dst) const { scaleAndAddTo(dst,Scalar(-1)); } 00116 00117 template<typename Dest> 00118 inline void scaleAndAddTo(Dest& dst, const Scalar& alpha) const { derived().scaleAndAddTo(dst,alpha); } 00119 00120 const _LhsNested& lhs() const { return m_lhs; } 00121 const _RhsNested& rhs() const { return m_rhs; } 00122 00123 // Implicit conversion to the nested type (trigger the evaluation of the product) 00124 operator const PlainObject& () const 00125 { 00126 m_result.resize(m_lhs.rows(), m_rhs.cols()); 00127 derived().evalTo(m_result); 00128 return m_result; 00129 } 00130 00131 const Diagonal<const FullyLazyCoeffBaseProductType,0> diagonal() const 00132 { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs); } 00133 00134 template<int Index> 00135 const Diagonal<FullyLazyCoeffBaseProductType,Index> diagonal() const 00136 { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs); } 00137 00138 const Diagonal<FullyLazyCoeffBaseProductType,Dynamic> diagonal(Index index) const 00139 { return FullyLazyCoeffBaseProductType(m_lhs, m_rhs).diagonal(index); } 00140 00141 // restrict coeff accessors to 1x1 expressions. No need to care about mutators here since this isnt a Lvalue expression 00142 typename Base::CoeffReturnType coeff(Index row, Index col) const 00143 { 00144 #ifdef EIGEN2_SUPPORT 00145 return lhs().row(row).cwiseProduct(rhs().col(col).transpose()).sum(); 00146 #else 00147 EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) 00148 eigen_assert(this->rows() == 1 && this->cols() == 1); 00149 Matrix<Scalar,1,1> result = *this; 00150 return result.coeff(row,col); 00151 #endif 00152 } 00153 00154 typename Base::CoeffReturnType coeff(Index i) const 00155 { 00156 EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) 00157 eigen_assert(this->rows() == 1 && this->cols() == 1); 00158 Matrix<Scalar,1,1> result = *this; 00159 return result.coeff(i); 00160 } 00161 00162 const Scalar& coeffRef(Index row, Index col) const 00163 { 00164 EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) 00165 eigen_assert(this->rows() == 1 && this->cols() == 1); 00166 return derived().coeffRef(row,col); 00167 } 00168 00169 const Scalar& coeffRef(Index i) const 00170 { 00171 EIGEN_STATIC_ASSERT_SIZE_1x1(Derived) 00172 eigen_assert(this->rows() == 1 && this->cols() == 1); 00173 return derived().coeffRef(i); 00174 } 00175 00176 protected: 00177 00178 LhsNested m_lhs; 00179 RhsNested m_rhs; 00180 00181 mutable PlainObject m_result; 00182 }; 00183 00184 // here we need to overload the nested rule for products 00185 // such that the nested type is a const reference to a plain matrix 00186 namespace internal { 00187 template<typename Lhs, typename Rhs, int Mode, int N, typename PlainObject> 00188 struct nested<GeneralProduct<Lhs,Rhs,Mode>, N, PlainObject> 00189 { 00190 typedef typename GeneralProduct<Lhs,Rhs,Mode>::PlainObject const& type; 00191 }; 00192 template<typename Lhs, typename Rhs, int Mode, int N, typename PlainObject> 00193 struct nested<const GeneralProduct<Lhs,Rhs,Mode>, N, PlainObject> 00194 { 00195 typedef typename GeneralProduct<Lhs,Rhs,Mode>::PlainObject const& type; 00196 }; 00197 } 00198 00199 template<typename NestedProduct> 00200 class ScaledProduct; 00201 00202 // Note that these two operator* functions are not defined as member 00203 // functions of ProductBase, because, otherwise we would have to 00204 // define all overloads defined in MatrixBase. Furthermore, Using 00205 // "using Base::operator*" would not work with MSVC. 00206 // 00207 // Also note that here we accept any compatible scalar types 00208 template<typename Derived,typename Lhs,typename Rhs> 00209 const ScaledProduct<Derived> 00210 operator*(const ProductBase<Derived,Lhs,Rhs>& prod, const typename Derived::Scalar& x) 00211 { return ScaledProduct<Derived>(prod.derived(), x); } 00212 00213 template<typename Derived,typename Lhs,typename Rhs> 00214 typename internal::enable_if<!internal::is_same<typename Derived::Scalar,typename Derived::RealScalar>::value, 00215 const ScaledProduct<Derived> >::type 00216 operator*(const ProductBase<Derived,Lhs,Rhs>& prod, const typename Derived::RealScalar& x) 00217 { return ScaledProduct<Derived>(prod.derived(), x); } 00218 00219 00220 template<typename Derived,typename Lhs,typename Rhs> 00221 const ScaledProduct<Derived> 00222 operator*(const typename Derived::Scalar& x,const ProductBase<Derived,Lhs,Rhs>& prod) 00223 { return ScaledProduct<Derived>(prod.derived(), x); } 00224 00225 template<typename Derived,typename Lhs,typename Rhs> 00226 typename internal::enable_if<!internal::is_same<typename Derived::Scalar,typename Derived::RealScalar>::value, 00227 const ScaledProduct<Derived> >::type 00228 operator*(const typename Derived::RealScalar& x,const ProductBase<Derived,Lhs,Rhs>& prod) 00229 { return ScaledProduct<Derived>(prod.derived(), x); } 00230 00231 namespace internal { 00232 template<typename NestedProduct> 00233 struct traits<ScaledProduct<NestedProduct> > 00234 : traits<ProductBase<ScaledProduct<NestedProduct>, 00235 typename NestedProduct::_LhsNested, 00236 typename NestedProduct::_RhsNested> > 00237 { 00238 typedef typename traits<NestedProduct>::StorageKind StorageKind; 00239 }; 00240 } 00241 00242 template<typename NestedProduct> 00243 class ScaledProduct 00244 : public ProductBase<ScaledProduct<NestedProduct>, 00245 typename NestedProduct::_LhsNested, 00246 typename NestedProduct::_RhsNested> 00247 { 00248 public: 00249 typedef ProductBase<ScaledProduct<NestedProduct>, 00250 typename NestedProduct::_LhsNested, 00251 typename NestedProduct::_RhsNested> Base; 00252 typedef typename Base::Scalar Scalar; 00253 typedef typename Base::PlainObject PlainObject; 00254 // EIGEN_PRODUCT_PUBLIC_INTERFACE(ScaledProduct) 00255 00256 ScaledProduct(const NestedProduct& prod, const Scalar& x) 00257 : Base(prod.lhs(),prod.rhs()), m_prod(prod), m_alpha(x) {} 00258 00259 template<typename Dest> 00260 inline void evalTo(Dest& dst) const { dst.setZero(); scaleAndAddTo(dst, Scalar(1)); } 00261 00262 template<typename Dest> 00263 inline void addTo(Dest& dst) const { scaleAndAddTo(dst, Scalar(1)); } 00264 00265 template<typename Dest> 00266 inline void subTo(Dest& dst) const { scaleAndAddTo(dst, Scalar(-1)); } 00267 00268 template<typename Dest> 00269 inline void scaleAndAddTo(Dest& dst, const Scalar& a_alpha) const { m_prod.derived().scaleAndAddTo(dst,a_alpha * m_alpha); } 00270 00271 const Scalar& alpha() const { return m_alpha; } 00272 00273 protected: 00274 const NestedProduct& m_prod; 00275 Scalar m_alpha; 00276 }; 00277 00280 template<typename Derived> 00281 template<typename ProductDerived, typename Lhs, typename Rhs> 00282 Derived& MatrixBase<Derived>::lazyAssign(const ProductBase<ProductDerived, Lhs,Rhs>& other) 00283 { 00284 other.derived().evalTo(derived()); 00285 return derived(); 00286 } 00287 00288 } // end namespace Eigen 00289 00290 #endif // EIGEN_PRODUCTBASE_H