1use std::collections::hash_map::DefaultHasher;
4use std::error::Error;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7
8use itertools::{Itertools, MultiProduct};
9use log;
10use ndarray::{stack, Array1, Array2, Axis};
11use num_complex::ComplexFloat;
12
/// Floating-point helpers for building hashable representations of real numbers.
///
/// Rust's primitive floats implement neither `Hash` nor `Eq`. This trait offers two helpers for
/// working around that:
/// - snapping a value onto a regular grid with [`round_factor`](HashableFloat::round_factor);
/// - splitting a value into integer components with
///   [`integer_decode`](HashableFloat::integer_decode), which can be hashed and compared exactly.
pub trait HashableFloat {
    /// Rounds `self` to the nearest integer multiple of `factor`.
    ///
    /// `factor` is the grid spacing. For example, rounding `1.26` with a factor of `0.5` gives
    /// `1.5`.
    #[must_use]
    fn round_factor(self, factor: Self) -> Self;

    /// Decomposes `self` into `(mantissa, exponent, sign)`.
    ///
    /// For finite values, `sign * mantissa * 2^exponent` reproduces `self` exactly.
    #[must_use]
    fn integer_decode(self) -> (u64, i16, i8);
}
44
45impl HashableFloat for f64 {
46 fn round_factor(self, factor: f64) -> Self {
47 (self / factor).round() * factor + 0.0
48 }
49
50 fn integer_decode(self) -> (u64, i16, i8) {
51 let bits: u64 = self.to_bits();
52 let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 };
53 let mut exponent: i16 = ((bits >> 52) & 0x7ff) as i16;
54 let mantissa = if exponent == 0 {
55 (bits & 0x000f_ffff_ffff_ffff) << 1
56 } else {
57 (bits & 0x000f_ffff_ffff_ffff) | 0x0010_0000_0000_0000
58 };
59
60 exponent -= 1023 + 52;
61 (mantissa, exponent, sign)
62 }
63}
64
/// Hashes `t` with the standard library's [`DefaultHasher`] and returns the 64-bit digest.
///
/// Every call starts from a fresh `DefaultHasher::new()`, so equal values give equal digests
/// within a single build of the program. The algorithm behind `DefaultHasher` is unspecified and
/// may change between Rust releases, so digests should not be persisted or compared across
/// builds.
///
/// `T` may be unsized, so string slices and slices can be hashed directly, e.g.
/// `calculate_hash("abc")` or `calculate_hash(&[1, 2, 3][..])`.
pub fn calculate_hash<T: Hash + ?Sized>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    t.hash(&mut hasher);
    hasher.finish()
}
79
80pub trait ProductRepeat: Iterator + Clone
82where
83 Self::Item: Clone,
84{
85 fn product_repeat(self, repeat: usize) -> MultiProduct<Self> {
97 std::iter::repeat(self)
98 .take(repeat)
99 .multi_cartesian_product()
100 }
101}
102
103impl<T: Iterator + Clone> ProductRepeat for T where T::Item: Clone {}
104
105#[derive(Debug, Clone)]
111pub struct GramSchmidtError<'a, T> {
112 pub mat: Option<&'a Array2<T>>,
113 pub vecs: Option<&'a [Array1<T>]>,
114}
115
116impl<'a, T: fmt::Display + fmt::Debug> fmt::Display for GramSchmidtError<'a, T> {
117 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118 writeln!(f, "Unable to perform Gram--Schmidt orthogonalisation on:",)?;
119 if let Some(mat) = self.mat {
120 writeln!(f, "{mat}")?;
121 } else if let Some(vecs) = self.vecs {
122 for vec in vecs {
123 writeln!(f, "{vec}")?;
124 }
125 } else {
126 writeln!(f, "Unspecified basis vectors for Gram--Schmidt.")?;
127 }
128 Ok(())
129 }
130}
131
132impl<'a, T: fmt::Display + fmt::Debug> Error for GramSchmidtError<'a, T> {}
133
134pub fn complex_modified_gram_schmidt<T>(
154 vmat: &Array2<T>,
155 complex_symmetric: bool,
156 thresh: T::Real,
157) -> Result<Array2<T>, GramSchmidtError<T>>
158where
159 T: ComplexFloat + fmt::Display + 'static,
160{
161 let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
162 let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
163 for (i, vi) in vmat.columns().into_iter().enumerate() {
164 us.push(vi.to_owned());
166
167 for j in 0..i {
171 let p_uj_ui = if complex_symmetric {
172 us[j].t().dot(&us[i]) / us_sq_norm[j]
173 } else {
174 us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
175 };
176 us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
177 }
178
179 let us_sq_norm_i = if complex_symmetric {
182 us[i].t().dot(&us[i])
183 } else {
184 us[i].t().map(|x| x.conj()).dot(&us[i])
185 };
186 if us_sq_norm_i.abs() < thresh {
187 log::error!("A zero-norm vector found: {}", us[i]);
188 return Err(GramSchmidtError {
189 mat: Some(vmat),
190 vecs: None,
191 });
192 }
193 us_sq_norm.push(us_sq_norm_i);
194 }
195
196 for i in 0..us.len() {
198 us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
199 }
200
201 let ortho_check = us.iter().enumerate().all(|(i, ui)| {
202 us.iter().enumerate().all(|(j, uj)| {
203 let ov_ij = if complex_symmetric {
204 ui.dot(uj)
205 } else {
206 ui.map(|x| x.conj()).dot(uj)
207 };
208 i == j || ov_ij.abs() < thresh
209 })
210 });
211
212 if ortho_check {
213 stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| {
214 log::error!("{}", err);
215 GramSchmidtError {
216 mat: Some(vmat),
217 vecs: None
218 }
219 })
220 } else {
221 log::error!("Post-Gram--Schmidt orthogonality check failed.");
222 Err(GramSchmidtError {
223 mat: Some(vmat),
224 vecs: None,
225 })
226 }
227}