Skip to content

Commit

Permalink
Clean up source a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross35 committed Aug 19, 2024
1 parent 492bc0b commit 34483e7
Showing 1 changed file with 89 additions and 105 deletions.
194 changes: 89 additions & 105 deletions src/float/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,83 +116,6 @@ mod fmt {
impl Display for i128 {}
}

/// Work out how many refinement iterations are needed to reach the precision of float type `F`.
///
/// The result is a pair `(h, f)`: `h` is the number of iterations to be done using integers at
/// half the float's width, and `f` is the number of iterations done using integers of the
/// float's full width. Doing some iterations at half width is an optimization when the float is
/// larger than a word.
///
/// ASSUMPTION: the initial estimate should have at least 8 bits of precision. If this is not
/// true, results will be inaccurate.
const fn get_iterations<F: Float>() -> (usize, usize) {
    // Every iteration doubles the precision of the estimate.
    let needed = F::BITS.ilog2() as usize - 2;

    // When the widening multiplication only needs pointer-sized integers it is already
    // efficient, so there is no reason to run any iteration at half width.
    // TODO: use half iterations.
    let widening_fits_in_word = 2 * size_of::<F>() <= size_of::<*const ()>();

    if widening_fits_in_word {
        (0, needed)
    } else {
        (needed - 1, 1)
    }
}

/// u_n for different precisions (with N-1 half-width iterations):
/// W0 is the precision of C
/// u_0 = (3/4 - 1/sqrt(2) + 2^-W0) * 2^HW
///
/// Estimated with bc:
/// define half1(un) { return 2.0 * (un + un^2) / 2.0^hw + 1.0; }
/// define half2(un) { return 2.0 * un / 2.0^hw + 2.0; }
/// define full1(un) { return 4.0 * (un + 3.01) / 2.0^hw + 2.0 * (un + 3.01)^2 + 4.0; }
/// define full2(un) { return 4.0 * (un + 3.01) / 2.0^hw + 8.0; }
///
/// | f32 (0 + 3) | f32 (2 + 1) | f64 (3 + 1) | f128 (4 + 1)
/// u_0 | < 184224974 | < 2812.1 | < 184224974 | < 791240234244348797
/// u_1 | < 15804007 | < 242.7 | < 15804007 | < 67877681371350440
/// u_2 | < 116308 | < 2.81 | < 116308 | < 499533100252317
/// u_3 | < 7.31 | | < 7.31 | < 27054456580
/// u_4 | | | | < 80.4
/// Final (U_N) | same as u_3 | < 72 | < 218 | < 13920
///
/// Add 2 to U_N due to final decrement.
const fn reciprocal_precision<F: Float>() -> u16 {
    let (half, full) = get_iterations::<F>();

    // At least one iteration has to be performed at full width.
    assert!(full >= 1, "Must have at least one full iteration");

    // Each arm is keyed by `(float width, half-width iterations, full-width iterations)`.
    // FIXME(tgross35): calculate this programmatically
    match (F::BITS, half, full) {
        (32, 2, 1) => 74,
        (32, 0, 3) => 10,
        (64, 3, 1) => 220,
        (128, 4, 1) => 13922,
        _ => panic!("Invalid number of iterations"),
    }
}

/// The constant `C`, narrowed to the half-width integer type of `F`.
///
/// C is (3/4 + 1/sqrt(2)) - 1 truncated to W0 fractional bits as UQ0.HW, with W0 being either
/// 16 or 32 and W0 <= HW. Put differently, C is the 3/4 + 1/sqrt(2) constant (from which b/2 is
/// subtracted to obtain x0) wrapped to the [0, 1) range.
fn c_hw<F: Float>() -> HalfRep<F>
where
    F::Int: DInt,
    u128: CastInto<HalfRep<F>>,
{
    // `C` written out with 128 fractional bits.
    const C_FULL: u128 = 0x7504_f333_f9de_6108_b2fb_1366_eaa6_a542;

    // Evaluated at compile time: keep only the top `HalfRep<F>::BITS` bits of the constant.
    let top_bits: u128 = const { C_FULL >> (u128::BITS - <HalfRep<F>>::BITS) };
    top_bits.cast()
}

fn div<F: Float>(a: F, b: F) -> F
where
F::Int: CastInto<u32>,
Expand Down Expand Up @@ -338,7 +261,7 @@ where
);

// Transform to a fixed-point representation by shifting the significand to the high bits. We
// know this is in the range [1.0, 2.0] since the explicit bit is set above.
// know this is in the range [1.0, 2.0) since the implicit bit is set to 1 above.
let b_uq1 = b_significand << (F::BITS - significand_bits - 1);

println!("b_uq1: {:#034x}", b_uq1);
Expand Down Expand Up @@ -390,10 +313,6 @@ where
// b/2 is subtracted to obtain x0) wrapped to [0, 1) range.
let c_hw = c_hw::<F>();

// guess!(HalfRep<F>);

// F::C_HW;

// Check that the top bit is set, i.e. value is within `[1, 2)`.
debug_assert!(b_uq1_hw & one_hw << (HalfRep::<F>::BITS - 1) > zero_hw);

Expand Down Expand Up @@ -643,34 +562,99 @@ where
F::from_repr(abs_result | quotient_sign)
}

mod implementation {
use crate::int::{DInt, HInt, Int};
use core::ops;

/// Perform one iteration at any width to approach `1/b`, given previous guess `x`. It returns
/// the next `x` as a UQ0 number.
///
/// This is the `x_{n+1} = 2*x_n - b*x_n^2` algorithm, implemented as `x_n * (2 - b*x_n)`.
pub fn iter_once<I>(x_uq0: I, b_uq1: I) -> I
where
I: Int + HInt,
<I as HInt>::D: ops::Shr<u32, Output = <I as HInt>::D>,
{
// `corr = 2 - b*x_n`
//
// This looks like `0 - b*x_n`. However, this works - in `UQ1`, `0.0 - x = 2.0 - x`.
let corr_uq1: I = I::ZERO.wrapping_sub(x_uq0.widen_mul(b_uq1).hi());
/// Work out how many refinement iterations are needed to reach the precision of float type `F`.
///
/// The result is a pair `(h, f)`: `h` is the number of iterations to be done using integers at
/// half the float's width, and `f` is the number of iterations done using integers of the
/// float's full width. Doing some iterations at half width is an optimization when the float is
/// larger than a word.
///
/// ASSUMPTION: the initial estimate should have at least 8 bits of precision. If this is not
/// true, results will be inaccurate.
const fn get_iterations<F: Float>() -> (usize, usize) {
    // Every iteration doubles the precision of the estimate.
    let needed = F::BITS.ilog2() as usize - 2;

    // When the widening multiplication only needs pointer-sized integers it is already
    // efficient, so there is no reason to run any iteration at half width.
    let widening_fits_in_word = 2 * size_of::<F>() <= size_of::<*const ()>();

    if widening_fits_in_word {
        (0, needed)
    } else {
        (needed - 1, 1)
    }
}

/// u_n for different precisions (with N-1 half-width iterations):
/// W0 is the precision of C
/// u_0 = (3/4 - 1/sqrt(2) + 2^-W0) * 2^HW
///
/// Estimated with bc:
/// define half1(un) { return 2.0 * (un + un^2) / 2.0^hw + 1.0; }
/// define half2(un) { return 2.0 * un / 2.0^hw + 2.0; }
/// define full1(un) { return 4.0 * (un + 3.01) / 2.0^hw + 2.0 * (un + 3.01)^2 + 4.0; }
/// define full2(un) { return 4.0 * (un + 3.01) / 2.0^hw + 8.0; }
///
/// | f32 (0 + 3) | f32 (2 + 1) | f64 (3 + 1) | f128 (4 + 1)
/// u_0 | < 184224974 | < 2812.1 | < 184224974 | < 791240234244348797
/// u_1 | < 15804007 | < 242.7 | < 15804007 | < 67877681371350440
/// u_2 | < 116308 | < 2.81 | < 116308 | < 499533100252317
/// u_3 | < 7.31 | | < 7.31 | < 27054456580
/// u_4 | | | | < 80.4
/// Final (U_N) | same as u_3 | < 72 | < 218 | < 13920
///
/// Add 2 to U_N due to final decrement.
const fn reciprocal_precision<F: Float>() -> u16 {
// Split of the refinement into half-width and full-width iterations for `F`.
let (half_iterations, full_iterations) = get_iterations::<F>();

// At least one iteration has to be performed at full width.
if full_iterations < 1 {
panic!("Must have at least one full iteration");
}

// NOTE(review): the comment and expression below duplicate the tail of `iter_once` and look
// like leftover lines from the diff view rather than part of this function - confirm.
// `x_n * corr = x_n * (2 - b*x_n)`
(x_uq0.widen_mul(corr_uq1) >> (I::BITS - 1)).lo()
// FIXME(tgross35): calculate this programmatically
// Each branch is selected by the float width together with the (half, full) iteration split.
// The constants appear to be the final `U_N` bounds from the table above with 2 added.
if F::BITS == 32 && half_iterations == 2 && full_iterations == 1 {
74u16
} else if F::BITS == 32 && half_iterations == 0 && full_iterations == 3 {
10
} else if F::BITS == 64 && half_iterations == 3 && full_iterations == 1 {
220
} else if F::BITS == 128 && half_iterations == 4 && full_iterations == 1 {
13922
} else {
// Any other width / iteration combination has no entry here.
panic!("Invalid number of iterations")
}
}

#[cfg(not(feature = "public-test-deps"))]
use implementation::*;
/// The constant `C`, narrowed to the half-width integer type of `F`.
///
/// C is (3/4 + 1/sqrt(2)) - 1 truncated to W0 fractional bits as UQ0.HW, with W0 being either
/// 16 or 32 and W0 <= HW. Put differently, C is the 3/4 + 1/sqrt(2) constant (from which b/2 is
/// subtracted to obtain x0) wrapped to the [0, 1) range.
fn c_hw<F: Float>() -> HalfRep<F>
where
    F::Int: DInt,
    u128: CastInto<HalfRep<F>>,
{
    // `C` written out with 128 fractional bits.
    const C_FULL: u128 = 0x7504_f333_f9de_6108_b2fb_1366_eaa6_a542;

    // Evaluated at compile time: keep only the top `HalfRep<F>::BITS` bits of the constant.
    let top_bits: u128 = const { C_FULL >> (u128::BITS - <HalfRep<F>>::BITS) };
    top_bits.cast()
}

/// Perform one iteration at any width to approach `1/b`, given previous guess `x`. It returns
/// the next `x` as a UQ0 number.
///
/// This is the `x_{n+1} = 2*x_n - b*x_n^2` algorithm, implemented as `x_n * (2 - b*x_n)`.
///
/// `x_uq0` is the previous guess and `b_uq1` is the value whose reciprocal is approached; going
/// by their names they are UQ0 and UQ1 fixed-point values respectively.
pub fn iter_once<I>(x_uq0: I, b_uq1: I) -> I
where
I: Int + HInt,
<I as HInt>::D: ops::Shr<u32, Output = <I as HInt>::D>,
{
// `corr = 2 - b*x_n`
//
// This looks like `0 - b*x_n`. However, this works - in `UQ1`, `0.0 - x = 2.0 - x`.
// Only the high half of the widened product `x_n * b` is used for the subtraction.
let corr_uq1: I = I::ZERO.wrapping_sub(x_uq0.widen_mul(b_uq1).hi());

// NOTE(review): the attribute and `pub use` below refer to an `implementation` module and look
// like leftover lines from the diff view rather than part of this function - confirm.
#[cfg(feature = "public-test-deps")]
pub use implementation::*;
// `x_n * corr = x_n * (2 - b*x_n)`
//
// The widened product is shifted right by `I::BITS - 1` and its low half is returned.
(x_uq0.widen_mul(corr_uq1) >> (I::BITS - 1)).lo()
}

intrinsics! {
#[avr_skip]
Expand Down

0 comments on commit 34483e7

Please sign in to comment.