10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
11 #define EIGEN_MATRIX_SQUARE_ROOT
26 template <
typename MatrixType>
42 eigen_assert(A.rows() == A.cols());
53 template <
typename ResultType>
void compute(ResultType &result);
56 typedef typename MatrixType::Index Index;
57 typedef typename MatrixType::Scalar Scalar;
59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
const MatrixType& T);
61 void compute2x2diagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i);
62 void compute1x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
63 typename MatrixType::Index i,
typename MatrixType::Index j);
64 void compute1x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
65 typename MatrixType::Index i,
typename MatrixType::Index j);
66 void compute2x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
67 typename MatrixType::Index i,
typename MatrixType::Index j);
68 void compute2x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
69 typename MatrixType::Index i,
typename MatrixType::Index j);
71 template <
typename SmallMatrixType>
72 static void solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
73 const SmallMatrixType& B,
const SmallMatrixType& C);
75 const MatrixType& m_A;
78 template <
typename MatrixType>
79 template <
typename ResultType>
84 const MatrixType& T = schurOfA.
matrixT();
85 const MatrixType& U = schurOfA.
matrixU();
88 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
89 computeDiagonalPartOfSqrt(sqrtT, T);
90 computeOffDiagonalPartOfSqrt(sqrtT, T);
93 result = U * sqrtT * U.adjoint();
98 template <
typename MatrixType>
102 const Index size = m_A.rows();
103 for (Index i = 0; i < size; i++) {
104 if (i == size - 1 || T.coeff(i+1, i) == 0) {
105 eigen_assert(T(i,i) > 0);
106 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
109 compute2x2diagonalBlock(sqrtT, T, i);
117 template <
typename MatrixType>
118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
121 const Index size = m_A.rows();
122 for (Index j = 1; j < size; j++) {
123 if (T.coeff(j, j-1) != 0)
125 for (Index i = j-1; i >= 0; i--) {
126 if (i > 0 && T.coeff(i, i-1) != 0)
128 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
129 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
130 if (iBlockIs2x2 && jBlockIs2x2)
131 compute2x2offDiagonalBlock(sqrtT, T, i, j);
132 else if (iBlockIs2x2 && !jBlockIs2x2)
133 compute2x1offDiagonalBlock(sqrtT, T, i, j);
134 else if (!iBlockIs2x2 && jBlockIs2x2)
135 compute1x2offDiagonalBlock(sqrtT, T, i, j);
136 else if (!iBlockIs2x2 && !jBlockIs2x2)
137 compute1x1offDiagonalBlock(sqrtT, T, i, j);
144 template <
typename MatrixType>
145 void MatrixSquareRootQuasiTriangular<MatrixType>
146 ::compute2x2diagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
typename MatrixType::Index i)
150 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
151 EigenSolver<Matrix<Scalar,2,2> > es(block);
152 sqrtT.template block<2,2>(i,i)
153 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
159 template <
typename MatrixType>
160 void MatrixSquareRootQuasiTriangular<MatrixType>
161 ::compute1x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
162 typename MatrixType::Index i,
typename MatrixType::Index j)
164 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
165 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
169 template <
typename MatrixType>
170 void MatrixSquareRootQuasiTriangular<MatrixType>
171 ::compute1x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
172 typename MatrixType::Index i,
typename MatrixType::Index j)
174 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
176 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
177 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
178 A += sqrtT.template block<2,2>(j,j).transpose();
179 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
183 template <
typename MatrixType>
184 void MatrixSquareRootQuasiTriangular<MatrixType>
185 ::compute2x1offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
186 typename MatrixType::Index i,
typename MatrixType::Index j)
188 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
190 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
191 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
192 A += sqrtT.template block<2,2>(i,i);
193 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
197 template <
typename MatrixType>
198 void MatrixSquareRootQuasiTriangular<MatrixType>
199 ::compute2x2offDiagonalBlock(MatrixType& sqrtT,
const MatrixType& T,
200 typename MatrixType::Index i,
typename MatrixType::Index j)
202 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
203 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
204 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
206 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
207 Matrix<Scalar,2,2> X;
208 solveAuxiliaryEquation(X, A, B, C);
209 sqrtT.template block<2,2>(i,j) = X;
213 template <
typename MatrixType>
214 template <
typename SmallMatrixType>
215 void MatrixSquareRootQuasiTriangular<MatrixType>
216 ::solveAuxiliaryEquation(SmallMatrixType& X,
const SmallMatrixType& A,
217 const SmallMatrixType& B,
const SmallMatrixType& C)
219 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
220 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
222 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
223 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
224 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
225 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
226 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
227 coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
228 coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
229 coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
230 coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
231 coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
232 coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
233 coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
234 coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
236 Matrix<Scalar,4,1> rhs;
237 rhs.coeffRef(0) = C.coeff(0,0);
238 rhs.coeffRef(1) = C.coeff(0,1);
239 rhs.coeffRef(2) = C.coeff(1,0);
240 rhs.coeffRef(3) = C.coeff(1,1);
242 Matrix<Scalar,4,1> result;
243 result = coeffMatrix.fullPivLu().solve(rhs);
245 X.coeffRef(0,0) = result.coeff(0);
246 X.coeffRef(0,1) = result.coeff(1);
247 X.coeffRef(1,0) = result.coeff(2);
248 X.coeffRef(1,1) = result.coeff(3);
263 template <
typename MatrixType>
270 eigen_assert(A.rows() == A.cols());
282 template <
typename ResultType>
void compute(ResultType &result);
285 const MatrixType& m_A;
288 template <
typename MatrixType>
289 template <
typename ResultType>
294 const MatrixType& T = schurOfA.
matrixT();
295 const MatrixType& U = schurOfA.
matrixU();
299 result.resize(m_A.rows(), m_A.cols());
300 typedef typename MatrixType::Index Index;
301 for (Index i = 0; i < m_A.rows(); i++) {
302 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
304 for (Index j = 1; j < m_A.cols(); j++) {
305 for (Index i = j-1; i >= 0; i--) {
306 typedef typename MatrixType::Scalar Scalar;
308 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
310 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
316 tmp.noalias() = U * result.template triangularView<Upper>();
317 result.noalias() = tmp * U.adjoint();
328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
349 template <
typename ResultType>
void compute(ResultType &result);
355 template <
typename MatrixType>
363 eigen_assert(A.rows() == A.cols());
366 template <
typename ResultType>
void compute(ResultType &result)
369 const RealSchur<MatrixType> schurOfA(m_A);
370 const MatrixType& T = schurOfA.matrixT();
371 const MatrixType& U = schurOfA.matrixU();
374 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
375 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
379 result = U * sqrtT * U.adjoint();
383 const MatrixType& m_A;
389 template <
typename MatrixType>
390 class MatrixSquareRoot<MatrixType, 1>
397 eigen_assert(A.rows() == A.cols());
400 template <
typename ResultType>
void compute(ResultType &result)
403 const ComplexSchur<MatrixType> schurOfA(m_A);
404 const MatrixType& T = schurOfA.matrixT();
405 const MatrixType& U = schurOfA.matrixU();
408 MatrixSquareRootTriangular<MatrixType> tmp(T);
409 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
413 result = U * sqrtT * U.adjoint();
417 const MatrixType& m_A;
434 :
public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
436 typedef typename Derived::Index Index;
450 template <
typename ResultType>
451 inline void evalTo(ResultType& result)
const
453 const typename Derived::PlainObject srcEvaluated = m_src.eval();
458 Index rows()
const {
return m_src.rows(); }
459 Index cols()
const {
return m_src.cols(); }
462 const Derived& m_src;
468 template<
typename Derived>
469 struct traits<MatrixSquareRootReturnValue<Derived> >
471 typedef typename Derived::PlainObject ReturnType;
475 template <
typename Derived>
476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt()
const
478 eigen_assert(rows() == cols());
479 return MatrixSquareRootReturnValue<Derived>(derived());
484 #endif // EIGEN_MATRIX_SQUARE_ROOT