Skip to content

Commit

Permalink
Merge pull request #67 from termoshtt/solveh
Browse files Browse the repository at this point in the history
Solve Hermite/Symmetric linear problems
  • Loading branch information
termoshtt committed Aug 22, 2017
2 parents ad7624e + 4638183 commit 278a7bf
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 4 deletions.
33 changes: 33 additions & 0 deletions examples/solveh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

extern crate ndarray;
extern crate ndarray_linalg;

use ndarray::*;
use ndarray_linalg::*;

// Solve `Ax=b` for Hermite matrix A
fn solve() -> Result<(), error::LinalgError> {
let a: Array2<c64> = random_hermite(3); // complex Hermite positive definite matrix
let b: Array1<c64> = random(3);
println!("b = {:?}", &b);
let x = a.solveh(&b)?;
println!("Ax = {:?}", a.dot(&x));;
Ok(())
}

// Solve `Ax=b` for many b with fixed A
fn factorize() -> Result<(), error::LinalgError> {
let a: Array2<f64> = random_hpd(3);
let f = a.factorizeh_into()?;
// once factorized, you can use it several times:
for _ in 0..10 {
let b: Array1<f64> = random(3);
let _x = f.solveh_into(b)?;
}
Ok(())
}

fn main() {
solve().unwrap();
factorize().unwrap();
}
7 changes: 6 additions & 1 deletion src/lapack_traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod opnorm;
pub mod qr;
pub mod svd;
pub mod solve;
pub mod solveh;
pub mod cholesky;
pub mod eigh;
pub mod triangular;
Expand All @@ -13,14 +14,18 @@ pub use self::eigh::*;
pub use self::opnorm::*;
pub use self::qr::*;
pub use self::solve::*;
pub use self::solveh::*;
pub use self::svd::*;
pub use self::triangular::*;

use super::error::*;
use super::types::*;

pub type Pivot = Vec<i32>;

pub trait LapackScalar
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Cholesky_ + Eigh_ + Triangular_ {
: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_
{
}

impl LapackScalar for f32 {}
Expand Down
4 changes: 1 addition & 3 deletions src/lapack_traits/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ use error::*;
use layout::MatrixLayout;
use types::*;

use super::{Transpose, into_result};

pub type Pivot = Vec<i32>;
use super::{Pivot, Transpose, into_result};

/// Wraps `*getrf`, `*getri`, and `*getrs`
pub trait Solve_: Sized {
Expand Down
53 changes: 53 additions & 0 deletions src/lapack_traits/solveh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method.
//!
//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html)

use lapack::c;

use error::*;
use layout::MatrixLayout;
use types::*;

use super::{Pivot, UPLO, into_result};

pub trait Solveh_: Sized {
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf`
unsafe fn bk(MatrixLayout, UPLO, a: &mut [Self]) -> Result<Pivot>;
/// Wrapper of `*sytri` and `*hetri`
unsafe fn invh(MatrixLayout, UPLO, a: &mut [Self], &Pivot) -> Result<()>;
/// Wrapper of `*sytrs` and `*hetrs`
unsafe fn solveh(MatrixLayout, UPLO, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>;
}

macro_rules! impl_solveh {
($scalar:ty, $trf:path, $tri:path, $trs:path) => {

impl Solveh_ for $scalar {
unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
let (n, _) = l.size();
let mut ipiv = vec![0; n as usize];
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv);
into_result(info, ipiv)
}

unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
let (n, _) = l.size();
let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv);
into_result(info, ())
}

unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> {
let (n, _) = l.size();
let nrhs = 1;
let ldb = 1;
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb);
into_result(info, ())
}
}

}} // impl_solveh!

impl_solveh!(f64, c::dsytrf, c::dsytri, c::dsytrs);
impl_solveh!(f32, c::ssytrf, c::ssytri, c::ssytrs);
impl_solveh!(c64, c::zhetrf, c::zhetri, c::zhetrs);
impl_solveh!(c32, c::chetrf, c::chetri, c::chetrs);
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub mod operator;
pub mod opnorm;
pub mod qr;
pub mod solve;
pub mod solveh;
pub mod svd;
pub mod trace;
pub mod triangular;
Expand All @@ -59,6 +60,7 @@ pub use operator::*;
pub use opnorm::*;
pub use qr::*;
pub use solve::*;
pub use solveh::*;
pub use svd::*;
pub use trace::*;
pub use triangular::*;
Expand Down
153 changes: 153 additions & 0 deletions src/solveh.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//! Solve Hermite/Symmetric linear problems

use ndarray::*;

use super::convert::*;
use super::error::*;
use super::layout::*;
use super::types::*;

pub use lapack_traits::{Pivot, UPLO};

pub trait SolveH<A: Scalar> {
fn solveh<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> {
let mut a = replicate(a);
self.solveh_mut(&mut a)?;
Ok(a)
}
fn solveh_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> {
self.solveh_mut(&mut a)?;
Ok(a)
}
fn solveh_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>;
}

pub struct FactorizedH<S: Data> {
pub a: ArrayBase<S, Ix2>,
pub ipiv: Pivot,
}

impl<A, S> SolveH<A> for FactorizedH<S>
where
A: Scalar,
S: Data<Elem = A>,
{
fn solveh_mut<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
unsafe {
A::solveh(
self.a.square_layout()?,
UPLO::Upper,
self.a.as_allocated()?,
&self.ipiv,
rhs.as_slice_mut().unwrap(),
)?
};
Ok(rhs)
}
}

impl<A, S> SolveH<A> for ArrayBase<S, Ix2>
where
A: Scalar,
S: Data<Elem = A>,
{
fn solveh_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>>
where
Sb: DataMut<Elem = A>,
{
let f = self.factorizeh()?;
f.solveh_mut(rhs)
}
}


impl<A, S> FactorizedH<S>
where
A: Scalar,
S: DataMut<Elem = A>,
{
pub fn into_inverseh(mut self) -> Result<ArrayBase<S, Ix2>> {
unsafe {
A::invh(
self.a.square_layout()?,
UPLO::Upper,
self.a.as_allocated_mut()?,
&self.ipiv,
)?
};
Ok(self.a)
}
}

pub trait FactorizeH<S: Data> {
fn factorizeh(&self) -> Result<FactorizedH<S>>;
}

pub trait FactorizeHInto<S: Data> {
fn factorizeh_into(self) -> Result<FactorizedH<S>>;
}

impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2>
where
A: Scalar,
S: DataMut<Elem = A>,
{
fn factorizeh_into(mut self) -> Result<FactorizedH<S>> {
let ipiv = unsafe { A::bk(self.layout()?, UPLO::Upper, self.as_allocated_mut()?)? };
Ok(FactorizedH {
a: self,
ipiv: ipiv,
})
}
}

impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2>
where
A: Scalar,
Si: Data<Elem = A>,
{
fn factorizeh(&self) -> Result<FactorizedH<OwnedRepr<A>>> {
let mut a: Array2<A> = replicate(self);
let ipiv = unsafe { A::bk(a.layout()?, UPLO::Upper, a.as_allocated_mut()?)? };
Ok(FactorizedH { a: a, ipiv: ipiv })
}
}

pub trait InverseH {
type Output;
fn invh(&self) -> Result<Self::Output>;
}

pub trait InverseHInto {
type Output;
fn invh_into(self) -> Result<Self::Output>;
}

impl<A, S> InverseHInto for ArrayBase<S, Ix2>
where
A: Scalar,
S: DataMut<Elem = A>,
{
type Output = Self;

fn invh_into(self) -> Result<Self::Output> {
let f = self.factorizeh_into()?;
f.into_inverseh()
}
}

impl<A, Si> InverseH for ArrayBase<Si, Ix2>
where
A: Scalar,
Si: Data<Elem = A>,
{
type Output = Array2<A>;

fn invh(&self) -> Result<Self::Output> {
let f = self.factorizeh()?;
f.into_inverseh()
}
}

0 comments on commit 278a7bf

Please sign in to comment.