qsym2/auxiliary/
misc.rs

1//! Miscellaneous items.
2
3use std::collections::hash_map::DefaultHasher;
4use std::error::Error;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7
8use itertools::{Itertools, MultiProduct};
9use log;
10use ndarray::{Array1, Array2, Axis, stack};
11use num_complex::ComplexFloat;
12
/// Trait to enable floating point numbers to be hashed.
pub trait HashableFloat {
    /// Returns the float rounded to the nearest integer multiple of `factor`.
    ///
    /// Let $`x`$ be a float, $`k`$ a factor, and $`[\cdot]`$ denote the
    /// rounding-to-integer operation. This function yields $`[x \times k] / k`$.
    ///
    /// # Arguments
    ///
    /// * `factor` - The inverse $`k^{-1}`$ of the factor $`k`$ used in the rounding of the
    ///   float, *i.e.* the spacing of the grid onto which the float is rounded.
    ///
    /// # Returns
    ///
    /// The rounded float.
    #[must_use]
    fn round_factor(self, factor: Self) -> Self;

    /// Returns the mantissa-exponent-sign triplet for a float.
    ///
    /// Reference: <https://stackoverflow.com/questions/39638363/how-can-i-use-a-hashmap-with-f64-as-key-in-rust>
    ///
    /// # Returns
    ///
    /// The corresponding mantissa-exponent-sign triplet $`(m, e, s)`$ such that, for finite
    /// floats, the original value equals $`s \times m \times 2^{e}`$. Unlike the float itself,
    /// the triplet consists of integers and can therefore be hashed and compared for equality.
    fn integer_decode(self) -> (u64, i16, i8);
}

impl HashableFloat for f64 {
    fn round_factor(self, factor: f64) -> Self {
        // Adding `0.0` turns a negative-zero result (e.g. from rounding a small negative number)
        // into positive zero, so that both zeros subsequently decode to the same triplet.
        (self / factor).round() * factor + 0.0
    }

    fn integer_decode(self) -> (u64, i16, i8) {
        /// The 52 explicitly stored mantissa bits.
        const MANTISSA_MASK: u64 = 0x000f_ffff_ffff_ffff;
        /// The implicit leading bit of a normal number.
        const IMPLICIT_BIT: u64 = 0x0010_0000_0000_0000;
        /// Exponent bias (1023) plus the number of explicit mantissa bits (52), so that the
        /// mantissa can be interpreted as an integer.
        const EXPONENT_OFFSET: i16 = 1023 + 52;

        let bits: u64 = self.to_bits();
        let sign: i8 = if bits >> 63 == 0 { 1 } else { -1 };
        // The 11-bit biased exponent is at most 2047, so it always fits in an `i16`.
        let biased_exponent = ((bits >> 52) & 0x7ff) as i16;
        let mantissa = if biased_exponent == 0 {
            // Zero and subnormal numbers have no implicit leading bit. Their effective biased
            // exponent is 1 rather than 0, which is compensated for by shifting left by one.
            (bits & MANTISSA_MASK) << 1
        } else {
            (bits & MANTISSA_MASK) | IMPLICIT_BIT
        };

        (mantissa, biased_exponent - EXPONENT_OFFSET, sign)
    }
}
64
/// Returns the hash value of a hashable value.
///
/// The value is hashed with a freshly created [`DefaultHasher`]. Every hasher created through
/// [`DefaultHasher::new`] starts from the same state, so equal values yield equal hashes within
/// the same program. The hashing algorithm itself is unspecified by the standard library and may
/// change between Rust releases, so the result should not be persisted across builds.
///
/// # Arguments
///
/// * `t` - A value of a hashable type. Unsized types such as `str` and slices are accepted.
///
/// # Returns
///
/// The hash value.
pub fn calculate_hash<T: Hash + ?Sized>(t: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    t.hash(&mut hasher);
    hasher.finish()
}
79
80/// Trait for performing repeated products of iterators.
81pub trait ProductRepeat: Iterator + Clone
82where
83    Self::Item: Clone,
84{
85    /// Rust implementation of Python's `itertools.product()` with repetition.
86    ///
87    /// From <https://stackoverflow.com/a/68231315>.
88    ///
89    /// # Arguments
90    ///
91    /// * `repeat` - Number of repetitions of the given iterator.
92    ///
93    /// # Returns
94    ///
95    /// A [`MultiProduct`] iterator.
96    fn product_repeat(self, repeat: usize) -> MultiProduct<Self> {
97        std::iter::repeat_n(self, repeat)
98            .multi_cartesian_product()
99    }
100}
101
102impl<T: Iterator + Clone> ProductRepeat for T where T::Item: Clone {}
103
104// =============
105// Gram--Schmidt
106// =============
107
108/// Error during Gram--Schmidt orthogonalisation.
109#[derive(Debug, Clone)]
110pub struct GramSchmidtError<'a, T> {
111    pub mat: Option<&'a Array2<T>>,
112    pub vecs: Option<&'a [Array1<T>]>,
113}
114
115impl<'a, T: fmt::Display + fmt::Debug> fmt::Display for GramSchmidtError<'a, T> {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        writeln!(f, "Unable to perform Gram--Schmidt orthogonalisation on:",)?;
118        if let Some(mat) = self.mat {
119            writeln!(f, "{mat}")?;
120        } else if let Some(vecs) = self.vecs {
121            for vec in vecs {
122                writeln!(f, "{vec}")?;
123            }
124        } else {
125            writeln!(f, "Unspecified basis vectors for Gram--Schmidt.")?;
126        }
127        Ok(())
128    }
129}
130
131impl<'a, T: fmt::Display + fmt::Debug> Error for GramSchmidtError<'a, T> {}
132
133/// Performs modified Gram--Schmidt orthonormalisation on a set of column vectors in a matrix with
134/// respect to the complex-symmetric or Hermitian dot product.
135///
136/// # Arguments
137///
138/// * `vmat` - Matrix containing column vectors forming a basis for a subspace.
139/// * `complex_symmetric` - A boolean indicating if the vector dot product is complex-symmetric. If
140///   `false`, the conventional Hermitian dot product is used.
141/// * `thresh` - A threshold for determining self-orthogonal vectors.
142///
143/// # Returns
144///
145/// The orthonormal vectors forming a basis for the same subspace collected as column vectors in a
146/// matrix.
147///
148/// # Errors
149///
150/// Errors when the orthonormalisation procedure fails, which occurs when there is linear dependency
151/// between the basis vectors and/or when self-orthogonal vectors are encountered.
152pub fn complex_modified_gram_schmidt<T>(
153    vmat: &'_ Array2<T>,
154    complex_symmetric: bool,
155    thresh: T::Real,
156) -> Result<Array2<T>, GramSchmidtError<'_, T>>
157where
158    T: ComplexFloat + fmt::Display + 'static,
159{
160    let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
161    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
162    for (i, vi) in vmat.columns().into_iter().enumerate() {
163        // u[i] now initialised with v[i]
164        us.push(vi.to_owned());
165
166        // Project ui onto all uj (0 <= j < i)
167        // This is the 'modified' part of Gram--Schmidt. We project the current (and being updated)
168        // ui onto uj, rather than projecting vi onto uj. This enhances numerical stability.
169        for j in 0..i {
170            let p_uj_ui = if complex_symmetric {
171                us[j].t().dot(&us[i]) / us_sq_norm[j]
172            } else {
173                us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
174            };
175            us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
176        }
177
178        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
179        // us_sq_norm[i] now available.
180        let us_sq_norm_i = if complex_symmetric {
181            us[i].t().dot(&us[i])
182        } else {
183            us[i].t().map(|x| x.conj()).dot(&us[i])
184        };
185        if us_sq_norm_i.abs() < thresh {
186            log::error!("A zero-norm vector found: {}", us[i]);
187            return Err(GramSchmidtError {
188                mat: Some(vmat),
189                vecs: None,
190            });
191        }
192        us_sq_norm.push(us_sq_norm_i);
193    }
194
195    // Normalise ui
196    for i in 0..us.len() {
197        us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
198    }
199
200    let ortho_check = us.iter().enumerate().all(|(i, ui)| {
201        us.iter().enumerate().all(|(j, uj)| {
202            let ov_ij = if complex_symmetric {
203                ui.dot(uj)
204            } else {
205                ui.map(|x| x.conj()).dot(uj)
206            };
207            i == j || ov_ij.abs() < thresh
208        })
209    });
210
211    if ortho_check {
212        stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| {
213            log::error!("{}", err);
214            GramSchmidtError {
215                mat: Some(vmat),
216                vecs: None,
217            }
218        })
219    } else {
220        Err(GramSchmidtError {
221            mat: Some(vmat),
222            vecs: None,
223        })
224    }
225}