diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 2dc2988d3..e7d2ad592 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -575,10 +575,7 @@ static struct ggml_tensor * forward( // KQ_scaled = KQ / sqrt(n_embd/n_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); // KQ_masked = mask_past(KQ_scaled) // KQ_masked shape [n_past + N, N, n_head, 1] @@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch( // KQ_scaled = KQ / sqrt(n_embd/n_head) // KQ_scaled shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch); // KQ_masked = mask_past(KQ_scaled) @@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora( // KQ_scaled = KQ / sqrt(n_embd/n_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = - ggml_scale(ctx0, - KQ, - ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); // KQ_masked = mask_past(KQ_scaled) // KQ_masked shape [n_past + N, N, n_head, 1] diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index c8754ce70..58fbe204d 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora( ) { struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b); if (scaling != 1.0f) { - ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling)); + ab = ggml_scale(ctx, ab, scaling); } struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab); diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 6a668d764..7b1333a9d 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h float rope_freq_scale = 1.0f; GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); if (rope_freq_scale != 1.0f) { hparams->rope_freq_scale = 1.0f / rope_freq_scale; } @@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( const int n_rot = hparams.n_embd_head(); const int n_embd_head = hparams.n_embd_head(); const int n_embd_gqa = hparams.n_embd_gqa(); + const float rms_norm_eps = hparams.f_norm_rms_eps; const float rope_freq_base = hparams.rope_freq_base; const float rope_freq_scale = hparams.rope_freq_scale; @@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( checkpoints.push_back(t01); } - struct ggml_tensor * kv_scale = NULL; - if (!enable_flash_attn) { - kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head)); - } + const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head); for (int il = 0; il < n_layer; ++il) { struct my_llama_layer & layer = model->layers[il]; @@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // make sure some tensors are not reallocated by inserting new temporary nodes depending on them int n_leafs_before = gb->n_leafs; int n_nodes_before = gb->n_nodes; - struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f); + // output tensors - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f)); // input gradient - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f)); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); ggml_allocr_alloc(alloc, t36->grad); // KQ_pos - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f)); // make sure base model tensors data cannot be used in viewable operations - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f)); for (int il = 0; il < n_layer; ++il) { struct my_llama_layer & layer = model->layers[il]; - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f)); } // allocating checkpoints in one block to reduce memory fragmentation diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 112465968..f06ec400d 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima ggml_repeat(ctx0, model.pre_ln_b, embeddings)); } - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - ggml_allocr_alloc(ctx->alloc, KQ_scale); - if (!ggml_allocr_is_measure(ctx->alloc)) { - ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head)); - } - // loop over layers for (int il = 0; il < n_layer - 1; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states @@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima struct ggml_tensor * Q = ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur)); - Q = ggml_scale_inplace(ctx0, Q, KQ_scale); + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index f7ed63365..4a9a2340b 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs( checkpoints.push_back(t00); checkpoints.push_back(t01); - struct ggml_tensor * kv_scale = NULL; - if (!enable_flash_attn) { - kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head)); - } + const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head); for (int il = 0; il < n_layer; ++il) { struct my_llama_layer & layer = model->layers[il]; @@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs( // make sure some tensors are not reallocated by inserting new temporary nodes depending on them int n_leafs_before = gb->n_leafs; int n_nodes_before = gb->n_nodes; - struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f); // output tensors - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one)); - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f)); // input gradient - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f)); // KQ_pos - ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one)); + ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f)); GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL); ggml_allocr_alloc(alloc, t36->grad); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f5e060d32..ac91ee12e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale( const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - float scale; - // HACK: support for ggml backend interface - if (src1->backend == GGML_BACKEND_CPU) { - scale = ((float *) src1->data)[0]; - } else { - // TODO: pass pointer to kernel instead of copying to host - CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost)); - } + const float scale = ((float *) dst->op_params)[0]; scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream); CUDA_CHECK(cudaGetLastError()); @@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU; const bool dst_on_device = dst->backend == GGML_BACKEND_GPU; - const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE; - // dd = data device float * src0_ddf = nullptr; float * src1_ddf = nullptr; @@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); } - if (use_src1 && !src1_stays_on_host) { + if (use_src1) { if (src1_on_device) { src1_ddf = (float *) src1_extra->data_device[g_main_device]; } else { diff --git a/ggml-metal.m b/ggml-metal.m index e60b93b36..51a72ae33 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute( { GGML_ASSERT(ggml_is_contiguous(src0)); - const float scale = *(const float *) src1->data; + const float scale = *(const float *) dst->op_params; int64_t n = ggml_nelements(dst); @@ -1304,8 +1304,8 @@ void ggml_metal_graph_compute( [encoder setComputePipelineState:ctx->pipeline_scale]; } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; diff --git a/ggml.c b/ggml.c index 236148514..f27920a2d 100644 --- a/ggml.c +++ b/ggml.c @@ -4171,23 +4171,23 @@ struct ggml_tensor * ggml_out_prod( static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b, + float s, bool inplace) { - GGML_ASSERT(ggml_is_scalar(b)); GGML_ASSERT(ggml_is_padded_1d(a)); bool is_node = false; - if (a->grad || b->grad) { + if (a->grad) { is_node = true; } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &s, sizeof(s)); + result->op = GGML_OP_SCALE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -4195,15 +4195,15 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, false); + float s) { + return ggml_scale_impl(ctx, a, s, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_scale_impl(ctx, a, b, true); + float s) { + return ggml_scale_impl(ctx, a, s, true); } // ggml_set @@ -10325,19 +10325,17 @@ static void ggml_compute_forward_out_prod( static void ggml_compute_forward_scale_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } // scale factor - const float v = *(float *) src1->data; + const float v = *(float *) dst->op_params; const int ith = params->ith; const int nth = params->nth; @@ -10368,12 +10366,11 @@ static void ggml_compute_forward_scale_f32( static void ggml_compute_forward_scale( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_scale_f32(params, src0, src1, dst); + ggml_compute_forward_scale_f32(params, src0, dst); } break; default: { @@ -14383,7 +14380,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SCALE: { - ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_scale(params, tensor->src[0], tensor); } break; case GGML_OP_SET: { @@ -14839,7 +14836,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) { if (ggml_hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0)); + struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); } else { return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); @@ -14975,7 +14972,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad, ggml_scale(ctx, ggml_mul(ctx, src0, tensor->grad), - ggml_new_f32(ctx, 2.0f)), + 2.0f), zero_table); } } break; @@ -14989,7 +14986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_div(ctx, tensor->grad, tensor), - ggml_new_f32(ctx, 0.5f)), + 0.5f), zero_table); } } break; @@ -15155,17 +15152,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { + const float s = ((float *) tensor->op_params)[0]; + src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_scale_impl(ctx, tensor->grad, src1, false), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)), + ggml_scale_impl(ctx, tensor->grad, s, false), zero_table); } } break; diff --git a/ggml.h b/ggml.h index b17314897..75918502b 100644 --- a/ggml.h +++ b/ggml.h @@ -1094,13 +1094,13 @@ extern "C" { GGML_API struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + float s); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + float s); // b -> view(a,offset,nb1,nb2,3), return modified a GGML_API struct ggml_tensor * ggml_set( diff --git a/llama.cpp b/llama.cpp index ba970ce8d..d6c192441 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4032,13 +4032,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo, struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, - struct ggml_tensor * kq_scale, struct ggml_tensor * kq_mask, int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float max_alibi_bias, - float scale, + float kq_scale, const llm_build_cb & cb, int il) { const int64_t n_embd = hparams.n_embd; @@ -4086,7 +4085,7 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_soft_max(ctx, kq); cb(kq, "kq_soft_max", il); } else { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale); cb(kq, "kq_soft_max_ext", il); } @@ -4231,10 +4230,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -4295,7 +4290,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4416,10 +4411,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -4478,7 +4469,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4536,10 +4527,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -4602,7 +4589,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4659,10 +4646,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -4702,7 +4685,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4759,10 +4742,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -4911,7 +4890,7 @@ struct llm_build_context { // TODO: not tested, could be broken cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -4965,10 +4944,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); cb(inpL, "inp_embd", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5002,7 +4977,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5056,10 +5031,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); cb(inpL, "inp_embd", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5099,7 +5070,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5150,10 +5121,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); cb(inpL, "inp_embd", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5193,7 +5160,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5253,10 +5220,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5306,7 +5269,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5366,10 +5329,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5423,7 +5382,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", il); } @@ -5482,14 +5441,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // Q_scale - struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(Q_scale, "Q_scale", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); @@ -5531,7 +5482,9 @@ struct llm_build_context { ); cb(Qcur, "Qcur", il); - Qcur = ggml_scale(ctx0, Qcur, Q_scale); + // with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); Kcur = ggml_rope_custom( @@ -5544,7 +5497,7 @@ struct llm_build_context { cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il); cb(cur, "kqv_out", il); } @@ -5681,8 +5634,6 @@ static const std::unordered_map k_offload_map { "pos_embd", OFFLOAD_FUNC_NR }, { "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope) - { "Q_scale", OFFLOAD_FUNC_NOP }, - { "KQ_scale", OFFLOAD_FUNC_NOP }, { "KQ_mask", OFFLOAD_FUNC_FRC }, { "K_shift", OFFLOAD_FUNC_FRC }, @@ -5784,8 +5735,6 @@ static struct ggml_cgraph * llama_build_graph( bool alloc_inp_tokens = false; bool alloc_inp_embd = false; bool alloc_inp_pos = false; - bool alloc_inp_Q_scale = false; - bool alloc_inp_KQ_scale = false; bool alloc_inp_KQ_mask = false; bool alloc_inp_K_shift = false; @@ -5849,37 +5798,6 @@ static struct ggml_cgraph * llama_build_graph( alloc_inp_pos = true; } - if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) { - ggml_allocr_alloc(lctx.alloc, cur); - - if (!ggml_allocr_is_measure(lctx.alloc)) { - const int64_t n_embd_head = model.hparams.n_embd_head(); - float f = 1.0f/sqrtf(float(n_embd_head)); - ggml_backend_tensor_set(cur, &f, 0, sizeof(f)); - } - - alloc_inp_Q_scale = true; - } - - if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) { - ggml_allocr_alloc(lctx.alloc, cur); - - if (!ggml_allocr_is_measure(lctx.alloc)) { - const int64_t n_embd_head = model.hparams.n_embd_head(); - float f; - if (model.arch == LLM_ARCH_PHI2) { - // with phi2, we scale the Q to avoid precision issues - // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 - f = 1.0f; - } else { - f = 1.0f/sqrtf(float(n_embd_head)); - } - ggml_backend_tensor_set(cur, &f, 0, sizeof(f)); - } - - alloc_inp_KQ_scale = true; - } - if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) { ggml_allocr_alloc(lctx.alloc, cur); @@ -9054,10 +8972,7 @@ static int llama_apply_lora_from_file_internal( ggml_set_name(BA, "BA"); if (scaling != 1.0f) { - ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling); - ggml_set_name(scale_tensor, "scale_tensor"); - - BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor); + BA = ggml_scale_inplace(lora_ctx.get(), BA, scaling); offload_func(BA); ggml_set_name(BA, "BA_scaled"); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f04b9438a..f3df8a8c6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -766,18 +766,19 @@ struct test_bin_bcast : public test_case { struct test_scale : public test_case { const ggml_type type; const std::array ne; + float scale; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR3(type, ne, scale); } test_scale(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 10}) - : type(type), ne(ne) {} + std::array ne = {10, 10, 10, 10}, + float scale = 2.0f) + : type(type), ne(ne), scale(scale) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * scale = ggml_new_tensor_1d(ctx, type, 1); ggml_tensor * out = ggml_scale(ctx, a, scale); return out; } diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 81c20a89c..14914def5 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -881,19 +881,19 @@ int main(int argc, const char ** argv) { // scale { srand(seed); - const int nargs = 2; + const int nargs = 1; int64_t ne2[4]; ne2[0] = 1; for (int ndims = 1; ndims <= 2; ++ndims) { - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); + const float s = -1.0f + 2.0f*frand(); - struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1])); + ggml_set_param(ctx0, x[0]); + + struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s)); check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); } @@ -1395,7 +1395,7 @@ int main(int argc, const char ** argv) { ggml_add1(ctx0, ggml_scale(ctx0, ggml_soft_max(ctx0, x[0]), - ggml_new_f32(ctx0, 1.0f - eps)), + 1.0f - eps), ggml_new_f32(ctx0, eps)))); check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);