Skip to content

Commit

Permalink
Make a cosmetic adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanUkhov committed Aug 8, 2024
1 parent e564dfc commit ba55301
Showing 1 changed file with 93 additions and 93 deletions.
186 changes: 93 additions & 93 deletions src/gamma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub trait Gamma
where
Self: Sized,
{
/// Compute the gamma function.
fn gamma(self) -> Self;

/// Compute the real-valued digamma function.
///
/// The formula is as follows:
Expand Down Expand Up @@ -35,8 +38,12 @@ where
/// inference. University of London, 2003, pp. 265–266.
fn digamma(self) -> Self;

/// Compute the gamma function.
fn gamma(self) -> Self;
/// Compute the trigamma function.
///
/// The code is based on a [Julia implementation][1].
///
/// [1]: https://github.com/JuliaMath/SpecialFunctions.jl
fn trigamma(&self) -> Self;

/// Compute the regularized lower incomplete gamma function.
///
Expand All @@ -61,13 +68,6 @@ where

/// Compute the natural logarithm of the gamma function.
fn ln_gamma(self) -> (Self, i32);

/// Compute the trigamma function.
///
/// The code is based on a [Julia implementation][1].
///
/// [1]: https://github.com/JuliaMath/SpecialFunctions.jl
fn trigamma(&self) -> Self;
}

macro_rules! evaluate_polynomial(
Expand All @@ -78,6 +78,11 @@ macro_rules! evaluate_polynomial(

#[rustfmt::skip]
macro_rules! implement { ($kind:ty) => { impl Gamma for $kind {
#[inline]
fn gamma(self) -> Self {
self.tgamma()
}

fn digamma(self) -> Self {
let p = self;
if p <= 8.0 {
Expand All @@ -102,9 +107,40 @@ macro_rules! implement { ($kind:ty) => { impl Gamma for $kind {
)
}

#[inline]
fn gamma(self) -> Self {
self.tgamma()
fn trigamma(&self) -> Self {
let mut x: $kind = *self;
if x <= 0.0 {
return (<$kind>::PI * (<$kind>::PI * x).sin().recip()).powi(2)
- (1.0 - x).trigamma();
}

let mut psi: $kind = 0.0;
if x < 8.0 {
let n = (8.0 - x.floor()) as usize;
psi += x.recip().powi(2);
for v in 1..n {
psi += (x + (v as $kind)).recip().powi(2);
}
x += n as $kind;
}
let t = x.recip();
let w = t * t;
psi += t + 0.5 * w;
psi + t
* w
* evaluate_polynomial!(
w,
[
0.16666666666666666,
-0.03333333333333333,
0.023809523809523808,
-0.03333333333333333,
0.07575757575757576,
-0.2531135531135531,
1.1666666666666667,
-7.092156862745098,
]
)
}

fn inc_gamma(self, p: Self) -> Self {
Expand Down Expand Up @@ -205,42 +241,6 @@ macro_rules! implement { ($kind:ty) => { impl Gamma for $kind {
fn ln_gamma(self) -> (Self, i32) {
self.lgamma()
}

fn trigamma(&self) -> Self {
let mut x: $kind = *self;
if x <= 0.0 {
return (<$kind>::PI * (<$kind>::PI * x).sin().recip()).powi(2)
- (1.0 - x).trigamma();
}

let mut psi: $kind = 0.0;
if x < 8.0 {
let n = (8.0 - x.floor()) as usize;
psi += x.recip().powi(2);
for v in 1..n {
psi += (x + (v as $kind)).recip().powi(2);
}
x += n as $kind;
}
let t = x.recip();
let w = t * t;
psi += t + 0.5 * w;
psi + t
* w
* evaluate_polynomial!(
w,
[
0.16666666666666666,
-0.03333333333333333,
0.023809523809523808,
-0.03333333333333333,
0.07575757575757576,
-0.2531135531135531,
1.1666666666666667,
-7.092156862745098,
]
)
}
}}}

implement!(f32);
Expand All @@ -263,6 +263,51 @@ mod tests {
assert_eq!(-FRAC_PI_2 - 3.0 * LN_2 - EULER_MASCHERONI, 0.25.digamma());
}

#[test]
fn trigamma() {
#[cfg(feature = "no_std")]
use core::f64::consts::PI;
#[cfg(not(feature = "no_std"))]
use std::f64::consts::PI;
let x = vec![
0.1,
0.5,
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
-PI,
-2.0 * PI,
-3.0 * PI,
];
let y = vec![
101.43329915079276,
4.93480220054468,
1.6449340668482262,
0.6449340668482261,
0.39493406684822613,
0.28382295573711497,
0.221322955737115,
0.18132295573711496,
0.1535451779593372,
0.13313701469403108,
0.11751201469403139,
0.10516633568168575,
53.030438740085536,
16.206759250472963,
10.341296000533267,
];

let z = x.iter().map(|&x| x.trigamma()).collect::<Vec<_>>();
assert::close(&z, &y, 1e-12);
}

#[test]
fn inc_gamma_small_p() {
let p = 4.2;
Expand Down Expand Up @@ -330,49 +375,4 @@ mod tests {
let z = x.iter().map(|&x| x.inc_gamma(p)).collect::<Vec<_>>();
assert::close(&z, &y, 1e-12);
}

#[test]
fn trigamma() {
#[cfg(feature = "no_std")]
use core::f64::consts::PI;
#[cfg(not(feature = "no_std"))]
use std::f64::consts::PI;
let x = vec![
0.1,
0.5,
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
-PI,
-2.0 * PI,
-3.0 * PI,
];
let y = vec![
101.43329915079276,
4.93480220054468,
1.6449340668482262,
0.6449340668482261,
0.39493406684822613,
0.28382295573711497,
0.221322955737115,
0.18132295573711496,
0.1535451779593372,
0.13313701469403108,
0.11751201469403139,
0.10516633568168575,
53.030438740085536,
16.206759250472963,
10.341296000533267,
];

let z = x.iter().map(|&x| x.trigamma()).collect::<Vec<_>>();
assert::close(&z, &y, 1e-12);
}
}

0 comments on commit ba55301

Please sign in to comment.