Another very minor speedup on metal

This commit is contained in:
Iwan Kawrakow 2023-09-01 14:54:10 +03:00
parent 2cb47e0e16
commit e3ff8c20c8

View file

@ -133,19 +133,24 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
// broadcast //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
if (tpitg[0] == 0) { // the loop, and when that is done, buf[0] has the correct (synchronized) value
buf[0] = buf[0]; //if (tpitg[0] == 0) {
} // buf[0] = buf[0];
//}
threadgroup_barrier(mem_flags::mem_threadgroup); //threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0]; const float max = buf[0];
// parallel sum // parallel sum
buf[tpitg[0]] = 0.0f; buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] += exp(psrc0[i00] - max); const float exp_psrc0 = exp(psrc0[i00] - max);
buf[tpitg[0]] += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
pdst[i00] = exp_psrc0;
} }
// reduce // reduce
@ -157,17 +162,18 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
// broadcast // broadcast - not needed, see above
if (tpitg[0] == 0) { //// broadcast
buf[0] = buf[0]; //if (tpitg[0] == 0) {
} // buf[0] = buf[0];
//}
threadgroup_barrier(mem_flags::mem_threadgroup); //threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0]; const float sum = buf[0];
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] = exp(psrc0[i00] - max) / sum; pdst[i00] /= sum;
} }
} }
@ -214,25 +220,27 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
// broadcast //// broadcast
if (tpitg == 0) { //if (tpitg == 0) {
sum[0] /= ne00; // sum[0] /= ne00;
} //}
threadgroup_barrier(mem_flags::mem_threadgroup); //threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0]; const float mean = sum[0];
// recenter // recenter and VARIANCE
device float * y = dst + tgpig*ne00; device float * y = dst + tgpig*ne00;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] - mean;
}
// VARIANCE
// parallel sum
sum[tpitg] = 0.0f; sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) { for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] - mean;
sum[tpitg] += y[i00] * y[i00]; sum[tpitg] += y[i00] * y[i00];
} }
//// VARIANCE
//// parallel sum
//sum[tpitg] = 0.0f;
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
// sum[tpitg] += y[i00] * y[i00];
//}
// reduce // reduce
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) { for (uint i = ntg/2; i > 0; i /= 2) {
@ -241,11 +249,11 @@ kernel void kernel_norm(
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
// broadcast //// broadcast
if (tpitg == 0) { //if (tpitg == 0) {
sum[0] /= ne00; // sum[0] /= ne00;
} //}
threadgroup_barrier(mem_flags::mem_threadgroup); //threadgroup_barrier(mem_flags::mem_threadgroup);
const float variance = sum[0]; const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + eps); const float scale = 1.0f/sqrt(variance + eps);