// qsym2/auxiliary/misc.rs

//! Miscellaneous items.

use std::collections::hash_map::DefaultHasher;
use std::error::Error;
use std::fmt;
use std::hash::{Hash, Hasher};

use itertools::{Itertools, MultiProduct};
use log;
use ndarray::{stack, Array1, Array2, Axis};
use num_complex::ComplexFloat;

/// Trait to enable floating point numbers to be hashed.
pub trait HashableFloat {
    /// Rounds a float to the nearest multiple of `threshold`.
    ///
    /// Let $`x`$ be a float, $`k`$ a factor, and $`[\cdot]`$ denote the
    /// rounding-to-integer operation. This function yields $`[x \times k] / k`$,
    /// where $`k = \texttt{threshold}^{-1}`$.
    ///
    /// # Arguments
    ///
    /// * `threshold` - The inverse $`k^{-1}`$ of the factor $`k`$ used in the
    /// rounding of the float.
    ///
    /// # Returns
    ///
    /// The rounded float.
    #[must_use]
    fn round_factor(self, threshold: Self) -> Self;

    /// Returns the mantissa-exponent-sign triplet for a float.
    ///
    /// Reference: <https://stackoverflow.com/questions/39638363/how-can-i-use-a-hashmap-with-f64-as-key-in-rust>
    ///
    /// # Returns
    ///
    /// The corresponding `(mantissa, exponent, sign)` triplet of `self`.
    #[must_use]
    fn integer_decode(self) -> (u64, i16, i8);
}

45impl HashableFloat for f64 {
46    fn round_factor(self, factor: f64) -> Self {
47        (self / factor).round() * factor + 0.0
48    }
49
50    fn integer_decode(self) -> (u64, i16, i8) {
51        let bits: u64 = self.to_bits();
52        let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 };
53        let mut exponent: i16 = ((bits >> 52) & 0x7ff) as i16;
54        let mantissa = if exponent == 0 {
55            (bits & 0x000f_ffff_ffff_ffff) << 1
56        } else {
57            (bits & 0x000f_ffff_ffff_ffff) | 0x0010_0000_0000_0000
58        };
59
60        exponent -= 1023 + 52;
61        (mantissa, exponent, sign)
62    }
63}
64
/// Returns the hash value of a hashable value.
///
/// The value is fed into a freshly created [`DefaultHasher`], so equal inputs give equal hashes
/// within one build. The underlying algorithm is not guaranteed to stay the same across Rust
/// releases.
///
/// # Arguments
///
/// * `t` - A value of a hashable type. Unsized types such as `str` and slices are accepted.
///
/// # Returns
///
/// The hash value.
pub fn calculate_hash<T: Hash + ?Sized>(t: &T) -> u64 {
    let mut s = DefaultHasher::new();
    t.hash(&mut s);
    s.finish()
}

80/// Trait for performing repeated products of iterators.
81pub trait ProductRepeat: Iterator + Clone
82where
83    Self::Item: Clone,
84{
85    /// Rust implementation of Python's `itertools.product()` with repetition.
86    ///
87    /// From <https://stackoverflow.com/a/68231315>.
88    ///
89    /// # Arguments
90    ///
91    /// * `repeat` - Number of repetitions of the given iterator.
92    ///
93    /// # Returns
94    ///
95    /// A [`MultiProduct`] iterator.
96    fn product_repeat(self, repeat: usize) -> MultiProduct<Self> {
97        std::iter::repeat(self)
98            .take(repeat)
99            .multi_cartesian_product()
100    }
101}
102
103impl<T: Iterator + Clone> ProductRepeat for T where T::Item: Clone {}
104
// =============
// Gram--Schmidt
// =============

109/// Error during Gram--Schmidt orthogonalisation.
110#[derive(Debug, Clone)]
111pub struct GramSchmidtError<'a, T> {
112    pub mat: Option<&'a Array2<T>>,
113    pub vecs: Option<&'a [Array1<T>]>,
114}
115
116impl<'a, T: fmt::Display + fmt::Debug> fmt::Display for GramSchmidtError<'a, T> {
117    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118        writeln!(f, "Unable to perform Gram--Schmidt orthogonalisation on:",)?;
119        if let Some(mat) = self.mat {
120            writeln!(f, "{mat}")?;
121        } else if let Some(vecs) = self.vecs {
122            for vec in vecs {
123                writeln!(f, "{vec}")?;
124            }
125        } else {
126            writeln!(f, "Unspecified basis vectors for Gram--Schmidt.")?;
127        }
128        Ok(())
129    }
130}
131
132impl<'a, T: fmt::Display + fmt::Debug> Error for GramSchmidtError<'a, T> {}
133
134/// Performs modified Gram--Schmidt orthonormalisation on a set of column vectors in a matrix with
135/// respect to the complex-symmetric or Hermitian dot product.
136///
137/// # Arguments
138///
139/// * `vmat` - Matrix containing column vectors forming a basis for a subspace.
140/// * `complex_symmetric` - A boolean indicating if the vector dot product is complex-symmetric. If
141/// `false`, the conventional Hermitian dot product is used.
142/// * `thresh` - A threshold for determining self-orthogonal vectors.
143///
144/// # Returns
145///
146/// The orthonormal vectors forming a basis for the same subspace collected as column vectors in a
147/// matrix.
148///
149/// # Errors
150///
151/// Errors when the orthonormalisation procedure fails, which occurs when there is linear dependency
152/// between the basis vectors and/or when self-orthogonal vectors are encountered.
153pub fn complex_modified_gram_schmidt<T>(
154    vmat: &Array2<T>,
155    complex_symmetric: bool,
156    thresh: T::Real,
157) -> Result<Array2<T>, GramSchmidtError<T>>
158where
159    T: ComplexFloat + fmt::Display + 'static,
160{
161    let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
162    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
163    for (i, vi) in vmat.columns().into_iter().enumerate() {
164        // u[i] now initialised with v[i]
165        us.push(vi.to_owned());
166
167        // Project ui onto all uj (0 <= j < i)
168        // This is the 'modified' part of Gram--Schmidt. We project the current (and being updated)
169        // ui onto uj, rather than projecting vi onto uj. This enhances numerical stability.
170        for j in 0..i {
171            let p_uj_ui = if complex_symmetric {
172                us[j].t().dot(&us[i]) / us_sq_norm[j]
173            } else {
174                us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
175            };
176            us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
177        }
178
179        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
180        // us_sq_norm[i] now available.
181        let us_sq_norm_i = if complex_symmetric {
182            us[i].t().dot(&us[i])
183        } else {
184            us[i].t().map(|x| x.conj()).dot(&us[i])
185        };
186        if us_sq_norm_i.abs() < thresh {
187            log::error!("A zero-norm vector found: {}", us[i]);
188            return Err(GramSchmidtError {
189                mat: Some(vmat),
190                vecs: None,
191            });
192        }
193        us_sq_norm.push(us_sq_norm_i);
194    }
195
196    // Normalise ui
197    for i in 0..us.len() {
198        us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
199    }
200
201    let ortho_check = us.iter().enumerate().all(|(i, ui)| {
202        us.iter().enumerate().all(|(j, uj)| {
203            let ov_ij = if complex_symmetric {
204                ui.dot(uj)
205            } else {
206                ui.map(|x| x.conj()).dot(uj)
207            };
208            i == j || ov_ij.abs() < thresh
209        })
210    });
211
212    if ortho_check {
213        stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| {
214            log::error!("{}", err);
215            GramSchmidtError {
216                mat: Some(vmat),
217                vecs: None
218            }
219        })
220    } else {
221        log::error!("Post-Gram--Schmidt orthogonality check failed.");
222        Err(GramSchmidtError {
223            mat: Some(vmat),
224            vecs: None,
225        })
226    }
227}