diff --git a/examples/solveh.rs b/examples/solveh.rs new file mode 100644 index 00000000..26214d6f --- /dev/null +++ b/examples/solveh.rs @@ -0,0 +1,33 @@ + +extern crate ndarray; +extern crate ndarray_linalg; + +use ndarray::*; +use ndarray_linalg::*; + +// Solve `Ax=b` for Hermite matrix A +fn solve() -> Result<(), error::LinalgError> { + let a: Array2 = random_hermite(3); // complex Hermite positive definite matrix + let b: Array1 = random(3); + println!("b = {:?}", &b); + let x = a.solveh(&b)?; + println!("Ax = {:?}", a.dot(&x));; + Ok(()) +} + +// Solve `Ax=b` for many b with fixed A +fn factorize() -> Result<(), error::LinalgError> { + let a: Array2 = random_hpd(3); + let f = a.factorizeh_into()?; + // once factorized, you can use it several times: + for _ in 0..10 { + let b: Array1 = random(3); + let _x = f.solveh_into(b)?; + } + Ok(()) +} + +fn main() { + solve().unwrap(); + factorize().unwrap(); +} diff --git a/src/lapack_traits/mod.rs b/src/lapack_traits/mod.rs index 5ad95ee5..89ce1a46 100644 --- a/src/lapack_traits/mod.rs +++ b/src/lapack_traits/mod.rs @@ -4,6 +4,7 @@ pub mod opnorm; pub mod qr; pub mod svd; pub mod solve; +pub mod solveh; pub mod cholesky; pub mod eigh; pub mod triangular; @@ -13,14 +14,18 @@ pub use self::eigh::*; pub use self::opnorm::*; pub use self::qr::*; pub use self::solve::*; +pub use self::solveh::*; pub use self::svd::*; pub use self::triangular::*; use super::error::*; use super::types::*; +pub type Pivot = Vec; + pub trait LapackScalar - : OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ { + : OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ + { } impl LapackScalar for f32 {} diff --git a/src/lapack_traits/solve.rs b/src/lapack_traits/solve.rs index 1b7a2b1a..2ee09d63 100644 --- a/src/lapack_traits/solve.rs +++ b/src/lapack_traits/solve.rs @@ -6,9 +6,7 @@ use error::*; use layout::MatrixLayout; use types::*; -use super::{Transpose, into_result}; - -pub type Pivot = Vec; +use super::{Pivot, Transpose, into_result}; /// Wraps `*getrf`, `*getri`, and `*getrs` pub trait Solve_: Sized { diff --git a/src/lapack_traits/solveh.rs b/src/lapack_traits/solveh.rs new file mode 100644 index 00000000..39d2b6b1 --- /dev/null +++ b/src/lapack_traits/solveh.rs @@ -0,0 +1,53 @@ +//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method. +//! +//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html) + +use lapack::c; + +use error::*; +use layout::MatrixLayout; +use types::*; + +use super::{Pivot, UPLO, into_result}; + +pub trait Solveh_: Sized { + /// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` + unsafe fn bk(MatrixLayout, UPLO, a: &mut [Self]) -> Result; + /// Wrapper of `*sytri` and `*hetri` + unsafe fn invh(MatrixLayout, UPLO, a: &mut [Self], &Pivot) -> Result<()>; + /// Wrapper of `*sytrs` and `*hetrs` + unsafe fn solveh(MatrixLayout, UPLO, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; +} + +macro_rules! impl_solveh { + ($scalar:ty, $trf:path, $tri:path, $trs:path) => { + +impl Solveh_ for $scalar { + unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result { + let (n, _) = l.size(); + let mut ipiv = vec![0; n as usize]; + let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv); + into_result(info, ipiv) + } + + unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { + let (n, _) = l.size(); + let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv); + into_result(info, ()) + } + + unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { + let (n, _) = l.size(); + let nrhs = 1; + let ldb = 1; + let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); + into_result(info, ()) + } +} + +}} // impl_solveh! + +impl_solveh!(f64, c::dsytrf, c::dsytri, c::dsytrs); +impl_solveh!(f32, c::ssytrf, c::ssytri, c::ssytrs); +impl_solveh!(c64, c::zhetrf, c::zhetri, c::zhetrs); +impl_solveh!(c32, c::chetrf, c::chetri, c::chetrs); diff --git a/src/lib.rs b/src/lib.rs index fab1aadf..d437cc39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,7 @@ pub mod operator; pub mod opnorm; pub mod qr; pub mod solve; +pub mod solveh; pub mod svd; pub mod trace; pub mod triangular; @@ -59,6 +60,7 @@ pub use operator::*; pub use opnorm::*; pub use qr::*; pub use solve::*; +pub use solveh::*; pub use svd::*; pub use trace::*; pub use triangular::*; diff --git a/src/solveh.rs b/src/solveh.rs new file mode 100644 index 00000000..62c3728f --- /dev/null +++ b/src/solveh.rs @@ -0,0 +1,153 @@ +//! Solve Hermite/Symmetric linear problems + +use ndarray::*; + +use super::convert::*; +use super::error::*; +use super::layout::*; +use super::types::*; + +pub use lapack_traits::{Pivot, UPLO}; + +pub trait SolveH { + fn solveh>(&self, a: &ArrayBase) -> Result> { + let mut a = replicate(a); + self.solveh_mut(&mut a)?; + Ok(a) + } + fn solveh_into>(&self, mut a: ArrayBase) -> Result> { + self.solveh_mut(&mut a)?; + Ok(a) + } + fn solveh_mut<'a, S: DataMut>(&self, &'a mut ArrayBase) -> Result<&'a mut ArrayBase>; +} + +pub struct FactorizedH { + pub a: ArrayBase, + pub ipiv: Pivot, +} + +impl SolveH for FactorizedH +where + A: Scalar, + S: Data, +{ + fn solveh_mut<'a, Sb>(&self, rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + unsafe { + A::solveh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated()?, + &self.ipiv, + rhs.as_slice_mut().unwrap(), + )? + }; + Ok(rhs) + } +} + +impl SolveH for ArrayBase +where + A: Scalar, + S: Data, +{ + fn solveh_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase) -> Result<&'a mut ArrayBase> + where + Sb: DataMut, + { + let f = self.factorizeh()?; + f.solveh_mut(rhs) + } +} + + +impl FactorizedH +where + A: Scalar, + S: DataMut, +{ + pub fn into_inverseh(mut self) -> Result> { + unsafe { + A::invh( + self.a.square_layout()?, + UPLO::Upper, + self.a.as_allocated_mut()?, + &self.ipiv, + )? + }; + Ok(self.a) + } +} + +pub trait FactorizeH { + fn factorizeh(&self) -> Result>; +} + +pub trait FactorizeHInto { + fn factorizeh_into(self) -> Result>; +} + +impl FactorizeHInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + fn factorizeh_into(mut self) -> Result> { + let ipiv = unsafe { A::bk(self.layout()?, UPLO::Upper, self.as_allocated_mut()?)? }; + Ok(FactorizedH { + a: self, + ipiv: ipiv, + }) + } +} + +impl FactorizeH> for ArrayBase +where + A: Scalar, + Si: Data, +{ + fn factorizeh(&self) -> Result>> { + let mut a: Array2 = replicate(self); + let ipiv = unsafe { A::bk(a.layout()?, UPLO::Upper, a.as_allocated_mut()?)? }; + Ok(FactorizedH { a: a, ipiv: ipiv }) + } +} + +pub trait InverseH { + type Output; + fn invh(&self) -> Result; +} + +pub trait InverseHInto { + type Output; + fn invh_into(self) -> Result; +} + +impl InverseHInto for ArrayBase +where + A: Scalar, + S: DataMut, +{ + type Output = Self; + + fn invh_into(self) -> Result { + let f = self.factorizeh_into()?; + f.into_inverseh() + } +} + +impl InverseH for ArrayBase +where + A: Scalar, + Si: Data, +{ + type Output = Array2; + + fn invh(&self) -> Result { + let f = self.factorizeh()?; + f.into_inverseh() + } +}