Actual source code: mpimatmatmatmult.c
1: /*
2: Defines matrix-matrix-matrix product routines for MPIAIJ matrices
3: D = A * B * C
4: */
5: #include <../src/mat/impls/aij/mpi/mpiaij.h>
7: #if defined(PETSC_HAVE_HYPRE)
8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);
11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
12: {
13: Mat_Product *product = RAP->product;
14: Mat Rt,R=product->A,A=product->B,P=product->C;
16: MatTransposeGetMat(R,&Rt);
17: MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);
18: return 0;
19: }
21: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
22: {
23: Mat_Product *product = RAP->product;
24: Mat Rt,R=product->A,A=product->B,P=product->C;
25: PetscBool flg;
27: /* local sizes of matrices will be checked by the calling subroutines */
28: MatTransposeGetMat(R,&Rt);
29: PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);
31: MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);
32: RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
33: return 0;
34: }
36: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
37: {
38: Mat_Product *product = C->product;
40: if (product->type == MATPRODUCT_ABC) {
41: C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
42: } else SETERRQ(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
43: return 0;
44: }
45: #endif
47: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
48: {
49: Mat BC;
50: PetscBool scalable;
51: Mat_Product *product;
53: MatCheckProduct(D,4);
55: product = D->product;
56: MatProductCreate(B,C,NULL,&BC);
57: MatProductSetType(BC,MATPRODUCT_AB);
58: PetscStrcmp(product->alg,"scalable",&scalable);
59: if (scalable) {
60: MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);
61: MatZeroEntries(BC); /* initialize value entries of BC */
62: MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);
63: } else {
64: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);
65: MatZeroEntries(BC); /* initialize value entries of BC */
66: MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);
67: }
68: MatDestroy(&product->Dwork);
69: product->Dwork = BC;
71: D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
72: return 0;
73: }
75: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
76: {
77: Mat_Product *product;
78: Mat BC;
80: MatCheckProduct(D,4);
82: product = D->product;
83: BC = product->Dwork;
85: (*BC->ops->matmultnumeric)(B,C,BC);
87: (*D->ops->matmultnumeric)(A,BC,D);
88: return 0;
89: }
91: /* ----------------------------------------------------- */
92: PetscErrorCode MatDestroy_MPIAIJ_RARt(void *data)
93: {
94: Mat_RARt *rart = (Mat_RARt*)data;
96: MatDestroy(&rart->Rt);
97: if (rart->destroy) {
98: (*rart->destroy)(rart->data);
99: }
100: PetscFree(rart);
101: return 0;
102: }
104: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
105: {
106: Mat_RARt *rart;
107: Mat A,R,Rt;
109: MatCheckProduct(C,1);
111: rart = (Mat_RARt*)C->product->data;
112: A = C->product->A;
113: R = C->product->B;
114: Rt = rart->Rt;
115: MatTranspose(R,MAT_REUSE_MATRIX,&Rt);
116: if (rart->data) C->product->data = rart->data;
117: (*C->ops->matmatmultnumeric)(R,A,Rt,C);
118: C->product->data = rart;
119: return 0;
120: }
122: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
123: {
124: Mat A,R,Rt;
125: Mat_RARt *rart;
127: MatCheckProduct(C,1);
129: A = C->product->A;
130: R = C->product->B;
131: MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);
132: /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
133: MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,C->product->fill,C);
134: C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;
136: /* create a supporting struct */
137: PetscNew(&rart);
138: rart->Rt = Rt;
139: rart->data = C->product->data;
140: rart->destroy = C->product->destroy;
141: C->product->data = rart;
142: C->product->destroy = MatDestroy_MPIAIJ_RARt;
143: return 0;
144: }