ggml : fix conv_2d batch mode (ggml/737)

Co-authored-by: bssrdf <bssrdf@gmail.com>
pull/1891/head
bssrdf 2024-02-20 14:17:09 -05:00 committed by Georgi Gerganov
parent eb23f4ef16
commit d352dbd163
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 3 additions and 1 deletions

4
ggml.c
View File

@ -5629,7 +5629,9 @@ struct ggml_tensor * ggml_conv_2d(
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])); // [OCIC, KH, KW] => [OC, IC * KH * KW]
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], a->ne[3], im2col->ne[3]); // [N, OC, OH, OW]
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
return result;
}