qsym2/target/noci/backend/solver/mod.rs

1use anyhow::{self, ensure, format_err};
2use duplicate::duplicate_item;
3use itertools::Itertools;
4use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, Ix2, LinalgScalar};
5use ndarray_einsum::einsum;
6use ndarray_linalg::{
7    Eig, EigGeneralized, Eigh, GeneralizedEigenvalue, Lapack, Norm, Scalar, UPLO,
8};
9use num::traits::FloatConst;
10use num::{Float, One};
11use num_complex::{Complex, ComplexFloat};
12
13use crate::analysis::EigenvalueComparisonMode;
14
15use crate::target::noci::backend::nonortho::CanonicalOrthogonalisable;
16
17pub mod noci;
18
19#[cfg(test)]
20#[path = "solver_tests.rs"]
21mod solver_tests;
22
23// -----------------------------
24// GeneralisedEigenvalueSolvable
25// -----------------------------
26
27/// Trait to solve the generalised eigenvalue equation for a pair of square matrices $`\mathbf{A}`$
28/// and $`\mathbf{B}`$:
29/// ```math
30///     \mathbf{A} \mathbf{v} = \lambda \mathbf{B} \mathbf{v},
31/// ```
32/// where $`\mathbf{A}`$ and $`\mathbf{B}`$ are in general non-Hermitian and non-positive-definite.
33pub trait GeneralisedEigenvalueSolvable {
34    /// Numerical type of the matrix elements constituting the generalised eigenvalue problem.
35    type NumType;
36
37    /// Numerical type of the various thresholds for comparison.
38    type RealType;
39
40    /// Solves the *auxiliary* generalised eigenvalue problem
41    /// ```math
42    ///     \tilde{\mathbf{A}} \tilde{\mathbf{v}} = \tilde{\lambda} \tilde{\mathbf{B}} \tilde{\mathbf{v}},
43    /// ```
44    /// where $`\tilde{\mathbf{B}}`$ is the canonical-orthogonalised version of $`\mathbf{B}`$. If
45    /// $`\mathbf{B}`$ is not of full rank, then the two eigenvalue problems are different.
46    ///
47    /// # Arguments
48    ///
49    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
50    /// complex-symmetric.
51    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
52    /// verifying orthogonality.
53    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
54    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
55    /// corresponding eigenvectors.
56    ///
57    /// # Returns
58    ///
59    /// The generalised eigenvalue result.
60    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
61        &self,
62        complex_symmetric: bool,
63        thresh_offdiag: Self::RealType,
64        thresh_zeroov: Self::RealType,
65        eigenvalue_comparison_mode: EigenvalueComparisonMode,
66    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
67
68    /// Solves the generalised eigenvalue problem using LAPACK's `?ggev` generalised eigensolver.
69    ///
70    /// Note that this can be numerically unstable if $`\mathbf{B}`$ is not of full rank.
71    ///
72    /// # Arguments
73    ///
74    /// * `complex_symmetric` - Boolean indicating whether the provided pair of matrices are
75    /// complex-symmetric.
76    /// * `thresh_offdiag` - Threshold for checking if any off-diagonal elements are non-zero when
77    /// verifying orthogonality.
78    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of $`\mathbf{B}`$.
79    /// * `eigenvalue_comparison_mode` - Comparison mode for sorting eigenvalues and their
80    /// corresponding eigenvectors.
81    ///
82    /// # Returns
83    ///
84    /// The generalised eigenvalue result.
85    fn solve_generalised_eigenvalue_problem_with_ggev(
86        &self,
87        complex_symmetric: bool,
88        thresh_offdiag: Self::RealType,
89        thresh_zeroov: Self::RealType,
90        eigenvalue_comparison_mode: EigenvalueComparisonMode,
91    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
92}
93
94/// Structure containing the eigenvalues and eigenvectors of a generalised
95/// eigenvalue problem.
96pub struct GeneralisedEigenvalueResult<T> {
97    /// The resulting eigenvalues.
98    eigenvalues: Array1<T>,
99
100    /// The corresponding eigenvectors.
101    eigenvectors: Array2<T>,
102}
103
104impl<T> GeneralisedEigenvalueResult<T> {
105    /// Returns the eigenvalues.
106    pub fn eigenvalues(&self) -> ArrayView1<T> {
107        self.eigenvalues.view()
108    }
109
110    /// Returns the eigenvectors.
111    pub fn eigenvectors(&self) -> ArrayView2<T> {
112        self.eigenvectors.view()
113    }
114}
115
116#[duplicate_item(
117    [
118        dtype_ [ f64 ]
119    ]
120    [
121        dtype_ [ f32 ]
122    ]
123)]
124impl GeneralisedEigenvalueSolvable for (&ArrayView2<'_, dtype_>, &ArrayView2<'_, dtype_>) {
125    type NumType = dtype_;
126    type RealType = dtype_;
127
128    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
129        &self,
130        _: bool,
131        thresh_offdiag: dtype_,
132        thresh_zeroov: dtype_,
133        eigenvalue_comparison_mode: EigenvalueComparisonMode,
134    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
135        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
136
137        // Real, symmetric S and H
138        ensure!(
139            (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
140            "Hamiltonian matrix is not real-symmetric."
141        );
142
143        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
144        // real-symmetry of S.
145        // This will fail over the reals if smat contains negative eigenvalues.
146        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
147            true,
148            false,
149            thresh_offdiag,
150            thresh_zeroov,
151        )?;
152
153        let xmat = xmat_res.xmat();
154        let xmat_d = xmat_res.xmat_d();
155
156        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
157        let smat_t = xmat_d.dot(&smat).dot(&xmat);
158
159        // Over the reals, canonical orthogonalisation cannot handle `smat` with negative
160        // eigenvalues. This means that `smat_t` can only be the identity.
161        let max_diff = (&smat_t - &Array2::<dtype_>::eye(smat_t.nrows()))
162            .iter()
163            .map(|x| ComplexFloat::abs(*x))
164            .max_by(|x, y| {
165                x.partial_cmp(y)
166                    .expect("Unable to compare two `abs` values.")
167            })
168            .ok_or_else(|| {
169                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
170            })?;
171        ensure!(
172            max_diff <= thresh_offdiag,
173            "The orthogonalised overlap matrix is not the identity matrix."
174        );
175
176        let (eigvals_t, eigvecs_t) = hmat_t.eigh(UPLO::Lower)?;
177
178        // Sort the eigenvalues and eigenvectors
179        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
180            &eigvals_t.view(),
181            &eigvecs_t.view(),
182            &eigenvalue_comparison_mode,
183        );
184        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
185
186        // Normalise the eigenvectors
187        let eigvecs_sorted_normalised =
188            normalise_eigenvectors_real(&eigvecs_sorted.view(), &smat.view(), thresh_offdiag)?;
189
190        // Regularise the eigenvectors
191        let eigvecs_sorted_normalised_regularised =
192            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
193
194        Ok(GeneralisedEigenvalueResult {
195            eigenvalues: eigvals_t_sorted,
196            eigenvectors: eigvecs_sorted_normalised_regularised,
197        })
198    }
199
200    fn solve_generalised_eigenvalue_problem_with_ggev(
201        &self,
202        _: bool,
203        thresh_offdiag: dtype_,
204        thresh_zeroov: dtype_,
205        eigenvalue_comparison_mode: EigenvalueComparisonMode,
206    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
207        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
208
209        // Real, symmetric S and H
210        ensure!(
211            (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
212            "Hamiltonian matrix is not real-symmetric."
213        );
214        ensure!(
215            (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
216            "Overlap matrix is not real-symmetric."
217        );
218
219        let (geneigvals, eigvecs) =
220            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
221
222        for gv in geneigvals.iter() {
223            if let GeneralizedEigenvalue::Finite(v, _) = gv {
224                ensure!(
225                    v.im().abs() <= thresh_offdiag,
226                    "Unexpected complex eigenvalue {v} for real, symmetric S and H."
227                );
228            }
229        }
230
231        // Filter and sort the eigenvalues and eigenvectors
232        let mut indices = (0..geneigvals.len())
233            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
234            .collect_vec();
235
236        match eigenvalue_comparison_mode {
237            EigenvalueComparisonMode::Modulus => {
238                indices.sort_by(|i, j| {
239                    if let (
240                        GeneralizedEigenvalue::Finite(e_i, _),
241                        GeneralizedEigenvalue::Finite(e_j, _),
242                    ) = (&geneigvals[*i], &geneigvals[*j])
243                    {
244                        ComplexFloat::abs(*e_i)
245                            .partial_cmp(&ComplexFloat::abs(*e_j))
246                            .unwrap()
247                    } else {
248                        panic!("Unable to compare some eigenvalues.")
249                    }
250                });
251            }
252            EigenvalueComparisonMode::Real => {
253                indices.sort_by(|i, j| {
254                    if let (
255                        GeneralizedEigenvalue::Finite(e_i, _),
256                        GeneralizedEigenvalue::Finite(e_j, _),
257                    ) = (&geneigvals[*i], &geneigvals[*j])
258                    {
259                        e_i.re().partial_cmp(&e_j.re()).unwrap()
260                    } else {
261                        panic!("Unable to compare some eigenvalues.")
262                    }
263                });
264            }
265        }
266
267        let eigvals_re_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
268            if let GeneralizedEigenvalue::Finite(v, _) = gv {
269                v.re()
270            } else {
271                panic!("Unexpected indeterminate eigenvalue.")
272            }
273        });
274        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
275        ensure!(
276            eigvecs_sorted.iter().all(|v| v.im().abs() < thresh_offdiag),
277            "Unexpected complex eigenvectors."
278        );
279        let eigvecs_re_sorted = eigvecs_sorted.map(|v| v.re());
280
281        // Normalise the eigenvectors
282        let eigvecs_re_sorted_normalised =
283            normalise_eigenvectors_real(&eigvecs_re_sorted.view(), &smat.view(), thresh_offdiag)?;
284
285        // Regularise the eigenvectors
286        let eigvecs_re_sorted_normalised_regularised =
287            regularise_eigenvectors(&eigvecs_re_sorted_normalised.view(), thresh_offdiag);
288
289        Ok(GeneralisedEigenvalueResult {
290            eigenvalues: eigvals_re_sorted,
291            eigenvectors: eigvecs_re_sorted_normalised_regularised,
292        })
293    }
294}
295
296impl<T> GeneralisedEigenvalueSolvable for (&ArrayView2<'_, Complex<T>>, &ArrayView2<'_, Complex<T>>)
297where
298    T: Float + FloatConst + Scalar<Complex = Complex<T>>,
299    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
300    for<'a> ArrayView2<'a, Complex<T>>:
301        CanonicalOrthogonalisable<NumType = Complex<T>, RealType = T>,
302{
303    type NumType = Complex<T>;
304
305    type RealType = T;
306
307    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
308        &self,
309        complex_symmetric: bool,
310        thresh_offdiag: T,
311        thresh_zeroov: T,
312        eigenvalue_comparison_mode: EigenvalueComparisonMode,
313    ) -> Result<GeneralisedEigenvalueResult<Complex<T>>, anyhow::Error> {
314        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
315
316        if complex_symmetric {
317            // Complex-symmetric H
318            ensure!(
319                (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
320                "Hamiltonian matrix is not complex-symmetric."
321            );
322        } else {
323            // Complex-Hermitian H
324            ensure!(
325                (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
326                "Hamiltonian matrix is not complex-Hermitian."
327            );
328        }
329
330        // CanonicalOrthogonalisationResult::calc_canonical_orthogonal_matrix checks for
331        // complex-symmetry or complex-Hermiticity of S.
332        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
333            complex_symmetric,
334            false,
335            thresh_offdiag,
336            thresh_zeroov,
337        )?;
338
339        let xmat = xmat_res.xmat();
340        let xmat_d = xmat_res.xmat_d();
341
342        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
343        let smat_t = xmat_d.dot(&smat).dot(&xmat);
344        let smat_t_d = smat_t.map(|v| v.conj()).t().to_owned();
345
346        // smat_t is not necessarily the identity, but is guaranteed to be Hermitian.
347        let max_diff = (&smat_t_d.dot(&smat_t) - &Array2::<T>::eye(smat_t.nrows()))
348            .iter()
349            .map(|x| ComplexFloat::abs(*x))
350            .max_by(|x, y| {
351                x.partial_cmp(y)
352                    .expect("Unable to compare two `abs` values.")
353            })
354            .ok_or_else(|| {
355                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
356            })?;
357        ensure!(
358            max_diff <= thresh_offdiag,
359            "The orthogonalised overlap matrix is not the identity matrix."
360        );
361        let smat_t_d_hmat_t = smat_t_d.dot(&hmat_t);
362
363        let (eigvals_t, eigvecs_t) = smat_t_d_hmat_t.eig()?;
364
365        // Sort the eigenvalues and eigenvectors
366        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
367            &eigvals_t.view(),
368            &eigvecs_t.view(),
369            &eigenvalue_comparison_mode,
370        );
371        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
372
373        // Normalise the eigenvectors
374        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
375            &eigvecs_sorted.view(),
376            &smat.view(),
377            complex_symmetric,
378            thresh_offdiag,
379        )?;
380
381        // Regularise the eigenvectors
382        let eigvecs_sorted_normalised_regularised =
383            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
384
385        Ok(GeneralisedEigenvalueResult {
386            eigenvalues: eigvals_t_sorted,
387            eigenvectors: eigvecs_sorted_normalised_regularised,
388        })
389    }
390
391    fn solve_generalised_eigenvalue_problem_with_ggev(
392        &self,
393        complex_symmetric: bool,
394        thresh_offdiag: T,
395        thresh_zeroov: T,
396        eigenvalue_comparison_mode: EigenvalueComparisonMode,
397    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
398        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
399
400        if complex_symmetric {
401            // Complex-symmetric H and S
402            ensure!(
403                (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
404                "Hamiltonian matrix is not complex-symmetric."
405            );
406            ensure!(
407                (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
408                "Overlap matrix is not complex-symmetric."
409            );
410        } else {
411            // Complex-Hermitian H and S
412            ensure!(
413                (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
414                "Hamiltonian matrix is not complex-Hermitian."
415            );
416            ensure!(
417                (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
418                "Overlap matrix is not complex-Hermitian."
419            );
420        }
421
422        let (geneigvals, eigvecs) =
423            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
424
425        // Filter and sort the eigenvalues and eigenvectors
426        let mut indices = (0..geneigvals.len())
427            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
428            .collect_vec();
429
430        match eigenvalue_comparison_mode {
431            EigenvalueComparisonMode::Modulus => {
432                indices.sort_by(|i, j| {
433                    if let (
434                        GeneralizedEigenvalue::Finite(e_i, _),
435                        GeneralizedEigenvalue::Finite(e_j, _),
436                    ) = (&geneigvals[*i], &geneigvals[*j])
437                    {
438                        ComplexFloat::abs(*e_i)
439                            .partial_cmp(&ComplexFloat::abs(*e_j))
440                            .unwrap()
441                    } else {
442                        panic!("Unable to compare some eigenvalues.")
443                    }
444                });
445            }
446            EigenvalueComparisonMode::Real => {
447                indices.sort_by(|i, j| {
448                    if let (
449                        GeneralizedEigenvalue::Finite(e_i, _),
450                        GeneralizedEigenvalue::Finite(e_j, _),
451                    ) = (&geneigvals[*i], &geneigvals[*j])
452                    {
453                        e_i.re().partial_cmp(&e_j.re()).unwrap()
454                    } else {
455                        panic!("Unable to compare some eigenvalues.")
456                    }
457                });
458            }
459        }
460
461        let eigvals_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
462            if let GeneralizedEigenvalue::Finite(v, _) = gv {
463                *v
464            } else {
465                panic!("Unexpected indeterminate eigenvalue.")
466            }
467        });
468        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
469
470        // Normalise the eigenvectors
471        let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
472            &eigvecs_sorted.view(),
473            &smat.view(),
474            complex_symmetric,
475            thresh_offdiag,
476        )?;
477
478        // Regularise the eigenvectors
479        let eigvecs_sorted_normalised_regularised =
480            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
481
482        Ok(GeneralisedEigenvalueResult {
483            eigenvalues: eigvals_sorted,
484            eigenvectors: eigvecs_sorted_normalised_regularised,
485        })
486    }
487}
488
489// -------------------
490// Auxiliary functions
491// -------------------
492
493/// Sorts the eigenvalues and the corresponding eigenvectors.
494///
495/// # Arguments
496///
497/// * `eigvals` - The eigenvalues.
498/// * `eigvecs` - The corresponding eigenvectors.
499/// * `eigenvalue_comparison_mode` - Eigenvalue comparison mode.
500///
501/// # Returns
502///
503/// A tuple containing thw sorted eigenvalues and eigenvectors.
504fn sort_eigenvalues_eigenvectors<T: ComplexFloat>(
505    eigvals: &ArrayView1<T>,
506    eigvecs: &ArrayView2<T>,
507    eigenvalue_comparison_mode: &EigenvalueComparisonMode,
508) -> (Array1<T>, Array2<T>) {
509    let mut indices = (0..eigvals.len()).collect_vec();
510    match eigenvalue_comparison_mode {
511        EigenvalueComparisonMode::Modulus => {
512            indices.sort_by(|i, j| {
513                ComplexFloat::abs(eigvals[*i])
514                    .partial_cmp(&ComplexFloat::abs(eigvals[*j]))
515                    .unwrap()
516            });
517        }
518        EigenvalueComparisonMode::Real => {
519            indices.sort_by(|i, j| eigvals[*i].re().partial_cmp(&eigvals[*j].re()).unwrap());
520        }
521    }
522    let eigvals_sorted = eigvals.select(Axis(0), &indices);
523    let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
524    (eigvals_sorted, eigvecs_sorted)
525}
526
527/// Regularises the eigenvectors such that the first entry of each of them has a positive real
528/// part, or a positive imaginary part if the real part is zero.
529///
530/// # Arguments
531///
532/// * `eigvecs` - The eigenvectors to be regularised.
533/// * `thresh` - Threshold for determining if a real number is zero.
534///
535/// # Returns
536///
537/// The regularised eigenvectors.
538fn regularise_eigenvectors<T>(eigvecs: &ArrayView2<T>, thresh: T::Real) -> Array2<T>
539where
540    T: ComplexFloat + One,
541    T::Real: Float,
542{
543    let eigvecs_sgn = stack!(
544        Axis(0),
545        eigvecs
546            .row(0)
547            .map(|v| {
548                if Float::abs(ComplexFloat::re(*v)) > thresh {
549                    T::from(v.re().signum()).expect("Unable to convert a signum to the right type.")
550                } else if Float::abs(ComplexFloat::im(*v)) > thresh {
551                    T::from(v.im().signum()).expect("Unable to convert a signum to the right type.")
552                } else {
553                    T::one()
554                }
555            })
556            .view()
557    );
558    let eigvecs_regularised = eigvecs * eigvecs_sgn;
559    eigvecs_regularised
560}
561
562/// Normalises the real eigenvectors with respect to a metric.
563///
564/// # Arguments
565///
566/// * `eigvecs` - The eigenvectors to be normalised.
567/// * `smat` - The metric.
568/// * `thresh` - Threshold for verifying the orthogonality of the eigenvectors.
569///
570/// # Returns
571///
572/// The normalised eigenvectors.
573fn normalise_eigenvectors_real<T>(
574    eigvecs: &ArrayView2<T>,
575    smat: &ArrayView2<T>,
576    thresh: T,
577) -> Result<Array2<T>, anyhow::Error>
578where
579    T: LinalgScalar + Float,
580{
581    let sq_norm = einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
582        .map_err(|err| format_err!(err))?
583        .into_dimensionality::<Ix2>()
584        .map_err(|err| format_err!(err))?;
585    let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
586        .iter()
587        .map(|x| x.abs())
588        .max_by(|x, y| {
589            x.partial_cmp(y)
590                .expect("Unable to compare two `abs` values.")
591        })
592        .ok_or_else(|| {
593            format_err!(
594                "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
595            )
596        })?;
597
598    ensure!(
599        max_diff <= thresh,
600        "The C^T.S.C matrix is not a diagonal matrix."
601    );
602    ensure!(
603        sq_norm.diag().iter().all(|v| *v > T::zero()),
604        "Some eigenvectors have negative squared norms and cannot be normalised over the reals."
605    );
606    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
607    Ok(eigvecs_normalised)
608}
609
610/// Normalises the complex eigenvectors with respect to a metric.
611///
612/// # Arguments
613///
614/// * `eigvecs` - The eigenvectors to be normalised.
615/// * `smat` - The metric.
616/// * `complex_symmetric` - Boolean indicating if the inner product is complex-symmetric or not.
617/// * `thresh` - Threshold for verifying the orthogonality of the eigenvectors.
618///
619/// # Returns
620///
621/// The normalised eigenvectors.
622fn normalise_eigenvectors_complex<T>(
623    eigvecs: &ArrayView2<T>,
624    smat: &ArrayView2<T>,
625    complex_symmetric: bool,
626    thresh: T::Real,
627) -> Result<Array2<T>, anyhow::Error>
628where
629    T: LinalgScalar + ComplexFloat + std::fmt::Display,
630    T::Real: Float,
631{
632    let sq_norm = if complex_symmetric {
633        einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
634            .map_err(|err| format_err!(err))?
635            .into_dimensionality::<Ix2>()
636            .map_err(|err| format_err!(err))?
637    } else {
638        einsum(
639            "ji,jk,kl->il",
640            &[&eigvecs.map(|v| v.conj()).view(), smat, eigvecs],
641        )
642        .map_err(|err| format_err!(err))?
643        .into_dimensionality::<Ix2>()
644        .map_err(|err| format_err!(err))?
645    };
646    let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
647        .iter()
648        .map(|x| ComplexFloat::abs(*x))
649        .max_by(|x, y| {
650            x.partial_cmp(y)
651                .expect("Unable to compare two `abs` values.")
652        })
653        .ok_or_else(|| {
654            if complex_symmetric {
655                format_err!(
656                    "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
657                )
658            } else {
659                format_err!(
660                    "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
661                )
662            }
663        })?;
664
665    ensure!(
666        max_diff <= thresh,
667        if complex_symmetric {
668            "The C^T.S.C matrix is not a diagonal matrix."
669        } else {
670            "The C^†.S.C matrix is not a diagonal matrix."
671        }
672    );
673    let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
674    Ok(eigvecs_normalised)
675}