Added first version of 2D register tiling kernel with A and C transposed as well

pull/274/head
Cedric Nugteren 2018-04-03 21:18:40 +02:00
parent 63996eb68b
commit eae25f5727
2 changed files with 216 additions and 79 deletions


@@ -7,15 +7,19 @@
// Author(s):
// Cedric Nugteren <www.cedricnugteren.nl>
//
// This file contains an optimized matrix-multiplication kernel inspired by the paper by Matsumoto
// et al. and the tutorial on http://www.cedricnugteren.nl/tutorial.php. It is fully configurable
// (and tunable!) using more or less the same parameters/naming conventions as in the paper. It
// supports different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
// This file contains two optimized matrix-multiplication kernels:
// - Kernel 0: inspired by the paper by Matsumoto et al. and the tutorial on
// http://www.cedricnugteren.nl/tutorial.php
// - Kernel 1: inspired by a Qualcomm optimized GPU kernel with 2D register tiling
// https://developer.qualcomm.com/blog/matrix-multiply-adreno-gpus-part-2-host-code-and-kernel
// Both are fully configurable (and tunable!) using many parameters. Both kernels support
// different data-types (SGEMM/DGEMM/CGEMM/ZGEMM/HGEMM) through a pre-processor define.
//
// Matrices are accessed as follows:
// For kernel 0, matrices are accessed as follows:
// A: [k*M + m], with 'k' ranging from 0:K and 'm' from 0:M (m,k,m)
// B: [k*N + n], with 'k' ranging from 0:K and 'n' from 0:N (n,k,n)
// C: [n*M + m], with 'n' ranging from 0:N and 'm' from 0:M (m,n,m)
// For kernel 1, both A and C are transposed w.r.t. the above
//
// Or as an image (assuming column-major)
// K
@@ -31,7 +35,7 @@
// o-------o o-----o
//
//
// This kernel is separated into three files. This is part 1 out of 4.
// This kernel is separated into multiple files. This is part 1 out of 4.
//
// =================================================================================================
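As a plain-C illustration of the access patterns described in the header comment above, a naive reference computation for the two conventions could look as follows. This is only a sketch: the function names are invented here, a single data-type (float) is used, and the alpha/beta scaling and the OpenCL work decomposition are left out.

// Reference for kernel 0's layouts: A[k*M + m], B[k*N + n], C[n*M + m].
void gemm_ref_kernel0(int M, int N, int K,
                      const float* A, const float* B, float* C) {
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) { acc += A[k*M + m] * B[k*N + n]; }
      C[n*M + m] = acc;
    }
  }
}

// Kernel 1 expects A and C transposed w.r.t. the above: A[m*K + k], C[m*N + n].
void gemm_ref_kernel1(int M, int N, int K,
                      const float* A, const float* B, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) { acc += A[m*K + k] * B[k*N + n]; }
      C[m*N + n] = acc;
    }
  }
}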
@@ -43,6 +47,9 @@ R"(
// Parameters set by the tuner or by the database. Here they are given a basic default value in case
// this kernel file is used outside of the CLBlast library.
#ifndef GEMMK
#define GEMMK 0 // Kernel to choose: 0 regular, 1 with 2D register tiling
#endif
#ifndef MWG
#define MWG 8 // Tile-size in dimension M (e.g. 64, 128)
#endif
@@ -59,10 +66,10 @@ R"(
#define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
#endif
#ifndef MDIMA
#define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA (kernel 0 only)
#endif
#ifndef NDIMB
#define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB (kernel 0 only)
#endif
#ifndef KWI
#define KWI 1 // Unroll factor of the KWG loop (smaller than or equal to KWG)
@@ -74,16 +81,19 @@ R"(
#define VWN 1 // Vector width of matrix B
#endif
#ifndef STRM
#define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0)
#define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef STRN
#define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0)
#define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef SA
#define SA 0 // Use local/shared memory to cache matrix A (1) or not (0)
#define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)
#endif
#ifndef SB
#define SB 0 // Use local/shared memory to cache matrix B (1) or not (0)
#define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)
#endif
#ifndef KREG
#define KREG 1 // Amount of register tiling in second dimension, multiple of VWN (kernel 1 only)
#endif
// Helper parameters based on the above tuning parameters
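The comments above state a few constraints on these parameters: KWI must not exceed KWG, KREG must be a multiple of VWN, and the vector widths have to match one of the cases handled by the vload switches further down in this file. A small, hypothetical C sketch restating those constraints as runtime checks (the concrete values are example tuner settings, not CLBlast defaults):

#include <assert.h>

int main(void) {
  const int KWG = 32, KWI = 2, KREG = 2, VWN = 2, VWM = 4;  // example settings

  assert(KWI <= KWG);          // KWI is "smaller than or equal to KWG"
  assert(KREG % VWN == 0);     // KREG is a "multiple of VWN" (kernel 1 only)
  // VWM/VWN must hit one of the vector-width cases used by the loads below:
  assert(VWN == 1 || VWN == 2 || VWN == 4 || VWN == 8 || VWN == 16);
  assert(VWM == 1 || VWM == 2 || VWM == 4 || VWM == 8 || VWM == 16);
  return 0;
}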
@@ -244,7 +254,7 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix.
#if SA == 0
#if SA == 0 && GEMMK == 0
INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi,
const int kSizeM, const int idk, const int kwg) {
// Computes the indices based on strided/non-strided access
@@ -263,7 +273,7 @@ INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int
#endif
// Same as above, but now for the B input matrix
#if SB == 0
#if SB == 0 && GEMMK == 0
INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni,
const int kSizeN, const int idk) {
// Computes the indices based on strided/non-strided access
@@ -281,6 +291,45 @@ INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int
}
#endif
// =================================================================================================
#if GEMMK == 1
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific for caching the A input matrix for kernel 1.
INLINE_FUNC realN GlobalToPrivateA2D(const __global real* restrict a_ptr, const int tid_y, const int _ni,
const int kSizeK, const int idk, const int _ki) {
const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN;
#if VWN == 1
return a_ptr[a_index];
#elif VWN == 2
return vload2(0, a_ptr + a_index);
#elif VWN == 4
return vload4(0, a_ptr + a_index);
#elif VWN == 8
return vload8(0, a_ptr + a_index);
#elif VWN == 16
return vload16(0, a_ptr + a_index);
#endif
}
// Same as above, but now for the B input matrix
INLINE_FUNC realM GlobalToPrivateB2D(const __global real* restrict b_ptr, const int tid_x, const int _mi,
const int kSizeN, const int idk, const int _ki) {
const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM;
#if VWM == 1
return b_ptr[b_index];
#elif VWM == 2
return vload2(0, b_ptr + b_index);
#elif VWM == 4
return vload4(0, b_ptr + b_index);
#elif VWM == 8
return vload8(0, b_ptr + b_index);
#elif VWM == 16
return vload16(0, b_ptr + b_index);
#endif
}
#endif
// =================================================================================================
// Caches on-chip local memory into per-thread private memory (registers). This function is specific

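To make the index arithmetic of GlobalToPrivateA2D and GlobalToPrivateB2D above concrete, the following small C sketch evaluates the same two formulas for one made-up configuration (all values here are illustrative, not defaults):

#include <stdio.h>

int main(void) {
  const int NWI = 4, MWI = 8, VWN = 2, VWM = 4;       // example tile settings
  const int kSizeK = 256, kSizeN = 512;               // example matrix sizes
  const int tid_y = 3, tid_x = 5;                     // example thread ids
  const int idk = 16, _ni = 1, _mi = 2, _ki = 1;      // example loop indices

  // GlobalToPrivateA2D: row (tid_y*NWI + _ni) of A, which has kSizeK
  // contiguous elements per row in kernel 1's transposed layout, at
  // column idk + _ki*VWN, read as a width-VWN vector.
  const int a_index = (tid_y * NWI + _ni) * kSizeK + idk + _ki * VWN;

  // GlobalToPrivateB2D: row (idk + _ki) of B, column tid_x*MWI + _mi*VWM,
  // read as a width-VWM vector.
  const int b_index = (idk + _ki) * kSizeN + tid_x * MWI + _mi * VWM;

  printf("a_index = %d, b_index = %d\n", a_index, b_index);  // 3346 and 8752
  return 0;
}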

@@ -31,12 +31,26 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
) {
// Allocates workitem-private memory (registers)
#if GEMMK == 0
#pragma promote_to_registers
realM apm[MWI/VWM]; // MWI * 1
#pragma promote_to_registers
realN bpm[NWI/VWN]; // 1 * NWI
#elif GEMMK == 1
#pragma promote_to_registers
realN apm[NWI*(KREG/VWN)]; // NWI * KREG
#pragma promote_to_registers
realM bpm[KREG*(MWI/VWM)]; // KREG * MWI
#endif
#pragma promote_to_registers
realM apm[MWI/VWM];
#pragma promote_to_registers
realN bpm[NWI/VWN];
#pragma promote_to_registers
realM cpm[NWI*(MWI/VWM)];
realM cpm[NWI*(MWI/VWM)]; // NWI * MWI
#if GEMMK == 1
const __global real* restrict a_ptr = (const __global real* restrict) &agm[0];
const __global real* restrict b_ptr = (const __global real* restrict) &bgm[0];
const int tid_x = get_global_id(0);
const int tid_y = get_global_id(1);
#endif
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
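The shape comments next to the apm/bpm/cpm declarations above fix the per-thread register footprint of both kernels. The short C sketch below simply tallies those vector counts for one example configuration (values chosen arbitrarily), showing how KREG grows the private A and B tiles under kernel 1:

#include <stdio.h>

int main(void) {
  const int MWI = 8, NWI = 8, VWM = 4, VWN = 4, KREG = 4;  // example settings

  const int k0_apm = MWI / VWM;           // kernel 0: MWI x 1, in realM vectors
  const int k0_bpm = NWI / VWN;           // kernel 0: 1 x NWI, in realN vectors
  const int k1_apm = NWI * (KREG / VWN);  // kernel 1: NWI x KREG, in realN vectors
  const int k1_bpm = KREG * (MWI / VWM);  // kernel 1: KREG x MWI, in realM vectors
  const int cpm    = NWI * (MWI / VWM);   // both kernels: NWI x MWI, in realM vectors

  printf("kernel 0: apm=%d bpm=%d cpm=%d vectors per thread\n", k0_apm, k0_bpm, cpm);
  printf("kernel 1: apm=%d bpm=%d cpm=%d vectors per thread\n", k1_apm, k1_bpm, cpm);
  return 0;
}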
@@ -52,9 +66,8 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
}
}
// Loops over all workgroup tiles
for (int kwg = 0; kwg < kSizeK; kwg += KWG) {
for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) {
// Loads data: off-chip --> local (matrix A)
#if SA == 1
@@ -69,9 +82,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Loops over all workitem tiles, unrolled by a factor KWI
for (int pwi = 0; pwi < KWG; pwi += KWI) {
for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) {
#pragma unroll
for (int _pit = 0; _pit < KWI; _pit += 1) {
for (int _pit = 0; _pit < KWI * KREG; _pit += KREG) {
#if SA == 0 || SB == 0
int idk = kwg + pwi + _pit;
#endif
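Because each idk position now covers KREG consecutive k values (consumed by the _ki loops further down), all three K loops above stride in multiples of KREG when GEMMK == 1; with the default KREG of 1 the bounds reduce to the previous KWG/KWI loops. The following C sketch, with arbitrarily chosen example values, enumerates the resulting idk sequence using the same loop bounds:

#include <stdio.h>

int main(void) {
  const int kSizeK = 32, KWG = 4, KWI = 1, KREG = 2;  // example settings

  for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) {
    for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) {
      for (int _pit = 0; _pit < KWI * KREG; _pit += KREG) {
        const int idk = kwg + pwi + _pit;
        printf("idk = %2d covers k = %2d..%2d\n", idk, idk, idk + KREG - 1);
      }
    }
  }
  return 0;
}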
@@ -79,73 +92,143 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
int kg = pwi + _pit;
#endif
// Loads matrix A (kernel 0) or matrix B (kernel 1)
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
// Loads data: local --> private (matrix A)
#if SA == 1
#if GEMMK == 0 && SA == 1
apm[_mi] = LocalToPrivateA(alm, _mi, kg);
// Loads data: off-chip --> private (matrix A)
#else
#elif GEMMK == 0 && SA == 0
apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg);
// Loads data: 2D global --> 2D private (matrix B)
#elif GEMMK == 1
#pragma unroll
for (int _ki = 0; _ki < KREG; _ki += 1) {
bpm[_ki * (MWI/VWM) + _mi] = GlobalToPrivateB2D(b_ptr, tid_x, _mi, kSizeN, idk, _ki);
}
#endif
}
// Loads data: local --> private (matrix B)
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#if SB == 1
bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
// Loads data: off-chip --> private (matrix B)
#else
bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
#endif
}
// Performs the accumulation (Cpm += Apm * Bpm)
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
// Loads matrix B (kernel 0) or matrix A (kernel 1)
#if GEMMK == 0
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const realM aval = apm[_mi];
#if VWN == 1
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
#elif VWN == 2
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
#elif VWN == 4
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
#elif VWN == 8
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
#elif VWN == 16
cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
// Loads data: local --> private (matrix B)
#if SB == 1
bpm[_ni] = LocalToPrivateB(blm, _ni, kg);
// Loads data: off-chip --> private (matrix B)
#else
bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk);
#endif
}
}
#elif GEMMK == 1
// Loads data: 2D global --> 2D private (matrix A)
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#pragma unroll
for (int _ki = 0; _ki < KREG/VWN; _ki += 1) {
apm[_ni * (KREG/VWN) + _ki] = GlobalToPrivateA2D(a_ptr, tid_y, _ni, kSizeK, idk, _ki);
}
}
#endif
// Performs the accumulation (Cpm += Apm * Bpm)
#if GEMMK == 0
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const realM aval = apm[_mi];
#if VWN == 1
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
#elif VWN == 2
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
#elif VWN == 4
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
#elif VWN == 8
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
#elif VWN == 16
cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
#endif
}
}
#elif GEMMK == 1
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
#pragma unroll
for (int _ki = 0; _ki < KREG/VWN; _ki += 1) {
const int index = _ni * (MWI/VWM) + _mi;
const realN aval = apm[_ni * (KREG/VWN) + _ki];
#if VWN == 1
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval);
#elif VWN == 2
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
#elif VWN == 4
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.x);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.y);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.z);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.w);
#elif VWN == 8
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0) * (MWI/VWM) + _mi], aval.s0);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1) * (MWI/VWM) + _mi], aval.s1);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2) * (MWI/VWM) + _mi], aval.s2);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3) * (MWI/VWM) + _mi], aval.s3);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4) * (MWI/VWM) + _mi], aval.s4);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5) * (MWI/VWM) + _mi], aval.s5);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6) * (MWI/VWM) + _mi], aval.s6);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7) * (MWI/VWM) + _mi], aval.s7);
#elif VWN == 16
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 0 ) * (MWI/VWM) + _mi], aval.s0);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 1 ) * (MWI/VWM) + _mi], aval.s1);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 2 ) * (MWI/VWM) + _mi], aval.s2);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 3 ) * (MWI/VWM) + _mi], aval.s3);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 4 ) * (MWI/VWM) + _mi], aval.s4);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 5 ) * (MWI/VWM) + _mi], aval.s5);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 6 ) * (MWI/VWM) + _mi], aval.s6);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 7 ) * (MWI/VWM) + _mi], aval.s7);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 8 ) * (MWI/VWM) + _mi], aval.s8);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 9 ) * (MWI/VWM) + _mi], aval.s9);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 10) * (MWI/VWM) + _mi], aval.sA);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 11) * (MWI/VWM) + _mi], aval.sB);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 12) * (MWI/VWM) + _mi], aval.sC);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 13) * (MWI/VWM) + _mi], aval.sD);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 14) * (MWI/VWM) + _mi], aval.sE);
cpm[index] = MultiplyAddVector(cpm[index], bpm[(VWN * _ki + 15) * (MWI/VWM) + _mi], aval.sF);
#endif
}
}
}
#endif
}
}
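Stripped of the vector widths and the component selects, the two accumulation variants above amount to the following scalar per-thread updates for each value of idk (a paraphrase only; the tile array names are hypothetical and the real kernel keeps these values in registers):

// Kernel 0: one rank-1 update of the NWI x MWI tile per k step,
// i.e. c_tile[n][m] += a_col[m] * b_row[n].
void accumulate_kernel0(int MWI, int NWI, float *c_tile,
                        const float *a_col, const float *b_row) {
  for (int n = 0; n < NWI; ++n)
    for (int m = 0; m < MWI; ++m)
      c_tile[n * MWI + m] += a_col[m] * b_row[n];
}

// Kernel 1: KREG k steps fused per idk, using the 2D register tiles
// a_tile (NWI x KREG) and b_tile (KREG x MWI).
void accumulate_kernel1(int MWI, int NWI, int KREG, float *c_tile,
                        const float *a_tile, const float *b_tile) {
  for (int n = 0; n < NWI; ++n)
    for (int m = 0; m < MWI; ++m)
      for (int kk = 0; kk < KREG; ++kk)
        c_tile[n * MWI + m] += a_tile[n * KREG + kk] * b_tile[kk * MWI + m];
}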
@@ -158,11 +241,16 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK,
#endif
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
#if GEMMK == 0
const int cld = kSizeM;
#elif GEMMK == 1
const int cld = kSizeN;
#endif
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, kSizeM, alpha, beta);
StoreResults(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, cld, alpha, beta);
}
}
}
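The cld passed to StoreResults switches from kSizeM to kSizeN because, per the header comment in part 1, kernel 1 stores C transposed. Without assuming anything about StoreResults' internals, the element index of C(m, n) under the two documented layouts would be computed along these lines (an illustrative C sketch, not CLBlast code):

// Kernel 0: C stored as [n*M + m], so the leading dimension is kSizeM.
int c_index_kernel0(int m, int n, int cld /* = kSizeM */) {
  return n * cld + m;
}

// Kernel 1: C transposed, stored as [m*N + n], so the leading dimension is kSizeN.
int c_index_kernel1(int m, int n, int cld /* = kSizeN */) {
  return m * cld + n;
}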