diff --git a/src/lib.rs b/src/lib.rs index dfc2aa0..36e802a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,13 +1,13 @@ use std::fmt::Display; #[derive(Debug, Clone, Default)] -pub struct InterpLUT { - inner: Vec<(T, T)>, +pub struct InterpLUT { + inner: Vec<(f64, f64)>, } -impl InterpLUT { +impl InterpLUT { /// Creates a new, empty table. - pub fn new() -> InterpLUT { + pub fn new() -> InterpLUT { Self::default() } @@ -27,9 +27,42 @@ impl InterpLUT { self.inner.push((m, n)); self.sort(); } + + fn binary_search(&self, n: &f64) -> Result { + self.inner + .binary_search_by(|&(x, _)| x.partial_cmp(n).unwrap_or(std::cmp::Ordering::Equal)) + } + + pub fn get(&self, n: f64) -> Option { + let searched = self.binary_search(&n); + if let Ok(i) = searched { + return Some(self.inner[i].1); + } else if let Err(i) = searched { + // value was not found + if i == 0 || i == self.inner.len() { + return None; + } + let t = Self::inverse_lerp(self.inner[i - 1].0, self.inner[i].0, n); + let res = Self::lerp(self.inner[i - 1].1, self.inner[i].1, t); + return Some(res); + } + None + } + + fn inverse_lerp(low: f64, high: f64, n: f64) -> f64 { + let total_distance = high - low; + let partial_distance = n - low; + partial_distance / total_distance + } + + fn lerp(low: f64, high: f64, t: f64) -> f64 { + let total_distance = high - low; + let partial_distance = total_distance * t; + low + partial_distance + } } -impl Display for InterpLUT { +impl Display for InterpLUT { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let num_digits_largest = self .inner diff --git a/src/main.rs b/src/main.rs index 7879768..0db5ef6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,16 @@ use interplut::InterpLUT; fn main() { - let mut lut: InterpLUT = InterpLUT::new(); + let mut lut = InterpLUT::new(); lut.insert(31.0, 5.0); - lut.insert(5.0, 31.0); - lut.insert(5.0, 31.0); - lut.insert(5.0, 31.0); - lut.insert(4.0, 31.0); - lut.insert(3.0, 31.0); - lut.insert(2.0, 31.0); - lut.insert(1.0, 31.0); + lut.insert(5.0, 30.2); + lut.insert(4.0, 30.0); + lut.insert(3.0, 29.0); + lut.insert(2.0, 28.0); + lut.insert(1.0, 27.0); lut.insert(18.0, 31.0); println!("{}", lut); + dbg!(lut.get(0.5)); + dbg!(lut.get(17.)); }