diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5a5450982..45d68b26c 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
-    const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
-
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
+    std::vector<float> original_logits;
+
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
+    if (!is_resampling) {
+        GGML_ASSERT(!original_logits.empty());
+    }
 
     llama_token id = 0;
 
-    // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
-
-    if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
-    }
-
     if (temp < 0.0) {
         // greedy sampling, with probs
         llama_sample_softmax(ctx_main, &cur_p);
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+    const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
 
     return cur_p;
 }
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
diff --git a/common/sampling.h b/common/sampling.h
index 79a998be8..56ed991b8 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         int idx = 0);
 
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
-       int idx = 0);
+       int idx = 0,
+       bool apply_grammar = true,
+       std::vector<float> * original_logits = nullptr);
 
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index e991b8846..8b31b678a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
             if (params.sparams.temp > 0) {
                 // stochastic verification
-                llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+                llama_sample_softmax(ctx_tgt, &dist_tgt);
 
                 float p_tgt = 0, p_dft = 0;
 
                 // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
diff --git a/retrieval b/retrieval
new file mode 100755
index 000000000..dd31789f8
Binary files /dev/null and b/retrieval differ
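Usage note: below is a minimal sketch of how a caller might migrate from the removed llama_sampling_probability_distribution to the new llama_sampling_prepare API. The helper name get_normalized_candidates and the surrounding variables (ctx_sampling, ctx_main, idx) are illustrative placeholders, not part of this diff; the explicit llama_sample_softmax call mirrors the change made in speculative.cpp above, since the softmax is no longer applied inside the prepare path.

// Minimal usage sketch (not part of the diff) for the new llama_sampling_prepare API.
// The helper name and variables are hypothetical; the include path depends on the build setup.
#include "sampling.h"   // common/sampling.h

static llama_token_data_array get_normalized_candidates(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context          * ctx_main,
        int                             idx) {
    // Build the candidate list with logit biases, penalties and (optionally)
    // grammar constraints applied, as llama_sampling_prepare now does.
    llama_token_data_array cur_p = llama_sampling_prepare(
            ctx_sampling, ctx_main, /*ctx_cfg =*/ NULL, idx,
            /*apply_grammar   =*/ true,
            /*original_logits =*/ NULL);   // pass NULL to skip copying the raw logits

    // The softmax was removed from the prepare path, so callers that need a
    // probability distribution (e.g. speculative verification) apply it themselves.
    llama_sample_softmax(ctx_main, &cur_p);

    return cur_p;
}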