Eigen 3.2.5 — source listing of the sparse coefficient-wise binary operation header (SparseCwiseBinaryOp.h)
00001 // This file is part of Eigen, a lightweight C++ template library 00002 // for linear algebra. 00003 // 00004 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr> 00005 // 00006 // This Source Code Form is subject to the terms of the Mozilla 00007 // Public License v. 2.0. If a copy of the MPL was not distributed 00008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 00009 00010 #ifndef EIGEN_SPARSE_CWISE_BINARY_OP_H 00011 #define EIGEN_SPARSE_CWISE_BINARY_OP_H 00012 00013 namespace Eigen { 00014 00015 // Here we have to handle 3 cases: 00016 // 1 - sparse op dense 00017 // 2 - dense op sparse 00018 // 3 - sparse op sparse 00019 // We also need to implement a 4th iterator for: 00020 // 4 - dense op dense 00021 // Finally, we also need to distinguish between the product and other operations : 00022 // configuration returned mode 00023 // 1 - sparse op dense product sparse 00024 // generic dense 00025 // 2 - dense op sparse product sparse 00026 // generic dense 00027 // 3 - sparse op sparse product sparse 00028 // generic sparse 00029 // 4 - dense op dense product dense 00030 // generic dense 00031 00032 namespace internal { 00033 00034 template<> struct promote_storage_type<Dense,Sparse> 00035 { typedef Sparse ret; }; 00036 00037 template<> struct promote_storage_type<Sparse,Dense> 00038 { typedef Sparse ret; }; 00039 00040 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived, 00041 typename _LhsStorageMode = typename traits<Lhs>::StorageKind, 00042 typename _RhsStorageMode = typename traits<Rhs>::StorageKind> 00043 class sparse_cwise_binary_op_inner_iterator_selector; 00044 00045 } // end namespace internal 00046 00047 template<typename BinaryOp, typename Lhs, typename Rhs> 00048 class CwiseBinaryOpImpl<BinaryOp, Lhs, Rhs, Sparse> 00049 : public SparseMatrixBase<CwiseBinaryOp<BinaryOp, Lhs, Rhs> > 00050 { 00051 public: 00052 class InnerIterator; 00053 class ReverseInnerIterator; 00054 typedef 
CwiseBinaryOp<BinaryOp, Lhs, Rhs> Derived; 00055 EIGEN_SPARSE_PUBLIC_INTERFACE(Derived) 00056 CwiseBinaryOpImpl() 00057 { 00058 typedef typename internal::traits<Lhs>::StorageKind LhsStorageKind; 00059 typedef typename internal::traits<Rhs>::StorageKind RhsStorageKind; 00060 EIGEN_STATIC_ASSERT(( 00061 (!internal::is_same<LhsStorageKind,RhsStorageKind>::value) 00062 || ((Lhs::Flags&RowMajorBit) == (Rhs::Flags&RowMajorBit))), 00063 THE_STORAGE_ORDER_OF_BOTH_SIDES_MUST_MATCH); 00064 } 00065 }; 00066 00067 template<typename BinaryOp, typename Lhs, typename Rhs> 00068 class CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator 00069 : public internal::sparse_cwise_binary_op_inner_iterator_selector<BinaryOp,Lhs,Rhs,typename CwiseBinaryOpImpl<BinaryOp,Lhs,Rhs,Sparse>::InnerIterator> 00070 { 00071 public: 00072 typedef typename Lhs::Index Index; 00073 typedef internal::sparse_cwise_binary_op_inner_iterator_selector< 00074 BinaryOp,Lhs,Rhs, InnerIterator> Base; 00075 00076 // NOTE: we have to prefix Index by "typename Lhs::" to avoid an ICE with VC11 00077 EIGEN_STRONG_INLINE InnerIterator(const CwiseBinaryOpImpl& binOp, typename Lhs::Index outer) 00078 : Base(binOp.derived(),outer) 00079 {} 00080 }; 00081 00082 /*************************************************************************** 00083 * Implementation of inner-iterators 00084 ***************************************************************************/ 00085 00086 // template<typename T> struct internal::func_is_conjunction { enum { ret = false }; }; 00087 // template<typename T> struct internal::func_is_conjunction<internal::scalar_product_op<T> > { enum { ret = true }; }; 00088 00089 // TODO generalize the internal::scalar_product_op specialization to all conjunctions if any ! 
00090 00091 namespace internal { 00092 00093 // sparse - sparse (generic) 00094 template<typename BinaryOp, typename Lhs, typename Rhs, typename Derived> 00095 class sparse_cwise_binary_op_inner_iterator_selector<BinaryOp, Lhs, Rhs, Derived, Sparse, Sparse> 00096 { 00097 typedef CwiseBinaryOp<BinaryOp, Lhs, Rhs> CwiseBinaryXpr; 00098 typedef typename traits<CwiseBinaryXpr>::Scalar Scalar; 00099 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00100 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00101 typedef typename _LhsNested::InnerIterator LhsIterator; 00102 typedef typename _RhsNested::InnerIterator RhsIterator; 00103 typedef typename Lhs::Index Index; 00104 00105 public: 00106 00107 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00108 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 00109 { 00110 this->operator++(); 00111 } 00112 00113 EIGEN_STRONG_INLINE Derived& operator++() 00114 { 00115 if (m_lhsIter && m_rhsIter && (m_lhsIter.index() == m_rhsIter.index())) 00116 { 00117 m_id = m_lhsIter.index(); 00118 m_value = m_functor(m_lhsIter.value(), m_rhsIter.value()); 00119 ++m_lhsIter; 00120 ++m_rhsIter; 00121 } 00122 else if (m_lhsIter && (!m_rhsIter || (m_lhsIter.index() < m_rhsIter.index()))) 00123 { 00124 m_id = m_lhsIter.index(); 00125 m_value = m_functor(m_lhsIter.value(), Scalar(0)); 00126 ++m_lhsIter; 00127 } 00128 else if (m_rhsIter && (!m_lhsIter || (m_lhsIter.index() > m_rhsIter.index()))) 00129 { 00130 m_id = m_rhsIter.index(); 00131 m_value = m_functor(Scalar(0), m_rhsIter.value()); 00132 ++m_rhsIter; 00133 } 00134 else 00135 { 00136 m_value = 0; // this is to avoid a compilation warning 00137 m_id = -1; 00138 } 00139 return *static_cast<Derived*>(this); 00140 } 00141 00142 EIGEN_STRONG_INLINE Scalar value() const { return m_value; } 00143 00144 EIGEN_STRONG_INLINE Index index() const { return m_id; } 00145 EIGEN_STRONG_INLINE 
Index row() const { return Lhs::IsRowMajor ? m_lhsIter.row() : index(); } 00146 EIGEN_STRONG_INLINE Index col() const { return Lhs::IsRowMajor ? index() : m_lhsIter.col(); } 00147 00148 EIGEN_STRONG_INLINE operator bool() const { return m_id>=0; } 00149 00150 protected: 00151 LhsIterator m_lhsIter; 00152 RhsIterator m_rhsIter; 00153 const BinaryOp& m_functor; 00154 Scalar m_value; 00155 Index m_id; 00156 }; 00157 00158 // sparse - sparse (product) 00159 template<typename T, typename Lhs, typename Rhs, typename Derived> 00160 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Sparse> 00161 { 00162 typedef scalar_product_op<T> BinaryFunc; 00163 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00164 typedef typename CwiseBinaryXpr::Scalar Scalar; 00165 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00166 typedef typename _LhsNested::InnerIterator LhsIterator; 00167 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00168 typedef typename _RhsNested::InnerIterator RhsIterator; 00169 typedef typename Lhs::Index Index; 00170 public: 00171 00172 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00173 : m_lhsIter(xpr.lhs(),outer), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()) 00174 { 00175 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 00176 { 00177 if (m_lhsIter.index() < m_rhsIter.index()) 00178 ++m_lhsIter; 00179 else 00180 ++m_rhsIter; 00181 } 00182 } 00183 00184 EIGEN_STRONG_INLINE Derived& operator++() 00185 { 00186 ++m_lhsIter; 00187 ++m_rhsIter; 00188 while (m_lhsIter && m_rhsIter && (m_lhsIter.index() != m_rhsIter.index())) 00189 { 00190 if (m_lhsIter.index() < m_rhsIter.index()) 00191 ++m_lhsIter; 00192 else 00193 ++m_rhsIter; 00194 } 00195 return *static_cast<Derived*>(this); 00196 } 00197 00198 EIGEN_STRONG_INLINE Scalar value() const { return m_functor(m_lhsIter.value(), 
m_rhsIter.value()); } 00199 00200 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 00201 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 00202 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 00203 00204 EIGEN_STRONG_INLINE operator bool() const { return (m_lhsIter && m_rhsIter); } 00205 00206 protected: 00207 LhsIterator m_lhsIter; 00208 RhsIterator m_rhsIter; 00209 const BinaryFunc& m_functor; 00210 }; 00211 00212 // sparse - dense (product) 00213 template<typename T, typename Lhs, typename Rhs, typename Derived> 00214 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Sparse, Dense> 00215 { 00216 typedef scalar_product_op<T> BinaryFunc; 00217 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00218 typedef typename CwiseBinaryXpr::Scalar Scalar; 00219 typedef typename traits<CwiseBinaryXpr>::_LhsNested _LhsNested; 00220 typedef typename traits<CwiseBinaryXpr>::RhsNested RhsNested; 00221 typedef typename _LhsNested::InnerIterator LhsIterator; 00222 typedef typename Lhs::Index Index; 00223 enum { IsRowMajor = (int(Lhs::Flags)&RowMajorBit)==RowMajorBit }; 00224 public: 00225 00226 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00227 : m_rhs(xpr.rhs()), m_lhsIter(xpr.lhs(),outer), m_functor(xpr.functor()), m_outer(outer) 00228 {} 00229 00230 EIGEN_STRONG_INLINE Derived& operator++() 00231 { 00232 ++m_lhsIter; 00233 return *static_cast<Derived*>(this); 00234 } 00235 00236 EIGEN_STRONG_INLINE Scalar value() const 00237 { return m_functor(m_lhsIter.value(), 00238 m_rhs.coeff(IsRowMajor?m_outer:m_lhsIter.index(),IsRowMajor?m_lhsIter.index():m_outer)); } 00239 00240 EIGEN_STRONG_INLINE Index index() const { return m_lhsIter.index(); } 00241 EIGEN_STRONG_INLINE Index row() const { return m_lhsIter.row(); } 00242 EIGEN_STRONG_INLINE Index col() const { return m_lhsIter.col(); } 00243 00244 
EIGEN_STRONG_INLINE operator bool() const { return m_lhsIter; } 00245 00246 protected: 00247 RhsNested m_rhs; 00248 LhsIterator m_lhsIter; 00249 const BinaryFunc m_functor; 00250 const Index m_outer; 00251 }; 00252 00253 // sparse - dense (product) 00254 template<typename T, typename Lhs, typename Rhs, typename Derived> 00255 class sparse_cwise_binary_op_inner_iterator_selector<scalar_product_op<T>, Lhs, Rhs, Derived, Dense, Sparse> 00256 { 00257 typedef scalar_product_op<T> BinaryFunc; 00258 typedef CwiseBinaryOp<BinaryFunc, Lhs, Rhs> CwiseBinaryXpr; 00259 typedef typename CwiseBinaryXpr::Scalar Scalar; 00260 typedef typename traits<CwiseBinaryXpr>::_RhsNested _RhsNested; 00261 typedef typename _RhsNested::InnerIterator RhsIterator; 00262 typedef typename Lhs::Index Index; 00263 00264 enum { IsRowMajor = (int(Rhs::Flags)&RowMajorBit)==RowMajorBit }; 00265 public: 00266 00267 EIGEN_STRONG_INLINE sparse_cwise_binary_op_inner_iterator_selector(const CwiseBinaryXpr& xpr, Index outer) 00268 : m_xpr(xpr), m_rhsIter(xpr.rhs(),outer), m_functor(xpr.functor()), m_outer(outer) 00269 {} 00270 00271 EIGEN_STRONG_INLINE Derived& operator++() 00272 { 00273 ++m_rhsIter; 00274 return *static_cast<Derived*>(this); 00275 } 00276 00277 EIGEN_STRONG_INLINE Scalar value() const 00278 { return m_functor(m_xpr.lhs().coeff(IsRowMajor?m_outer:m_rhsIter.index(),IsRowMajor?m_rhsIter.index():m_outer), m_rhsIter.value()); } 00279 00280 EIGEN_STRONG_INLINE Index index() const { return m_rhsIter.index(); } 00281 EIGEN_STRONG_INLINE Index row() const { return m_rhsIter.row(); } 00282 EIGEN_STRONG_INLINE Index col() const { return m_rhsIter.col(); } 00283 00284 EIGEN_STRONG_INLINE operator bool() const { return m_rhsIter; } 00285 00286 protected: 00287 const CwiseBinaryXpr& m_xpr; 00288 RhsIterator m_rhsIter; 00289 const BinaryFunc& m_functor; 00290 const Index m_outer; 00291 }; 00292 00293 } // end namespace internal 00294 00295 
/*************************************************************************** 00296 * Implementation of SparseMatrixBase and SparseCwise functions/operators 00297 ***************************************************************************/ 00298 00299 template<typename Derived> 00300 template<typename OtherDerived> 00301 EIGEN_STRONG_INLINE Derived & 00302 SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other) 00303 { 00304 return derived() = derived() - other.derived(); 00305 } 00306 00307 template<typename Derived> 00308 template<typename OtherDerived> 00309 EIGEN_STRONG_INLINE Derived & 00310 SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other) 00311 { 00312 return derived() = derived() + other.derived(); 00313 } 00314 00315 template<typename Derived> 00316 template<typename OtherDerived> 00317 EIGEN_STRONG_INLINE const EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE 00318 SparseMatrixBase<Derived>::cwiseProduct(const MatrixBase<OtherDerived> &other) const 00319 { 00320 return EIGEN_SPARSE_CWISE_PRODUCT_RETURN_TYPE(derived(), other.derived()); 00321 } 00322 00323 } // end namespace Eigen 00324 00325 #endif // EIGEN_SPARSE_CWISE_BINARY_OP_H