qsym2/target/noci/backend/solver/mod.rs

1use std::fmt::format;
2
3use anyhow::{self, ensure, format_err};
4use duplicate::duplicate_item;
5use itertools::Itertools;
6use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Ix2, LinalgScalar, stack};
7use ndarray_einsum::einsum;
8use ndarray_linalg::{Eig, EigGeneralized, Eigh, GeneralizedEigenvalue, Lapack, Scalar, UPLO};
9use num::traits::FloatConst;
10use num::{Float, One};
11use num_complex::{Complex, ComplexFloat};
12
13use crate::analysis::EigenvalueComparisonMode;
14
15use crate::target::noci::backend::nonortho::CanonicalOrthogonalisable;
16
17pub mod noci;
18
19#[cfg(test)]
20#[path = "solver_tests.rs"]
21mod solver_tests;
22
23// -----------------------------
24// GeneralisedEigenvalueSolvable
25// -----------------------------
26
27/// Trait to solve the generalised eigenvalue equation for a pair of square matrices $`\mathbf{A}`$
28/// and $`\mathbf{B}`$:
29/// ```math
30///     \mathbf{A} \mathbf{v} = \lambda \mathbf{B} \mathbf{v},
31/// ```
32/// where $`\mathbf{A}`$ and $`\mathbf{B}`$ are in general non-Hermitian and non-positive-definite.
33pub trait GeneralisedEigenvalueSolvable {
34    /// Numerical type of the matrix elements constituting the generalised eigenvalue problem.
35    type NumType;
36
37    /// Numerical type of the various thresholds for comparison.
38    type RealType;
39
40    /// Solves the *auxiliary* generalised eigenvalue problem
41    /// ```math
42    ///     \tilde{\mathbf{A}} \tilde{\mathbf{v}} = \tilde{\lambda} \tilde{\mathbf{B}} \tilde{\mathbf{v}},
43    /// ```
44    /// where $`\tilde{\mathbf{B}}`$ is the canonical-orthogonalised version of $`\mathbf{B}`$. If
45    /// $`\mathbf{B}`$ is not of full rank, then the two eigenvalue problems are different.
46    ///
47    /// # Arguments
48    ///
49    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
50    ///   complex-symmetric.
51    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
52    ///   verifying orthogonality.
53    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
54    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
55    ///   corresponding eigenvectors.
56    ///
57    /// # Returns
58    ///
59    /// The generalised eigenvalue result.
60    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
61        &self,
62        complex_symmetric: bool,
63        thresh_offdiag: Self::RealType,
64        thresh_zeroov: Self::RealType,
65        eigenvalue_comparison_mode: EigenvalueComparisonMode,
66    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
67
68    /// Solves the generalised eigenvalue problem using LAPACK's `?ggev` generalised eigensolver.
69    ///
70    /// Note that this can be numerically unstable if $`\mathbf{B}`$ is not of full rank.
71    ///
72    /// # Arguments
73    ///
74    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
75    ///   complex-symmetric.
76    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
77    ///   verifying orthogonality.
78    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
79    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
80    ///   corresponding eigenvectors.
81    ///
82    /// # Returns
83    ///
84    /// The generalised eigenvalue result.
85    fn solve_generalised_eigenvalue_problem_with_ggev(
86        &self,
87        complex_symmetric: bool,
88        thresh_offdiag: Self::RealType,
89        thresh_zeroov: Self::RealType,
90        eigenvalue_comparison_mode: EigenvalueComparisonMode,
91    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
92}
93
94/// Structure containing the eigenvalues and eigenvectors of a generalised
95/// eigenvalue problem.
96pub struct GeneralisedEigenvalueResult<T> {
97    /// The resulting eigenvalues.
98    eigenvalues: Array1<T>,
99
100    /// The corresponding eigenvectors.
101    eigenvectors: Array2<T>,
102}
103
104impl<T> GeneralisedEigenvalueResult<T> {
105    /// Returns the eigenvalues.
106    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
107        self.eigenvalues.view()
108    }
109
110    /// Returns the eigenvectors.
111    pub fn eigenvectors(&'_ self) -> ArrayView2<'_, T> {
112        self.eigenvectors.view()
113    }
114}
115
116#[duplicate_item(
117    [
118        dtype_ [ f64 ]
119    ]
120    [
121        dtype_ [ f32 ]
122    ]
123)]
124impl GeneralisedEigenvalueSolvable for (&ArrayView2<'_, dtype_>, &ArrayView2<'_, dtype_>) {
125    type NumType = dtype_;
126    type RealType = dtype_;
127
128    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
129        &self,
130        _: bool,
131        thresh_offdiag: dtype_,
132        thresh_zeroov: dtype_,
133        eigenvalue_comparison_mode: EigenvalueComparisonMode,
134    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
135        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
136
137        // Symmetrise `hmat` and `smat` to improve numerical stability
138        let (hmat, smat): (Array2<dtype_>, Array2<dtype_>) = {
139            // Real, symmetric S and H
140            let (pos, &max_offdiag_h) = (hmat.to_owned() - hmat.t())
141                    .map(|v| v.abs())
142                    .iter()
143                    .enumerate()
144                    .max_by(|(_, a), (_, b)| a.total_cmp(b))
145                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian symmetric deviation matrix."))?;
146            log::debug!("Hamiltonian matrix:\n  {hmat:+.3e}");
147            log::debug!(
148                "Hamiltonian matrix symmetric deviation:\n  {:+.3e}",
149                hmat.to_owned() - hmat.t()
150            );
151            let (pos_i, pos_j) = (pos.div_euclid(hmat.ncols()), pos.rem_euclid(hmat.ncols()));
152            ensure!(
153                max_offdiag_h <= thresh_offdiag,
154                "Hamiltonian matrix is not real-symmetric: ||H - H^T||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
155            );
156            log::debug!("Overlap matrix:\n  {smat:+.3e}");
157
158            (
159                (hmat.to_owned() + hmat.t().to_owned()).map(|v| v / (2.0)),
160                (smat.to_owned() + smat.t().to_owned()).map(|v| v / (2.0)),
161            )
162        };
163
164        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
165        // real-symmetry of S.
166        // This will fail over the reals if smat contains negative eigenvalues.
167        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
168            true,
169            false,
170            thresh_offdiag,
171            thresh_zeroov,
172        )?;
173
174        let xmat = xmat_res.xmat();
175        let xmat_d = xmat_res.xmat_d();
176
177        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
178        let smat_t = xmat_d.dot(&smat).dot(&xmat);
179
180        log::debug!("Canonical-orthogonalised NOCI Hamiltonian matrix H~:\n  {hmat_t:+.8e}");
181        log::debug!("Canonical-orthogonalised NOCI overlap matrix S~:\n  {smat_t:+.8e}");
182
183        // Over the reals, canonical orthogonalisation cannot handle `smat` with negative
184        // eigenvalues. This means that `smat_t` can only be the identity.
185        let (pos, max_diff) = (&smat_t - &Array2::<dtype_>::eye(smat_t.nrows()))
186            .iter()
187            .map(|x| ComplexFloat::abs(*x))
188            .enumerate()
189            .max_by(|(_, x), (_, y)| {
190                x.partial_cmp(y)
191                    .expect("Unable to compare two `abs` values.")
192            })
193            .ok_or_else(|| {
194                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
195            })?;
196        let (pos_i, pos_j) = (pos.div_euclid(hmat.ncols()), pos.rem_euclid(hmat.ncols()));
197        ensure!(
198            max_diff <= thresh_offdiag,
199            "The orthogonalised overlap matrix is not the identity matrix: the maximum absolute deviation is {max_diff:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
200        );
201
202        let (eigvals_t, eigvecs_t) = hmat_t.eigh(UPLO::Lower)?;
203
204        // Sort the eigenvalues and eigenvectors
205        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
206            &eigvals_t.view(),
207            &eigvecs_t.view(),
208            &eigenvalue_comparison_mode,
209        );
210        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
211
212        // Normalise the eigenvectors
213        let eigvecs_sorted_normalised =
214            normalise_eigenvectors_real(&eigvecs_sorted.view(), &smat.view(), thresh_offdiag)?;
215
216        // Regularise the eigenvectors
217        let eigvecs_sorted_normalised_regularised =
218            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
219
220        Ok(GeneralisedEigenvalueResult {
221            eigenvalues: eigvals_t_sorted,
222            eigenvectors: eigvecs_sorted_normalised_regularised,
223        })
224    }
225
226    fn solve_generalised_eigenvalue_problem_with_ggev(
227        &self,
228        _: bool,
229        thresh_offdiag: dtype_,
230        thresh_zeroov: dtype_,
231        eigenvalue_comparison_mode: EigenvalueComparisonMode,
232    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
233        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
234
235        // Real, symmetric S and H
236        let max_offdiag_h = *(hmat.to_owned() - hmat.t())
237                .map(|v| v.abs())
238                .iter()
239                .max_by(|a, b| a.total_cmp(b))
240                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian symmetric deviation matrix."))?;
241        ensure!(
242            max_offdiag_h <= thresh_offdiag,
243            "Hamiltonian matrix is not real-symmetric: ||H - H^T||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
244        );
245        let max_offdiag_s = *(smat.to_owned() - smat.t())
246                .map(|v| v.abs())
247                .iter()
248                .max_by(|a, b| a.total_cmp(b))
249                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap symmetric deviation matrix."))?;
250        ensure!(
251            max_offdiag_s <= thresh_offdiag,
252            "Overlap matrix is not real-symmetric: ||S - S^T|| = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
253        );
254
255        let (geneigvals, eigvecs) =
256            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
257
258        for gv in geneigvals.iter() {
259            if let GeneralizedEigenvalue::Finite(v, _) = gv {
260                ensure!(
261                    v.im().abs() <= thresh_offdiag,
262                    "Unexpected complex eigenvalue {v} for real, symmetric S and H."
263                );
264            }
265        }
266
267        // Filter and sort the eigenvalues and eigenvectors
268        let mut indices = (0..geneigvals.len())
269            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
270            .collect_vec();
271
272        match eigenvalue_comparison_mode {
273            EigenvalueComparisonMode::Modulus => {
274                indices.sort_by(|i, j| {
275                    if let (
276                        GeneralizedEigenvalue::Finite(e_i, _),
277                        GeneralizedEigenvalue::Finite(e_j, _),
278                    ) = (&geneigvals[*i], &geneigvals[*j])
279                    {
280                        ComplexFloat::abs(*e_i)
281                            .partial_cmp(&ComplexFloat::abs(*e_j))
282                            .unwrap()
283                    } else {
284                        panic!("Unable to compare some eigenvalues.")
285                    }
286                });
287            }
288            EigenvalueComparisonMode::Real => {
289                indices.sort_by(|i, j| {
290                    if let (
291                        GeneralizedEigenvalue::Finite(e_i, _),
292                        GeneralizedEigenvalue::Finite(e_j, _),
293                    ) = (&geneigvals[*i], &geneigvals[*j])
294                    {
295                        e_i.re().partial_cmp(&e_j.re()).unwrap()
296                    } else {
297                        panic!("Unable to compare some eigenvalues.")
298                    }
299                });
300            }
301        }
302
303        let eigvals_re_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
304            if let GeneralizedEigenvalue::Finite(v, _) = gv {
305                v.re()
306            } else {
307                panic!("Unexpected indeterminate eigenvalue.")
308            }
309        });
310        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
311        ensure!(
312            eigvecs_sorted.iter().all(|v| v.im().abs() < thresh_offdiag),
313            "Unexpected complex eigenvectors."
314        );
315        let eigvecs_re_sorted = eigvecs_sorted.map(|v| v.re());
316
317        // Normalise the eigenvectors
318        let eigvecs_re_sorted_normalised =
319            normalise_eigenvectors_real(&eigvecs_re_sorted.view(), &smat.view(), thresh_offdiag)?;
320
321        // Regularise the eigenvectors
322        let eigvecs_re_sorted_normalised_regularised =
323            regularise_eigenvectors(&eigvecs_re_sorted_normalised.view(), thresh_offdiag);
324
325        Ok(GeneralisedEigenvalueResult {
326            eigenvalues: eigvals_re_sorted,
327            eigenvectors: eigvecs_re_sorted_normalised_regularised,
328        })
329    }
330}
331
332impl<T> GeneralisedEigenvalueSolvable for (&ArrayView2<'_, Complex<T>>, &ArrayView2<'_, Complex<T>>)
333where
334    T: Float + FloatConst + Scalar<Complex = Complex<T>>,
335    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
336    for<'a> ArrayView2<'a, Complex<T>>:
337        CanonicalOrthogonalisable<NumType = Complex<T>, RealType = T>,
338{
339    type NumType = Complex<T>;
340
341    type RealType = T;
342
343    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
344        &self,
345        complex_symmetric: bool,
346        thresh_offdiag: T,
347        thresh_zeroov: T,
348        eigenvalue_comparison_mode: EigenvalueComparisonMode,
349    ) -> Result<GeneralisedEigenvalueResult<Complex<T>>, anyhow::Error> {
350        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
351
352        // Symmetrise `hmat` and `smat` to improve numerical stability
353        let (hmat, smat): (Array2<Complex<T>>, Array2<Complex<T>>) = if complex_symmetric {
354            // Complex-symmetric
355            let max_offdiag_h = *(hmat.to_owned() - hmat.t())
356                    .mapv(ComplexFloat::abs)
357                    .iter()
358                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
359                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-symmetric deviation matrix."))?;
360            ensure!(
361                max_offdiag_h <= thresh_offdiag,
362                "Hamiltonian matrix is not complex-symmetric: ||H - H^T||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
363            );
364            let max_offdiag_s = *(smat.to_owned() - smat.t())
365                    .mapv(ComplexFloat::abs)
366                    .iter()
367                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
368                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
369            ensure!(
370                max_offdiag_s <= thresh_offdiag,
371                "Overlap matrix is not complex-symmetric: ||S - S^T||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
372            );
373            (
374                (hmat.to_owned() + hmat.t().to_owned())
375                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
376                (smat.to_owned() + smat.t().to_owned())
377                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
378            )
379        } else {
380            // Complex-Hermitian
381            let max_offdiag_h = *(hmat.to_owned() - hmat.map(|v| v.conj()).t())
382                    .mapv(ComplexFloat::abs)
383                    .iter()
384                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
385                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-Hermitian deviation matrix."))?;
386            ensure!(
387                max_offdiag_h <= thresh_offdiag,
388                "Hamiltonian matrix is not complex-Hermitian: ||H - H^†||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
389            );
390            let max_offdiag_s = *(smat.to_owned() - smat.map(|v| v.conj()).t())
391                    .mapv(ComplexFloat::abs)
392                    .iter()
393                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
394                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-Hermitian deviation matrix."))?;
395            ensure!(
396                max_offdiag_s <= thresh_offdiag,
397                "Overlap matrix is not complex-Hermitian: ||S - S^†||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
398            );
399            (
400                (hmat.to_owned() + hmat.map(|v| v.conj()).t().to_owned())
401                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
402                (smat.to_owned() + smat.map(|v| v.conj()).t().to_owned())
403                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
404            )
405        };
406
407        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
408        // complex-symmetry or complex-Hermiticity of S.
409        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
410            complex_symmetric,
411            false,
412            thresh_offdiag,
413            thresh_zeroov,
414        )?;
415
416        let xmat = xmat_res.xmat();
417        let xmat_d = xmat_res.xmat_d();
418
419        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
420        let smat_t = xmat_d.dot(&smat).dot(&xmat);
421
422        // Symmetrise `hmat_t` and `smat_t` to improve numerical stability
423        let (hmat_t_sym, smat_t_sym): (Array2<Complex<T>>, Array2<Complex<T>>) =
424            if complex_symmetric {
425                // Complex-symmetric
426                let max_offdiag_h = *(hmat_t.to_owned() - hmat_t.t())
427                        .mapv(ComplexFloat::abs)
428                        .iter()
429                        .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
430                        .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-symmetric deviation matrix."))?;
431                ensure!(
432                    max_offdiag_h <= thresh_offdiag,
433                    "Transformed Hamiltonian matrix is not complex-symmetric: ||H~ - (H~)^T||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
434                );
435                let max_offdiag_s = *(smat_t.to_owned() - smat_t.t())
436                        .mapv(ComplexFloat::abs)
437                        .iter()
438                        .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
439                        .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
440                ensure!(
441                    max_offdiag_s <= thresh_offdiag,
442                    "Transformed overlap matrix is not complex-symmetric: ||S~ - (S~)^T||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
443                );
444                let hmat_t_s = (hmat_t.to_owned() + hmat_t.t().to_owned())
445                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
446                let smat_t_s = (smat_t.to_owned() + smat_t.t().to_owned())
447                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
448                (hmat_t_s, smat_t_s)
449            } else {
450                // Complex-Hermitian
451                let max_offdiag_h = *(hmat_t.to_owned() - hmat_t.map(|v| v.conj()).t())
452                        .mapv(ComplexFloat::abs)
453                        .iter()
454                        .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
455                        .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-Hermitian deviation matrix."))?;
456                ensure!(
457                    max_offdiag_h <= thresh_offdiag,
458                    "Transformed Hamiltonian matrix is not complex-Hermitian: ||H~ - (H~)^†||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
459                );
460                let max_offdiag_s = *(smat_t.to_owned() - smat_t.map(|v| v.conj()).t())
461                        .mapv(ComplexFloat::abs)
462                        .iter()
463                        .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
464                        .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-Hermitian deviation matrix."))?;
465                ensure!(
466                    max_offdiag_s <= thresh_offdiag,
467                    "Transformed overlap matrix is not complex-Hermitian: ||S~ - (S~)^†||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
468                );
469                let hmat_t_s = (hmat_t.to_owned() + hmat_t.map(|v| v.conj()).t().to_owned())
470                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
471                let smat_t_s = (smat_t.to_owned() + smat_t.map(|v| v.conj()).t().to_owned())
472                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
473                (hmat_t_s, smat_t_s)
474            };
475        let smat_t_sym_d = smat_t_sym.map(|v| v.conj()).t().to_owned();
476        log::debug!("Complex-symmetric? {complex_symmetric}");
477        log::debug!("Canonical orthogonalisation X matrix:\n  {xmat:+.8e}");
478        log::debug!("Canonical-orthogonalised NOCI Hamiltonian matrix H~:\n  {hmat_t_sym:+.8e}");
479        log::debug!("Canonical-orthogonalised NOCI overlap matrix S~:\n  {smat_t_sym:+.8e}");
480
481        // smat_t_sym is not necessarily the identity, but is guaranteed to be Hermitian.
482        let max_diff = (&smat_t_sym_d.dot(&smat_t_sym) - &Array2::<T>::eye(smat_t_sym.nrows()))
483            .iter()
484            .map(|x| ComplexFloat::abs(*x))
485            .max_by(|x, y| {
486                x.partial_cmp(y)
487                    .expect("Unable to compare two `abs` values.")
488            })
489            .ok_or_else(|| {
490                format_err!("Unable to determine the maximum element of the |S^†.S - I| matrix.")
491            })?;
492        ensure!(
493            max_diff <= thresh_offdiag,
494            "The S^†.S matrix is not the identity matrix. S is therefore not Hermitian."
495        );
496        let smat_t_sym_d_hmat_t_sym = smat_t_sym_d.dot(&hmat_t_sym);
497        log::debug!(
498            "Hamiltonian matrix for diagonalisation (S~)^†.(H~):\n  {smat_t_sym_d_hmat_t_sym:+.8e}"
499        );
500
501        let (eigvals_t, eigvecs_t) = smat_t_sym_d_hmat_t_sym.eig()?;
502
503        // Sort the eigenvalues and eigenvectors
504        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
505            &eigvals_t.view(),
506            &eigvecs_t.view(),
507            &eigenvalue_comparison_mode,
508        );
509        log::debug!("Sorted eigenvalues of (S~)^†.(H~):");
510        for (i, eigval) in eigvals_t_sorted.iter().enumerate() {
511            log::debug!("  {i}: {eigval:+.8e}");
512        }
513        log::debug!("");
514        log::debug!("Sorted eigenvectors of (S~)^†.(H~):\n  {eigvecs_t_sorted:+.8e}");
515        log::debug!("");
516
517        // Check orthogonality
518        // let _ = normalise_eigenvectors_complex(
519        //     &eigvecs_t.view(),
520        //     &smat_t.view(),
521        //     complex_symmetric,
522        //     Some(thresh_offdiag),
523        // )?;
524
525        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
526
527        // Normalise the eigenvectors
528        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
529            &eigvecs_sorted.view(),
530            &smat.view(),
531            complex_symmetric,
532            None,
533        )?;
534
535        // Regularise the eigenvectors
536        let eigvecs_sorted_normalised_regularised =
537            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
538
539        Ok(GeneralisedEigenvalueResult {
540            eigenvalues: eigvals_t_sorted,
541            eigenvectors: eigvecs_sorted_normalised_regularised,
542        })
543    }
544
545    fn solve_generalised_eigenvalue_problem_with_ggev(
546        &self,
547        complex_symmetric: bool,
548        thresh_offdiag: T,
549        thresh_zeroov: T,
550        eigenvalue_comparison_mode: EigenvalueComparisonMode,
551    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
552        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
553
554        if complex_symmetric {
555            // Complex-symmetric H and S
556            let max_offdiag_h = *(hmat.to_owned() - hmat.t())
557                    .mapv(ComplexFloat::abs)
558                    .iter()
559                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
560                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-symmetric deviation matrix."))?;
561            ensure!(
562                max_offdiag_h <= thresh_offdiag,
563                "Hamiltonian matrix is not complex-symmetric: ||H - H^T||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
564            );
565            let max_offdiag_s = *(smat.to_owned() - smat.t())
566                    .mapv(ComplexFloat::abs)
567                    .iter()
568                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
569                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
570            ensure!(
571                max_offdiag_s <= thresh_offdiag,
572                "Overlap matrix is not complex-symmetric: ||S - S^T||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
573            );
574        } else {
575            // Complex-Hermitian H and S
576            let max_offdiag_h = *(hmat.to_owned() - hmat.map(|v| v.conj()).t())
577                    .mapv(ComplexFloat::abs)
578                    .iter()
579                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
580                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the Hamiltonian complex-Hermitian deviation matrix."))?;
581            ensure!(
582                max_offdiag_h <= thresh_offdiag,
583                "Hamiltonian matrix is not complex-Hermitian: ||H - H^†||_∞ = {max_offdiag_h:.3e} > {thresh_offdiag:.3e}."
584            );
585            let max_offdiag_s = *(smat.to_owned() - smat.map(|v| v.conj()).t())
586                    .mapv(ComplexFloat::abs)
587                    .iter()
588                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
589                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-Hermitian deviation matrix."))?;
590            ensure!(
591                max_offdiag_s <= thresh_offdiag,
592                "Overlap matrix is not complex-Hermitian: ||S - S^†||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
593            );
594        }
595
596        let (geneigvals, eigvecs) =
597            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
598
599        // Filter and sort the eigenvalues and eigenvectors
600        let mut indices = (0..geneigvals.len())
601            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
602            .collect_vec();
603
604        match eigenvalue_comparison_mode {
605            EigenvalueComparisonMode::Modulus => {
606                indices.sort_by(|i, j| {
607                    if let (
608                        GeneralizedEigenvalue::Finite(e_i, _),
609                        GeneralizedEigenvalue::Finite(e_j, _),
610                    ) = (&geneigvals[*i], &geneigvals[*j])
611                    {
612                        ComplexFloat::abs(*e_i)
613                            .partial_cmp(&ComplexFloat::abs(*e_j))
614                            .unwrap()
615                    } else {
616                        panic!("Unable to compare some eigenvalues.")
617                    }
618                });
619            }
620            EigenvalueComparisonMode::Real => {
621                indices.sort_by(|i, j| {
622                    if let (
623                        GeneralizedEigenvalue::Finite(e_i, _),
624                        GeneralizedEigenvalue::Finite(e_j, _),
625                    ) = (&geneigvals[*i], &geneigvals[*j])
626                    {
627                        e_i.re().partial_cmp(&e_j.re()).unwrap()
628                    } else {
629                        panic!("Unable to compare some eigenvalues.")
630                    }
631                });
632            }
633        }
634
635        let eigvals_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
636            if let GeneralizedEigenvalue::Finite(v, _) = gv {
637                *v
638            } else {
639                panic!("Unexpected indeterminate eigenvalue.")
640            }
641        });
642        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
643
644        // Normalise the eigenvectors
645        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
646            &eigvecs_sorted.view(),
647            &smat.view(),
648            complex_symmetric,
649            Some(thresh_offdiag),
650        )?;
651
652        // Regularise the eigenvectors
653        let eigvecs_sorted_normalised_regularised =
654            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
655
656        Ok(GeneralisedEigenvalueResult {
657            eigenvalues: eigvals_sorted,
658            eigenvectors: eigvecs_sorted_normalised_regularised,
659        })
660    }
661}
662
663// -------------------
664// Auxiliary functions
665// -------------------
666
667/// Sorts the eigenvalues and the corresponding eigenvectors.
668///
669/// # Arguments
670///
671/// * `eigvals` - The eigenvalues.
672/// * `eigvecs` - The corresponding eigenvectors.
673/// * `eigenvalue_comparison_mode` - Eigenvalue comparison mode.
674///
675/// # Returns
676///
677/// A tuple containing thw sorted eigenvalues and eigenvectors.
678fn sort_eigenvalues_eigenvectors<T: ComplexFloat>(
679    eigvals: &ArrayView1<T>,
680    eigvecs: &ArrayView2<T>,
681    eigenvalue_comparison_mode: &EigenvalueComparisonMode,
682) -> (Array1<T>, Array2<T>) {
683    let mut indices = (0..eigvals.len()).collect_vec();
684    match eigenvalue_comparison_mode {
685        EigenvalueComparisonMode::Modulus => {
686            indices.sort_by(|i, j| {
687                ComplexFloat::abs(eigvals[*i])
688                    .partial_cmp(&ComplexFloat::abs(eigvals[*j]))
689                    .unwrap()
690            });
691        }
692        EigenvalueComparisonMode::Real => {
693            indices.sort_by(|i, j| eigvals[*i].re().partial_cmp(&eigvals[*j].re()).unwrap());
694        }
695    }
696    let eigvals_sorted = eigvals.select(Axis(0), &indices);
697    let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
698    (eigvals_sorted, eigvecs_sorted)
699}
700
701/// Regularises the eigenvectors such that the first entry of each of them has a positive real
702/// part, or a positive imaginary part if the real part is zero.
703///
704/// # Arguments
705///
706/// * `eigvecs` - The eigenvectors to be regularised.
707/// * `thresh` - Threshold for determining if a real number is zero.
708///
709/// # Returns
710///
711/// The regularised eigenvectors.
712fn regularise_eigenvectors<T>(eigvecs: &ArrayView2<T>, thresh: T::Real) -> Array2<T>
713where
714    T: ComplexFloat + One,
715    T::Real: Float,
716{
717    let eigvecs_sgn = stack!(
718        Axis(0),
719        eigvecs
720            .row(0)
721            .map(|v| {
722                if Float::abs(ComplexFloat::re(*v)) > thresh {
723                    T::from(v.re().signum()).expect("Unable to convert a signum to the right type.")
724                } else if Float::abs(ComplexFloat::im(*v)) > thresh {
725                    T::from(v.im().signum()).expect("Unable to convert a signum to the right type.")
726                } else {
727                    T::one()
728                }
729            })
730            .view()
731    );
732    eigvecs * eigvecs_sgn
733}
734
735/// Normalises the real eigenvectors with respect to a metric.
736///
737/// # Arguments
738///
739/// * `eigvecs` - The eigenvectors to be normalised.
740/// * `smat` - The metric.
741/// * `thresh` - Threshold for verifying the orthogonality of the eigenvectors.
742///
743/// # Returns
744///
745/// The normalised eigenvectors.
746fn normalise_eigenvectors_real<T>(
747    eigvecs: &ArrayView2<T>,
748    smat: &ArrayView2<T>,
749    thresh: T,
750) -> Result<Array2<T>, anyhow::Error>
751where
752    T: LinalgScalar + Float + std::fmt::LowerExp,
753{
754    let sq_norm = einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
755        .map_err(|err| format_err!(err))?
756        .into_dimensionality::<Ix2>()
757        .map_err(|err| format_err!(err))?;
758    let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
759        .iter()
760        .map(|x| x.abs())
761        .max_by(|x, y| {
762            x.partial_cmp(y)
763                .expect("Unable to compare two `abs` values.")
764        })
765        .ok_or_else(|| {
766            format_err!(
767                "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
768            )
769        })?;
770
771    ensure!(
772        max_diff <= thresh,
773        "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thresh:.3e}."
774    );
775    ensure!(
776        sq_norm.diag().iter().all(|v| *v > T::zero()),
777        "Some eigenvectors have negative squared norms and cannot be normalised over the reals."
778    );
779    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
780    Ok(eigvecs_normalised)
781}
782
783/// Normalises the complex eigenvectors with respect to a metric.
784///
785/// # Arguments
786///
787/// * `eigvecs` - The eigenvectors to be normalised.
788/// * `smat` - The metric.
789/// * `complex_symmetric` - Boolean indicating if the inner product is complex-symmetric or not.
790/// * `thresh` - Optioanl threshold for verifying the orthogonality of the eigenvectors. If `None`,
791///   orthogonality will not be verified.
792///
793/// # Returns
794///
795/// The normalised eigenvectors.
796fn normalise_eigenvectors_complex<T>(
797    eigvecs: &ArrayView2<T>,
798    smat: &ArrayView2<T>,
799    complex_symmetric: bool,
800    thresh: Option<T::Real>,
801) -> Result<Array2<T>, anyhow::Error>
802where
803    T: LinalgScalar + ComplexFloat + std::fmt::Display + std::fmt::LowerExp,
804    T::Real: Float + std::fmt::LowerExp,
805{
806    let sq_norm = if complex_symmetric {
807        einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
808            .map_err(|err| format_err!(err))?
809            .into_dimensionality::<Ix2>()
810            .map_err(|err| format_err!(err))?
811    } else {
812        einsum(
813            "ji,jk,kl->il",
814            &[&eigvecs.map(|v| v.conj()).view(), smat, eigvecs],
815        )
816        .map_err(|err| format_err!(err))?
817        .into_dimensionality::<Ix2>()
818        .map_err(|err| format_err!(err))?
819    };
820
821    if let Some(thr) = thresh {
822        let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
823            .iter()
824            .map(|x| ComplexFloat::abs(*x))
825            .max_by(|x, y| {
826                x.partial_cmp(y)
827                    .expect("Unable to compare two `abs` values.")
828            })
829            .ok_or_else(|| {
830                if complex_symmetric {
831                    format_err!(
832                        "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
833                    )
834                } else {
835                    format_err!(
836                        "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
837                    )
838                }
839            })?;
840
841        if complex_symmetric {
842            log::debug!("C^T.S.C:\n  {sq_norm:+.8e}");
843            ensure!(
844                max_diff <= thr,
845                "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
846            )
847        } else {
848            log::debug!("C^†.S.C:\n  {sq_norm:+.8e}");
849            ensure!(
850                max_diff <= thr,
851                "The C^†.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
852            )
853        };
854    }
855    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
856    Ok(eigvecs_normalised)
857}