qsym2/target/noci/backend/solver/
mod.rs

1use anyhow::{self, ensure, format_err};
2use duplicate::duplicate_item;
3use itertools::Itertools;
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Ix2, LinalgScalar, stack};
5use ndarray_einsum::einsum;
6use ndarray_linalg::{
7    Eig, EigGeneralized, Eigh, GeneralizedEigenvalue, Lapack, Norm, Scalar, UPLO,
8};
9use num::traits::FloatConst;
10use num::{Float, One};
11use num_complex::{Complex, ComplexFloat};
12
13use crate::analysis::EigenvalueComparisonMode;
14
15use crate::target::noci::backend::nonortho::CanonicalOrthogonalisable;
16
17pub mod noci;
18
19#[cfg(test)]
20#[path = "solver_tests.rs"]
21mod solver_tests;
22
23// -----------------------------
24// GeneralisedEigenvalueSolvable
25// -----------------------------
26
/// Trait to solve the generalised eigenvalue equation for a pair of square matrices $`\mathbf{A}`$
/// and $`\mathbf{B}`$:
/// ```math
///     \mathbf{A} \mathbf{v} = \lambda \mathbf{B} \mathbf{v},
/// ```
/// where $`\mathbf{A}`$ and $`\mathbf{B}`$ are in general non-Hermitian and non-positive-definite.
///
/// Two solution strategies are exposed: one based on canonical orthogonalisation of
/// $`\mathbf{B}`$, and one based on LAPACK's `?ggev` generalised eigensolver.
pub trait GeneralisedEigenvalueSolvable {
    /// Numerical type of the matrix elements constituting the generalised eigenvalue problem.
    type NumType;

    /// Numerical type of the various thresholds for comparison.
    type RealType;

    /// Solves the *auxiliary* generalised eigenvalue problem
    /// ```math
    ///     \tilde{\mathbf{A}} \tilde{\mathbf{v}} = \tilde{\lambda} \tilde{\mathbf{B}} \tilde{\mathbf{v}},
    /// ```
    /// where $`\tilde{\mathbf{B}}`$ is the canonical-orthogonalised version of $`\mathbf{B}`$. If
    /// $`\mathbf{B}`$ is not of full rank, then the two eigenvalue problems are different.
    ///
    /// # Arguments
    ///
    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
    ///   complex-symmetric.
    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
    ///   verifying orthogonality.
    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
    ///   corresponding eigenvectors.
    ///
    /// # Returns
    ///
    /// The generalised eigenvalue result.
    ///
    /// # Errors
    ///
    /// The implementations in this module return an error when a symmetry or orthogonality
    /// check exceeds `thresh_offdiag`, or when an underlying linear-algebra routine fails.
    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
        &self,
        complex_symmetric: bool,
        thresh_offdiag: Self::RealType,
        thresh_zeroov: Self::RealType,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;

    /// Solves the generalised eigenvalue problem using LAPACK's `?ggev` generalised eigensolver.
    ///
    /// Note that this can be numerically unstable if $`\mathbf{B}`$ is not of full rank.
    ///
    /// # Arguments
    ///
    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
    ///   complex-symmetric.
    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
    ///   verifying orthogonality.
    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
    ///   corresponding eigenvectors.
    ///
    /// # Returns
    ///
    /// The generalised eigenvalue result.
    ///
    /// # Errors
    ///
    /// The implementations in this module return an error when a symmetry or orthogonality
    /// check exceeds `thresh_offdiag`, or when the underlying eigensolver fails.
    fn solve_generalised_eigenvalue_problem_with_ggev(
        &self,
        complex_symmetric: bool,
        thresh_offdiag: Self::RealType,
        thresh_zeroov: Self::RealType,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
}
93
94/// Structure containing the eigenvalues and eigenvectors of a generalised
95/// eigenvalue problem.
96pub struct GeneralisedEigenvalueResult<T> {
97    /// The resulting eigenvalues.
98    eigenvalues: Array1<T>,
99
100    /// The corresponding eigenvectors.
101    eigenvectors: Array2<T>,
102}
103
104impl<T> GeneralisedEigenvalueResult<T> {
105    /// Returns the eigenvalues.
106    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
107        self.eigenvalues.view()
108    }
109
110    /// Returns the eigenvectors.
111    pub fn eigenvectors(&'_ self) -> ArrayView2<'_, T> {
112        self.eigenvectors.view()
113    }
114}
115
116#[duplicate_item(
117    [
118        dtype_ [ f64 ]
119    ]
120    [
121        dtype_ [ f32 ]
122    ]
123)]
124impl GeneralisedEigenvalueSolvable for (&ArrayView2<'_, dtype_>, &ArrayView2<'_, dtype_>) {
125    type NumType = dtype_;
126    type RealType = dtype_;
127
128    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
129        &self,
130        _: bool,
131        thresh_offdiag: dtype_,
132        thresh_zeroov: dtype_,
133        eigenvalue_comparison_mode: EigenvalueComparisonMode,
134    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
135        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
136
137        // Real, symmetric S and H
138        let deviation_h = (hmat.to_owned() - hmat.t()).norm_l2();
139        ensure!(
140            deviation_h <= thresh_offdiag,
141            "Hamiltonian matrix is not real-symmetric: ||H - H^T|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
142        );
143
144        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
145        // real-symmetry of S.
146        // This will fail over the reals if smat contains negative eigenvalues.
147        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
148            true,
149            false,
150            thresh_offdiag,
151            thresh_zeroov,
152        )?;
153
154        let xmat = xmat_res.xmat();
155        let xmat_d = xmat_res.xmat_d();
156
157        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
158        let smat_t = xmat_d.dot(&smat).dot(&xmat);
159
160        // Over the reals, canonical orthogonalisation cannot handle `smat` with negative
161        // eigenvalues. This means that `smat_t` can only be the identity.
162        let max_diff = (&smat_t - &Array2::<dtype_>::eye(smat_t.nrows()))
163            .iter()
164            .map(|x| ComplexFloat::abs(*x))
165            .max_by(|x, y| {
166                x.partial_cmp(y)
167                    .expect("Unable to compare two `abs` values.")
168            })
169            .ok_or_else(|| {
170                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
171            })?;
172        ensure!(
173            max_diff <= thresh_offdiag,
174            "The orthogonalised overlap matrix is not the identity matrix: the maximum absolute deviation is {max_diff:.3e} > {thresh_offdiag:.3e}."
175        );
176
177        let (eigvals_t, eigvecs_t) = hmat_t.eigh(UPLO::Lower)?;
178
179        // Sort the eigenvalues and eigenvectors
180        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
181            &eigvals_t.view(),
182            &eigvecs_t.view(),
183            &eigenvalue_comparison_mode,
184        );
185        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
186
187        // Normalise the eigenvectors
188        let eigvecs_sorted_normalised =
189            normalise_eigenvectors_real(&eigvecs_sorted.view(), &smat.view(), thresh_offdiag)?;
190
191        // Regularise the eigenvectors
192        let eigvecs_sorted_normalised_regularised =
193            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
194
195        Ok(GeneralisedEigenvalueResult {
196            eigenvalues: eigvals_t_sorted,
197            eigenvectors: eigvecs_sorted_normalised_regularised,
198        })
199    }
200
201    fn solve_generalised_eigenvalue_problem_with_ggev(
202        &self,
203        _: bool,
204        thresh_offdiag: dtype_,
205        thresh_zeroov: dtype_,
206        eigenvalue_comparison_mode: EigenvalueComparisonMode,
207    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
208        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
209
210        // Real, symmetric S and H
211        let deviation_h = (hmat.to_owned() - hmat.t()).norm_l2();
212        ensure!(
213            deviation_h <= thresh_offdiag,
214            "Hamiltonian matrix is not real-symmetric: ||H - H^T|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
215        );
216        let deviation_s = (smat.to_owned() - smat.t()).norm_l2();
217        ensure!(
218            deviation_s <= thresh_offdiag,
219            "Overlap matrix is not real-symmetric: ||S - S^T|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
220        );
221
222        let (geneigvals, eigvecs) =
223            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
224
225        for gv in geneigvals.iter() {
226            if let GeneralizedEigenvalue::Finite(v, _) = gv {
227                ensure!(
228                    v.im().abs() <= thresh_offdiag,
229                    "Unexpected complex eigenvalue {v} for real, symmetric S and H."
230                );
231            }
232        }
233
234        // Filter and sort the eigenvalues and eigenvectors
235        let mut indices = (0..geneigvals.len())
236            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
237            .collect_vec();
238
239        match eigenvalue_comparison_mode {
240            EigenvalueComparisonMode::Modulus => {
241                indices.sort_by(|i, j| {
242                    if let (
243                        GeneralizedEigenvalue::Finite(e_i, _),
244                        GeneralizedEigenvalue::Finite(e_j, _),
245                    ) = (&geneigvals[*i], &geneigvals[*j])
246                    {
247                        ComplexFloat::abs(*e_i)
248                            .partial_cmp(&ComplexFloat::abs(*e_j))
249                            .unwrap()
250                    } else {
251                        panic!("Unable to compare some eigenvalues.")
252                    }
253                });
254            }
255            EigenvalueComparisonMode::Real => {
256                indices.sort_by(|i, j| {
257                    if let (
258                        GeneralizedEigenvalue::Finite(e_i, _),
259                        GeneralizedEigenvalue::Finite(e_j, _),
260                    ) = (&geneigvals[*i], &geneigvals[*j])
261                    {
262                        e_i.re().partial_cmp(&e_j.re()).unwrap()
263                    } else {
264                        panic!("Unable to compare some eigenvalues.")
265                    }
266                });
267            }
268        }
269
270        let eigvals_re_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
271            if let GeneralizedEigenvalue::Finite(v, _) = gv {
272                v.re()
273            } else {
274                panic!("Unexpected indeterminate eigenvalue.")
275            }
276        });
277        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
278        ensure!(
279            eigvecs_sorted.iter().all(|v| v.im().abs() < thresh_offdiag),
280            "Unexpected complex eigenvectors."
281        );
282        let eigvecs_re_sorted = eigvecs_sorted.map(|v| v.re());
283
284        // Normalise the eigenvectors
285        let eigvecs_re_sorted_normalised =
286            normalise_eigenvectors_real(&eigvecs_re_sorted.view(), &smat.view(), thresh_offdiag)?;
287
288        // Regularise the eigenvectors
289        let eigvecs_re_sorted_normalised_regularised =
290            regularise_eigenvectors(&eigvecs_re_sorted_normalised.view(), thresh_offdiag);
291
292        Ok(GeneralisedEigenvalueResult {
293            eigenvalues: eigvals_re_sorted,
294            eigenvectors: eigvecs_re_sorted_normalised_regularised,
295        })
296    }
297}
298
299impl<T> GeneralisedEigenvalueSolvable for (&ArrayView2<'_, Complex<T>>, &ArrayView2<'_, Complex<T>>)
300where
301    T: Float + FloatConst + Scalar<Complex = Complex<T>>,
302    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
303    for<'a> ArrayView2<'a, Complex<T>>:
304        CanonicalOrthogonalisable<NumType = Complex<T>, RealType = T>,
305{
306    type NumType = Complex<T>;
307
308    type RealType = T;
309
310    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
311        &self,
312        complex_symmetric: bool,
313        thresh_offdiag: T,
314        thresh_zeroov: T,
315        eigenvalue_comparison_mode: EigenvalueComparisonMode,
316    ) -> Result<GeneralisedEigenvalueResult<Complex<T>>, anyhow::Error> {
317        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
318
319        // Symmetrise `hmat` and `smat` to improve numerical stability
320        let (hmat, smat): (Array2<Complex<T>>, Array2<Complex<T>>) = if complex_symmetric {
321            // Complex-symmetric
322            let deviation_h = (hmat.to_owned() - hmat.t()).norm_l2();
323            ensure!(
324                deviation_h <= thresh_offdiag,
325                "Hamiltonian matrix is not complex-symmetric: ||H - H^T|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
326            );
327            let deviation_s = (smat.to_owned() - smat.t()).norm_l2();
328            ensure!(
329                deviation_s <= thresh_offdiag,
330                "Overlap matrix is not complex-symmetric: ||S - S^T|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
331            );
332            (
333                (hmat.to_owned() + hmat.t().to_owned())
334                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
335                (smat.to_owned() + smat.t().to_owned())
336                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
337            )
338        } else {
339            // Complex-Hermitian
340            let deviation_h = (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2();
341            ensure!(
342                deviation_h <= thresh_offdiag,
343                "Hamiltonian matrix is not complex-Hermitian: ||H - H^†|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
344            );
345            let deviation_s = (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2();
346            ensure!(
347                deviation_s <= thresh_offdiag,
348                "Overlap matrix is not complex-Hermitian: ||S - S^†|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
349            );
350            (
351                (hmat.to_owned() + hmat.map(|v| v.conj()).t().to_owned())
352                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
353                (smat.to_owned() + smat.map(|v| v.conj()).t().to_owned())
354                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
355            )
356        };
357
358        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
359        // complex-symmetry or complex-Hermiticity of S.
360        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
361            complex_symmetric,
362            false,
363            thresh_offdiag,
364            thresh_zeroov,
365        )?;
366
367        let xmat = xmat_res.xmat();
368        let xmat_d = xmat_res.xmat_d();
369
370        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
371        let smat_t = xmat_d.dot(&smat).dot(&xmat);
372
373        // Symmetrise `hmat_t` and `smat_t` to improve numerical stability
374        let (hmat_t_sym, smat_t_sym): (Array2<Complex<T>>, Array2<Complex<T>>) =
375            if complex_symmetric {
376                // Complex-symmetric
377                let deviation_h = (hmat_t.to_owned() - hmat_t.t()).norm_l2();
378                ensure!(
379                    deviation_h <= thresh_offdiag,
380                    "Transformed Hamiltonian matrix is not complex-symmetric: ||H~ - (H~)^T|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
381                );
382                let deviation_s = (smat_t.to_owned() - smat_t.t()).norm_l2();
383                ensure!(
384                    deviation_s <= thresh_offdiag,
385                    "Transformed overlap matrix is not complex-symmetric: ||S~ - (S~)^T|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
386                );
387                let hmat_t_s = (hmat_t.to_owned() + hmat_t.t().to_owned())
388                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
389                let smat_t_s = (smat_t.to_owned() + smat_t.t().to_owned())
390                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
391                (hmat_t_s, smat_t_s)
392            } else {
393                // Complex-Hermitian
394                let deviation_h = (hmat_t.to_owned() - hmat_t.map(|v| v.conj()).t()).norm_l2();
395                ensure!(
396                    deviation_h <= thresh_offdiag,
397                    "Transformed Hamiltonian matrix is not complex-Hermitian: ||H~ - (H~)^†|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
398                );
399                let deviation_s = (smat_t.to_owned() - smat_t.map(|v| v.conj()).t()).norm_l2();
400                ensure!(
401                    deviation_s <= thresh_offdiag,
402                    "Transformed overlap matrix is not complex-Hermitian: ||S~ - (S~)^†|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
403                );
404                let hmat_t_s = (hmat_t.to_owned() + hmat_t.map(|v| v.conj()).t().to_owned())
405                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
406                let smat_t_s = (smat_t.to_owned() + smat_t.map(|v| v.conj()).t().to_owned())
407                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
408                (hmat_t_s, smat_t_s)
409            };
410        let smat_t_sym_d = smat_t_sym.map(|v| v.conj()).t().to_owned();
411        log::debug!("Complex-symmetric? {complex_symmetric}");
412        log::debug!("Canonical orthogonalisation X matrix:\n  {xmat:+.8e}");
413        log::debug!("Canonical-orthogonalised NOCI Hamiltonian matrix H~:\n  {hmat_t_sym:+.8e}");
414        log::debug!("Canonical-orthogonalised NOCI overlap matrix S~:\n  {smat_t_sym:+.8e}");
415
416        // smat_t_sym is not necessarily the identity, but is guaranteed to be Hermitian.
417        let max_diff = (&smat_t_sym_d.dot(&smat_t_sym) - &Array2::<T>::eye(smat_t_sym.nrows()))
418            .iter()
419            .map(|x| ComplexFloat::abs(*x))
420            .max_by(|x, y| {
421                x.partial_cmp(y)
422                    .expect("Unable to compare two `abs` values.")
423            })
424            .ok_or_else(|| {
425                format_err!("Unable to determine the maximum element of the |S^†.S - I| matrix.")
426            })?;
427        ensure!(
428            max_diff <= thresh_offdiag,
429            "The S^†.S matrix is not the identity matrix. S is therefore not Hermitian."
430        );
431        let smat_t_sym_d_hmat_t_sym = smat_t_sym_d.dot(&hmat_t_sym);
432        log::debug!(
433            "Hamiltonian matrix for diagonalisation (S~)^†.(H~):\n  {smat_t_sym_d_hmat_t_sym:+.8e}"
434        );
435
436        let (eigvals_t, eigvecs_t) = smat_t_sym_d_hmat_t_sym.eig()?;
437
438        // Sort the eigenvalues and eigenvectors
439        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
440            &eigvals_t.view(),
441            &eigvecs_t.view(),
442            &eigenvalue_comparison_mode,
443        );
444        log::debug!("Sorted eigenvalues of (S~)^†.(H~):");
445        for (i, eigval) in eigvals_t_sorted.iter().enumerate() {
446            log::debug!("  {i}: {eigval:+.8e}");
447        }
448        log::debug!("");
449        log::debug!("Sorted eigenvectors of (S~)^†.(H~):\n  {eigvecs_t_sorted:+.8e}");
450        log::debug!("");
451
452        // Check orthogonality
453        // let _ = normalise_eigenvectors_complex(
454        //     &eigvecs_t.view(),
455        //     &smat_t.view(),
456        //     complex_symmetric,
457        //     Some(thresh_offdiag),
458        // )?;
459
460        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
461
462        // Normalise the eigenvectors
463        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
464            &eigvecs_sorted.view(),
465            &smat.view(),
466            complex_symmetric,
467            None,
468        )?;
469
470        // Regularise the eigenvectors
471        let eigvecs_sorted_normalised_regularised =
472            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
473
474        Ok(GeneralisedEigenvalueResult {
475            eigenvalues: eigvals_t_sorted,
476            eigenvectors: eigvecs_sorted_normalised_regularised,
477        })
478    }
479
480    fn solve_generalised_eigenvalue_problem_with_ggev(
481        &self,
482        complex_symmetric: bool,
483        thresh_offdiag: T,
484        thresh_zeroov: T,
485        eigenvalue_comparison_mode: EigenvalueComparisonMode,
486    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
487        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
488
489        if complex_symmetric {
490            // Complex-symmetric H and S
491            let deviation_h = (hmat.to_owned() - hmat.t()).norm_l2();
492            ensure!(
493                deviation_h <= thresh_offdiag,
494                "Hamiltonian matrix is not complex-symmetric: ||H - H^T|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
495            );
496            let deviation_s = (smat.to_owned() - smat.t()).norm_l2();
497            ensure!(
498                deviation_s <= thresh_offdiag,
499                "Overlap matrix is not complex-symmetric: ||S - S^T|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
500            );
501        } else {
502            // Complex-Hermitian H and S
503            let deviation_h = (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2();
504            ensure!(
505                deviation_h <= thresh_offdiag,
506                "Hamiltonian matrix is not complex-Hermitian: ||H - H^†|| = {deviation_h:.3e} > {thresh_offdiag:.3e}."
507            );
508            let deviation_s = (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2();
509            ensure!(
510                deviation_s <= thresh_offdiag,
511                "Overlap matrix is not complex-Hermitian: ||S - S^†|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
512            );
513        }
514
515        let (geneigvals, eigvecs) =
516            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
517
518        // Filter and sort the eigenvalues and eigenvectors
519        let mut indices = (0..geneigvals.len())
520            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
521            .collect_vec();
522
523        match eigenvalue_comparison_mode {
524            EigenvalueComparisonMode::Modulus => {
525                indices.sort_by(|i, j| {
526                    if let (
527                        GeneralizedEigenvalue::Finite(e_i, _),
528                        GeneralizedEigenvalue::Finite(e_j, _),
529                    ) = (&geneigvals[*i], &geneigvals[*j])
530                    {
531                        ComplexFloat::abs(*e_i)
532                            .partial_cmp(&ComplexFloat::abs(*e_j))
533                            .unwrap()
534                    } else {
535                        panic!("Unable to compare some eigenvalues.")
536                    }
537                });
538            }
539            EigenvalueComparisonMode::Real => {
540                indices.sort_by(|i, j| {
541                    if let (
542                        GeneralizedEigenvalue::Finite(e_i, _),
543                        GeneralizedEigenvalue::Finite(e_j, _),
544                    ) = (&geneigvals[*i], &geneigvals[*j])
545                    {
546                        e_i.re().partial_cmp(&e_j.re()).unwrap()
547                    } else {
548                        panic!("Unable to compare some eigenvalues.")
549                    }
550                });
551            }
552        }
553
554        let eigvals_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
555            if let GeneralizedEigenvalue::Finite(v, _) = gv {
556                *v
557            } else {
558                panic!("Unexpected indeterminate eigenvalue.")
559            }
560        });
561        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
562
563        // Normalise the eigenvectors
564        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
565            &eigvecs_sorted.view(),
566            &smat.view(),
567            complex_symmetric,
568            Some(thresh_offdiag),
569        )?;
570
571        // Regularise the eigenvectors
572        let eigvecs_sorted_normalised_regularised =
573            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
574
575        Ok(GeneralisedEigenvalueResult {
576            eigenvalues: eigvals_sorted,
577            eigenvectors: eigvecs_sorted_normalised_regularised,
578        })
579    }
580}
581
582// -------------------
583// Auxiliary functions
584// -------------------
585
586/// Sorts the eigenvalues and the corresponding eigenvectors.
587///
588/// # Arguments
589///
590/// * `eigvals` - The eigenvalues.
591/// * `eigvecs` - The corresponding eigenvectors.
592/// * `eigenvalue_comparison_mode` - Eigenvalue comparison mode.
593///
594/// # Returns
595///
596/// A tuple containing thw sorted eigenvalues and eigenvectors.
597fn sort_eigenvalues_eigenvectors<T: ComplexFloat>(
598    eigvals: &ArrayView1<T>,
599    eigvecs: &ArrayView2<T>,
600    eigenvalue_comparison_mode: &EigenvalueComparisonMode,
601) -> (Array1<T>, Array2<T>) {
602    let mut indices = (0..eigvals.len()).collect_vec();
603    match eigenvalue_comparison_mode {
604        EigenvalueComparisonMode::Modulus => {
605            indices.sort_by(|i, j| {
606                ComplexFloat::abs(eigvals[*i])
607                    .partial_cmp(&ComplexFloat::abs(eigvals[*j]))
608                    .unwrap()
609            });
610        }
611        EigenvalueComparisonMode::Real => {
612            indices.sort_by(|i, j| eigvals[*i].re().partial_cmp(&eigvals[*j].re()).unwrap());
613        }
614    }
615    let eigvals_sorted = eigvals.select(Axis(0), &indices);
616    let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
617    (eigvals_sorted, eigvecs_sorted)
618}
619
620/// Regularises the eigenvectors such that the first entry of each of them has a positive real
621/// part, or a positive imaginary part if the real part is zero.
622///
623/// # Arguments
624///
625/// * `eigvecs` - The eigenvectors to be regularised.
626/// * `thresh` - Threshold for determining if a real number is zero.
627///
628/// # Returns
629///
630/// The regularised eigenvectors.
631fn regularise_eigenvectors<T>(eigvecs: &ArrayView2<T>, thresh: T::Real) -> Array2<T>
632where
633    T: ComplexFloat + One,
634    T::Real: Float,
635{
636    let eigvecs_sgn = stack!(
637        Axis(0),
638        eigvecs
639            .row(0)
640            .map(|v| {
641                if Float::abs(ComplexFloat::re(*v)) > thresh {
642                    T::from(v.re().signum()).expect("Unable to convert a signum to the right type.")
643                } else if Float::abs(ComplexFloat::im(*v)) > thresh {
644                    T::from(v.im().signum()).expect("Unable to convert a signum to the right type.")
645                } else {
646                    T::one()
647                }
648            })
649            .view()
650    );
651    let eigvecs_regularised = eigvecs * eigvecs_sgn;
652    eigvecs_regularised
653}
654
655/// Normalises the real eigenvectors with respect to a metric.
656///
657/// # Arguments
658///
659/// * `eigvecs` - The eigenvectors to be normalised.
660/// * `smat` - The metric.
661/// * `thresh` - Threshold for verifying the orthogonality of the eigenvectors.
662///
663/// # Returns
664///
665/// The normalised eigenvectors.
666fn normalise_eigenvectors_real<T>(
667    eigvecs: &ArrayView2<T>,
668    smat: &ArrayView2<T>,
669    thresh: T,
670) -> Result<Array2<T>, anyhow::Error>
671where
672    T: LinalgScalar + Float + std::fmt::LowerExp,
673{
674    let sq_norm = einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
675        .map_err(|err| format_err!(err))?
676        .into_dimensionality::<Ix2>()
677        .map_err(|err| format_err!(err))?;
678    let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
679        .iter()
680        .map(|x| x.abs())
681        .max_by(|x, y| {
682            x.partial_cmp(y)
683                .expect("Unable to compare two `abs` values.")
684        })
685        .ok_or_else(|| {
686            format_err!(
687                "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
688            )
689        })?;
690
691    ensure!(
692        max_diff <= thresh,
693        "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thresh:.3e}."
694    );
695    ensure!(
696        sq_norm.diag().iter().all(|v| *v > T::zero()),
697        "Some eigenvectors have negative squared norms and cannot be normalised over the reals."
698    );
699    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
700    Ok(eigvecs_normalised)
701}
702
703/// Normalises the complex eigenvectors with respect to a metric.
704///
705/// # Arguments
706///
707/// * `eigvecs` - The eigenvectors to be normalised.
708/// * `smat` - The metric.
709/// * `complex_symmetric` - Boolean indicating if the inner product is complex-symmetric or not.
710/// * `thresh` - Optioanl threshold for verifying the orthogonality of the eigenvectors. If `None`,
711/// orthogonality will not be verified.
712///
713/// # Returns
714///
715/// The normalised eigenvectors.
716fn normalise_eigenvectors_complex<T>(
717    eigvecs: &ArrayView2<T>,
718    smat: &ArrayView2<T>,
719    complex_symmetric: bool,
720    thresh: Option<T::Real>,
721) -> Result<Array2<T>, anyhow::Error>
722where
723    T: LinalgScalar + ComplexFloat + std::fmt::Display + std::fmt::LowerExp,
724    T::Real: Float + std::fmt::LowerExp,
725{
726    let sq_norm = if complex_symmetric {
727        einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
728            .map_err(|err| format_err!(err))?
729            .into_dimensionality::<Ix2>()
730            .map_err(|err| format_err!(err))?
731    } else {
732        einsum(
733            "ji,jk,kl->il",
734            &[&eigvecs.map(|v| v.conj()).view(), smat, eigvecs],
735        )
736        .map_err(|err| format_err!(err))?
737        .into_dimensionality::<Ix2>()
738        .map_err(|err| format_err!(err))?
739    };
740
741    if let Some(thr) = thresh {
742        let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
743            .iter()
744            .map(|x| ComplexFloat::abs(*x))
745            .max_by(|x, y| {
746                x.partial_cmp(y)
747                    .expect("Unable to compare two `abs` values.")
748            })
749            .ok_or_else(|| {
750                if complex_symmetric {
751                    format_err!(
752                        "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
753                    )
754                } else {
755                    format_err!(
756                        "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
757                    )
758                }
759            })?;
760
761        if complex_symmetric {
762            log::debug!("C^T.S.C:\n  {sq_norm:+.8e}");
763            ensure!(
764                max_diff <= thr,
765                "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
766            )
767        } else {
768            log::debug!("C^†.S.C:\n  {sq_norm:+.8e}");
769            ensure!(
770                max_diff <= thr,
771                "The C^†.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
772            )
773        };
774    }
775    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
776    Ok(eigvecs_normalised)
777}