00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 template<typename eT>
00021 arma_inline
00022 u32 glue_times::mul_storage_cost(const Mat<eT>& X, const Mat<eT>& Y)
00023 {
00024 return X.n_rows * Y.n_cols;
00025 }
00026
00027
00028
00029
00030
00031 template<typename eT>
00032 inline
00033 void
00034 glue_times::apply_noalias(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B)
00035 {
00036 arma_extra_debug_sigprint();
00037
00038 arma_debug_assert_mul_size(A, B, "matrix multiply");
00039
00040 out.set_size(A.n_rows,B.n_cols);
00041 gemm<>::apply(out,A,B);
00042 }
00043
00044
00045
00046 template<typename eT>
00047 inline
00048 void
00049 glue_times::apply(Mat<eT>& out, const Mat<eT>& A_in, const Mat<eT>& B_in)
00050 {
00051 arma_extra_debug_sigprint();
00052
00053 if( (&out != &A_in) && (&out != &B_in) )
00054 {
00055 glue_times::apply_noalias(out,A_in,B_in);
00056 }
00057 else
00058 {
00059
00060 if( (&out == &A_in) && (&out != &B_in) )
00061 {
00062 Mat<eT> A_copy(A_in);
00063 glue_times::apply_noalias(out,A_copy,B_in);
00064 }
00065 else
00066 if( (&out != &A_in) && (&out == &B_in) )
00067 {
00068 Mat<eT> B_copy(B_in);
00069 glue_times::apply_noalias(out,A_in,B_copy);
00070 }
00071 else
00072 if( (&out == &A_in) && (&out == &B_in) )
00073 {
00074 Mat<eT> tmp(A_in);
00075 glue_times::apply_noalias(out,tmp,tmp);
00076 }
00077
00078 }
00079
00080 }
00081
00082
00083 template<typename eT>
00084 inline
00085 void
00086 glue_times::apply(Mat<eT>& out, const Mat<eT>& A, const Mat<eT>& B, const Mat<eT>& C)
00087 {
00088 arma_extra_debug_sigprint();
00089
00090 arma_debug_assert_mul_size(A, B, "matrix multiply");
00091 arma_debug_assert_mul_size(B, C, "matrix multiply");
00092
00093 if( mul_storage_cost(A,B) <= mul_storage_cost(B,C) )
00094 {
00095 Mat<eT> tmp;
00096 glue_times::apply_noalias(tmp, A, B);
00097
00098 if(&out != &C)
00099 {
00100 glue_times::apply_noalias(out, tmp, C);
00101 }
00102 else
00103 {
00104 Mat<eT> C_copy = C;
00105 glue_times::apply_noalias(out, tmp, C_copy);
00106 }
00107
00108 }
00109 else
00110 {
00111 Mat<eT> tmp;
00112 glue_times::apply_noalias(tmp, B, C);
00113
00114 if(&out != &A)
00115 {
00116 glue_times::apply_noalias(out, A, tmp);
00117 }
00118 else
00119 {
00120 Mat<eT> A_copy = A;
00121 glue_times::apply_noalias(out, A_copy, tmp);
00122 }
00123 }
00124
00125 }
00126
00127
00128
//! Multiply a chain of matrices described by a (possibly nested) Glue expression.
//!
//! With exactly two operands, both are unwrapped and passed to the two-matrix apply().
//! With more operands the work is done in two phases:
//!  1. pointers to all operands are collected and each operand is given a position
//!     in the evaluation order, by repeatedly comparing the storage cost of the
//!     remaining chain without its last operand against the cost without its first;
//!  2. starting from the operand with order 0, the neighbouring operand with the
//!     next-lowest order is pre- or post-multiplied, one at a time; intermediate
//!     results alternate between two temporaries and the last product goes into 'out'.
//!
//! NOTE(review): the leading "000NN" tokens on each line are listing numbers left over
//! from the extraction of this file; they are kept here untouched.
00129 template<typename T1, typename T2>
00130 void
00131 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_times>& X)
00132 {
00133 arma_extra_debug_sigprint();
00134
00135 typedef typename T1::elem_type eT;
00136
// number of matrices in the chain, obtained from depth_lhs at compile time
00137 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num;
00138
00139 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat);
00140
00141 if(N_mat == 2)
00142 {
// plain two-operand product: the two-matrix apply() handles aliasing with 'out'
00143 const unwrap<T1> tmp1(X.A);
00144 const unwrap<T2> tmp2(X.B);
00145
00146 glue_times::apply(out, tmp1.M, tmp2.M);
00147 }
00148 else
00149 {
00150
00151
// ptrs[i] points to the i-th operand of the chain;
// del[i] == true means ptrs[i] is deleted by this function once the product is done
00152 const Mat<eT>* ptrs[N_mat];
00153 bool del[N_mat];
00154
00155
// NOTE(review): presumably get_ptrs() copies any operand that is the same object as 'out'
// (hence &out being passed) and flags such copies in del[] -- confirm in mat_ptrs_outcheck
00156 mat_ptrs_outcheck<glue_times, Glue<T1,T2,glue_times> >::get_ptrs(ptrs, del, X, &out);
00157
00158 for(s32 i=0; i<N_mat; ++i) arma_extra_debug_print( arma_boost::format("ptrs[%d] = %x") % i % ptrs[i] );
00159 for(s32 i=0; i<N_mat; ++i) arma_extra_debug_print( arma_boost::format(" del[%d] = %d") % i % del[i] );
00160
00161
00162 arma_extra_debug_print( arma_boost::format("required size of 'out': %d, %d") % ptrs[0]->n_rows % ptrs[N_mat-1]->n_cols );
00163
// phase 1: order[i] is the position of operand i in the evaluation sequence; -1 = not yet assigned
00164 int order[N_mat]; for(s32 i=0; i<N_mat; ++i) order[i] = -1;
00165
00166 int first_id = 0;
00167 int last_id = N_mat-1;
00168 int starting_id = -1;
00169
00170 int mat_count = N_mat;
00171
// largest intermediate size seen while choosing the order; only reported via debug output
00172 int largest_size = 0;
00173
00174 while(mat_count != 0)
00175 {
00176
// move first_id forwards to the first operand that has no order yet
00177 for(s32 i=first_id; i != N_mat; ++i)
00178 {
00179 if(order[i] == -1) { first_id = i; break; }
00180 }
00181
// move last_id backwards to the last operand that has no order yet
00182 for(s32 i=last_id; i != -1; --i)
00183 {
00184 if(order[i] == -1) { last_id = i; break; }
00185 }
00186
00187 arma_extra_debug_print();
00188 arma_extra_debug_print(arma_boost::format("mat_count = %d") % mat_count );
00189 arma_extra_debug_print(arma_boost::format("first_id = %d") % first_id );
00190 arma_extra_debug_print(arma_boost::format("last_id = %d") % last_id );
00191
// only one unassigned operand left: it gets order 0 and is where the multiplication starts
00192 if(first_id == last_id) { order[first_id] = 0; starting_id = first_id; break; }
00193
// size of the product of the remaining chain when the last / the first operand is left out
00194 s32 storage_cost_wo_last = mul_storage_cost( *ptrs[ first_id ], *ptrs[ last_id-1 ] );
00195 s32 storage_cost_wo_first = mul_storage_cost( *ptrs[ first_id+1 ], *ptrs[ last_id ] );
00196
// the end operand whose omission gives the smaller product receives the highest remaining order,
// i.e. among the still-unassigned operands it will be multiplied in last
00197 if(storage_cost_wo_last < storage_cost_wo_first)
00198 {
00199 order[last_id] = mat_count-1;
00200 if(storage_cost_wo_last > largest_size) largest_size = storage_cost_wo_last;
00201 }
00202 else
00203 {
00204 order[first_id] = mat_count-1;
00205 if(storage_cost_wo_first > largest_size) largest_size = storage_cost_wo_first;
00206 }
00207
00208 arma_extra_debug_print(arma_boost::format("storage_cost_wo_last = %d") % storage_cost_wo_last );
00209 arma_extra_debug_print(arma_boost::format("storage_cost_wo_first = %d") % storage_cost_wo_first );
00210
00211 arma_extra_debug_print("order = ");
00212 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00213
00214 --mat_count;
00215 }
00216
00217 arma_extra_debug_print("final order = ");
00218 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00219
00220 arma_extra_debug_print(arma_boost::format("*** largest_size = %d") % largest_size);
00221 arma_extra_debug_print(arma_boost::format("starting_id = %d") % starting_id);
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
// phase 2: carry out the N_mat-1 multiplications, working outwards from starting_id
00239 const u32 N_mul = N_mat - 1;
00240 int mul_count = N_mul;
00241 int current_id = starting_id;
00242
// src_mat_1_ptr: the running product so far (initially the starting operand);
// src_mat_2_ptr: the operand multiplied onto it in the current step
00243 const Mat<eT>* src_mat_1_ptr = ptrs[current_id];
00244 const Mat<eT>* src_mat_2_ptr = 0;
00245
00246
00247
00248
00249
00250 Mat<eT> tmp_mat_1;
00251 Mat<eT> tmp_mat_2;
00252
// with at most two multiplications only one temporary is ever written
// (the last product goes straight into 'out'), so the second one is left unused
00253 Mat<eT>* tmp_mat_1_ptr = &tmp_mat_1;
00254 Mat<eT>* tmp_mat_2_ptr = (N_mul <= 2) ? 0 : &tmp_mat_2;
00255
// starting at tmp_mat_2_ptr makes the first intermediate product land in tmp_mat_1
00256 Mat<eT>* dest_mat_ptr = tmp_mat_2_ptr;
00257
00258 arma_extra_debug_print(arma_boost::format("tmp_mat_1_ptr = %x") % tmp_mat_1_ptr );
00259 arma_extra_debug_print(arma_boost::format("tmp_mat_2_ptr = %x") % tmp_mat_2_ptr );
00260 arma_extra_debug_print(arma_boost::format("&out = %x") % &out );
00261
00262 while(mul_count != 0)
00263 {
00264 arma_extra_debug_print("");
00265 arma_extra_debug_print("");
00266 arma_extra_debug_print(arma_boost::format("mul_count = %d") % mul_count);
00267
00268 arma_extra_debug_print("order = ");
00269 for(s32 i=0; i != N_mat; ++i) arma_extra_debug_print(order[i]);
00270 arma_extra_debug_print("");
00271
00272
// pick the destination: 'out' for the final multiplication, otherwise whichever
// temporary was not written in the previous step (the previous one is now a source)
00273 if(mul_count == 1)
00274 {
00275 arma_extra_debug_print("dest_mat = &out");
00276 dest_mat_ptr = &out;
00277 }
00278 else
00279 {
00280 if(dest_mat_ptr == tmp_mat_2_ptr)
00281 {
00282 arma_extra_debug_print("dest_mat_ptr = tmp_mat_2_ptr");
00283 dest_mat_ptr = tmp_mat_1_ptr;
00284 }
00285 else
00286 {
00287 arma_extra_debug_print("dest_mat_ptr = tmp_mat_1_ptr");
00288 dest_mat_ptr = tmp_mat_2_ptr;
00289 }
00290 }
00291
00292 arma_extra_debug_print(arma_boost::format("dest_mat_ptr = %x") % dest_mat_ptr );
00293
00294
// N_mat is larger than any assigned order (0 .. N_mat-1), so it acts as "no candidate on this side"
00295 s32 left_val = N_mat;
00296 s32 left_id = -1;
00297
00298 s32 right_val = N_mat;
00299 s32 right_id = -1;
00300
00301
// nearest operand to the left whose order is greater than the current one
// (operands already consumed have order -1 and are therefore skipped)
00302 for(s32 i=current_id-1; i >= 0; --i)
00303 if( order[i] > order[current_id] ) { left_val = order[i]; left_id = i; break; }
00304
00305
// nearest operand to the right whose order is greater than the current one
00306 for(s32 i=current_id+1; i < N_mat; ++i)
00307 if( order[current_id] < order[i] ) { right_val = order[i]; right_id = i; break; }
00308
00309 arma_extra_debug_print("");
00310 arma_extra_debug_print(arma_boost::format("left_id = %d") % left_id );
00311 arma_extra_debug_print(arma_boost::format("left_val = %f") % left_val );
00312
00313 arma_extra_debug_print("");
00314 arma_extra_debug_print(arma_boost::format("right_id = %d") % right_id );
00315 arma_extra_debug_print(arma_boost::format("right_val = %f") % right_val );
00316
00317
// the neighbour with the smaller order is multiplied in next
00318 if(left_val < right_val)
00319 {
00320
// pre-multiply: dest = ptrs[left_id] * (running product)
00321 src_mat_2_ptr = ptrs[left_id];
00322
00323 arma_extra_debug_print("");
00324 arma_extra_debug_print(arma_boost::format("case pre-multiply with matrix %d") % left_id);
00325 arma_extra_debug_print(arma_boost::format("required destination size: %d, %d (%d)") % src_mat_2_ptr->n_rows % src_mat_1_ptr->n_cols % (src_mat_2_ptr->n_rows * src_mat_1_ptr->n_cols) );
00326
00327 glue_times::apply_noalias(*dest_mat_ptr, *src_mat_2_ptr, *src_mat_1_ptr);
00328
// mark the previous position as consumed and continue from the operand just used
00329 order[current_id] = -1;
00330 current_id = left_id;
00331 }
00332 else
00333 {
00334
// post-multiply: dest = (running product) * ptrs[right_id]
00335 src_mat_2_ptr = ptrs[right_id];
00336
00337 arma_extra_debug_print("");
00338 arma_extra_debug_print(arma_boost::format("case post-multiply with matrix %d") % right_id);
00339 arma_extra_debug_print(arma_boost::format("required destination size: %d, %d (%d)") % src_mat_1_ptr->n_rows % src_mat_2_ptr->n_cols % (src_mat_1_ptr->n_rows * src_mat_2_ptr->n_cols) );
00340
00341 glue_times::apply_noalias(*dest_mat_ptr, *src_mat_1_ptr, *src_mat_2_ptr);
00342
00343 order[current_id] = -1;
00344 current_id = right_id;
00345 }
00346
00347
// the product just computed becomes the running product for the next step
00348 src_mat_1_ptr = dest_mat_ptr;
00349
00350 --mul_count;
00351 }
00352
00353
// release the operands flagged in del[]
// NOTE(review): this cleanup is skipped if any call above throws
00354 for(s32 i=0; i<N_mat; ++i)
00355 {
00356 if(del[i] == true)
00357 {
00358 arma_extra_debug_print(arma_boost::format("delete mat_ptr[%d]") % i );
00359 delete ptrs[i];
00360 }
00361 }
00362 }
00363 }
00364
00365
00366
00367 template<typename eT>
00368 inline
00369 void
00370 glue_times::apply(Mat<eT>& out, const Glue<Mat<eT>,Mat<eT>,glue_times>& X)
00371 {
00372 glue_times::apply(out, X.A, X.B);
00373 }
00374
00375
00376
00377 template<typename eT>
00378 inline
00379 void
00380 glue_times::apply(Mat<eT>& out, const Glue< Glue<Mat<eT>,Mat<eT>, glue_times>, Mat<eT>, glue_times>& X)
00381 {
00382 glue_times::apply(out, X.A.A, X.A.B, X.B);
00383 }
00384
00385
00386
00387 template<typename eT>
00388 inline
00389 void
00390 glue_times::apply_inplace(Mat<eT>& out, const Mat<eT>& B)
00391 {
00392 arma_extra_debug_sigprint();
00393
00394 arma_debug_assert_mul_size(out, B, "matrix multiply");
00395
00396 if(out.n_cols == B.n_cols)
00397 {
00398 podarray<eT> tmp(out.n_cols);
00399 eT* tmp_rowdata = tmp.memptr();
00400
00401 for(u32 out_row=0; out_row < out.n_rows; ++out_row)
00402 {
00403 for(u32 out_col=0; out_col < out.n_cols; ++out_col)
00404 {
00405 tmp_rowdata[out_col] = out.at(out_row,out_col);
00406 }
00407
00408 for(u32 B_col=0; B_col < B.n_cols; ++B_col)
00409 {
00410 const eT* B_coldata = B.colptr(B_col);
00411
00412 eT val = eT(0);
00413 for(u32 i=0; i < B.n_rows; ++i)
00414 {
00415 val += tmp_rowdata[i] * B_coldata[i];
00416 }
00417
00418 out.at(out_row,B_col) = val;
00419 }
00420 }
00421
00422 }
00423 else
00424 {
00425 Mat<eT> tmp = out;
00426 glue_times::apply(out, tmp, B);
00427 }
00428
00429 }
00430
00431
00432
00433 template<typename T1, typename op_type>
00434 inline
00435 void
00436 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Op<T1, op_type>& X)
00437 {
00438 arma_extra_debug_sigprint();
00439
00440 typedef typename T1::elem_type eT;
00441
00442 const Mat<eT> tmp(X);
00443 glue_times::apply(out, out, tmp);
00444 }
00445
00446
00447
00448 template<typename T1, typename T2, typename glue_type>
00449 inline
00450 void
00451 glue_times::apply_inplace(Mat<typename T1::elem_type>& out, const Glue<T1, T2, glue_type>& X)
00452 {
00453 arma_extra_debug_sigprint();
00454
00455 out = out * X;
00456 }
00457
00458
00459
00460
00461 template<typename T1, typename T2>
00462 inline
00463 void
00464 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_trans>, glue_times>& X)
00465 {
00466 arma_extra_debug_sigprint();
00467
00468
00469 typedef typename T1::elem_type eT;
00470
00471
00472
00473 const unwrap<T1> tmp1(X.A);
00474 const unwrap<T2> tmp2(X.B.m);
00475
00476 const Mat<eT>& A = tmp1.M;
00477 const Mat<eT>& B = tmp2.M;
00478
00479 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "matrix multiply");
00480
00481 if( (A.n_rows*B.n_rows) > 0)
00482 {
00483 if(&A != &B)
00484 {
00485 unwrap_check< Mat<eT> > A_safe_tmp(A, out);
00486 unwrap_check< Mat<eT> > B_safe_tmp(B, out);
00487
00488 const Mat<eT>& A_safe = A_safe_tmp.M;
00489 const Mat<eT>& B_safe = B_safe_tmp.M;
00490
00491 out.set_size(A_safe.n_rows, B_safe.n_rows);
00492
00493 gemm<false,true>::apply(out, A, B);
00494 }
00495 else
00496 {
00497 arma_extra_debug_print("glue_times::apply(): detected A*A'");
00498
00499 Mat<eT> tmp;
00500 op_trans::apply(tmp,A);
00501
00502
00503 out.set_size(A.n_rows, A.n_rows);
00504
00505 for(u32 row=0; row != A.n_rows; ++row)
00506 {
00507 for(u32 col=0; col <= row; ++col)
00508 {
00509 const eT* coldata1 = tmp.colptr(row);
00510 const eT* coldata2 = tmp.colptr(col);
00511
00512 eT val = eT(0);
00513 for(u32 i=0; i < tmp.n_rows; ++i)
00514 {
00515 val += coldata1[i] * coldata2[i];
00516 }
00517
00518 out.at(row,col) = val;
00519 out.at(col,row) = val;
00520 }
00521 }
00522
00523 }
00524
00525 }
00526
00527 }
00528
00529
00530
00531
00532 template<typename T1, typename T2>
00533 inline
00534 void
00535 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, T2, glue_times>& X)
00536 {
00537 arma_extra_debug_sigprint();
00538
00539 typedef typename T1::elem_type eT;
00540
00541 const unwrap_check<T1> tmp1(X.A.m, out);
00542 const unwrap_check<T2> tmp2(X.B, out);
00543
00544 const Mat<eT>& A = tmp1.M;
00545 const Mat<eT>& B = tmp2.M;
00546
00547 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiply");
00548
00549 if( (A.n_cols*B.n_cols) > 0 )
00550 {
00551 out.set_size(A.n_cols, B.n_cols);
00552
00553 gemm<true,false>::apply(out, A, B);
00554 }
00555
00556 }
00557
00558
00559
00560
00561 template<typename T1, typename T2>
00562 inline
00563 void
00564 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1,op_trans>, Op<T2,op_trans>, glue_times>& X)
00565 {
00566 arma_extra_debug_sigprint();
00567
00568 typedef typename T1::elem_type eT;
00569
00570 const unwrap_check<T1> tmp1(X.A.m, out);
00571 const unwrap_check<T2> tmp2(X.B.m, out);
00572
00573 const Mat<eT>& A = tmp1.M;
00574 const Mat<eT>& B = tmp2.M;
00575
00576 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_cols, B.n_rows, "matrix multiply");
00577
00578 if( (A.n_cols*B.n_rows) > 0 )
00579 {
00580 out.set_size(A.n_cols, B.n_rows);
00581
00582 gemm<true,true>::apply(out, A, B);
00583
00584 }
00585
00586 }
00587
00588
00589
00590
00591
00592 template<typename T1, typename T2>
00593 inline
00594 void
00595 glue_times::apply(Mat<typename T1::elem_type>& out, const Glue< Op<T1, op_neg>, T2, glue_times>& X)
00596 {
00597 arma_extra_debug_sigprint();
00598
00599 typedef typename T1::elem_type eT;
00600
00601 const unwrap_check<T1> tmp1(X.A.m, out);
00602 const unwrap_check<T2> tmp2(X.B, out);
00603
00604 const Mat<eT>& A = tmp1.M;
00605 const Mat<eT>& B = tmp2.M;
00606
00607 glue_times::apply(out, A, B);
00608
00609 const u32 n_elem = out.n_elem;
00610 for(u32 i=0; i<n_elem; ++i)
00611 {
00612 out[i] = -out[i];
00613 }
00614 }
00615
00616
00617
00618
00619 template<typename eT>
00620 inline
00621 eT
00622 glue_times::direct_rowvec_mat_colvec
00623 (
00624 const eT* A_mem,
00625 const Mat<eT>& B,
00626 const eT* C_mem
00627 )
00628 {
00629 arma_extra_debug_sigprint();
00630
00631 const u32 cost_AB = B.n_cols;
00632 const u32 cost_BC = B.n_rows;
00633
00634 if(cost_AB <= cost_BC)
00635 {
00636 podarray<eT> tmp(B.n_cols);
00637
00638 for(u32 col=0; col<B.n_cols; ++col)
00639 {
00640 const eT* B_coldata = B.colptr(col);
00641
00642 eT val = eT(0);
00643 for(u32 i=0; i<B.n_rows; ++i)
00644 {
00645 val += A_mem[i] * B_coldata[i];
00646 }
00647
00648 tmp[col] = val;
00649 }
00650
00651 return op_dot::direct_dot(B.n_cols, tmp.mem, C_mem);
00652 }
00653 else
00654 {
00655 podarray<eT> tmp(B.n_rows);
00656
00657 for(u32 row=0; row<B.n_rows; ++row)
00658 {
00659 eT val = eT(0);
00660 for(u32 col=0; col<B.n_cols; ++col)
00661 {
00662 val += B.at(row,col) * C_mem[col];
00663 }
00664
00665 tmp[row] = val;
00666 }
00667
00668 return op_dot::direct_dot(B.n_rows, A_mem, tmp.mem);
00669 }
00670
00671
00672 }
00673
00674
00675
00676 template<typename eT>
00677 inline
00678 eT
00679 glue_times::direct_rowvec_diagmat_colvec
00680 (
00681 const eT* A_mem,
00682 const Mat<eT>& B,
00683 const eT* C_mem
00684 )
00685 {
00686 arma_extra_debug_sigprint();
00687
00688 eT val = eT(0);
00689
00690 for(u32 i=0; i<B.n_rows; ++i)
00691 {
00692 val += A_mem[i] * B.at(i,i) * C_mem[i];
00693 }
00694
00695 return val;
00696 }
00697
00698
00699
00700 template<typename eT>
00701 inline
00702 eT
00703 glue_times::direct_rowvec_invdiagmat_colvec
00704 (
00705 const eT* A_mem,
00706 const Mat<eT>& B,
00707 const eT* C_mem
00708 )
00709 {
00710 arma_extra_debug_sigprint();
00711
00712 eT val = eT(0);
00713
00714 for(u32 i=0; i<B.n_rows; ++i)
00715 {
00716 val += (A_mem[i] * C_mem[i]) / B.at(i,i);
00717 }
00718
00719 return val;
00720 }
00721
00722
00723
00724 template<typename eT>
00725 inline
00726 eT
00727 glue_times::direct_rowvec_invdiagvec_colvec
00728 (
00729 const eT* A_mem,
00730 const Mat<eT>& B,
00731 const eT* C_mem
00732 )
00733 {
00734 arma_extra_debug_sigprint();
00735
00736 const eT* B_mem = B.mem;
00737
00738 eT val = eT(0);
00739
00740 for(u32 i=0; i<B.n_elem; ++i)
00741 {
00742 val += (A_mem[i] * C_mem[i]) / B_mem[i];
00743 }
00744
00745 return val;
00746 }
00747
00748
00749
00750
00751
00752
00753 template<typename eT1, typename eT2>
00754 inline
00755 void
00756 glue_times::apply_mixed(Mat<typename promote_type<eT1,eT2>::result>& out, const Mat<eT1>& X, const Mat<eT2>& Y)
00757 {
00758 arma_extra_debug_sigprint();
00759
00760 typedef typename promote_type<eT1,eT2>::result out_eT;
00761
00762 arma_debug_assert_mul_size(X,Y, "matrix multiply");
00763
00764 out.set_size(X.n_rows,Y.n_cols);
00765 gemm_mixed<>::apply(out, X, Y);
00766 }
00767
00768
00769
00770
00771
00772
00773
00774 template<typename T1, typename T2>
00775 inline
00776 void
00777 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const T1& A_orig, const Op<T2,op_diagmat>& B_orig)
00778 {
00779 arma_extra_debug_sigprint();
00780
00781 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00782
00783 const unwrap_check<T1> tmp1(A_orig, out);
00784 const unwrap_check<T2> tmp2(B_orig.m, out);
00785
00786 typedef typename T1::elem_type eT;
00787
00788 const Mat<eT>& A = tmp1.M;
00789 const Mat<eT>& B = tmp2.M;
00790
00791 arma_debug_check( (B.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00792 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00793
00794 out.set_size(A.n_rows, B.n_cols);
00795
00796 for(u32 col=0; col<A.n_cols; ++col)
00797 {
00798 const eT val = B.at(col,col);
00799
00800 const eT* A_coldata = A.colptr(col);
00801 eT* out_coldata = out.colptr(col);
00802
00803 for(u32 row=0; row<B.n_rows; ++row)
00804 {
00805 out_coldata[row] = A_coldata[row] * val;
00806 }
00807
00808 }
00809
00810 }
00811
00812
00813
00814 template<typename T1, typename T2>
00815 inline
00816 void
00817 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const T2& B_orig)
00818 {
00819 arma_extra_debug_sigprint();
00820
00821 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00822
00823 const unwrap_check<T1> tmp1(A_orig.m, out);
00824 const unwrap_check<T2> tmp2(B_orig, out);
00825
00826 typedef typename T1::elem_type eT;
00827
00828 const Mat<eT>& A = tmp1.M;
00829 const Mat<eT>& B = tmp2.M;
00830
00831 arma_debug_check( (A.is_square() == false), "glue_times_diag::apply(): incompatible matrix dimensions" );
00832 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00833
00834 out.set_size(A.n_rows, B.n_cols);
00835
00836
00837 for(u32 col=0; col<A.n_cols; ++col)
00838 {
00839 const eT* B_coldata = B.colptr(col);
00840 eT* out_coldata = out.colptr(col);
00841
00842 for(u32 row=0; row<B.n_rows; ++row)
00843 {
00844 out_coldata[row] = A.at(row,row) * B_coldata[row];
00845 }
00846
00847 }
00848
00849 }
00850
00851
00852
00853 template<typename T1, typename T2>
00854 inline
00855 void
00856 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_diagmat>& A_orig, const Op<T2,op_diagmat>& B_orig)
00857 {
00858 arma_extra_debug_sigprint();
00859
00860 isnt_same_type<typename T1::elem_type, typename T2::elem_type>::check();
00861
00862 unwrap_check<T1> tmp1(A_orig.m, out);
00863 unwrap_check<T2> tmp2(B_orig.m, out);
00864
00865 typedef typename T1::elem_type eT;
00866
00867 const Mat<eT>& A = tmp1.M;
00868 const Mat<eT>& B = tmp2.M;
00869
00870 arma_debug_check( !A.is_square() || !B.is_square(), "glue_times_diag::apply(): incompatible matrix dimensions" );
00871 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiply");
00872
00873 out.zeros(A.n_rows, B.n_cols);
00874
00875 for(u32 i=0; i<A.n_rows; ++i)
00876 {
00877 out.at(i,i) = A.at(i,i) * B.at(i,i);
00878 }
00879 }
00880
00881
00882
00883 template<typename T1, typename T2>
00884 inline
00885 void
00886 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Op<T2,op_diagmat>, glue_times_diag>& X)
00887 {
00888 glue_times_diag::apply(out, X.A, X.B);
00889 }
00890
00891
00892
00893 template<typename T1, typename T2>
00894 inline
00895 void
00896 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, T2, glue_times_diag>& X)
00897 {
00898 glue_times_diag::apply(out, X.A, X.B);
00899 }
00900
00901
00902
00903 template<typename T1, typename T2>
00904 inline
00905 void
00906 glue_times_diag::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1,op_diagmat>, Op<T2,op_diagmat>, glue_times_diag>& X)
00907 {
00908 glue_times_diag::apply(out, X.A, X.B);
00909 }
00910
00911
00912
00913
00914
00915
00916
00917
00918 template<typename eT>
00919 inline
00920 void
00921 glue_times_vec::mul_col_row(Mat<eT>& out, const eT* A, const eT* B)
00922 {
00923 const u32 n_rows = out.n_rows;
00924 const u32 n_cols = out.n_cols;
00925
00926 for(u32 col=0; col < n_cols; ++col)
00927 {
00928 const eT val = B[col];
00929
00930 eT* out_coldata = out.colptr(col);
00931
00932 for(u32 row=0; row < n_rows; ++row)
00933 {
00934 out_coldata[row] = A[row] * val;
00935 }
00936 }
00937
00938 }
00939
00940
00941
00942 template<typename T1>
00943 inline
00944 void
00945 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Col<typename T1::elem_type>,glue_times_vec>& X)
00946 {
00947 arma_extra_debug_sigprint();
00948
00949 typedef typename T1::elem_type eT;
00950
00951 unwrap_check< T1 > tmp1(X.A, out);
00952 unwrap_check< Col<eT> > tmp2(X.B, out);
00953
00954 const Mat<eT>& A = tmp1.M;
00955 const Col<eT>& B = tmp2.M;
00956
00957 arma_debug_assert_mul_size(A, B, "vector multiply");
00958
00959 out.set_size(A.n_rows, 1);
00960
00961
00962 gemv<>::apply(out.memptr(), A, B.mem);
00963 }
00964
00965
00966
00967 template<typename T1>
00968 inline
00969 void
00970 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<T1, Row<typename T1::elem_type>,glue_times_vec>& X)
00971 {
00972 arma_extra_debug_sigprint();
00973
00974
00975
00976 typedef typename T1::elem_type eT;
00977
00978 unwrap_check< T1 > tmp1(X.A, out);
00979 unwrap_check< Row<eT> > tmp2(X.B, out);
00980
00981 const Mat<eT>& A = tmp1.M;
00982 const Mat<eT>& B = tmp2.M;
00983
00984 arma_debug_assert_mul_size(A, B, "vector multiply");
00985
00986 out.set_size(A.n_rows, B.n_cols);
00987
00988 glue_times_vec::mul_col_row(out, A.mem, B.mem);
00989 }
00990
00991
00992
00993 template<typename T1>
00994 inline
00995 void
00996 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Col<typename T1::elem_type>,T1,glue_times_vec>& X)
00997 {
00998 arma_extra_debug_sigprint();
00999
01000
01001
01002
01003 typedef typename T1::elem_type eT;
01004
01005 unwrap_check< Col<eT> > tmp1(X.A, out);
01006 unwrap_check< T1 > tmp2(X.B, out);
01007
01008 const Mat<eT>& A = tmp1.M;
01009 const Mat<eT>& B = tmp2.M;
01010
01011 arma_debug_assert_mul_size(A, B, "vector multiply");
01012
01013 out.set_size(A.n_rows, B.n_cols);
01014
01015 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01016 }
01017
01018
01019
01020 template<typename T1>
01021 inline
01022 void
01023 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Row<typename T1::elem_type>,T1,glue_times_vec>& X)
01024 {
01025 arma_extra_debug_sigprint();
01026
01027 typedef typename T1::elem_type eT;
01028
01029 unwrap_check< Row<eT> > tmp1(X.A, out);
01030 unwrap_check< T1 > tmp2(X.B, out);
01031
01032 const Row<eT>& A = tmp1.M;
01033 const Mat<eT>& B = tmp2.M;
01034
01035 arma_debug_assert_mul_size(A, B, "vector multiply");
01036
01037 out.set_size(A.n_rows, B.n_cols);
01038
01039
01040
01041
01042
01043
01044
01045
01046
01047
01048
01049
01050
01051
01052
01053
01054
01055
01056
01057
01058 gemv<true>::apply(out.memptr(), B, A.mem);
01059 }
01060
01061
01062
01063 template<typename eT>
01064 inline
01065 void
01066 glue_times_vec::apply(Mat<eT>& out, const Glue<Col<eT>,Row<eT>,glue_times_vec>& X)
01067 {
01068 arma_extra_debug_sigprint();
01069
01070 unwrap_check< Col<eT> > tmp1(X.A, out);
01071 unwrap_check< Row<eT> > tmp2(X.B, out);
01072
01073 const Col<eT>& A = tmp1.M;
01074 const Row<eT>& B = tmp2.M;
01075
01076 arma_debug_assert_mul_size(A, B, "vector multiply");
01077
01078 out.set_size(A.n_rows, B.n_cols);
01079
01080 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01081 }
01082
01083
01084
01085 template<typename eT>
01086 inline
01087 void
01088 glue_times_vec::apply(Mat<eT>& out, const Glue< Op<Row<eT>, op_trans>, Row<eT>, glue_times_vec>& X)
01089 {
01090 arma_extra_debug_sigprint();
01091
01092 unwrap_check< Row<eT> > tmp1(X.A.m, out);
01093 unwrap_check< Row<eT> > tmp2(X.B, out);
01094
01095 const Row<eT>& A = tmp1.M;
01096 const Row<eT>& B = tmp2.M;
01097
01098 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01099
01100 out.set_size(A.n_cols, B.n_cols);
01101
01102 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01103 }
01104
01105
01106
01107 template<typename eT>
01108 inline
01109 void
01110 glue_times_vec::apply(Mat<eT>& out, const Glue< Col<eT>, Op<Col<eT>, op_trans>, glue_times_vec>& X)
01111 {
01112 arma_extra_debug_sigprint();
01113
01114 unwrap_check< Col<eT> > tmp1(X.A, out);
01115 unwrap_check< Col<eT> > tmp2(X.B.m, out);
01116
01117 const Col<eT>& A = tmp1.M;
01118 const Col<eT>& B = tmp2.M;
01119
01120 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_cols, B.n_rows, "vector multiply");
01121
01122 out.set_size(A.n_rows, B.n_rows);
01123
01124 glue_times_vec::mul_col_row(out, A.mem, B.mem);
01125 }
01126
01127
01128
01129 template<typename T1>
01130 inline
01131 void
01132 glue_times_vec::apply(Mat<typename T1::elem_type>& out, const Glue<Op<T1, op_trans>, Col<typename T1::elem_type>,glue_times_vec>& X)
01133 {
01134 arma_extra_debug_sigprint();
01135
01136 typedef typename T1::elem_type eT;
01137
01138 unwrap_check< T1 > tmp1(X.A.m, out);
01139 unwrap_check< Col<eT> > tmp2(X.B, out);
01140
01141 const Mat<eT>& A = tmp1.M;
01142 const Col<eT>& B = tmp2.M;
01143
01144 arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "vector multiply");
01145
01146 out.set_size(A.n_cols, B.n_cols);
01147
01148
01149
01150
01151
01152
01153
01154
01155
01156
01157
01158
01159
01160
01161
01162
01163
01164
01165
01166
01167 gemv<true>::apply(out.memptr(), A, B.mem);
01168 }
01169
01170
01171
01172