-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #67 from termoshtt/solveh
Solve Hermite/Symmetric linear problems
- Loading branch information
Showing
6 changed files
with
248 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
|
||
extern crate ndarray; | ||
extern crate ndarray_linalg; | ||
|
||
use ndarray::*; | ||
use ndarray_linalg::*; | ||
|
||
// Solve `Ax=b` for Hermite matrix A | ||
fn solve() -> Result<(), error::LinalgError> { | ||
let a: Array2<c64> = random_hermite(3); // complex Hermite positive definite matrix | ||
let b: Array1<c64> = random(3); | ||
println!("b = {:?}", &b); | ||
let x = a.solveh(&b)?; | ||
println!("Ax = {:?}", a.dot(&x));; | ||
Ok(()) | ||
} | ||
|
||
// Solve `Ax=b` for many b with fixed A | ||
fn factorize() -> Result<(), error::LinalgError> { | ||
let a: Array2<f64> = random_hpd(3); | ||
let f = a.factorizeh_into()?; | ||
// once factorized, you can use it several times: | ||
for _ in 0..10 { | ||
let b: Array1<f64> = random(3); | ||
let _x = f.solveh_into(b)?; | ||
} | ||
Ok(()) | ||
} | ||
|
||
fn main() { | ||
solve().unwrap(); | ||
factorize().unwrap(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
//! Solve symmetric linear problem using the Bunch-Kaufman diagonal pivoting method. | ||
//! | ||
//! See also [the manual of dsytrf](http://www.netlib.org/lapack/lapack-3.1.1/html/dsytrf.f.html) | ||
|
||
use lapack::c; | ||
|
||
use error::*; | ||
use layout::MatrixLayout; | ||
use types::*; | ||
|
||
use super::{Pivot, UPLO, into_result}; | ||
|
||
pub trait Solveh_: Sized { | ||
/// Bunch-Kaufman: wrapper of `*sytrf` and `*hetrf` | ||
unsafe fn bk(MatrixLayout, UPLO, a: &mut [Self]) -> Result<Pivot>; | ||
/// Wrapper of `*sytri` and `*hetri` | ||
unsafe fn invh(MatrixLayout, UPLO, a: &mut [Self], &Pivot) -> Result<()>; | ||
/// Wrapper of `*sytrs` and `*hetrs` | ||
unsafe fn solveh(MatrixLayout, UPLO, a: &[Self], &Pivot, b: &mut [Self]) -> Result<()>; | ||
} | ||
|
||
macro_rules! impl_solveh { | ||
($scalar:ty, $trf:path, $tri:path, $trs:path) => { | ||
|
||
impl Solveh_ for $scalar { | ||
unsafe fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> { | ||
let (n, _) = l.size(); | ||
let mut ipiv = vec![0; n as usize]; | ||
let info = $trf(l.lapacke_layout(), uplo as u8, n, a, l.lda(), &mut ipiv); | ||
into_result(info, ipiv) | ||
} | ||
|
||
unsafe fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> { | ||
let (n, _) = l.size(); | ||
let info = $tri(l.lapacke_layout(), uplo as u8, n, a, l.lda(), ipiv); | ||
into_result(info, ()) | ||
} | ||
|
||
unsafe fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()> { | ||
let (n, _) = l.size(); | ||
let nrhs = 1; | ||
let ldb = 1; | ||
let info = $trs(l.lapacke_layout(), uplo as u8, n, nrhs, a, l.lda(), ipiv, b, ldb); | ||
into_result(info, ()) | ||
} | ||
} | ||
|
||
}} // impl_solveh! | ||
|
||
impl_solveh!(f64, c::dsytrf, c::dsytri, c::dsytrs); | ||
impl_solveh!(f32, c::ssytrf, c::ssytri, c::ssytrs); | ||
impl_solveh!(c64, c::zhetrf, c::zhetri, c::zhetrs); | ||
impl_solveh!(c32, c::chetrf, c::chetri, c::chetrs); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
//! Solve Hermite/Symmetric linear problems | ||
|
||
use ndarray::*; | ||
|
||
use super::convert::*; | ||
use super::error::*; | ||
use super::layout::*; | ||
use super::types::*; | ||
|
||
pub use lapack_traits::{Pivot, UPLO}; | ||
|
||
pub trait SolveH<A: Scalar> { | ||
fn solveh<S: Data<Elem = A>>(&self, a: &ArrayBase<S, Ix1>) -> Result<Array1<A>> { | ||
let mut a = replicate(a); | ||
self.solveh_mut(&mut a)?; | ||
Ok(a) | ||
} | ||
fn solveh_into<S: DataMut<Elem = A>>(&self, mut a: ArrayBase<S, Ix1>) -> Result<ArrayBase<S, Ix1>> { | ||
self.solveh_mut(&mut a)?; | ||
Ok(a) | ||
} | ||
fn solveh_mut<'a, S: DataMut<Elem = A>>(&self, &'a mut ArrayBase<S, Ix1>) -> Result<&'a mut ArrayBase<S, Ix1>>; | ||
} | ||
|
||
pub struct FactorizedH<S: Data> { | ||
pub a: ArrayBase<S, Ix2>, | ||
pub ipiv: Pivot, | ||
} | ||
|
||
impl<A, S> SolveH<A> for FactorizedH<S> | ||
where | ||
A: Scalar, | ||
S: Data<Elem = A>, | ||
{ | ||
fn solveh_mut<'a, Sb>(&self, rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>> | ||
where | ||
Sb: DataMut<Elem = A>, | ||
{ | ||
unsafe { | ||
A::solveh( | ||
self.a.square_layout()?, | ||
UPLO::Upper, | ||
self.a.as_allocated()?, | ||
&self.ipiv, | ||
rhs.as_slice_mut().unwrap(), | ||
)? | ||
}; | ||
Ok(rhs) | ||
} | ||
} | ||
|
||
impl<A, S> SolveH<A> for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar, | ||
S: Data<Elem = A>, | ||
{ | ||
fn solveh_mut<'a, Sb>(&self, mut rhs: &'a mut ArrayBase<Sb, Ix1>) -> Result<&'a mut ArrayBase<Sb, Ix1>> | ||
where | ||
Sb: DataMut<Elem = A>, | ||
{ | ||
let f = self.factorizeh()?; | ||
f.solveh_mut(rhs) | ||
} | ||
} | ||
|
||
|
||
impl<A, S> FactorizedH<S> | ||
where | ||
A: Scalar, | ||
S: DataMut<Elem = A>, | ||
{ | ||
pub fn into_inverseh(mut self) -> Result<ArrayBase<S, Ix2>> { | ||
unsafe { | ||
A::invh( | ||
self.a.square_layout()?, | ||
UPLO::Upper, | ||
self.a.as_allocated_mut()?, | ||
&self.ipiv, | ||
)? | ||
}; | ||
Ok(self.a) | ||
} | ||
} | ||
|
||
pub trait FactorizeH<S: Data> { | ||
fn factorizeh(&self) -> Result<FactorizedH<S>>; | ||
} | ||
|
||
pub trait FactorizeHInto<S: Data> { | ||
fn factorizeh_into(self) -> Result<FactorizedH<S>>; | ||
} | ||
|
||
impl<A, S> FactorizeHInto<S> for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar, | ||
S: DataMut<Elem = A>, | ||
{ | ||
fn factorizeh_into(mut self) -> Result<FactorizedH<S>> { | ||
let ipiv = unsafe { A::bk(self.layout()?, UPLO::Upper, self.as_allocated_mut()?)? }; | ||
Ok(FactorizedH { | ||
a: self, | ||
ipiv: ipiv, | ||
}) | ||
} | ||
} | ||
|
||
impl<A, Si> FactorizeH<OwnedRepr<A>> for ArrayBase<Si, Ix2> | ||
where | ||
A: Scalar, | ||
Si: Data<Elem = A>, | ||
{ | ||
fn factorizeh(&self) -> Result<FactorizedH<OwnedRepr<A>>> { | ||
let mut a: Array2<A> = replicate(self); | ||
let ipiv = unsafe { A::bk(a.layout()?, UPLO::Upper, a.as_allocated_mut()?)? }; | ||
Ok(FactorizedH { a: a, ipiv: ipiv }) | ||
} | ||
} | ||
|
||
pub trait InverseH { | ||
type Output; | ||
fn invh(&self) -> Result<Self::Output>; | ||
} | ||
|
||
pub trait InverseHInto { | ||
type Output; | ||
fn invh_into(self) -> Result<Self::Output>; | ||
} | ||
|
||
impl<A, S> InverseHInto for ArrayBase<S, Ix2> | ||
where | ||
A: Scalar, | ||
S: DataMut<Elem = A>, | ||
{ | ||
type Output = Self; | ||
|
||
fn invh_into(self) -> Result<Self::Output> { | ||
let f = self.factorizeh_into()?; | ||
f.into_inverseh() | ||
} | ||
} | ||
|
||
impl<A, Si> InverseH for ArrayBase<Si, Ix2> | ||
where | ||
A: Scalar, | ||
Si: Data<Elem = A>, | ||
{ | ||
type Output = Array2<A>; | ||
|
||
fn invh(&self) -> Result<Self::Output> { | ||
let f = self.factorizeh()?; | ||
f.into_inverseh() | ||
} | ||
} |