Define non-positive temperature behavior (#720)

This commit is contained in:
Ivan Stepanov 2023-04-03 03:19:04 +03:00 committed by GitHub
parent a0c0516416
commit cd7fa95690
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1194,6 +1194,20 @@ static llama_vocab::id llama_sample_top_p_top_k(
const auto & logits = lctx.logits; const auto & logits = lctx.logits;
const auto * plogits = logits.data() + logits.size() - n_logits; const auto * plogits = logits.data() + logits.size() - n_logits;
if (temp <= 0) {
// select the token with the highest logit directly
float max_logit = plogits[0];
llama_vocab::id max_id = 0;
for (int i = 1; i < n_logits; ++i) {
if (plogits[i] > max_logit) {
max_logit = plogits[i];
max_id = i;
}
}
return max_id;
}
std::vector<std::pair<float, llama_vocab::id>> logits_id; std::vector<std::pair<float, llama_vocab::id>> logits_id;
logits_id.reserve(n_logits); logits_id.reserve(n_logits);