diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 8ca1874da..9ae4bc198 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -332,8 +332,8 @@ static void init_model(struct llama_model * input, struct my_llama_model * model assert_shape_1d(layer.attention_norm, hparams.n_embd); assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd); - assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd); - assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd); + assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa()); + assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa()); assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd); assert_shape_1d(layer.ffn_norm, hparams.n_embd); assert_shape_2d(layer.w1, hparams.n_embd, hparams.n_ff);