metal : fix bug in soft_max kernels (out-of-bounds access) (#3194)

This commit is contained in:
Georgi Gerganov 2023-09-15 20:17:24 +03:00 committed by GitHub
parent e3d87a6c36
commit c6f1491da0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -118,7 +118,7 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
float lmax = psrc0[tpitg[0]];
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float4 lmax4 = psrc4[tpitg[0]];
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
lmax4 = fmax(lmax4, psrc4[i00]);
}