qsym2/target/noci/backend/solver/
mod.rs

1use std::fmt::LowerExp;
2
3use anyhow::{self, ensure, format_err};
4use duplicate::duplicate_item;
5use itertools::Itertools;
6use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Ix2, LinalgScalar, stack};
7use ndarray_einsum::einsum;
8use ndarray_linalg::{Eig, EigGeneralized, Eigh, GeneralizedEigenvalue, Lapack, Scalar, UPLO};
9use num::traits::FloatConst;
10use num::{Float, One};
11use num_complex::{Complex, ComplexFloat};
12use num_traits::float::TotalOrder;
13
14use crate::analysis::EigenvalueComparisonMode;
15
16use crate::io::format::qsym2_warn;
17use crate::target::noci::backend::nonortho::CanonicalOrthogonalisable;
18
19pub mod noci;
20
21#[cfg(test)]
22#[path = "solver_tests.rs"]
23mod solver_tests;
24
25// -----------------------------
26// GeneralisedEigenvalueSolvable
27// -----------------------------
28
29/// Trait to solve the generalised eigenvalue equation for a pair of square matrices $`\mathbf{A}`$
30/// and $`\mathbf{B}`$:
31/// ```math
32///     \mathbf{A} \mathbf{v} = \lambda \mathbf{B} \mathbf{v},
33/// ```
34/// where $`\mathbf{A}`$ and $`\mathbf{B}`$ are in general non-Hermitian and non-positive-definite.
35pub trait GeneralisedEigenvalueSolvable {
36    /// Numerical type of the matrix elements constituting the generalised eigenvalue problem.
37    type NumType;
38
39    /// Numerical type of the various thresholds for comparison.
40    type RealType;
41
42    /// Solves the *auxiliary* generalised eigenvalue problem
43    /// ```math
44    ///     \tilde{\mathbf{A}} \tilde{\mathbf{v}} = \tilde{\lambda} \tilde{\mathbf{B}} \tilde{\mathbf{v}},
45    /// ```
46    /// where $`\tilde{\mathbf{B}}`$ is the canonical-orthogonalised version of $`\mathbf{B}`$. If
47    /// $`\mathbf{B}`$ is not of full rank, then the two eigenvalue problems are different.
48    ///
49    /// # Arguments
50    ///
51    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
52    ///   complex-symmetric.
53    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
54    ///   verifying orthogonality.
55    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
56    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
57    ///   corresponding eigenvectors.
58    ///
59    /// # Returns
60    ///
61    /// The generalised eigenvalue result.
62    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
63        &self,
64        complex_symmetric: bool,
65        thresh_offdiag: Self::RealType,
66        thresh_zeroov: Self::RealType,
67        eigenvalue_comparison_mode: EigenvalueComparisonMode,
68    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
69
70    /// Solves the generalised eigenvalue problem using LAPACK's `?ggev` generalised eigensolver.
71    ///
72    /// Note that this can be numerically unstable if $`\mathbf{B}`$ is not of full rank.
73    ///
74    /// # Arguments
75    ///
76    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
77    ///   complex-symmetric.
78    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
79    ///   verifying orthogonality.
80    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
81    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
82    ///   corresponding eigenvectors.
83    ///
84    /// # Returns
85    ///
86    /// The generalised eigenvalue result.
87    fn solve_generalised_eigenvalue_problem_with_ggev(
88        &self,
89        complex_symmetric: bool,
90        thresh_offdiag: Self::RealType,
91        thresh_zeroov: Self::RealType,
92        eigenvalue_comparison_mode: EigenvalueComparisonMode,
93    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
94}
95
96/// Structure containing the eigenvalues and eigenvectors of a generalised
97/// eigenvalue problem.
98pub struct GeneralisedEigenvalueResult<T> {
99    /// The resulting eigenvalues.
100    eigenvalues: Array1<T>,
101
102    /// The corresponding eigenvectors.
103    eigenvectors: Array2<T>,
104}
105
106impl<T> GeneralisedEigenvalueResult<T> {
107    /// Returns the eigenvalues.
108    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
109        self.eigenvalues.view()
110    }
111
112    /// Returns the eigenvectors.
113    pub fn eigenvectors(&'_ self) -> ArrayView2<'_, T> {
114        self.eigenvectors.view()
115    }
116}
117
118#[duplicate_item(
119    [
120        dtype_ [ f64 ]
121    ]
122    [
123        dtype_ [ f32 ]
124    ]
125)]
126impl GeneralisedEigenvalueSolvable for (&ArrayView2<'_, dtype_>, &ArrayView2<'_, dtype_>) {
127    type NumType = dtype_;
128    type RealType = dtype_;
129
130    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
131        &self,
132        _: bool,
133        thresh_offdiag: dtype_,
134        thresh_zeroov: dtype_,
135        eigenvalue_comparison_mode: EigenvalueComparisonMode,
136    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
137        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
138
139        // Symmetrise `hmat` and `smat` to improve numerical stability
140        let (hmat, smat): (Array2<dtype_>, Array2<dtype_>) = {
141            // Real, symmetric S and H
142            check_real_matrix_symmetry(&hmat.view(), thresh_offdiag, "Hamiltonian", "H")?;
143            check_real_matrix_symmetry(&smat.view(), thresh_offdiag, "Overlap", "S")?;
144
145            (
146                (hmat.to_owned() + hmat.t().to_owned()).map(|v| v / (2.0)),
147                (smat.to_owned() + smat.t().to_owned()).map(|v| v / (2.0)),
148            )
149        };
150
151        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
152        // real-symmetry of S.
153        // This will fail over the reals if smat contains negative eigenvalues.
154        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
155            true,
156            false,
157            thresh_offdiag,
158            thresh_zeroov,
159        )?;
160
161        let xmat = xmat_res.xmat();
162        let xmat_d = xmat_res.xmat_d();
163
164        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
165        let smat_t = xmat_d.dot(&smat).dot(&xmat);
166
167        log::debug!("Canonical-orthogonalised NOCI Hamiltonian matrix H~:\n  {hmat_t:+.8e}");
168        log::debug!("Canonical-orthogonalised NOCI overlap matrix S~:\n  {smat_t:+.8e}");
169
170        // Over the reals, canonical orthogonalisation cannot handle `smat` with negative
171        // eigenvalues. This means that `smat_t` can only be the identity.
172        let (pos, max_diff) = (&smat_t - &Array2::<dtype_>::eye(smat_t.nrows()))
173            .iter()
174            .map(|x| ComplexFloat::abs(*x))
175            .enumerate()
176            .max_by(|(_, x), (_, y)| {
177                x.partial_cmp(y)
178                    .expect("Unable to compare two `abs` values.")
179            })
180            .ok_or_else(|| {
181                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
182            })?;
183        let (pos_i, pos_j) = (pos.div_euclid(hmat.ncols()), pos.rem_euclid(hmat.ncols()));
184        ensure!(
185            max_diff <= thresh_offdiag,
186            "The orthogonalised overlap matrix is not the identity matrix: the maximum absolute deviation is {max_diff:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
187        );
188
189        let (eigvals_t, eigvecs_t) = hmat_t.eigh(UPLO::Lower)?;
190
191        // Sort the eigenvalues and eigenvectors
192        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
193            &eigvals_t.view(),
194            &eigvecs_t.view(),
195            &eigenvalue_comparison_mode,
196        );
197        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
198
199        // Normalise the eigenvectors
200        let eigvecs_sorted_normalised =
201            normalise_eigenvectors_real(&eigvecs_sorted.view(), &smat.view(), thresh_offdiag)?;
202
203        // Regularise the eigenvectors
204        let eigvecs_sorted_normalised_regularised =
205            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
206
207        Ok(GeneralisedEigenvalueResult {
208            eigenvalues: eigvals_t_sorted,
209            eigenvectors: eigvecs_sorted_normalised_regularised,
210        })
211    }
212
213    fn solve_generalised_eigenvalue_problem_with_ggev(
214        &self,
215        _: bool,
216        thresh_offdiag: dtype_,
217        thresh_zeroov: dtype_,
218        eigenvalue_comparison_mode: EigenvalueComparisonMode,
219    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
220        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
221
222        // Real, symmetric S and H
223        check_real_matrix_symmetry(&hmat.view(), thresh_offdiag, "Hamiltonian", "H")?;
224        check_real_matrix_symmetry(&smat.view(), thresh_offdiag, "Overlap", "S")?;
225
226        let (geneigvals, eigvecs) =
227            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
228
229        for gv in geneigvals.iter() {
230            if let GeneralizedEigenvalue::Finite(v, _) = gv {
231                ensure!(
232                    v.im().abs() <= thresh_offdiag,
233                    "Unexpected complex eigenvalue {v} for real, symmetric S and H."
234                );
235            }
236        }
237
238        // Filter and sort the eigenvalues and eigenvectors
239        let mut indices = (0..geneigvals.len())
240            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
241            .collect_vec();
242
243        match eigenvalue_comparison_mode {
244            EigenvalueComparisonMode::Modulus => {
245                indices.sort_by(|i, j| {
246                    if let (
247                        GeneralizedEigenvalue::Finite(e_i, _),
248                        GeneralizedEigenvalue::Finite(e_j, _),
249                    ) = (&geneigvals[*i], &geneigvals[*j])
250                    {
251                        ComplexFloat::abs(*e_i)
252                            .partial_cmp(&ComplexFloat::abs(*e_j))
253                            .unwrap()
254                    } else {
255                        panic!("Unable to compare some eigenvalues.")
256                    }
257                });
258            }
259            EigenvalueComparisonMode::Real => {
260                indices.sort_by(|i, j| {
261                    if let (
262                        GeneralizedEigenvalue::Finite(e_i, _),
263                        GeneralizedEigenvalue::Finite(e_j, _),
264                    ) = (&geneigvals[*i], &geneigvals[*j])
265                    {
266                        e_i.re().partial_cmp(&e_j.re()).unwrap()
267                    } else {
268                        panic!("Unable to compare some eigenvalues.")
269                    }
270                });
271            }
272        }
273
274        let eigvals_re_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
275            if let GeneralizedEigenvalue::Finite(v, _) = gv {
276                v.re()
277            } else {
278                panic!("Unexpected indeterminate eigenvalue.")
279            }
280        });
281        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
282        ensure!(
283            eigvecs_sorted.iter().all(|v| v.im().abs() < thresh_offdiag),
284            "Unexpected complex eigenvectors."
285        );
286        let eigvecs_re_sorted = eigvecs_sorted.map(|v| v.re());
287
288        // Normalise the eigenvectors
289        let eigvecs_re_sorted_normalised =
290            normalise_eigenvectors_real(&eigvecs_re_sorted.view(), &smat.view(), thresh_offdiag)?;
291
292        // Regularise the eigenvectors
293        let eigvecs_re_sorted_normalised_regularised =
294            regularise_eigenvectors(&eigvecs_re_sorted_normalised.view(), thresh_offdiag);
295
296        Ok(GeneralisedEigenvalueResult {
297            eigenvalues: eigvals_re_sorted,
298            eigenvectors: eigvecs_re_sorted_normalised_regularised,
299        })
300    }
301}
302
303impl<T> GeneralisedEigenvalueSolvable for (&ArrayView2<'_, Complex<T>>, &ArrayView2<'_, Complex<T>>)
304where
305    T: Float + FloatConst + Scalar<Complex = Complex<T>>,
306    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
307    for<'a> ArrayView2<'a, Complex<T>>:
308        CanonicalOrthogonalisable<NumType = Complex<T>, RealType = T>,
309{
310    type NumType = Complex<T>;
311
312    type RealType = T;
313
314    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
315        &self,
316        complex_symmetric: bool,
317        thresh_offdiag: T,
318        thresh_zeroov: T,
319        eigenvalue_comparison_mode: EigenvalueComparisonMode,
320    ) -> Result<GeneralisedEigenvalueResult<Complex<T>>, anyhow::Error> {
321        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
322
323        // Symmetrise `hmat` and `smat` to improve numerical stability
324        let (hmat, smat): (Array2<Complex<T>>, Array2<Complex<T>>) = if complex_symmetric {
325            // Complex-symmetric
326            check_complex_matrix_symmetry(&hmat.view(), complex_symmetric, thresh_offdiag, "Hamiltonian", "H")?;
327            check_complex_matrix_symmetry(&smat.view(), complex_symmetric, thresh_offdiag, "Overlap", "S")?;
328            (
329                (hmat.to_owned() + hmat.t().to_owned())
330                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
331                (smat.to_owned() + smat.t().to_owned())
332                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
333            )
334        } else {
335            // Complex-Hermitian
336            check_complex_matrix_symmetry(&hmat.view(), complex_symmetric, thresh_offdiag, "Hamiltonian", "H")?;
337            check_complex_matrix_symmetry(&smat.view(), complex_symmetric, thresh_offdiag, "Overlap", "S")?;
338            (
339                (hmat.to_owned() + hmat.map(|v| v.conj()).t().to_owned())
340                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
341                (smat.to_owned() + smat.map(|v| v.conj()).t().to_owned())
342                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one())),
343            )
344        };
345
346        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
347        // complex-symmetry or complex-Hermiticity of S.
348        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
349            complex_symmetric,
350            false,
351            thresh_offdiag,
352            thresh_zeroov,
353        )?;
354
355        let xmat = xmat_res.xmat();
356        let xmat_d = xmat_res.xmat_d();
357
358        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
359        let smat_t = xmat_d.dot(&smat).dot(&xmat);
360
361        // Symmetrise `hmat_t` and `smat_t` to improve numerical stability
362        let (hmat_t_sym, smat_t_sym): (Array2<Complex<T>>, Array2<Complex<T>>) =
363            if complex_symmetric {
364                // Complex-symmetric
365                check_complex_matrix_symmetry(&hmat_t.view(), complex_symmetric, thresh_offdiag, "Transformed Hamiltonian", "H~")?;
366                check_complex_matrix_symmetry(&smat_t.view(), complex_symmetric, thresh_offdiag, "Transformed Overlap", "S~")?;
367                let hmat_t_s = (hmat_t.to_owned() + hmat_t.t().to_owned())
368                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
369                let smat_t_s = (smat_t.to_owned() + smat_t.t().to_owned())
370                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
371                (hmat_t_s, smat_t_s)
372            } else {
373                // Complex-Hermitian
374                check_complex_matrix_symmetry(&hmat_t.view(), complex_symmetric, thresh_offdiag, "Transformed Hamiltonian", "H~")?;
375                check_complex_matrix_symmetry(&smat_t.view(), complex_symmetric, thresh_offdiag, "Transformed Overlap", "S~")?;
376                let hmat_t_s = (hmat_t.to_owned() + hmat_t.map(|v| v.conj()).t().to_owned())
377                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
378                let smat_t_s = (smat_t.to_owned() + smat_t.map(|v| v.conj()).t().to_owned())
379                    .map(|v| v / (Complex::<T>::one() + Complex::<T>::one()));
380                (hmat_t_s, smat_t_s)
381            };
382        let smat_t_sym_d = smat_t_sym.map(|v| v.conj()).t().to_owned();
383        log::debug!("Complex-symmetric? {complex_symmetric}");
384        log::debug!("Canonical orthogonalisation X matrix:\n  {xmat:+.8e}");
385        log::debug!("Canonical-orthogonalised NOCI Hamiltonian matrix H~:\n  {hmat_t_sym:+.8e}");
386        log::debug!("Canonical-orthogonalised NOCI overlap matrix S~:\n  {smat_t_sym:+.8e}");
387
388        // smat_t_sym is not necessarily the identity, but is guaranteed to be Hermitian.
389        let max_diff = (&smat_t_sym_d.dot(&smat_t_sym) - &Array2::<T>::eye(smat_t_sym.nrows()))
390            .iter()
391            .map(|x| ComplexFloat::abs(*x))
392            .max_by(|x, y| {
393                x.partial_cmp(y)
394                    .expect("Unable to compare two `abs` values.")
395            })
396            .ok_or_else(|| {
397                format_err!("Unable to determine the maximum element of the |S^†.S - I| matrix.")
398            })?;
399        ensure!(
400            max_diff <= thresh_offdiag,
401            "The S^†.S matrix is not the identity matrix. S is therefore not Hermitian."
402        );
403        let smat_t_sym_d_hmat_t_sym = smat_t_sym_d.dot(&hmat_t_sym);
404        log::debug!(
405            "Hamiltonian matrix for diagonalisation (S~)^†.(H~):\n  {smat_t_sym_d_hmat_t_sym:+.8e}"
406        );
407
408        let (eigvals_t, eigvecs_t) = smat_t_sym_d_hmat_t_sym.eig()?;
409
410        // Sort the eigenvalues and eigenvectors
411        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
412            &eigvals_t.view(),
413            &eigvecs_t.view(),
414            &eigenvalue_comparison_mode,
415        );
416        log::debug!("Sorted eigenvalues of (S~)^†.(H~):");
417        for (i, eigval) in eigvals_t_sorted.iter().enumerate() {
418            log::debug!("  {i}: {eigval:+.8e}");
419        }
420        log::debug!("");
421        log::debug!("Sorted eigenvectors of (S~)^†.(H~):\n  {eigvecs_t_sorted:+.8e}");
422        log::debug!("");
423
424        // Check orthogonality
425        // let _ = normalise_eigenvectors_complex(
426        //     &eigvecs_t.view(),
427        //     &smat_t.view(),
428        //     complex_symmetric,
429        //     Some(thresh_offdiag),
430        // )?;
431
432        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
433
434        // Normalise the eigenvectors
435        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
436            &eigvecs_sorted.view(),
437            &smat.view(),
438            complex_symmetric,
439            None,
440        )?;
441
442        // Regularise the eigenvectors
443        let eigvecs_sorted_normalised_regularised =
444            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
445
446        Ok(GeneralisedEigenvalueResult {
447            eigenvalues: eigvals_t_sorted,
448            eigenvectors: eigvecs_sorted_normalised_regularised,
449        })
450    }
451
452    fn solve_generalised_eigenvalue_problem_with_ggev(
453        &self,
454        complex_symmetric: bool,
455        thresh_offdiag: T,
456        thresh_zeroov: T,
457        eigenvalue_comparison_mode: EigenvalueComparisonMode,
458    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
459        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
460
461        check_complex_matrix_symmetry(&hmat.view(), complex_symmetric, thresh_offdiag, "Hamiltonian", "H")?;
462        check_complex_matrix_symmetry(&smat.view(), complex_symmetric, thresh_offdiag, "Overlap", "S")?;
463
464        let (geneigvals, eigvecs) =
465            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
466
467        // Filter and sort the eigenvalues and eigenvectors
468        let mut indices = (0..geneigvals.len())
469            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
470            .collect_vec();
471
472        match eigenvalue_comparison_mode {
473            EigenvalueComparisonMode::Modulus => {
474                indices.sort_by(|i, j| {
475                    if let (
476                        GeneralizedEigenvalue::Finite(e_i, _),
477                        GeneralizedEigenvalue::Finite(e_j, _),
478                    ) = (&geneigvals[*i], &geneigvals[*j])
479                    {
480                        ComplexFloat::abs(*e_i)
481                            .partial_cmp(&ComplexFloat::abs(*e_j))
482                            .unwrap()
483                    } else {
484                        panic!("Unable to compare some eigenvalues.")
485                    }
486                });
487            }
488            EigenvalueComparisonMode::Real => {
489                indices.sort_by(|i, j| {
490                    if let (
491                        GeneralizedEigenvalue::Finite(e_i, _),
492                        GeneralizedEigenvalue::Finite(e_j, _),
493                    ) = (&geneigvals[*i], &geneigvals[*j])
494                    {
495                        e_i.re().partial_cmp(&e_j.re()).unwrap()
496                    } else {
497                        panic!("Unable to compare some eigenvalues.")
498                    }
499                });
500            }
501        }
502
503        let eigvals_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
504            if let GeneralizedEigenvalue::Finite(v, _) = gv {
505                *v
506            } else {
507                panic!("Unexpected indeterminate eigenvalue.")
508            }
509        });
510        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
511
512        // Normalise the eigenvectors
513        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
514            &eigvecs_sorted.view(),
515            &smat.view(),
516            complex_symmetric,
517            Some(thresh_offdiag),
518        )?;
519
520        // Regularise the eigenvectors
521        let eigvecs_sorted_normalised_regularised =
522            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
523
524        Ok(GeneralisedEigenvalueResult {
525            eigenvalues: eigvals_sorted,
526            eigenvectors: eigvecs_sorted_normalised_regularised,
527        })
528    }
529}
530
531// -------------------
532// Auxiliary functions
533// -------------------
534
535/// Sorts the eigenvalues and the corresponding eigenvectors.
536///
537/// # Arguments
538///
539/// * `eigvals` - The eigenvalues.
540/// * `eigvecs` - The corresponding eigenvectors.
541/// * `eigenvalue_comparison_mode` - Eigenvalue comparison mode.
542///
543/// # Returns
544///
545/// A tuple containing thw sorted eigenvalues and eigenvectors.
546fn sort_eigenvalues_eigenvectors<T: ComplexFloat>(
547    eigvals: &ArrayView1<T>,
548    eigvecs: &ArrayView2<T>,
549    eigenvalue_comparison_mode: &EigenvalueComparisonMode,
550) -> (Array1<T>, Array2<T>) {
551    let mut indices = (0..eigvals.len()).collect_vec();
552    match eigenvalue_comparison_mode {
553        EigenvalueComparisonMode::Modulus => {
554            indices.sort_by(|i, j| {
555                ComplexFloat::abs(eigvals[*i])
556                    .partial_cmp(&ComplexFloat::abs(eigvals[*j]))
557                    .unwrap()
558            });
559        }
560        EigenvalueComparisonMode::Real => {
561            indices.sort_by(|i, j| eigvals[*i].re().partial_cmp(&eigvals[*j].re()).unwrap());
562        }
563    }
564    let eigvals_sorted = eigvals.select(Axis(0), &indices);
565    let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
566    (eigvals_sorted, eigvecs_sorted)
567}
568
569/// Regularises the eigenvectors such that the first entry of each of them has a positive real
570/// part, or a positive imaginary part if the real part is zero.
571///
572/// # Arguments
573///
574/// * `eigvecs` - The eigenvectors to be regularised.
575/// * `thresh` - Threshold for determining if a real number is zero.
576///
577/// # Returns
578///
579/// The regularised eigenvectors.
580fn regularise_eigenvectors<T>(eigvecs: &ArrayView2<T>, thresh: T::Real) -> Array2<T>
581where
582    T: ComplexFloat + One,
583    T::Real: Float,
584{
585    let eigvecs_sgn = stack!(
586        Axis(0),
587        eigvecs
588            .row(0)
589            .map(|v| {
590                if Float::abs(ComplexFloat::re(*v)) > thresh {
591                    T::from(v.re().signum()).expect("Unable to convert a signum to the right type.")
592                } else if Float::abs(ComplexFloat::im(*v)) > thresh {
593                    T::from(v.im().signum()).expect("Unable to convert a signum to the right type.")
594                } else {
595                    T::one()
596                }
597            })
598            .view()
599    );
600    eigvecs * eigvecs_sgn
601}
602
603/// Normalises the real eigenvectors with respect to a metric.
604///
605/// # Arguments
606///
607/// * `eigvecs` - The eigenvectors to be normalised.
608/// * `smat` - The metric.
609/// * `thresh` - Threshold for verifying the orthogonality of the eigenvectors.
610///
611/// # Returns
612///
613/// The normalised eigenvectors.
614fn normalise_eigenvectors_real<T>(
615    eigvecs: &ArrayView2<T>,
616    smat: &ArrayView2<T>,
617    thresh: T,
618) -> Result<Array2<T>, anyhow::Error>
619where
620    T: LinalgScalar + Float + std::fmt::LowerExp,
621{
622    let sq_norm = einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
623        .map_err(|err| format_err!(err))?
624        .into_dimensionality::<Ix2>()
625        .map_err(|err| format_err!(err))?;
626    let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
627        .iter()
628        .map(|x| x.abs())
629        .max_by(|x, y| {
630            x.partial_cmp(y)
631                .expect("Unable to compare two `abs` values.")
632        })
633        .ok_or_else(|| {
634            format_err!(
635                "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
636            )
637        })?;
638
639    ensure!(
640        max_diff <= thresh,
641        "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thresh:.3e}."
642    );
643    ensure!(
644        sq_norm.diag().iter().all(|v| *v > T::zero()),
645        "Some eigenvectors have negative squared norms and cannot be normalised over the reals."
646    );
647    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
648    Ok(eigvecs_normalised)
649}
650
651/// Normalises the complex eigenvectors with respect to a metric.
652///
653/// # Arguments
654///
655/// * `eigvecs` - The eigenvectors to be normalised.
656/// * `smat` - The metric.
657/// * `complex_symmetric` - Boolean indicating if the inner product is complex-symmetric or not.
658/// * `thresh` - Optioanl threshold for verifying the orthogonality of the eigenvectors. If `None`,
659///   orthogonality will not be verified.
660///
661/// # Returns
662///
663/// The normalised eigenvectors.
664fn normalise_eigenvectors_complex<T>(
665    eigvecs: &ArrayView2<T>,
666    smat: &ArrayView2<T>,
667    complex_symmetric: bool,
668    thresh: Option<T::Real>,
669) -> Result<Array2<T>, anyhow::Error>
670where
671    T: LinalgScalar + ComplexFloat + std::fmt::Display + std::fmt::LowerExp,
672    T::Real: Float + std::fmt::LowerExp,
673{
674    let sq_norm = if complex_symmetric {
675        einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
676            .map_err(|err| format_err!(err))?
677            .into_dimensionality::<Ix2>()
678            .map_err(|err| format_err!(err))?
679    } else {
680        einsum(
681            "ji,jk,kl->il",
682            &[&eigvecs.map(|v| v.conj()).view(), smat, eigvecs],
683        )
684        .map_err(|err| format_err!(err))?
685        .into_dimensionality::<Ix2>()
686        .map_err(|err| format_err!(err))?
687    };
688
689    if let Some(thr) = thresh {
690        let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
691            .iter()
692            .map(|x| ComplexFloat::abs(*x))
693            .max_by(|x, y| {
694                x.partial_cmp(y)
695                    .expect("Unable to compare two `abs` values.")
696            })
697            .ok_or_else(|| {
698                if complex_symmetric {
699                    format_err!(
700                        "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
701                    )
702                } else {
703                    format_err!(
704                        "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
705                    )
706                }
707            })?;
708
709        if complex_symmetric {
710            log::debug!("C^T.S.C:\n  {sq_norm:+.8e}");
711            ensure!(
712                max_diff <= thr,
713                "The C^T.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
714            )
715        } else {
716            log::debug!("C^†.S.C:\n  {sq_norm:+.8e}");
717            ensure!(
718                max_diff <= thr,
719                "The C^†.S.C matrix is not a diagonal matrix: the maximum absolute value of the off-diagonal elements is {max_diff:.3e} > {thr:.3e}."
720            )
721        };
722    }
723    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
724    Ok(eigvecs_normalised)
725}
726
727/// Checks for complex-symmetric or complex-Hermitian symmetry of a complex square matrix.
728///
729/// # Arguments
730///
731/// * `mat` - The complex square matrix to be checked.
732/// * `complex_symmetric` - Boolean indicating if complex-symmetry is to be checked instead of
733///   complex-Hermiticity.
734/// * `thresh_offdiag` - Threshold for checking.
735/// * `matname` - Name of the matrix.
736/// * `matsymbol` - Symbol of the matrix.
737pub(crate) fn check_complex_matrix_symmetry<T>(
738    mat: &ArrayView2<T>,
739    complex_symmetric: bool,
740    thresh_offdiag: <T as ComplexFloat>::Real,
741    matname: &str,
742    matsymbol: &str,
743) -> Result<(), anyhow::Error>
744where
745    T: LinalgScalar + ComplexFloat + std::fmt::Display + std::fmt::LowerExp,
746    <T as ComplexFloat>::Real: Float + std::fmt::LowerExp + std::fmt::Display,
747{
748    if complex_symmetric {
749        let deviation = mat.to_owned() - mat.t();
750        let (pos, &max_offdiag) = deviation
751            .mapv(ComplexFloat::abs)
752            .iter()
753            .enumerate()
754            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
755            .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the {matname} complex-symmetric deviation matrix."))?;
756        let (pos_i, pos_j) = (pos.div_euclid(mat.ncols()), pos.rem_euclid(mat.ncols()));
757        log::debug!("{matname} matrix:\n  {mat:+.3e}");
758        log::debug!("{matname} matrix complex-symmetric deviation:\n  {deviation:+.3e}",);
759        qsym2_warn!("{matname} matrix complex-symmetric deviation:\n  {deviation:+.3e}",);
760        ensure!(
761            max_offdiag <= thresh_offdiag,
762            "The {matname} matrix is not complex-symmetric: ||{matsymbol} - ({matsymbol})^T||_∞ = {max_offdiag:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
763        );
764    } else {
765        let deviation = mat.to_owned() - mat.map(|v| v.conj()).t();
766        let (pos, &max_offdiag) = deviation
767                .mapv(ComplexFloat::abs)
768                .iter()
769                .enumerate()
770                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
771                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the {matname} complex-Hermitian deviation matrix."))?;
772        let (pos_i, pos_j) = (pos.div_euclid(mat.ncols()), pos.rem_euclid(mat.ncols()));
773        log::debug!("{matname} matrix:\n  {mat:+.3e}");
774        log::debug!("{matname} matrix complex-Hermitian deviation:\n  {deviation:+.3e}",);
775        qsym2_warn!("{matname} matrix complex-Hermitian deviation:\n  {deviation:+.3e}",);
776        ensure!(
777            max_offdiag <= thresh_offdiag,
778            "The {matname} matrix is not complex-Hermitian: ||{matsymbol} - ({matsymbol})^†||_∞ = {max_offdiag:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
779        );
780    }
781    Ok(())
782}
783
784/// Checks for real-symmetric symmetry of a complex square matrix.
785///
786/// # Arguments
787///
788/// * `mat` - The real square matrix to be checked.
789/// * `thresh_offdiag` - Threshold for checking.
790/// * `matname` - Name of the matrix.
791/// * `matsymbol` - Symbol of the matrix.
792pub(crate) fn check_real_matrix_symmetry<T>(
793    mat: &ArrayView2<T>,
794    thresh_offdiag: T,
795    matname: &str,
796    matsymbol: &str,
797) -> Result<(), anyhow::Error>
798where
799    T: LowerExp + Clone + LinalgScalar + Float + TotalOrder,
800{
801    let deviation = mat.to_owned() - mat.t();
802    let (pos, &max_offdiag) = deviation
803        .map(|v| v.abs())
804        .iter()
805        .enumerate()
806        .max_by(|(_, a), (_, b)| a.total_cmp(b))
807        .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the {matname} real-symmetric deviation matrix."))?;
808    log::debug!("{matname} matrix:\n  {mat:+.3e}");
809    log::debug!("{matname} matrix real-symmetric deviation:\n  {deviation:+.3e}",);
810    qsym2_warn!("{matname} matrix real-symmetric deviation:\n  {deviation:+.3e}",);
811    let (pos_i, pos_j) = (pos.div_euclid(mat.ncols()), pos.rem_euclid(mat.ncols()));
812    ensure!(
813        max_offdiag <= thresh_offdiag,
814        "{matname} matrix is not real-symmetric: ||{matsymbol} - ({matsymbol})^T||_∞ = {max_offdiag:.3e} > {thresh_offdiag:.3e} at ({pos_i}, {pos_j})."
815    );
816    Ok(())
817}