Actual source code: baijfact81.c
/*
  Factorization code for BAIJ format.
*/
#include <../src/mat/impls/baij/seq/baij.h>
#include <petsc/private/kernels/blockinvert.h>
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
#include <immintrin.h>
#endif
/*
  Version for when blocks are 9 by 9
*/
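/*
  Both routines below assume natural (identity) ordering.  The numeric factorization stores
  each diagonal block of the factor already inverted, so the AVX2/FMA triangular solves only
  need 9x9 block multiplications.
*/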
#if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B,Mat A,const MatFactorInfo *info)
{
  Mat             C = B;
  Mat_SeqBAIJ     *a = (Mat_SeqBAIJ*)A->data,*b = (Mat_SeqBAIJ*)C->data;
  PetscInt        i,j,k,nz,nzL,row;
  const PetscInt  n = a->mbs,*ai = a->i,*aj = a->j,*bi = b->i,*bj = b->j;
  const PetscInt  *ajtmp,*bjtmp,*bdiag = b->diag,*pj,bs2 = a->bs2;
  MatScalar       *rtmp,*pc,*mwork,*v,*pv,*aa = a->a;
  PetscInt        flg;
  PetscReal       shift = info->shiftamount;
  PetscBool       allowzeropivot,zeropivotdetected;

  allowzeropivot = PetscNot(A->erroriffailure);

  /* generate work space needed by the factorization */
  PetscMalloc2(bs2*n,&rtmp,bs2,&mwork);
  PetscArrayzero(rtmp,bs2*n);
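  /* rtmp provides a bs2-sized slot for every block column; only the slots that appear in the
     sparsity pattern of the current block row are zeroed and used.  mwork is scratch space
     for the 9x9 block kernels. */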
  for (i=0; i<n; i++) {
    /* zero rtmp */
    /* L part */
    nz    = bi[i+1] - bi[i];
    bjtmp = bj + bi[i];
    for (j=0; j<nz; j++) {
      PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
    }

    /* U part */
    nz    = bdiag[i] - bdiag[i+1];
    bjtmp = bj + bdiag[i+1]+1;
    for (j=0; j<nz; j++) {
      PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
    }

    /* load in initial (unfactored) row */
    nz    = ai[i+1] - ai[i];
    ajtmp = aj + ai[i];
    v     = aa + bs2*ai[i];
    for (j=0; j<nz; j++) {
      PetscArraycpy(rtmp+bs2*ajtmp[j],v+bs2*j,bs2);
    }

    /* elimination */
    bjtmp = bj + bi[i];
    nzL   = bi[i+1] - bi[i];
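    /* For each nonzero block in the L part of row i:
         rtmp(i,row) <- rtmp(i,row) * inv(U(row,row))        (diagonal blocks are stored inverted)
         rtmp(i,j)   <- rtmp(i,j) - rtmp(i,row) * U(row,j)   for every j in the U part of row `row` */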
    for (k=0; k < nzL; k++) {
      row = bjtmp[k];
      pc  = rtmp + bs2*row;
      for (flg=0,j=0; j<bs2; j++) {
        if (pc[j]!=0.0) {
          flg = 1;
          break;
        }
      }
      if (flg) {
        pv = b->a + bs2*bdiag[row];
        /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
        PetscKernel_A_gets_A_times_B_9(pc,pv,mwork);

        pj = b->j + bdiag[row+1]+1;         /* beginning of U(row,:) */
        pv = b->a + bs2*(bdiag[row+1]+1);
        nz = bdiag[row] - bdiag[row+1] - 1; /* num of entries in U(row,:), excluding diag */
        for (j=0; j<nz; j++) {
          /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
          /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
          v = rtmp + bs2*pj[j];
          PetscKernel_A_gets_A_minus_B_times_C_9(v,pc,pv+81*j);
          /* pv is not advanced here; each U(row,:) block is addressed directly as pv+81*j */
        }
        PetscLogFlops(1458*nz+1377); /* flops = 2*bs^3*nz + 2*bs^3 - bs^2 */
      }
    }
    /* finished row so stick it into b->a */
    /* L part */
    pv = b->a + bs2*bi[i];
    pj = b->j + bi[i];
    nz = bi[i+1] - bi[i];
    for (j=0; j<nz; j++) {
      PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
    }

    /* Mark diagonal and invert diagonal for simpler triangular solves */
    pv = b->a + bs2*bdiag[i];
    pj = b->j + bdiag[i];
    PetscArraycpy(pv,rtmp+bs2*pj[0],bs2);
    PetscKernel_A_gets_inverse_A_9(pv,shift,allowzeropivot,&zeropivotdetected);
    if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;

    /* U part */
    pv = b->a + bs2*(bdiag[i+1]+1);
    pj = b->j + bdiag[i+1]+1;
    nz = bdiag[i] - bdiag[i+1] - 1;
    for (j=0; j<nz; j++) {
      PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
    }
  }
  PetscFree2(rtmp,mwork);

  C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
  C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
  C->assembled           = PETSC_TRUE;

  PetscLogFlops(1.333333333333*9*9*9*n); /* from inverting diagonal blocks */
  return 0;
}
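
/*
  Illustrative sketch (not part of this file): these bs = 9, natural-ordering kernels are
  typically reached through an ordinary LU factorization of a sequential BAIJ matrix, e.g.

    MatCreateSeqBAIJ(PETSC_COMM_SELF,9,n,n,nz,NULL,&A);
    ... assemble A ...
    MatGetFactor(A,MATSOLVERPETSC,MAT_FACTOR_LU,&F);
    MatLUFactorSymbolic(F,A,isrow,iscol,&info);
    MatLUFactorNumeric(F,A,&info);
    MatSolve(F,b,x);

  The variable names and the symbolic setup above are illustrative only; the factorization
  routine above installs MatSolve_SeqBAIJ_9_NaturalOrdering (below) as the solve operation.
*/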
PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A,Vec bb,Vec xx)
{
  Mat_SeqBAIJ       *a = (Mat_SeqBAIJ*)A->data;
  const PetscInt    *ai = a->i,*aj = a->j,*adiag = a->diag,*vi;
  PetscInt          i,k,n = a->mbs;
  PetscInt          nz,bs = A->rmap->bs,bs2 = a->bs2;
  const MatScalar   *aa = a->a,*v;
  PetscScalar       *x,*s,*t,*ls;
  const PetscScalar *b;
  __m256d           a0,a1,a2,a3,a4,a5,w0,w1,w2,w3,s0,s1,s2,v0,v1,v2,v3;

  VecGetArrayRead(bb,&b);
  VecGetArray(xx,&x);
  t = a->solve_work;
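
  /*
     Each 9x9 block is stored column-major as 81 consecutive MatScalars.  The loops below walk
     a block one column at a time: the nine entries of a column are split over three 4-wide
     AVX2 lanes (entries 0-3, 4-7 and 8; the third lane over-reads into the next column, which
     is harmless because only its first element is ever stored, and the last column of a block
     is loaded with a mask so memory past the block is never touched), while the matching
     entry of the already-computed block of t is broadcast into the FMAs.  In scalar form, one
     iteration of the inner loops computes (loop variables here are illustrative only)

       for (c=0; c<9; c++) {
         xc = (t + bs*vi[k])[c];
         for (r=0; r<9; r++) s[r] -= v[9*c + r]*xc;
       }
       v += bs2;

     i.e. s -= (9x9 block at v) * (block of t at vi[k]).
  */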
  /* forward solve the lower triangular */
  PetscArraycpy(t,b,bs); /* copy 1st block of b to t */

  for (i=1; i<n; i++) {
    v  = aa + bs2*ai[i];
    vi = aj + ai[i];
    nz = ai[i+1] - ai[i];
    s  = t + bs*i;
    PetscArraycpy(s,b+bs*i,bs); /* copy the i-th block of b to t */

    __m256d s0,s1,s2;
    s0 = _mm256_loadu_pd(s+0);
    s1 = _mm256_loadu_pd(s+4);
    s2 = _mm256_maskload_pd(s+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
    for (k=0; k<nz; k++) {
      w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
      s2 = _mm256_fnmadd_pd(a2,w0,s2);
      v += bs2;
    }
    _mm256_storeu_pd(&s[0], s0);
    _mm256_storeu_pd(&s[4], s1);
    _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
  }
  /* backward solve the upper triangular */
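  /* For each block row i, last to first: accumulate ls = t_i - sum_j U(i,j)*t_j over the
     strictly upper-triangular blocks of row i, then apply the stored inverse of the diagonal
     block, t_i = inv(U(i,i))*ls, and copy the result into x. */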
  ls = a->solve_work + A->cmap->n;
  for (i=n-1; i>=0; i--) {
    v  = aa + bs2*(adiag[i+1]+1);
    vi = aj + adiag[i+1]+1;
    nz = adiag[i] - adiag[i+1]-1;
    PetscArraycpy(ls,t+i*bs,bs);

    s0 = _mm256_loadu_pd(ls+0);
    s1 = _mm256_loadu_pd(ls+4);
    s2 = _mm256_maskload_pd(ls+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));

    for (k=0; k<nz; k++) {
      w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
      a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
      a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
      a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
      a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
      a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

      w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
      a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
      a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
      a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

      w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
      a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
      a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
      a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

      w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
      a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
      a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
      a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

      w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
      a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
      a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
      a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
      s2 = _mm256_fnmadd_pd(a2,w0,s2);
      v += bs2;
    }

    _mm256_storeu_pd(&ls[0], s0); _mm256_storeu_pd(&ls[4], s1); _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
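
    /* apply the stored inverse of the diagonal block: t_i = inv(U(i,i)) * ls, accumulated one
       column of the inverse block at a time */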
    w0 = _mm256_setzero_pd(); w1 = _mm256_setzero_pd(); w2 = _mm256_setzero_pd();

    /* first column */
    v0 = _mm256_set1_pd(ls[0]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[0]);  w0 = _mm256_fmadd_pd(a0,v0,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[4]);  w1 = _mm256_fmadd_pd(a1,v0,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[8]);  w2 = _mm256_fmadd_pd(a2,v0,w2);

    /* second column */
    v1 = _mm256_set1_pd(ls[1]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[9]);  w0 = _mm256_fmadd_pd(a3,v1,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[13]); w1 = _mm256_fmadd_pd(a4,v1,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[17]); w2 = _mm256_fmadd_pd(a5,v1,w2);

    /* third column */
    v2 = _mm256_set1_pd(ls[2]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[18]); w0 = _mm256_fmadd_pd(a0,v2,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[22]); w1 = _mm256_fmadd_pd(a1,v2,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[26]); w2 = _mm256_fmadd_pd(a2,v2,w2);

    /* fourth column */
    v3 = _mm256_set1_pd(ls[3]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[27]); w0 = _mm256_fmadd_pd(a3,v3,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[31]); w1 = _mm256_fmadd_pd(a4,v3,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[35]); w2 = _mm256_fmadd_pd(a5,v3,w2);

    /* fifth column */
    v0 = _mm256_set1_pd(ls[4]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[36]); w0 = _mm256_fmadd_pd(a0,v0,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[40]); w1 = _mm256_fmadd_pd(a1,v0,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[44]); w2 = _mm256_fmadd_pd(a2,v0,w2);

    /* sixth column */
    v1 = _mm256_set1_pd(ls[5]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[45]); w0 = _mm256_fmadd_pd(a3,v1,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[49]); w1 = _mm256_fmadd_pd(a4,v1,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[53]); w2 = _mm256_fmadd_pd(a5,v1,w2);

    /* seventh column */
    v2 = _mm256_set1_pd(ls[6]);
    a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[54]); w0 = _mm256_fmadd_pd(a0,v2,w0);
    a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[58]); w1 = _mm256_fmadd_pd(a1,v2,w1);
    a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[62]); w2 = _mm256_fmadd_pd(a2,v2,w2);

    /* eighth column */
    v3 = _mm256_set1_pd(ls[7]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[63]); w0 = _mm256_fmadd_pd(a3,v3,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[67]); w1 = _mm256_fmadd_pd(a4,v3,w1);
    a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[71]); w2 = _mm256_fmadd_pd(a5,v3,w2);

    /* ninth column */
    v0 = _mm256_set1_pd(ls[8]);
    a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[72]); w0 = _mm256_fmadd_pd(a3,v0,w0);
    a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[76]); w1 = _mm256_fmadd_pd(a4,v0,w1);
    a2 = _mm256_maskload_pd((&(aa+bs2*adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
    w2 = _mm256_fmadd_pd(a2,v0,w2);

    _mm256_storeu_pd(&(t+i*bs)[0], w0); _mm256_storeu_pd(&(t+i*bs)[4], w1); _mm256_maskstore_pd(&(t+i*bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), w2);

    PetscArraycpy(x+i*bs,t+i*bs,bs);
  }
  VecRestoreArrayRead(bb,&b);
  VecRestoreArray(xx,&x);
  PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
  return 0;
}
#endif