1use std::collections::hash_map::DefaultHasher;
4use std::error::Error;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7
8use itertools::{Itertools, MultiProduct};
9use log;
10use ndarray::{Array1, Array2, Axis, stack};
11use num_complex::ComplexFloat;
12
/// Helpers for bringing floating-point values into a hashable form.
///
/// `f64` does not implement `Hash`, but the `(u64, i16, i8)` tuple returned by
/// [`HashableFloat::integer_decode`] does. Rounding with
/// [`HashableFloat::round_factor`] first lets values that differ only by small
/// numerical noise map onto the same decomposition.
pub trait HashableFloat {
    /// Rounds `self` to the nearest integer multiple of `factor`.
    ///
    /// The `f64` implementation also maps a `-0.0` result onto `+0.0`, so both
    /// zeros decode identically.
    #[must_use]
    fn round_factor(self, factor: Self) -> Self;

    /// Decomposes `self` into its integer parts `(mantissa, exponent, sign)`.
    ///
    /// For the `f64` implementation, `sign * mantissa * 2^exponent` reproduces
    /// any finite input exactly.
    #[must_use]
    fn integer_decode(self) -> (u64, i16, i8);
}
44
45impl HashableFloat for f64 {
46 fn round_factor(self, factor: f64) -> Self {
47 (self / factor).round() * factor + 0.0
48 }
49
50 fn integer_decode(self) -> (u64, i16, i8) {
51 let bits: u64 = self.to_bits();
52 let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 };
53 let mut exponent: i16 = ((bits >> 52) & 0x7ff) as i16;
54 let mantissa = if exponent == 0 {
55 (bits & 0x000f_ffff_ffff_ffff) << 1
56 } else {
57 (bits & 0x000f_ffff_ffff_ffff) | 0x0010_0000_0000_0000
58 };
59
60 exponent -= 1023 + 52;
61 (mantissa, exponent, sign)
62 }
63}
64
/// Hashes `t` with the standard library's [`DefaultHasher`].
///
/// Every `DefaultHasher::new()` starts from the same state, so equal values hash
/// equally within a program. The underlying algorithm is unspecified by `std`
/// and may change between Rust releases, so the result should not be persisted.
pub fn calculate_hash<T: Hash>(t: &T) -> u64 {
    let hasher = {
        let mut state = DefaultHasher::new();
        Hash::hash(t, &mut state);
        state
    };
    Hasher::finish(&hasher)
}
79
80pub trait ProductRepeat: Iterator + Clone
82where
83 Self::Item: Clone,
84{
85 fn product_repeat(self, repeat: usize) -> MultiProduct<Self> {
97 std::iter::repeat_n(self, repeat)
98 .multi_cartesian_product()
99 }
100}
101
102impl<T: Iterator + Clone> ProductRepeat for T where T::Item: Clone {}
103
104#[derive(Debug, Clone)]
110pub struct GramSchmidtError<'a, T> {
111 pub mat: Option<&'a Array2<T>>,
112 pub vecs: Option<&'a [Array1<T>]>,
113}
114
115impl<'a, T: fmt::Display + fmt::Debug> fmt::Display for GramSchmidtError<'a, T> {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 writeln!(f, "Unable to perform Gram--Schmidt orthogonalisation on:",)?;
118 if let Some(mat) = self.mat {
119 writeln!(f, "{mat}")?;
120 } else if let Some(vecs) = self.vecs {
121 for vec in vecs {
122 writeln!(f, "{vec}")?;
123 }
124 } else {
125 writeln!(f, "Unspecified basis vectors for Gram--Schmidt.")?;
126 }
127 Ok(())
128 }
129}
130
131impl<'a, T: fmt::Display + fmt::Debug> Error for GramSchmidtError<'a, T> {}
132
133pub fn complex_modified_gram_schmidt<T>(
153 vmat: &'_ Array2<T>,
154 complex_symmetric: bool,
155 thresh: T::Real,
156) -> Result<Array2<T>, GramSchmidtError<'_, T>>
157where
158 T: ComplexFloat + fmt::Display + 'static,
159{
160 let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
161 let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
162 for (i, vi) in vmat.columns().into_iter().enumerate() {
163 us.push(vi.to_owned());
165
166 for j in 0..i {
170 let p_uj_ui = if complex_symmetric {
171 us[j].t().dot(&us[i]) / us_sq_norm[j]
172 } else {
173 us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
174 };
175 us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
176 }
177
178 let us_sq_norm_i = if complex_symmetric {
181 us[i].t().dot(&us[i])
182 } else {
183 us[i].t().map(|x| x.conj()).dot(&us[i])
184 };
185 if us_sq_norm_i.abs() < thresh {
186 log::error!("A zero-norm vector found: {}", us[i]);
187 return Err(GramSchmidtError {
188 mat: Some(vmat),
189 vecs: None,
190 });
191 }
192 us_sq_norm.push(us_sq_norm_i);
193 }
194
195 for i in 0..us.len() {
197 us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
198 }
199
200 let ortho_check = us.iter().enumerate().all(|(i, ui)| {
201 us.iter().enumerate().all(|(j, uj)| {
202 let ov_ij = if complex_symmetric {
203 ui.dot(uj)
204 } else {
205 ui.map(|x| x.conj()).dot(uj)
206 };
207 i == j || ov_ij.abs() < thresh
208 })
209 });
210
211 if ortho_check {
212 stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| {
213 log::error!("{}", err);
214 GramSchmidtError {
215 mat: Some(vmat),
216 vecs: None,
217 }
218 })
219 } else {
220 Err(GramSchmidtError {
221 mat: Some(vmat),
222 vecs: None,
223 })
224 }
225}