-
Notifications
You must be signed in to change notification settings - Fork 78
/
logits.rs
35 lines (30 loc) · 951 Bytes
/
logits.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// based on https://github.com/huggingface/candle/blob/main/candle-transformers/src/generation/mod.rs
use crate::config::{SamplingParams, SAMPLING_EPS};
use rand::SeedableRng;
/// Per-request sampling state: the random number generator plus the
/// temperature / nucleus (top-p) settings taken from [`SamplingParams`].
#[derive(Debug)]
pub struct LogitsProcessor {
    /// Random source used when drawing tokens.
    pub rng: rand::rngs::StdRng,
    /// Sampling temperature; `None` when the requested temperature was below
    /// [`SAMPLING_EPS`] (presumably meaning greedy decoding — confirm in the sampler).
    pub temperature: Option<f32>,
    /// Top-p (nucleus) threshold, copied verbatim from the sampling parameters.
    pub top_p: f32,
}
impl LogitsProcessor {
    /// Builds a processor from `sampling_params`, seeding the RNG from OS entropy.
    ///
    /// A temperature below [`SAMPLING_EPS`] is stored as `None`.
    pub fn new(sampling_params: &SamplingParams) -> Self {
        Self::with_rng(sampling_params, rand::rngs::StdRng::from_entropy())
    }

    /// Builds a processor whose RNG is seeded with `seed`, so that sampling is
    /// reproducible across runs (useful for tests and debugging).
    pub fn with_seed(sampling_params: &SamplingParams, seed: u64) -> Self {
        Self::with_rng(sampling_params, rand::rngs::StdRng::seed_from_u64(seed))
    }

    /// Replaces the current temperature, applying the same `SAMPLING_EPS`
    /// threshold as the constructors.
    pub fn set_temperature(&mut self, temperature: f32) {
        self.temperature = Self::effective_temperature(temperature);
    }

    /// Shared constructor body: takes an already-created RNG.
    fn with_rng(sampling_params: &SamplingParams, rng: rand::rngs::StdRng) -> Self {
        Self {
            rng,
            temperature: Self::effective_temperature(sampling_params.temperature),
            top_p: sampling_params.top_p,
        }
    }

    /// Maps a requested temperature to the stored form: `None` below
    /// `SAMPLING_EPS`, otherwise `Some(temperature)`.
    ///
    /// Kept as a single helper so `new`/`with_seed` and `set_temperature` cannot
    /// drift apart. Note that a NaN input compares false and therefore yields
    /// `Some(NaN)`, matching the previous behavior.
    fn effective_temperature(temperature: f32) -> Option<f32> {
        if temperature < SAMPLING_EPS {
            None
        } else {
            Some(temperature)
        }
    }
}