Let's say we decide to cover the basics. The classical three-nested-loop implementation of matmul would be:
/*
 * Reference matmul kernel: three nested loops, one dot product per
 * element of C.
 *
 * For every 0 <= i < M and 0 <= j < N it stores
 *
 *    C[i+j*ldc] = c0 + sum over 0 <= k < K of A[k+i*lda] * B[k+j*ldb]
 *
 * where the starting value c0 is selected at compile time:
 *    BETA0 defined : c0 = 0.0             (C is never read)
 *    BETA1 defined : c0 = C[i+j*ldc]      (beta is not applied)
 *    otherwise     : c0 = C[i+j*ldc] * beta
 *
 * The k index is the unit-stride index of both A and B; lda, ldb and ldc
 * are the distances between consecutive i (for A) or j (for B and C)
 * slices.
 *
 * alpha is accepted but never used by this kernel.
 * NOTE(review): presumably the caller only invokes it with alpha == 1;
 * confirm against the kernel contract.
 */
void ATL_USERMM
   (const int M, const int N, const int K, const double alpha,
    const double *A, const int lda, const double *B, const int ldb,
    const double beta, double *C, const int ldc)
{
   int i, j, k;
   register double c00;

   for (j=0; j < N; j++)
   {
      for (i=0; i < M; i++)
      {
         /* Pick the accumulator's starting value for the chosen beta case. */
#ifdef BETA0
         c00 = 0.0;
#elif defined(BETA1)
         c00 = C[i+j*ldc];
#else
         c00 = C[i+j*ldc] * beta;
#endif
         /* Dot product of slice i of A with slice j of B. */
         for (k=0; k < K; k++)
         {
            c00 += A[k+i*lda] * B[k+j*ldb];
         }
         C[i+j*ldc] = c00;
      }
   }
}
We then save this paragon of performance to ATLAS/tune/blas/gemm/CASES/ATL_mm1x1x1.c. From ATLAS/tune/blas/gemm/<arch>, we can test that it gets the right answer by:
make mmutstcase pre=d nb=40 mmrout=../CASES/ATL_mm1x1x1.c beta=0
make mmutstcase pre=d nb=40 mmrout=../CASES/ATL_mm1x1x1.c beta=1
make mmutstcase pre=d nb=40 mmrout=../CASES/ATL_mm1x1x1.c beta=7
We pass four arguments to mmutstcase: a precision specifier (d : double precision real; s : single precision real; z : double precision complex; c : single precision complex), the size of the blocking parameter (nb), the beta value to test (0, 1, and other), and finally, the filename to test.
If the kernel passes these tests, we can then see what kind of performance we get by (this is the actual output on my 266 MHz PII):
make ummcase pre=d nb=40 mmrout=../CASES/ATL_mm1x1x1.c beta=1
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.820000, mflop=53.731868
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.810000, mflop=54.028729
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.830000, mflop=53.438251
This is the same timing repeated three times (this just tries to ensure timings are repeatable), and the only output of real interest is the MFLOP rate at the end. The values the timer prints (mu, nu, ku, lat) are all defaults because we didn't specify them; specifying them has no effect when the timer is used in this way, so don't worry about them.
Now we can trivially improve the implementation by using the macro constants in order to let the compiler unroll the loops:
void ATL_USERMM (const int M, const int N, const int K, const double alpha, const double *A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc) { int i, j, k; register double c00; for (j=0; j < NB; j++) { for (i=0; i < MB; i++) { #ifdef BETA0 c00 = 0.0; #elif defined(BETA1) c00 = C[i+j*ldc]; #else c00 = C[i+j*ldc] * beta; #endif for (k=0; k < KB; k++) c00 += A[k+i*KB] * B[k+j*KB]; C[i+j*ldc] = c00; } } }
We save this to ATL_mm1x1x1b.c, and then time:
make ummcase pre=d nb=40 mmrout=../CASES/ATL_mm1x1x1b.c beta=1
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.670000, mflop=58.558084
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.660000, mflop=58.910843
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=1.670000, mflop=58.558084
OK, maybe a little explicit loop unrolling will make things work better:
void ATL_USERMM (const int M, const int N, const int K, const double alpha, const double *A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc) { int i, j, k; register double c00, c10, b0; const double *pA0, *pB=B; #if ( (KB / 8)*8 != KB ) || (MB / 2)*2 != MB create syntax error!$@@& #endif for (j=0; j < NB; j++, pB += KB) { pA0 = A; for (i=0; i < MB; i += 2, pA0 += KB2) { #ifdef BETA0 c00 = c10 = 0.0; #elif defined(BETA1) c00 = C[i+j*ldc]; c10 = C[i+1+j*ldc]; #else c00 = beta*C[i+j*ldc]; c10 = beta*C[i+1+j*ldc]; #endif for (k=0; k < KB; k += 8) { b0 = pB[k]; c00 += pA0[k] * b0; c10 += pA0[KB+k] * b0; b0 = pB[k+1]; c00 += pA0[k+1] * b0; c10 += pA0[KB+k+1] * b0; b0 = pB[k+2]; c00 += pA0[k+2] * b0; c10 += pA0[KB+k+2] * b0; b0 = pB[k+3]; c00 += pA0[k+3] * b0; c10 += pA0[KB+k+3] * b0; b0 = pB[k+4]; c00 += pA0[k+4] * b0; c10 += pA0[KB+k+4] * b0; b0 = pB[k+5]; c00 += pA0[k+5] * b0; c10 += pA0[KB+k+5] * b0; b0 = pB[k+6]; c00 += pA0[k+6] * b0; c10 += pA0[KB+k+6] * b0; b0 = pB[k+7]; c00 += pA0[k+7] * b0; c10 += pA0[KB+k+7] * b0; } C[i+j*ldc] = c00; C[i+1+j*ldc] = c10; } } }
And with this ode to beauty and elegance we get (after checking that it still gets the right answer, of course):
make ummcase pre=d nb=40 mmrout=../CASES/ATL_mm2x1x8.c beta=1
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.720000, mflop=135.822222
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.710000, mflop=137.735211
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.710000, mflop=137.735211
It's interesting to see the effects of differing beta values on the code:
make ummcase pre=d nb=40 mmrout=../CASES/ATL_mm2x1x8a.c beta=0
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.700000, mflop=139.702857
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.700000, mflop=139.702857
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.700000, mflop=139.702857
make ummcase pre=d nb=40 mmrout=../CASES/ATL_mm2x1x8a.c beta=7
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.720000, mflop=135.822222
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.730000, mflop=133.961644
dNB=40, ldc=40, mu=4, nu=4, ku=1, lat=4: time=0.720000, mflop=135.822222
As well as differing blocking factors:
make ummcase pre=d mmrout=../CASES/ATL_mm2x1x8a.c beta=1 nb=16
dNB=16, ldc=16, mu=4, nu=4, ku=1, lat=4: time=0.850000, mflop=115.112056
dNB=16, ldc=16, mu=4, nu=4, ku=1, lat=4: time=0.860000, mflop=113.773544
dNB=16, ldc=16, mu=4, nu=4, ku=1, lat=4: time=0.850000, mflop=115.112056
make ummcase pre=d mmrout=../CASES/ATL_mm2x1x8a.c beta=1 nb=32
dNB=32, ldc=32, mu=4, nu=4, ku=1, lat=4: time=0.730000, mflop=134.034586
dNB=32, ldc=32, mu=4, nu=4, ku=1, lat=4: time=0.740000, mflop=132.223308
dNB=32, ldc=32, mu=4, nu=4, ku=1, lat=4: time=0.730000, mflop=134.034586
make ummcase pre=d mmrout=../CASES/ATL_mm2x1x8a.c beta=1 nb=48
dNB=48, ldc=48, mu=4, nu=4, ku=1, lat=4: time=0.820000, mflop=119.223571
dNB=48, ldc=48, mu=4, nu=4, ku=1, lat=4: time=0.820000, mflop=119.223571
dNB=48, ldc=48, mu=4, nu=4, ku=1, lat=4: time=0.820000, mflop=119.223571
If we wanted to have ATLAS try these crappy implementations during the ATLAS search, we would have the following ATLAS/tune/blas/gemm/CASES/dcases.dsc:
<ID> <flag> <mb> <nb> <kb> <muladd> <lat> <mu> <nu> <ku> <rout> "<Contributer>"
3
1 0 0 0 0 1 1 1 1 1 ATL_mm1x1x1.c "R. Clint Whaley"
2 0 1 1 1 1 1 1 1 1 ATL_mm1x1x1b.c "R. Clint Whaley"
3 0 2 1 8 1 1 2 1 8 ATL_mm2x1x8a.c "R. Clint Whaley"