[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

double precision use of AltiVec for G4



Guys,

With Nick's help, I've got his kernel compiling under OS X.  For anyone
wanting to use the AltiVec instructions under OS X, remember to pass the
-faltivec flag to the compiler.

Anyway, with that hurdle crossed, I was able to whip together a quick
D/ZGEMM kernel that uses AltiVec's data-stream prefetch (vec_dst) to get some
speedup.  For DGEMM, the new kernel is roughly 12% faster than the best
kernel we had before: on my 533 MHz G4, the generated kernel got ~603 MFLOPS,
and the one I include below gets ~679 MFLOPS.  I've also got an initial stab
at adding OS X to config.  All this is very preliminary, but when it comes
together, I hope to put out another developer release.

In the meantime, I include my first draft of a prefetched dgemm kernel in
case anyone is interested.

Cheers,
Clint

Here's the command I used for timing:

make ummcase pre=d MMFLAGS="-fomit-frame-pointer -O -traditional-cpp -fschedule-insns -faltivec" mmrout=../CASES/ATL_mm4x4x2_1p.c nb=56

/*
 *             Automatically Tuned Linear Algebra Software v3.3.0Dev
 **************** THIS IS AN UNSUPPORTED DEVELOPER RELEASE *****************
 *                    (C) Copyright 2000 R. Clint Whaley                     
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University of Tennessee, the ATLAS group,
 *      or the names of its contributers may not be used to endorse
 *      or promote products derived from this software without specific
 *      written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE. 
 *
 */
#include "atlas_misc.h"

/*
 * Use the AltiVec data-stream-touch prefetches (vec_dst) only when the
 * compiler actually provides AltiVec (Apple's -faltivec / GNU -maltivec
 * define __VEC__ and/or __ALTIVEC__).  Defining ATL_AltiVec unconditionally
 * made the kernel fail to compile on any target lacking the `vector` keyword
 * and the vec_dst intrinsic.
 */
#if defined(__VEC__) || defined(__ALTIVEC__)
   #define ATL_AltiVec
   /* GNU-style AltiVec needs the header for vec_dst; Apple's -faltivec does not */
   #if !defined(__APPLE_ALTIVEC__)
      #include <altivec.h>
   #endif
#endif

/*
 * A dst block is at most 32 16-byte vectors (512 bytes, i.e. 64 8-byte
 * elements).  We want one KB-long column of A/B to fit in a single prefetch
 * block, so prefetching is turned off for larger KB.
 */
#if KB > 64
   #ifdef ATL_AltiVec
      #undef ATL_AltiVec
   #endif
#endif

/* Without AltiVec every prefetch hint compiles away to nothing */
#ifndef ATL_AltiVec
   #define vec_dst(addr, ctrlw, str)
#endif

void ATL_USERMM
   (const int M, const int N, const int K, const TYPE alpha, const TYPE *A, const int lda, const TYPE *B, const int ldb, const TYPE beta, TYPE *C, const int ldc)
/*
 * matmul with muladd=1, TA=T, TB=N, mu=4, nu=4, ku=2, prefetching A and B
 *
 * Computes, for the whole MB x NB block of C:
 *    C(i,j) = beta*C(i,j) + sum_{k=0}^{KB-1} A[i*KB+k] * B[j*KB+k]
 * i.e. A is read as MB contiguous KB-element vectors and B as NB contiguous
 * KB-element vectors.  The beta term is dropped (C is overwritten) when BETA0
 * is defined, applied by multiplication when BETAX is defined, and C is
 * simply accumulated into otherwise.  C(i,j) is the element at
 * C + j*(ldc SHIFT) + i in the real (TREAL) build and at
 * C + j*(ldc SHIFT) + 2*i in the complex build.
 *
 * M, N, K, alpha, lda and ldb are not referenced: the block shape comes from
 * the compile-time macros MB, NB and KB.
 *
 * Unchecked assumptions:
 *  - MB and NB are positive multiples of 4: the do/while loops stop only when
 *    pA0/pB0 land exactly on stM/stN.
 *  - KB is even and >= 2: the ku=2 k-loop has no remainder iteration.
 *
 * All vec_dst calls are prefetch hints; they never change the results.
 */
{
   const TYPE *stM = A + KB*MB;   /* one past the last A vector: ends M-loop */
   const TYPE *stN = B + KB*NB;   /* one past the last B vector: ends N-loop */
   const int incAn = -KB*MB;      /* rewind A to its start for the next N panel */
   const int incBm = -KB;         /* undo the KB-element walk down the B panel */
   /* after the k-loop pA0 sits KB past its panel start: 3 more vectors reach
    * the next 4-vector panel */
   #define incAm KB3
   /* pB0 is back at its panel start after the M-loop: skip 4 vectors */
   #define incBn KB4
   #ifdef TREAL
      #define incCm 4
      const int incCn = (((ldc) << 2)) - MB;
   #else
      #define incCm 8
      const int incCn = (((ldc) << 3)) - (MB+MB);
   #endif
   TYPE *pC0=C, *pC1=pC0+(ldc SHIFT), *pC2=pC1+(ldc SHIFT),*pC3=pC2+(ldc SHIFT);
   #ifdef BETAX
      const TYPE *bp = &beta;   /* read-only: no need to cast away const */
   #endif
   const TYPE *pA0=A;
   const TYPE *pB0=B;
   #ifdef ATL_AltiVec
      unsigned char blksize=0;  /* dst block size in 16-byte vectors, 0 = 32 */
      unsigned char blkcount=0; /* dst number of blocks, 0 = 256 */
      short blkstride;          /* dst signed stride in bytes between blocks */
      int cwrd, ccwrd;          /* dst control words for A/B and for C */
   #endif
   register int k;
   register TYPE rA0, rA1, rA2, rA3, ra0, ra1, ra2, ra3;
   register TYPE rB0, rB1, rB2, rB3, rb0, rb1, rb2, rb3;
   register TYPE rC0_0, rC1_0, rC2_0, rC3_0, rC0_1, rC1_1, rC2_1, rC3_1, 
                 rC0_2, rC1_2, rC2_2, rC3_2, rC0_3, rC1_3, rC2_3, rC3_3;

   #ifdef ATL_AltiVec
      /*
       * One dst "block" is one KB-long vector of A or B.  Consecutive vectors
       * are KB elements apart, so the byte stride is the byte size of one
       * vector (the old code used blksize<<1 == KB bytes, 1/8 of the real
       * distance).  A block of 32 16-byte vectors is encoded as 0; the
       * KB > 64 guard keeps larger blocks from reaching this code.
       */
      blkstride = (short) (KB * (int) sizeof(TYPE));
      blksize = (unsigned char) ((blkstride >> 4) & 31);

      /* control word: size in bits 24-28, count in bits 16-23,
       * signed 16-bit byte stride in bits 0-15 */
      blkcount = 2;
      cwrd = (blksize<<24) | (blkcount<<16) | (blkstride & 0xFFFF);
      vec_dst((vector float *) B, cwrd, 0);       /* B vectors 0 and 1 */
      /* NOTE(review): A holds MB vectors; blkcount = KB presumably assumes
       * MB == KB -- confirm. */
      blkcount = KB;
      cwrd = (blksize<<24) | (blkcount<<16) | (blkstride & 0xFFFF);
      vec_dst((vector float *)A, cwrd, 3);
      /* from here on every A/B prefetch touches exactly one vector */
      blkcount = 1;
      cwrd = (blksize<<24) | (blkcount<<16) | (blkstride & 0xFFFF);
      vec_dst((vector float *)(B+KB2), cwrd, 1);  /* B vector 2 */
      vec_dst((vector float *)(B+KB3), cwrd, 2);  /* B vector 3 */
      /* C: one block of four 16-byte vectors (64 bytes) starting at each pCj */
      ccwrd = 0 | (1<<16) | (4<<24);
   #endif
   do /* N-loop */
   {
      /*
       * Prefetch the next 4-vector B panel, one vector per stream.
       * NOTE(review): these reuse stream tags 0-3, which are reissued for C
       * and A inside the M-loop; per the AltiVec PEM a new dst on an active
       * tag replaces the old stream, so these may be cut short.
       */
      vec_dst((vector float *)(pB0+KB4), cwrd, 0);
      vec_dst((vector float *)(pB0+KB5), cwrd, 1);
      vec_dst((vector float *)(pB0+KB6), cwrd, 2);
      vec_dst((vector float *)(pB0+KB7), cwrd, 3);
      do /* M-loop */
      {
         /* load (and possibly scale) the 4x4 block of C into registers */
         #ifdef BETA0
            rC0_0 = rC1_0 = rC2_0 = rC3_0 =
            rC0_1 = rC1_1 = rC2_1 = rC3_1 =
            rC0_2 = rC1_2 = rC2_2 = rC3_2 =
            rC0_3 = rC1_3 = rC2_3 = rC3_3 = ATL_rzero;
         #else
            #ifdef TREAL
               rC0_0 = *pC0; rC1_0 = pC0[1]; rC2_0 = pC0[2]; rC3_0 = pC0[3];
               rC0_1 = *pC1; rC1_1 = pC1[1]; rC2_1 = pC1[2]; rC3_1 = pC1[3];
               rC0_2 = *pC2; rC1_2 = pC2[1]; rC2_2 = pC2[2]; rC3_2 = pC2[3];
               rC0_3 = *pC3; rC1_3 = pC3[1]; rC2_3 = pC3[2]; rC3_3 = pC3[3];
            #else
               rC0_0 = *pC0; rC1_0 = pC0[2]; rC2_0 = pC0[4]; rC3_0 = pC0[6];
               rC0_1 = *pC1; rC1_1 = pC1[2]; rC2_1 = pC1[4]; rC3_1 = pC1[6];
               rC0_2 = *pC2; rC1_2 = pC2[2]; rC2_2 = pC2[4]; rC3_2 = pC2[6];
               rC0_3 = *pC3; rC1_3 = pC3[2]; rC2_3 = pC3[4]; rC3_3 = pC3[6];
            #endif
            #ifdef BETAX
               rA0 = *bp;   /* rA0 is free until the A loads below */
               rC0_0 *= rA0; rC1_0 *= rA0; rC2_0 *= rA0; rC3_0 *= rA0;
               rC0_1 *= rA0; rC1_1 *= rA0; rC2_1 *= rA0; rC3_1 *= rA0;
               rC0_2 *= rA0; rC1_2 *= rA0; rC2_2 *= rA0; rC3_2 *= rA0;
               rC0_3 *= rA0; rC1_3 *= rA0; rC2_3 *= rA0; rC3_3 *= rA0;
            #endif
         #endif
     
         /* software pipeline: preload k=0 (upper-case) and k=1 (lower-case) */
         rA0 = *pA0; rA1 = pA0[KB]; rA2 = pA0[KB2]; rA3 = pA0[KB3]; pA0++;
         ra0 = *pA0; ra1 = pA0[KB]; ra2 = pA0[KB2]; ra3 = pA0[KB3]; pA0++;
         rB0 = *pB0; rB1 = pB0[KB]; rB2 = pB0[KB2]; rB3 = pB0[KB3]; pB0++;
         rb0 = *pB0; rb1 = pB0[KB]; rb2 = pB0[KB2]; rb3 = pB0[KB3]; pB0++;
         /* each pass multiplies the preloaded pair and loads the next pair */
         for (k=(KB>>1)-1; k; k --) /* easy loop to unroll */
         {
            rC0_0 += rA0 * rB0;
            rC1_0 += rA1 * rB0;
            rC2_0 += rA2 * rB0;
            rC3_0 += rA3 * rB0; rB0 = *pB0;
            rC0_1 += rA0 * rB1;
            rC1_1 += rA1 * rB1;
            rC2_1 += rA2 * rB1;
            rC3_1 += rA3 * rB1; rB1 = pB0[KB];
            rC0_2 += rA0 * rB2;
            rC1_2 += rA1 * rB2;
            rC2_2 += rA2 * rB2;
            rC3_2 += rA3 * rB2; rB2 = pB0[KB2];
            rC0_3 += rA0 * rB3; rA0 = *pA0;
            rC1_3 += rA1 * rB3; rA1 = pA0[KB];
            rC2_3 += rA2 * rB3; rA2 = pA0[KB2];
            rC3_3 += rA3 * rB3; rB3 = pB0[KB3];

            rC0_0 += ra0 * rb0; rA3 = pA0[KB3];
            rC1_0 += ra1 * rb0;
            rC2_0 += ra2 * rb0;
            rC3_0 += ra3 * rb0; rb0 = pB0[1];
            rC0_1 += ra0 * rb1;
            rC1_1 += ra1 * rb1;
            rC2_1 += ra2 * rb1;
            rC3_1 += ra3 * rb1; rb1 = pB0[KB+1];
            rC0_2 += ra0 * rb2;
            rC1_2 += ra1 * rb2;
            rC2_2 += ra2 * rb2;
            rC3_2 += ra3 * rb2; rb2 = pB0[KB2+1];
            rC0_3 += ra0 * rb3; ra0 = pA0[1];
            rC1_3 += ra1 * rb3; ra1 = pA0[KB+1];
            rC2_3 += ra2 * rb3; ra2 = pA0[KB2+1];
            rC3_3 += ra3 * rb3; rb3 = pB0[KB3+1]; ra3 = pA0[KB3+1];
            pB0 += 2; pA0 += 2;
         }
         /* drain the pipeline: last k pair, no further loads */
         rC0_0 += rA0 * rB0; vec_dst((vector float *)pC0, ccwrd, 0);
         rC1_0 += rA1 * rB0; vec_dst((vector float *)pC1, ccwrd, 1);
         rC2_0 += rA2 * rB0; vec_dst((vector float *)pC2, ccwrd, 2);
         rC3_0 += rA3 * rB0; vec_dst((vector float *)pC3, ccwrd, 3);
         rC0_1 += rA0 * rB1;
         rC1_1 += rA1 * rB1;
         rC2_1 += rA2 * rB1;
         rC3_1 += rA3 * rB1;
         rC0_2 += rA0 * rB2;
         rC1_2 += rA1 * rB2;
         rC2_2 += rA2 * rB2;
         rC3_2 += rA3 * rB2;
         rC0_3 += rA0 * rB3;
         rC1_3 += rA1 * rB3;
         rC2_3 += rA2 * rB3;
         rC3_3 += rA3 * rB3;

         /*
          * Prefetch the next 4-vector A panel.  pA0 is now KB past the start
          * of the current panel, so the next panel starts at pA0+KB3 (the old
          * pA0+KB4..KB7 skipped its first vector and fetched one too many).
          */
         rC0_0 += ra0 * rb0; vec_dst((vector float *)(pA0+KB3), cwrd, 0);
         rC1_0 += ra1 * rb0; vec_dst((vector float *)(pA0+KB4), cwrd, 1);
         rC2_0 += ra2 * rb0; vec_dst((vector float *)(pA0+KB5), cwrd, 2);
         rC3_0 += ra3 * rb0; vec_dst((vector float *)(pA0+KB6), cwrd, 3);
         rC0_1 += ra0 * rb1;
         rC1_1 += ra1 * rb1;
         rC2_1 += ra2 * rb1;
         rC3_1 += ra3 * rb1;
         rC0_2 += ra0 * rb2;
         rC1_2 += ra1 * rb2;
         rC2_2 += ra2 * rb2;
         rC3_2 += ra3 * rb2;
         rC0_3 += ra0 * rb3;
         rC1_3 += ra1 * rb3;
         rC2_3 += ra2 * rb3;
         rC3_3 += ra3 * rb3;
         /* store the 4x4 block of C */
         #ifdef TREAL
            *pC0 = rC0_0; pC0[1] = rC1_0; pC0[2] = rC2_0; pC0[3] = rC3_0;
            *pC1 = rC0_1; pC1[1] = rC1_1; pC1[2] = rC2_1; pC1[3] = rC3_1;
            *pC2 = rC0_2; pC2[1] = rC1_2; pC2[2] = rC2_2; pC2[3] = rC3_2;
            *pC3 = rC0_3; pC3[1] = rC1_3; pC3[2] = rC2_3; pC3[3] = rC3_3;
         #else
            *pC0 = rC0_0; pC0[2] = rC1_0; pC0[4] = rC2_0; pC0[6] = rC3_0;
            *pC1 = rC0_1; pC1[2] = rC1_1; pC1[4] = rC2_1; pC1[6] = rC3_1;
            *pC2 = rC0_2; pC2[2] = rC1_2; pC2[4] = rC2_2; pC2[6] = rC3_2;
            *pC3 = rC0_3; pC3[2] = rC1_3; pC3[4] = rC2_3; pC3[6] = rC3_3;
         #endif
         pC0 += incCm;
         pC1 += incCm;
         pC2 += incCm;
         pC3 += incCm;
         pA0 += incAm;
         pB0 += incBm;
      }
      while(pA0 != stM);
      pC0 += incCn;
      pC1 += incCn;
      pC2 += incCn;
      pC3 += incCn;
      pA0 += incAn;
      pB0 += incBn;
   }
   while(pB0 != stN);
}
/*
 * Retire the kernel's local increment macros so they cannot leak into any
 * code that follows.  #undef of a name that is not a macro is ignored, so
 * no #ifdef guards are needed.
 */
#undef incAm
#undef incBn
#undef incCm