qsym2/target/noci/backend/nonortho/
mod.rs

1use std::fmt::LowerExp;
2use std::iter::Product;
3
4use anyhow::{self, ensure, format_err};
5use derive_builder::Builder;
6use duplicate::duplicate_item;
7use indexmap::IndexSet;
8use itertools::Itertools;
9use ndarray::{
10    Array1, Array2, ArrayView1, ArrayView2, ArrayView4, Axis, Ix0, Ix2, ScalarOperand, stack,
11};
12use ndarray_einsum::einsum;
13use ndarray_linalg::types::Lapack;
14use ndarray_linalg::{Determinant, Eig, Eigh, SVD, Scalar, UPLO};
15use num::{Complex, Float};
16use num_complex::ComplexFloat;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19
20use super::denmat::{calc_unweighted_codensity_matrix, calc_weighted_codensity_matrix};
21
22#[cfg(test)]
23#[path = "nonortho_tests.rs"]
24mod nonortho_tests;
25
26/// Structure containing Löwdin-paired coefficients, the corresponding Löwdin overlaps, and the
27/// indices of the zero overlaps.
28///
29/// The Löwdin-paired coefficients satisfy
30/// ```math
31///     ^{wx}\mathbf{\Lambda}
32///         = \mathrm{diag}(^{wx}\lambda_i)
33///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
34///           \ \mathbf{S}_{\mathrm{AO}}
35///           \ ^{x}\tilde{\mathbf{C}}.
36/// ```
37#[derive(Builder)]
38#[builder(build_fn(validate = "Self::validate"))]
39pub struct LowdinPairedCoefficients<T: ComplexFloat> {
40    /// The $`^{w}\!\tilde{\mathbf{C}}`$ coefficient matrix.
41    paired_cw: Array2<T>,
42
43    /// The $`^{x}\!\tilde{\mathbf{C}}`$ coefficient matrix.
44    paired_cx: Array2<T>,
45
46    /// The Löwdin overlaps $`\{^{wx}\lambda_i\}`$.
47    lowdin_overlaps: Vec<T>,
48
49    /// The indices of the zero overlaps with respect to [`Self::thresh_zeroov`]. If these are not
50    /// provided, the specified [`Self::thresh_zeroov`] will be used to deduce them from the
51    /// specified `[Self::lowdin_overlaps]`.
52    #[builder(default = "self.default_zero_indices()?")]
53    zero_indices: IndexSet<usize>,
54
55    /// Threshold for determining zero overlaps.
56    thresh_zeroov: T::Real,
57
58    /// Boolean indicating whether the coefficients have been Löwdin-paired with respect to the
59    /// complex-symmetric inner product.
60    complex_symmetric: bool,
61}
62
63impl<T: ComplexFloat> LowdinPairedCoefficientsBuilder<T> {
64    fn validate(&self) -> Result<(), String> {
65        let paired_cw = self
66            .paired_cw
67            .as_ref()
68            .ok_or("Löwdin-paired coefficients `paired_cw` not set.".to_string())?;
69        let paired_cx = self
70            .paired_cx
71            .as_ref()
72            .ok_or("Löwdin-paired coefficients `paired_cx` not set.".to_string())?;
73        let lowdin_overlaps = self
74            .lowdin_overlaps
75            .as_ref()
76            .ok_or("Löwdin overlaps not set.".to_string())?;
77        let zero_indices = self
78            .zero_indices
79            .as_ref()
80            .ok_or("Indices of zero Löwdin overlaps not set.".to_string())?;
81
82        if paired_cw.shape() == paired_cx.shape() {
83            let lowdin_dim = paired_cw.shape()[1];
84            if lowdin_dim == lowdin_overlaps.len() {
85                if zero_indices.iter().all(|i| *i < lowdin_dim) {
86                    Ok(())
87                } else {
88                    Err("Some indices of zero Löwdin overlaps are out-of-bound!".to_string())
89                }
90            } else {
91                Err(
92                    "Inconsistent number of Löwdin-paired orbitals and Löwdin overlaps."
93                        .to_string(),
94                )
95            }
96        } else {
97            Err(format!(
98                "Inconsistent shapes between `paired_cw` ({:?}) and `paired_cx` ({:?}).",
99                paired_cw.shape(),
100                paired_cx.shape()
101            ))
102        }
103    }
104
105    fn default_zero_indices(&self) -> Result<IndexSet<usize>, String> {
106        let lowdin_overlaps = self
107            .lowdin_overlaps
108            .as_ref()
109            .ok_or("Löwdin overlaps not set.".to_string())?;
110        let thresh_zeroov = self
111            .thresh_zeroov
112            .as_ref()
113            .ok_or("threshold for zero Löwdin overlaps not set.".to_string())?;
114        let zero_indices = lowdin_overlaps
115            .iter()
116            .enumerate()
117            .filter(|(_, ov)| ComplexFloat::abs(**ov) < *thresh_zeroov)
118            .map(|(i, _)| i)
119            .collect::<IndexSet<_>>();
120        Ok(zero_indices)
121    }
122}
123
124impl<T: ComplexFloat> LowdinPairedCoefficients<T> {
125    pub fn builder() -> LowdinPairedCoefficientsBuilder<T> {
126        LowdinPairedCoefficientsBuilder::<T>::default()
127    }
128
129    /// Returns the Löwdin-paired coefficients.
130    pub fn paired_coefficients(&self) -> (&Array2<T>, &Array2<T>) {
131        (&self.paired_cw, &self.paired_cx)
132    }
133
134    /// Returns the number of atomic-orbital basis functions in which the coefficient matrices are
135    /// expressed.
136    pub fn nbasis(&self) -> usize {
137        self.paired_cw.nrows()
138    }
139
140    /// Returns the number of molecular orbitals being Löwdin-paired.
141    pub fn lowdin_dim(&self) -> usize {
142        self.lowdin_overlaps.len()
143    }
144
145    /// Returns the number of zero Löwdin overlaps.
146    pub fn n_lowdin_zeros(&self) -> usize {
147        self.zero_indices.len()
148    }
149
150    /// Returns the Löwdin overlaps.
151    pub fn lowdin_overlaps(&self) -> &Vec<T> {
152        &self.lowdin_overlaps
153    }
154
155    /// Returns the indices of the zero Löwdin overlaps.
156    pub fn zero_indices(&self) -> &IndexSet<usize> {
157        &self.zero_indices
158    }
159
160    /// Returns the indices of the non-zero Löwdin overlaps.
161    pub fn nonzero_indices(&self) -> IndexSet<usize> {
162        (0..self.lowdin_dim())
163            .filter(|i| !self.zero_indices.contains(i))
164            .collect::<IndexSet<_>>()
165    }
166
167    /// Returns the boolean indicating whether the Löwdin pairing is with respect to the
168    /// complex-symmetric inner product.
169    pub fn complex_symmetric(&self) -> bool {
170        self.complex_symmetric
171    }
172}
173
174impl<T: ComplexFloat + Product> LowdinPairedCoefficients<T> {
175    /// The reduced overlap between the two determinants.
176    pub fn reduced_overlap(&self) -> T {
177        self.nonzero_indices()
178            .iter()
179            .map(|i| self.lowdin_overlaps[*i])
180            .product()
181    }
182}
183
184/// Performs Löwdin pairing on two coefficient matrices $`^{w}\mathbf{C}`$ and
185/// $`^{x}\mathbf{C}`$.
186///
187/// Löwdin pairing ensures that
188/// ```math
189///     ^{wx}\mathbf{\Lambda}
190///         = \mathrm{diag}(^{wx}\lambda_i)
191///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
192///           \ \mathbf{S}_{\mathrm{AO}}
193///           \ ^{x}\tilde{\mathbf{C}},
194/// ```
195/// where the Löwdin-paired coefficient matrices are given by
196/// ```math
197///     \begin{align*}
198///         ^{w}\!\tilde{\mathbf{C}}
199///             &=\ ^{w}\mathbf{C}\ ^{wx}\mathbf{U}^{\lozenge} \\
200///         ^{x}\!\tilde{\mathbf{C}}
201///             &=\ ^{x}\mathbf{C}\ ^{wx}\mathbf{V}
202///     \end{align*}
203/// ```
204/// with $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ being SVD factorisation matrices:
205/// ```math
206///     ^{w}\mathbf{C}^{\dagger\lozenge}
207///     \ \mathbf{S}_{\mathrm{AO}}
208///     \ ^{x}\mathbf{C}
209///     =
210///     \ ^{wx}\mathbf{U}
211///     \ ^{wx}\mathbf{\Lambda}
212///     \ ^{wx}\mathbf{V}^{\dagger}.
213/// ```
214///
215/// We note that the first columns of $`^{w}\!\tilde{\mathbf{C}}`$ and $`^{x}\!\tilde{\mathbf{C}}`$
216/// are also adjusted by the determinants of $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ as
217/// appropriate to ensure that the Slater determinants corresponding to them remain invariant with
218/// respect to the unitary transformations brought about by $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$.
219///
220/// # Arguments
221///
222/// * `cw` - Coefficient matrix $`^{w}\mathbf{C}`$.
223/// * `cx` - Coefficient matrix $`^{x}\mathbf{C}`$.
224/// * `sao` - The overlap matrix $`\mathbf{S}_{\mathrm{AO}}`$ of the underlying atomic basis
225///   functions.
226/// * `complex_symmetric` - If `true`, $`\lozenge = \star`$. If `false`, $`\lozenge = \hat{e}`$.
227/// * `thresh_offdiag` - Threshold to check if the off-diagonal elements in the original orbital
228///   overlap matrix and in the Löwdin-paired orbital overlap matrix $`^{wx}\mathbf{\Lambda}`$ are zero.
229/// * `thresh_zeroov` - Threshold to identify which Löwdin overlaps $`^{wx}\lambda_i`$ are zero.
230///
231/// # Returns
232///
233/// A [`LowdinPairedCoefficients`] structure containing the result of the Löwdin pairing.
234pub fn calc_lowdin_pairing<T>(
235    cw: &ArrayView2<T>,
236    cx: &ArrayView2<T>,
237    sao: &ArrayView2<T>,
238    complex_symmetric: bool,
239    thresh_offdiag: <T as ComplexFloat>::Real,
240    thresh_zeroov: <T as ComplexFloat>::Real,
241) -> Result<LowdinPairedCoefficients<T>, anyhow::Error>
242where
243    T: ComplexFloat + Lapack,
244    <T as ComplexFloat>::Real: PartialOrd + LowerExp,
245{
246    if cw.shape() != cx.shape() {
247        Err(format_err!(
248            "Coefficient dimensions mismatched: cw ({:?}) !~ cx ({:?}).",
249            cw.shape(),
250            cx.shape()
251        ))
252    } else {
253        let init_orb_ovmat = if complex_symmetric {
254            einsum("ji,jk,kl->il", &[cw, sao, cx])
255        } else {
256            einsum("ji,jk,kl->il", &[&cw.map(|x| x.conj()), sao, cx])
257        }
258        .map_err(|err| format_err!(err))?
259        .into_dimensionality::<Ix2>()?;
260
261        let max_offdiag = (&init_orb_ovmat - &Array2::from_diag(&init_orb_ovmat.diag().to_owned()))
262            .iter()
263            .map(|x| ComplexFloat::abs(*x))
264            .max_by(|x, y| {
265                x.partial_cmp(y)
266                    .expect("Unable to compare two `abs` values.")
267            })
268            .ok_or_else(|| {
269                format_err!(
270                    "Unable to determine the maximum off-diagonal element for\n{}.",
271                    &init_orb_ovmat
272                )
273            })?;
274
275        if max_offdiag <= thresh_offdiag {
276            let lowdin_overlaps = init_orb_ovmat.into_diag().to_vec();
277            let zero_indices = lowdin_overlaps
278                .iter()
279                .enumerate()
280                .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
281                .map(|(i, _)| i)
282                .collect::<IndexSet<_>>();
283            LowdinPairedCoefficients::builder()
284                .paired_cw(cw.to_owned())
285                .paired_cx(cx.to_owned())
286                .lowdin_overlaps(lowdin_overlaps)
287                .zero_indices(zero_indices)
288                .thresh_zeroov(thresh_zeroov)
289                .complex_symmetric(complex_symmetric)
290                .build()
291                .map_err(|err| format_err!(err))
292        } else {
293            let (u_opt, _, vh_opt) = init_orb_ovmat.svd(true, true)?;
294            let u = u_opt.ok_or_else(|| format_err!("Unable to compute the U matrix from SVD."))?;
295            let vh =
296                vh_opt.ok_or_else(|| format_err!("Unable to compute the V matrix from SVD."))?;
297            let v = vh.t().map(|x| x.conj());
298            let det_v_c = v.det()?.conj();
299
300            let paired_cw = if complex_symmetric {
301                let uc = u.map(|x| x.conj());
302                let mut cwt = cw.dot(&uc);
303                let det_uc_c = uc.det()?.conj();
304                cwt.column_mut(0).iter_mut().for_each(|x| *x *= det_uc_c);
305                cwt
306            } else {
307                let mut cwt = cw.dot(&u);
308                let det_u_c = u.det()?.conj();
309                cwt.column_mut(0).iter_mut().for_each(|x| *x *= det_u_c);
310                cwt
311            };
312
313            let paired_cx = {
314                let mut cxt = cx.dot(&v);
315                cxt.column_mut(0).iter_mut().for_each(|x| *x *= det_v_c);
316                cxt
317            };
318
319            let lowdin_orb_ovmat = if complex_symmetric {
320                einsum("ji,jk,kl->il", &[&paired_cw, sao, &paired_cx])
321            } else {
322                einsum(
323                    "ji,jk,kl->il",
324                    &[&paired_cw.map(|x| x.conj()), sao, &paired_cx],
325                )
326            }
327            .map_err(|err| format_err!(err))?
328            .into_dimensionality::<Ix2>()?;
329
330            let max_offdiag_lowdin = (&lowdin_orb_ovmat - &Array2::from_diag(&lowdin_orb_ovmat.diag().to_owned()))
331                .iter()
332                .map(|x| ComplexFloat::abs(*x))
333                .max_by(|x, y| {
334                    x.partial_cmp(y)
335                        .expect("Unable to compare two `abs` values.")
336                })
337                .ok_or_else(|| format_err!("Unable to determine the maximum off-diagonal element of the Lowdin-paired overlap matrix."))?;
338            if max_offdiag_lowdin <= thresh_offdiag {
339                let lowdin_overlaps = lowdin_orb_ovmat.into_diag().to_vec();
340                let zero_indices = lowdin_overlaps
341                    .iter()
342                    .enumerate()
343                    .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
344                    .map(|(i, _)| i)
345                    .collect::<IndexSet<_>>();
346                LowdinPairedCoefficients::builder()
347                    .paired_cw(paired_cw.clone())
348                    .paired_cx(paired_cx.clone())
349                    .lowdin_overlaps(lowdin_overlaps)
350                    .zero_indices(zero_indices)
351                    .thresh_zeroov(thresh_zeroov)
352                    .complex_symmetric(complex_symmetric)
353                    .build()
354                    .map_err(|err| format_err!(err))
355            } else {
356                Err(format_err!(
357                    "Löwdin overlap matrix deviates from diagonality. Maximum off-diagonal overlap has magnitude {max_offdiag_lowdin:.3e} > threshold of {thresh_offdiag:.3e}. Löwdin pairing has failed."
358                ))
359            }
360        }
361    }
362}
363
364/// Calculates the matrix element of a zero-particle operator between two Löwdin-paired
365/// determinants.
366///
367/// # Arguments
368///
369/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
370/// subspace determined by the specified structure constraint.
371/// `o0` - The zero-particle operator.
372/// `structure_constraint` - The structure constraint governing the coefficients.
373///
374/// # Returns
375///
376/// The zero-particle matrix element.
377pub fn calc_o0_matrix_element<T, SC>(
378    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
379    o0: T,
380    structure_constraint: &SC,
381) -> Result<T, anyhow::Error>
382where
383    T: ComplexFloat + ScalarOperand + Product,
384    SC: StructureConstraint,
385{
386    let nzeros_explicit: usize = lowdin_paired_coefficientss
387        .iter()
388        .map(|lpc| lpc.n_lowdin_zeros())
389        .sum();
390    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
391    if nzeros > 0 {
392        Ok(T::zero())
393    } else {
394        let reduced_ov_explicit: T = lowdin_paired_coefficientss
395            .iter()
396            .map(|lpc| lpc.reduced_overlap())
397            .product();
398        let reduced_ov = (0..structure_constraint.implicit_factor()?)
399            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
400        Ok(reduced_ov * o0)
401    }
402}
403
404/// Calculates the matrix element of a one-particle operator between two Löwdin-paired
405/// determinants.
406///
407/// # Arguments
408///
409/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
410/// subspace determined by the specified structure constraint.
411/// `o1` - The one-particle operator in the atomic-orbital basis.
412/// `structure_constraint` - The structure constraint governing the coefficients.
413///
414/// # Returns
415///
416/// The one-particle matrix element.
417pub fn calc_o1_matrix_element<T, SC>(
418    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
419    o1: &ArrayView2<T>,
420    structure_constraint: &SC,
421) -> Result<T, anyhow::Error>
422where
423    T: ComplexFloat + ScalarOperand + Product,
424    SC: StructureConstraint,
425{
426    let nzeros_explicit: usize = lowdin_paired_coefficientss
427        .iter()
428        .map(|lpc| lpc.n_lowdin_zeros())
429        .sum();
430    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
431    if nzeros > 1 {
432        Ok(T::zero())
433    } else {
434        let reduced_ov_explicit: T = lowdin_paired_coefficientss
435            .iter()
436            .map(|lpc| lpc.reduced_overlap())
437            .product();
438        let reduced_ov = (0..structure_constraint.implicit_factor()?)
439            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
440
441        if nzeros == 0 {
442            let nbasis = lowdin_paired_coefficientss[0].nbasis();
443            let w = (0..structure_constraint.implicit_factor()?)
444                .cartesian_product(lowdin_paired_coefficientss.iter())
445                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
446                    calc_weighted_codensity_matrix(lpc).map(|w| acc + w)
447                })?;
448            // i = μ, j = μ'
449            einsum("ij,ji->", &[o1, &w.view()])
450                .map_err(|err| format_err!(err))?
451                .into_dimensionality::<Ix0>()?
452                .into_iter()
453                .next()
454                .ok_or_else(|| {
455                    format_err!("Unable to extract the result of the einsum contraction.")
456                })
457                .map(|v| v * reduced_ov)
458        } else {
459            ensure!(
460                nzeros == 1,
461                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
462            );
463            let ps = (0..structure_constraint.implicit_factor()?)
464                .flat_map(|_| {
465                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
466                        lpc.zero_indices()
467                            .iter()
468                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
469                    })
470                })
471                .collect::<Result<Vec<_>, _>>()?;
472            ensure!(
473                ps.len() == 1,
474                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
475                ps.len()
476            );
477            let p_mbar = ps.first().ok_or_else(|| {
478                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
479            })?;
480
481            // i = μ, j = μ'
482            einsum("ij,ji->", &[o1, &p_mbar.view()])
483                .map_err(|err| format_err!(err))?
484                .into_dimensionality::<Ix0>()?
485                .into_iter()
486                .next()
487                .ok_or_else(|| {
488                    format_err!("Unable to extract the result of the einsum contraction.")
489                })
490                .map(|v| v * reduced_ov)
491        }
492    }
493}
494
495/// Calculates the matrix element of a two-particle operator between two Löwdin-paired
496/// determinants.
497///
498/// # Arguments
499///
500/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
501/// subspace determined by the specified structure constraint.
502/// `o2_opt` - The two-particle operator in the atomic-orbital basis.
503/// `structure_constraint` - The structure constraint governing the coefficients.
504///
505/// # Returns
506///
507/// The two-particle matrix element.
508pub fn calc_o2_matrix_element<'a, 'b, T, SC, F>(
509    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
510    o2_opt: Option<&'b ArrayView4<'a, T>>,
511    get_jk_opt: Option<&F>,
512    structure_constraint: &SC,
513) -> Result<T, anyhow::Error>
514where
515    'a: 'b,
516    T: ComplexFloat + ScalarOperand + Product + std::fmt::Display,
517    SC: StructureConstraint,
518    F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error>,
519{
520    let nzeros_explicit: usize = lowdin_paired_coefficientss
521        .iter()
522        .map(|lpc| lpc.n_lowdin_zeros())
523        .sum();
524    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
525    if nzeros > 2 {
526        Ok(T::zero())
527    } else {
528        let reduced_ov_explicit: T = lowdin_paired_coefficientss
529            .iter()
530            .map(|lpc| lpc.reduced_overlap())
531            .product();
532        let reduced_ov = (0..structure_constraint.implicit_factor()?)
533            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
534
535        if nzeros == 0 {
536            let nbasis = lowdin_paired_coefficientss[0].nbasis();
537            let w_sigmas = (0..structure_constraint.implicit_factor()?)
538                .cartesian_product(lowdin_paired_coefficientss.iter())
539                .map(|(_, lpc)| calc_weighted_codensity_matrix(lpc))
540                .collect::<Result<Vec<_>, _>>()?;
541            let w = w_sigmas
542                .iter()
543                .fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, w_sigma| {
544                    acc + w_sigma
545                });
546
547            match (o2_opt, get_jk_opt) {
548                (Some(o2), None) => {
549                    // i = μ, j = μ', k = ν, l = ν'
550                    let j_term = einsum("ikjl,ji,lk->", &[o2, &w.view(), &w.view()])
551                        .map_err(|err| format_err!(err))?
552                        .into_dimensionality::<Ix0>()?
553                        .into_iter()
554                        .next()
555                        .ok_or_else(|| {
556                            format_err!("Unable to extract the result of the einsum contraction.")
557                        })
558                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
559                    let k_term = w_sigmas
560                        .iter()
561                        .try_fold(T::zero(), |acc, w_sigma| {
562                            einsum("ikjl,li,jk->", &[o2, &w_sigma.view(), &w_sigma.view()])
563                                .map_err(|err| format_err!(err))?
564                                .into_dimensionality::<Ix0>()?
565                                .into_iter()
566                                .next()
567                                .ok_or_else(|| {
568                                    format_err!(
569                                        "Unable to extract the result of the einsum contraction."
570                                    )
571                                })
572                                .map(|v| acc + v)
573                        })
574                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
575                    Ok(j_term - k_term)
576                }
577                (None, Some(get_jk)) => {
578                    // i = μ, j = μ', k = ν, l = ν'
579                    let (j_w, _) = get_jk(&w)?;
580                    let j_term = einsum("ij,ji->", &[&j_w.view(), &w.view()])
581                        .map_err(|err| format_err!(err))?
582                        .into_dimensionality::<Ix0>()?
583                        .into_iter()
584                        .next()
585                        .ok_or_else(|| {
586                            format_err!("Unable to extract the result of the einsum contraction.")
587                        })
588                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
589                    let k_term = w_sigmas
590                        .iter()
591                        .try_fold(T::zero(), |acc, w_sigma| {
592                            let (_, k_w_sigma) = get_jk(w_sigma)?;
593                            einsum("il,li->", &[&k_w_sigma.view(), &w_sigma.view()])
594                                .map_err(|err| format_err!(err))?
595                                .into_dimensionality::<Ix0>()?
596                                .into_iter()
597                                .next()
598                                .ok_or_else(|| {
599                                    format_err!(
600                                        "Unable to extract the result of the einsum contraction."
601                                    )
602                                })
603                                .map(|v| acc + v)
604                        })
605                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
606                    Ok(j_term - k_term)
607                }
608                _ => Err(format_err!(
609                    "One and only one of `o2` or `get_jk` should be provided."
610                )),
611            }
612        } else if nzeros == 1 {
613            ensure!(
614                nzeros_explicit == 1,
615                "Unexpected number of explicit zero Löwdin overlaps: {nzeros_explicit} != 1."
616            );
617
618            let nbasis = lowdin_paired_coefficientss[0].nbasis();
619            let w = (0..structure_constraint.implicit_factor()?)
620                .cartesian_product(lowdin_paired_coefficientss.iter())
621                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
622                    calc_weighted_codensity_matrix(lpc).map(|w_sigma| acc + w_sigma)
623                })?;
624
625            lowdin_paired_coefficientss
626                .iter()
627                .filter_map(|lpc| {
628                    if lpc.n_lowdin_zeros() == 1 {
629                        let w_sigma_res = calc_weighted_codensity_matrix(lpc);
630                        let mbar = lpc.zero_indices()[0];
631                        let p_mbar_sigma_res = calc_unweighted_codensity_matrix(lpc, mbar);
632                        Some((w_sigma_res, p_mbar_sigma_res))
633                    } else {
634                        None
635                    }
636                })
637                .try_fold(T::zero(), |acc, (w_sigma_res, p_mbar_sigma_res)| {
638                    w_sigma_res.and_then(|w_sigma| {
639                        p_mbar_sigma_res.and_then(|p_mbar_sigma| {
640                            match (o2_opt, get_jk_opt) {
641                                (Some(o2), None) => {
642                                    // i = μ, j = μ', k = ν, l = ν'
643                                    let j_term_1 = einsum(
644                                        "ikjl,ji,lk->",
645                                        &[o2, &w.view(), &p_mbar_sigma.view()],
646                                    )
647                                    .map_err(|err| format_err!(err))?
648                                    .into_dimensionality::<Ix0>()?
649                                    .into_iter()
650                                    .next()
651                                    .ok_or_else(|| {
652                                        format_err!(
653                                            "Unable to extract the result of the einsum contraction."
654                                        )
655                                    })?;
656                                    let j_term_2 = einsum(
657                                        "ikjl,ji,lk->",
658                                        &[o2, &p_mbar_sigma.view(), &w.view()],
659                                    )
660                                    .map_err(|err| format_err!(err))?
661                                    .into_dimensionality::<Ix0>()?
662                                    .into_iter()
663                                    .next()
664                                    .ok_or_else(|| {
665                                        format_err!(
666                                            "Unable to extract the result of the einsum contraction."
667                                        )
668                                    })?;
669                                    let k_term_1 = einsum(
670                                        "ikjl,li,jk->",
671                                        &[o2, &w_sigma.view(), &p_mbar_sigma.view()],
672                                    )
673                                    .map_err(|err| format_err!(err))?
674                                    .into_dimensionality::<Ix0>()?
675                                    .into_iter()
676                                    .next()
677                                    .ok_or_else(|| {
678                                        format_err!(
679                                            "Unable to extract the result of the einsum contraction."
680                                        )
681                                    })?;
682                                    let k_term_2 = einsum(
683                                        "ikjl,li,jk->",
684                                        &[o2, &p_mbar_sigma.view(), &w_sigma.view()],
685                                    )
686                                    .map_err(|err| format_err!(err))?
687                                    .into_dimensionality::<Ix0>()?
688                                    .into_iter()
689                                    .next()
690                                    .ok_or_else(|| {
691                                        format_err!(
692                                            "Unable to extract the result of the einsum contraction."
693                                        )
694                                    })?;
695                                    Ok(acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
696                                },
697                                (None, Some(get_jk)) => {
698                                    // i = μ, j = μ', k = ν, l = ν'
699                                    let (j_p_mbar_sigma, k_p_mbar_sigma) = get_jk(&p_mbar_sigma)?;
700                                    let j_term_1 = einsum(
701                                        "ij,ji->",
702                                        &[&j_p_mbar_sigma.view(), &w.view()],
703                                    )
704                                    .map_err(|err| format_err!(err))?
705                                    .into_dimensionality::<Ix0>()?
706                                    .into_iter()
707                                    .next()
708                                    .ok_or_else(|| {
709                                        format_err!(
710                                            "Unable to extract the result of the einsum contraction."
711                                        )
712                                    })?;
713                                    let (j_w, _) = get_jk(&w)?;
714                                    let j_term_2 = einsum(
715                                        "ij,ji->",
716                                        &[&j_w.view(), &p_mbar_sigma.view()],
717                                    )
718                                    .map_err(|err| format_err!(err))?
719                                    .into_dimensionality::<Ix0>()?
720                                    .into_iter()
721                                    .next()
722                                    .ok_or_else(|| {
723                                        format_err!(
724                                            "Unable to extract the result of the einsum contraction."
725                                        )
726                                    })?;
727                                    let k_term_1 = einsum(
728                                        "il,li->",
729                                        &[&k_p_mbar_sigma.view(), &w_sigma.view()],
730                                    )
731                                    .map_err(|err| format_err!(err))?
732                                    .into_dimensionality::<Ix0>()?
733                                    .into_iter()
734                                    .next()
735                                    .ok_or_else(|| {
736                                        format_err!(
737                                            "Unable to extract the result of the einsum contraction."
738                                        )
739                                    })?;
740                                    let (_, k_w_sigma) = get_jk(&w_sigma)?;
741                                    let k_term_2 = einsum(
742                                        "il,li->",
743                                        &[&k_w_sigma.view(), &p_mbar_sigma.view()],
744                                    )
745                                    .map_err(|err| format_err!(err))?
746                                    .into_dimensionality::<Ix0>()?
747                                    .into_iter()
748                                    .next()
749                                    .ok_or_else(|| {
750                                        format_err!(
751                                            "Unable to extract the result of the einsum contraction."
752                                        )
753                                    })?;
754                                    Ok(acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
755                                }
756                                _ => Err(format_err!(
757                                    "One and only one of `o2` or `get_jk` should be provided."
758                                )),
759                            }
760                        })
761                    })
762                })
763                .map(|v| v * reduced_ov / (T::one() + T::one()))
764        } else {
765            ensure!(
766                nzeros == 2,
767                "Unexpected number of zero Löwdin overlaps: {nzeros} != 2."
768            );
769
770            let ps = (0..structure_constraint.implicit_factor()?)
771                .flat_map(|_| {
772                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
773                        lpc.zero_indices()
774                            .iter()
775                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
776                    })
777                })
778                .collect::<Result<Vec<_>, _>>()?;
779            ensure!(
780                ps.len() == 2,
781                "Unexpected number of unweighted codensity matrices ({}) for two zero overlaps.",
782                ps.len()
783            );
784            let p_mbar = ps.first().ok_or_else(|| {
785                format_err!("Unable to retrieve the first computed unweighted codensity matrix.")
786            })?;
787            let p_nbar = ps.last().ok_or_else(|| {
788                format_err!("Unable to retrieve the second computed unweighted codensity matrix.")
789            })?;
790
791            match (o2_opt, get_jk_opt) {
792                (Some(o2), None) => {
793                    // i = μ, j = μ', k = ν, l = ν'
794                    let j_term_1 = einsum("ikjl,ji,lk->", &[o2, &p_mbar.view(), &p_nbar.view()])
795                        .map_err(|err| format_err!(err))?
796                        .into_dimensionality::<Ix0>()?
797                        .into_iter()
798                        .next()
799                        .ok_or_else(|| {
800                            format_err!("Unable to extract the result of the einsum contraction.")
801                        })?;
802                    let j_term_2 = einsum("ikjl,ji,lk->", &[o2, &p_nbar.view(), &p_mbar.view()])
803                        .map_err(|err| format_err!(err))?
804                        .into_dimensionality::<Ix0>()?
805                        .into_iter()
806                        .next()
807                        .ok_or_else(|| {
808                            format_err!("Unable to extract the result of the einsum contraction.")
809                        })?;
810
811                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
812                        .iter()
813                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
814                    {
815                        let k_term_1 =
816                            einsum("ikjl,li,jk->", &[o2, &p_mbar.view(), &p_nbar.view()])
817                                .map_err(|err| format_err!(err))?
818                                .into_dimensionality::<Ix0>()?
819                                .into_iter()
820                                .next()
821                                .ok_or_else(|| {
822                                    format_err!(
823                                        "Unable to extract the result of the einsum contraction."
824                                    )
825                                })?;
826                        let k_term_2 =
827                            einsum("ikjl,li,jk->", &[o2, &p_nbar.view(), &p_mbar.view()])
828                                .map_err(|err| format_err!(err))?
829                                .into_dimensionality::<Ix0>()?
830                                .into_iter()
831                                .next()
832                                .ok_or_else(|| {
833                                    format_err!(
834                                        "Unable to extract the result of the einsum contraction."
835                                    )
836                                })?;
837                        (k_term_1, k_term_2)
838                    } else {
839                        (T::zero(), T::zero())
840                    };
841                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
842                        / (T::one() + T::one()))
843                }
844                (None, Some(get_jk)) => {
845                    // i = μ, j = μ', k = ν, l = ν'
846                    let (j_p_nbar, k_p_nbar) = get_jk(p_nbar)?;
847                    let j_term_1 = einsum("ij,ji->", &[&j_p_nbar.view(), &p_mbar.view()])
848                        .map_err(|err| format_err!(err))?
849                        .into_dimensionality::<Ix0>()?
850                        .into_iter()
851                        .next()
852                        .ok_or_else(|| {
853                            format_err!("Unable to extract the result of the einsum contraction.")
854                        })?;
855                    let (j_p_mbar, k_p_mbar) = get_jk(p_mbar)?;
856                    let j_term_2 = einsum("ij,ji->", &[&j_p_mbar.view(), &p_nbar.view()])
857                        .map_err(|err| format_err!(err))?
858                        .into_dimensionality::<Ix0>()?
859                        .into_iter()
860                        .next()
861                        .ok_or_else(|| {
862                            format_err!("Unable to extract the result of the einsum contraction.")
863                        })?;
864
865                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
866                        .iter()
867                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
868                    {
869                        let k_term_1 = einsum("il,li->", &[&k_p_nbar.view(), &p_mbar.view()])
870                            .map_err(|err| format_err!(err))?
871                            .into_dimensionality::<Ix0>()?
872                            .into_iter()
873                            .next()
874                            .ok_or_else(|| {
875                                format_err!(
876                                    "Unable to extract the result of the einsum contraction."
877                                )
878                            })?;
879                        let k_term_2 = einsum("il,li->", &[&k_p_mbar.view(), &p_nbar.view()])
880                            .map_err(|err| format_err!(err))?
881                            .into_dimensionality::<Ix0>()?
882                            .into_iter()
883                            .next()
884                            .ok_or_else(|| {
885                                format_err!(
886                                    "Unable to extract the result of the einsum contraction."
887                                )
888                            })?;
889                        (k_term_1, k_term_2)
890                    } else {
891                        (T::zero(), T::zero())
892                    };
893                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
894                        / (T::one() + T::one()))
895                }
896                _ => Err(format_err!(
897                    "One and only one of `o2` or `get_jk` should be provided."
898                )),
899            }
900        }
901    }
902}
903
904/// Calculates the transition density matrix between two Löwdin-paired determinants.
905///
906/// # Arguments
907///
908/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
909/// subspace determined by the specified structure constraint.
910/// `structure_constraint` - The structure constraint governing the coefficients.
911///
912/// # Returns
913///
914/// The one-particle matrix element.
915pub fn calc_transition_density_matrix<T, SC>(
916    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
917    structure_constraint: &SC,
918) -> Result<Array2<T>, anyhow::Error>
919where
920    T: ComplexFloat + ScalarOperand + Product,
921    SC: StructureConstraint,
922{
923    let nzeros_explicit: usize = lowdin_paired_coefficientss
924        .iter()
925        .map(|lpc| lpc.n_lowdin_zeros())
926        .sum();
927    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
928    let nbasis = lowdin_paired_coefficientss[0].nbasis();
929    if nzeros > 1 {
930        Ok(Array2::<T>::zeros((nbasis, nbasis)))
931    } else {
932        let reduced_ov_explicit: T = lowdin_paired_coefficientss
933            .iter()
934            .map(|lpc| lpc.reduced_overlap())
935            .product();
936        let reduced_ov = (0..structure_constraint.implicit_factor()?)
937            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
938
939        if nzeros == 0 {
940            let nbasis = lowdin_paired_coefficientss[0].nbasis();
941            let w = (0..structure_constraint.implicit_factor()?)
942                .cartesian_product(lowdin_paired_coefficientss.iter())
943                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
944                    calc_weighted_codensity_matrix(lpc).map(|w| acc + w)
945                })?;
946            Ok(w.mapv(|v| v * reduced_ov))
947        } else {
948            ensure!(
949                nzeros == 1,
950                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
951            );
952            let ps = (0..structure_constraint.implicit_factor()?)
953                .flat_map(|_| {
954                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
955                        lpc.zero_indices()
956                            .iter()
957                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
958                    })
959                })
960                .collect::<Result<Vec<_>, _>>()?;
961            ensure!(
962                ps.len() == 1,
963                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
964                ps.len()
965            );
966            let p_mbar = ps.first().ok_or_else(|| {
967                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
968            })?;
969            Ok(p_mbar.mapv(|v| v * reduced_ov))
970        }
971    }
972}
973
974/// Performs modified Gram--Schmidt orthonormalisation on a set of column vectors in a matrix with
975/// respect to the complex-symmetric or Hermitian dot product.
976///
977/// # Arguments
978///
979/// * `vmat` - Matrix containing column vectors forming a basis for a subspace.
980/// * `complex_symmetric` - A boolean indicating if the vector dot product is complex-symmetric. If
981///   `false`, the conventional Hermitian dot product is used.
982/// * `thresh` - A threshold for determining self-orthogonal vectors.
983///
984/// # Returns
985///
986/// The orthonormal vectors forming a basis for the same subspace collected as column vectors in a
987/// matrix.
988///
989/// # Errors
990///
991/// Errors when the orthonormalisation procedure fails, which occurs when there is linear dependency
992/// between the basis vectors and/or when self-orthogonal vectors are encountered.
993pub fn complex_modified_gram_schmidt<T>(
994    vmat: &ArrayView2<T>,
995    complex_symmetric: bool,
996    thresh: <T as ComplexFloat>::Real,
997) -> Result<Array2<T>, anyhow::Error>
998where
999    T: ComplexFloat + std::fmt::Display + 'static,
1000{
1001    let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
1002    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
1003    for (i, vi) in vmat.columns().into_iter().enumerate() {
1004        // u[i] now initialised with v[i]
1005        us.push(vi.to_owned());
1006
1007        // Project ui onto all uj (0 <= j < i)
1008        // This is the 'modified' part of Gram--Schmidt. We project the current (and being updated)
1009        // ui onto uj, rather than projecting vi onto uj. This enhances numerical stability.
1010        for j in 0..i {
1011            let p_uj_ui = if complex_symmetric {
1012                us[j].t().dot(&us[i]) / us_sq_norm[j]
1013            } else {
1014                us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
1015            };
1016            us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
1017        }
1018
1019        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
1020        // us_sq_norm[i] now available.
1021        let us_sq_norm_i = if complex_symmetric {
1022            us[i].t().dot(&us[i])
1023        } else {
1024            us[i].t().map(|x| x.conj()).dot(&us[i])
1025        };
1026        if us_sq_norm_i.abs() < thresh {
1027            return Err(format_err!("A zero-norm vector found: {}", us[i]));
1028        }
1029        us_sq_norm.push(us_sq_norm_i);
1030    }
1031
1032    // Normalise ui
1033    for i in 0..us.len() {
1034        us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
1035    }
1036
1037    let ortho_check = us.iter().enumerate().all(|(i, ui)| {
1038        us.iter().enumerate().all(|(j, uj)| {
1039            let ov_ij = if complex_symmetric {
1040                ui.dot(uj)
1041            } else {
1042                ui.map(|x| x.conj()).dot(uj)
1043            };
1044            i == j || ov_ij.abs() < thresh
1045        })
1046    });
1047
1048    if ortho_check {
1049        stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| format_err!(err))
1050    } else {
1051        Err(format_err!(
1052            "Post-Gram--Schmidt orthogonality check failed."
1053        ))
1054    }
1055}
1056
1057/// Trait for Löwdin canonical orthogonalisation of a square matrix.
1058pub trait CanonicalOrthogonalisable {
1059    /// Numerical type of the matrix elements.
1060    type NumType;
1061
1062    /// Type of real threshold values.
1063    type RealType;
1064
1065    /// Calculates the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$ for a square
1066    /// matrix.
1067    ///
1068    /// # Arguments
1069    ///
1070    /// * `complex_symmetric` - Boolean indicating if the orthogonalisation is with respect to the
1071    ///   complex-symmetric inner product.
1072    /// * `preserves_full_rank` - Boolean indicating if a full-rank square matrix should be left
1073    ///   unchanged, thus forcing $`\mathbf{X} = \mathbf{I}`$.
1074    /// * `thresh_offdiag` - Threshold for verifying that the orthogonalised matrix is indeed
1075    ///   orthogonal.
1076    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of the input square matrix.
1077    ///
1078    /// # Returns
1079    ///
1080    /// The canonical orthogonalisation result.
1081    fn calc_canonical_orthogonal_matrix(
1082        &self,
1083        complex_symmetric: bool,
1084        preserves_full_rank: bool,
1085        thresh_offdiag: Self::RealType,
1086        thresh_zeroov: Self::RealType,
1087    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error>;
1088}
1089
1090/// Structure containing the results of the Löwdin canonical orthogonalisation.
1091pub struct CanonicalOrthogonalisationResult<T> {
1092    /// The eigenvalues of the input matrix.
1093    eigenvalues: Array1<T>,
1094
1095    /// The Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1096    xmat: Array2<T>,
1097
1098    /// The conjugate of the Löwdin canonical orthogonalisation matrix,
1099    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1100    /// and $`\lozenge = \hat{e}`$ otherwise.
1101    xmat_d: Array2<T>,
1102}
1103
1104impl<T> CanonicalOrthogonalisationResult<T> {
1105    /// Returns the eigenvalues of the input matrix.
1106    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
1107        self.eigenvalues.view()
1108    }
1109
1110    /// Returns the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1111    pub fn xmat(&'_ self) -> ArrayView2<'_, T> {
1112        self.xmat.view()
1113    }
1114
1115    /// Returns the conjugate of the Löwdin canonical orthogonalisation matrix,
1116    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1117    /// and $`\lozenge = \hat{e}`$ otherwise.
1118    pub fn xmat_d(&'_ self) -> ArrayView2<'_, T> {
1119        self.xmat_d.view()
1120    }
1121}
1122
1123#[duplicate_item(
1124    [
1125        dtype_ [ f64 ]
1126    ]
1127    [
1128        dtype_ [ f32 ]
1129    ]
1130)]
1131impl CanonicalOrthogonalisable for ArrayView2<'_, dtype_> {
1132    type NumType = dtype_;
1133
1134    type RealType = dtype_;
1135
1136    fn calc_canonical_orthogonal_matrix(
1137        &self,
1138        _: bool,
1139        preserves_full_rank: bool,
1140        thresh_offdiag: dtype_,
1141        thresh_zeroov: dtype_,
1142    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1143        let smat = self;
1144
1145        // Real, symmetric S
1146        let max_offdiag_s = *(smat.to_owned() - smat.t()).map(|v| v.abs()).iter()
1147                .max_by(|a, b| a.total_cmp(b))
1148                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap symmetric deviation matrix."))?;
1149        ensure!(
1150            max_offdiag_s <= thresh_offdiag,
1151            "Overlap matrix is not real-symmetric: ||S - S^T||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1152        );
1153
1154        // S is real-symmetric, so U is orthogonal, i.e. U^T = U^(-1).
1155        let (s_eig, umat) = smat.eigh(UPLO::Lower).map_err(|err| format_err!(err))?;
1156        // Real eigenvalues, so both comparison modes are the same.
1157        let nonzero_s_indices = s_eig
1158            .iter()
1159            .positions(|x| x.abs() > thresh_zeroov)
1160            .collect_vec();
1161        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1162        if nonzero_s_eig.iter().any(|v| *v < 0.0) {
1163            return Err(format_err!(
1164                "The matrix has negative eigenvalues and therefore cannot be orthogonalised over the reals."
1165            ));
1166        }
1167        let nonzero_umat = umat.select(Axis(1), &nonzero_s_indices);
1168        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1169        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1170            (Array2::eye(smat.shape()[0]), Array2::eye(smat.shape()[0]))
1171        } else {
1172            let s_s = Array2::<dtype_>::from_diag(&nonzero_s_eig.mapv(|x| 1.0 / x.sqrt()));
1173            (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1174        };
1175        let res = CanonicalOrthogonalisationResult {
1176            eigenvalues: s_eig,
1177            xmat,
1178            xmat_d,
1179        };
1180        Ok(res)
1181    }
1182}
1183
1184impl<T> CanonicalOrthogonalisable for ArrayView2<'_, Complex<T>>
1185where
1186    T: Float + Scalar<Complex = Complex<T>>,
1187    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
1188{
1189    type NumType = Complex<T>;
1190
1191    type RealType = T;
1192
1193    fn calc_canonical_orthogonal_matrix(
1194        &self,
1195        complex_symmetric: bool,
1196        preserves_full_rank: bool,
1197        thresh_offdiag: T,
1198        thresh_zeroov: T,
1199    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1200        let smat = self;
1201
1202        if complex_symmetric {
1203            // Complex-symmetric S
1204            let max_offdiag = *(smat.to_owned() - smat.t())
1205                    .mapv(ComplexFloat::abs)
1206                    .iter()
1207                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1208                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
1209            ensure!(
1210                max_offdiag <= thresh_offdiag,
1211                "Overlap matrix is not complex-symmetric."
1212            );
1213        } else {
1214            // Complex-Hermitian S
1215            let max_offdiag = *(smat.to_owned() - smat.map(|v| v.conj()).t())
1216                    .mapv(ComplexFloat::abs)
1217                    .iter()
1218                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1219                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
1220            ensure!(
1221                max_offdiag <= thresh_offdiag,
1222                "Overlap matrix is not complex-Hermitian."
1223            );
1224        }
1225
1226        let (s_eig, umat_nonortho) = smat.eig().map_err(|err| format_err!(err))?;
1227        log::debug!("Overlap eigenvalues for canonical orthogonalisation:");
1228        for (i, eig) in s_eig.iter().enumerate() {
1229            log::debug!("  {i}: {eig:+.8e}");
1230        }
1231        log::debug!("");
1232
1233        let nonzero_s_indices = s_eig
1234            .iter()
1235            .positions(|x| ComplexFloat::abs(*x) > thresh_zeroov)
1236            .collect_vec();
1237        log::debug!("Non-zero overlap indices w.r.t. threshold {thresh_zeroov:.8e}:");
1238        log::debug!(
1239            "  {}",
1240            nonzero_s_indices.iter().map(|i| i.to_string()).join(", ")
1241        );
1242        log::debug!("");
1243        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1244        let nonzero_umat_nonortho = umat_nonortho.select(Axis(1), &nonzero_s_indices);
1245
1246        // `eig` does not guarantee orthogonality of `nonzero_umat_nonortho`.
1247        // Gram--Schmidt is therefore required.
1248        let nonzero_umat = complex_modified_gram_schmidt(
1249            &nonzero_umat_nonortho.view(),
1250            complex_symmetric,
1251            thresh_zeroov,
1252        )
1253        .map_err(
1254            |_| format_err!("Unable to orthonormalise the linearly-independent eigenvectors of the overlap matrix.")
1255        )?;
1256
1257        let nonzero_s_eig_from_u = if complex_symmetric {
1258            nonzero_umat.t().dot(smat).dot(&nonzero_umat)
1259        } else {
1260            nonzero_umat
1261                .map(|v| v.conj())
1262                .t()
1263                .dot(smat)
1264                .dot(&nonzero_umat)
1265        };
1266
1267        let max_offdiag_s = *(nonzero_s_eig_from_u - Array2::from_diag(&nonzero_s_eig)).mapv(ComplexFloat::abs).iter()
1268                .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1269                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap symmetric deviation matrix."))?;
1270        ensure!(
1271            max_offdiag_s <= thresh_offdiag,
1272            if complex_symmetric {
1273                "Canonical orthogonalisation has failed: ||U^T.S.U - s||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1274            } else {
1275                "Canonical orthogonalisation has failed: ||U^†.S.U - s||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1276            }
1277        );
1278
1279        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1280        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1281            (
1282                Array2::<Complex<T>>::eye(smat.shape()[0]),
1283                Array2::<Complex<T>>::eye(smat.shape()[0]),
1284            )
1285        } else {
1286            let s_s = Array2::<Complex<T>>::from_diag(
1287                &nonzero_s_eig.mapv(|x| Complex::<T>::from(T::one()) / x.sqrt()),
1288            );
1289            let xmat = nonzero_umat.dot(&s_s);
1290            let xmat_d = if complex_symmetric {
1291                // (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1292                xmat.t().to_owned()
1293            } else {
1294                xmat.map(|v| v.conj()).t().to_owned()
1295            };
1296            (xmat, xmat_d)
1297        };
1298        let res = CanonicalOrthogonalisationResult {
1299            eigenvalues: s_eig,
1300            xmat,
1301            xmat_d,
1302        };
1303        Ok(res)
1304    }
1305}