qsym2/target/noci/backend/nonortho/
mod.rs

1use std::fmt::LowerExp;
2use std::iter::Product;
3
4use anyhow::{self, ensure, format_err};
5use derive_builder::Builder;
6use duplicate::duplicate_item;
7use indexmap::IndexSet;
8use itertools::Itertools;
9use ndarray::{
10    Array1, Array2, ArrayView1, ArrayView2, ArrayView4, Axis, Ix0, Ix2, ScalarOperand, stack,
11};
12use ndarray_einsum::einsum;
13use ndarray_linalg::types::Lapack;
14use ndarray_linalg::{Determinant, Eig, Eigh, SVD, Scalar, UPLO};
15use num::{Complex, Float};
16use num_complex::ComplexFloat;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19
20use super::denmat::{calc_unweighted_codensity_matrix, calc_weighted_codensity_matrix};
21
22#[cfg(test)]
23#[path = "nonortho_tests.rs"]
24mod nonortho_tests;
25
26/// Structure containing Löwdin-paired coefficients, the corresponding Löwdin overlaps, and the
27/// indices of the zero overlaps.
28///
29/// The Löwdin-paired coefficients satisfy
30/// ```math
31///     ^{wx}\mathbf{\Lambda}
32///         = \mathrm{diag}(^{wx}\lambda_i)
33///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
34///           \ \mathbf{S}_{\mathrm{AO}}
35///           \ ^{x}\tilde{\mathbf{C}}.
36/// ```
37#[derive(Builder)]
38#[builder(build_fn(validate = "Self::validate"))]
39pub struct LowdinPairedCoefficients<T: ComplexFloat> {
40    /// The $`^{w}\!\tilde{\mathbf{C}}`$ coefficient matrix.
41    paired_cw: Array2<T>,
42
43    /// The $`^{x}\!\tilde{\mathbf{C}}`$ coefficient matrix.
44    paired_cx: Array2<T>,
45
46    /// The Löwdin overlaps $`\{^{wx}\lambda_i\}`$.
47    lowdin_overlaps: Vec<T>,
48
49    /// The indices of the zero overlaps with respect to [`Self::thresh_zeroov`]. If these are not
50    /// provided, the specified [`Self::thresh_zeroov`] will be used to deduce them from the
51    /// specified `[Self::lowdin_overlaps]`.
52    #[builder(default = "self.default_zero_indices()?")]
53    zero_indices: IndexSet<usize>,
54
55    /// Threshold for determining zero overlaps.
56    thresh_zeroov: T::Real,
57
58    /// Boolean indicating whether the coefficients have been Löwdin-paired with respect to the
59    /// complex-symmetric inner product.
60    complex_symmetric: bool,
61}
62
63impl<T: ComplexFloat> LowdinPairedCoefficientsBuilder<T> {
64    fn validate(&self) -> Result<(), String> {
65        let paired_cw = self
66            .paired_cw
67            .as_ref()
68            .ok_or("Löwdin-paired coefficients `paired_cw` not set.".to_string())?;
69        let paired_cx = self
70            .paired_cx
71            .as_ref()
72            .ok_or("Löwdin-paired coefficients `paired_cx` not set.".to_string())?;
73        let lowdin_overlaps = self
74            .lowdin_overlaps
75            .as_ref()
76            .ok_or("Löwdin overlaps not set.".to_string())?;
77        let zero_indices = self
78            .zero_indices
79            .as_ref()
80            .ok_or("Indices of zero Löwdin overlaps not set.".to_string())?;
81
82        if paired_cw.shape() == paired_cx.shape() {
83            let lowdin_dim = paired_cw.shape()[1];
84            if lowdin_dim == lowdin_overlaps.len() {
85                if zero_indices.iter().all(|i| *i < lowdin_dim) {
86                    Ok(())
87                } else {
88                    Err("Some indices of zero Löwdin overlaps are out-of-bound!".to_string())
89                }
90            } else {
91                Err(
92                    "Inconsistent number of Löwdin-paired orbitals and Löwdin overlaps."
93                        .to_string(),
94                )
95            }
96        } else {
97            Err(format!(
98                "Inconsistent shapes between `paired_cw` ({:?}) and `paired_cx` ({:?}).",
99                paired_cw.shape(),
100                paired_cx.shape()
101            ))
102        }
103    }
104
105    fn default_zero_indices(&self) -> Result<IndexSet<usize>, String> {
106        let lowdin_overlaps = self
107            .lowdin_overlaps
108            .as_ref()
109            .ok_or("Löwdin overlaps not set.".to_string())?;
110        let thresh_zeroov = self
111            .thresh_zeroov
112            .as_ref()
113            .ok_or("threshold for zero Löwdin overlaps not set.".to_string())?;
114        let zero_indices = lowdin_overlaps
115            .iter()
116            .enumerate()
117            .filter(|(_, ov)| ComplexFloat::abs(**ov) < *thresh_zeroov)
118            .map(|(i, _)| i)
119            .collect::<IndexSet<_>>();
120        Ok(zero_indices)
121    }
122}
123
124impl<T: ComplexFloat> LowdinPairedCoefficients<T> {
125    pub fn builder() -> LowdinPairedCoefficientsBuilder<T> {
126        LowdinPairedCoefficientsBuilder::<T>::default()
127    }
128
129    /// Returns the Löwdin-paired coefficients.
130    pub fn paired_coefficients(&self) -> (&Array2<T>, &Array2<T>) {
131        (&self.paired_cw, &self.paired_cx)
132    }
133
134    /// Returns the number of atomic-orbital basis functions in which the coefficient matrices are
135    /// expressed.
136    pub fn nbasis(&self) -> usize {
137        self.paired_cw.nrows()
138    }
139
140    /// Returns the number of molecular orbitals being Löwdin-paired.
141    pub fn lowdin_dim(&self) -> usize {
142        self.lowdin_overlaps.len()
143    }
144
145    /// Returns the number of zero Löwdin overlaps.
146    pub fn n_lowdin_zeros(&self) -> usize {
147        self.zero_indices.len()
148    }
149
150    /// Returns the Löwdin overlaps.
151    pub fn lowdin_overlaps(&self) -> &Vec<T> {
152        &self.lowdin_overlaps
153    }
154
155    /// Returns the indices of the zero Löwdin overlaps.
156    pub fn zero_indices(&self) -> &IndexSet<usize> {
157        &self.zero_indices
158    }
159
160    /// Returns the threshold for determining the zero Löwdin overlaps.
161    pub fn thresh_zeroov(&self) -> T::Real {
162        self.thresh_zeroov
163    }
164
165    /// Returns the indices of the non-zero Löwdin overlaps.
166    pub fn nonzero_indices(&self) -> IndexSet<usize> {
167        (0..self.lowdin_dim())
168            .filter(|i| !self.zero_indices.contains(i))
169            .collect::<IndexSet<_>>()
170    }
171
172    /// Returns the boolean indicating whether the Löwdin pairing is with respect to the
173    /// complex-symmetric inner product.
174    pub fn complex_symmetric(&self) -> bool {
175        self.complex_symmetric
176    }
177}
178
179impl<T: ComplexFloat + Product> LowdinPairedCoefficients<T> {
180    /// The reduced overlap between the two determinants.
181    pub fn reduced_overlap(&self) -> T {
182        self.nonzero_indices()
183            .iter()
184            .map(|i| self.lowdin_overlaps[*i])
185            .product()
186    }
187}
188
189/// Performs Löwdin pairing on two coefficient matrices $`^{w}\mathbf{C}`$ and
190/// $`^{x}\mathbf{C}`$.
191///
192/// Löwdin pairing ensures that
193/// ```math
194///     ^{wx}\mathbf{\Lambda}
195///         = \mathrm{diag}(^{wx}\lambda_i)
196///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
197///           \ \mathbf{S}_{\mathrm{AO}}
198///           \ ^{x}\tilde{\mathbf{C}},
199/// ```
200/// where the Löwdin-paired coefficient matrices are given by
201/// ```math
202///     \begin{align*}
203///         ^{w}\!\tilde{\mathbf{C}}
204///             &=\ ^{w}\mathbf{C}\ ^{wx}\mathbf{U}^{\lozenge} \\
205///         ^{x}\!\tilde{\mathbf{C}}
206///             &=\ ^{x}\mathbf{C}\ ^{wx}\mathbf{V}
207///     \end{align*}
208/// ```
209/// with $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ being SVD factorisation matrices:
210/// ```math
211///     ^{w}\mathbf{C}^{\dagger\lozenge}
212///     \ \mathbf{S}_{\mathrm{AO}}
213///     \ ^{x}\mathbf{C}
214///     =
215///     \ ^{wx}\mathbf{U}
216///     \ ^{wx}\mathbf{\Lambda}
217///     \ ^{wx}\mathbf{V}^{\dagger}.
218/// ```
219///
220/// We note that the first columns of $`^{w}\!\tilde{\mathbf{C}}`$ and $`^{x}\!\tilde{\mathbf{C}}`$
221/// are also adjusted by the determinants of $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ as
222/// appropriate to ensure that the Slater determinants corresponding to them remain invariant with
223/// respect to the unitary transformations brought about by $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$.
224///
225/// # Arguments
226///
227/// * `cw` - Coefficient matrix $`^{w}\mathbf{C}`$.
228/// * `cx` - Coefficient matrix $`^{x}\mathbf{C}`$.
229/// * `sao` - The overlap matrix $`\mathbf{S}_{\mathrm{AO}}`$ of the underlying atomic basis
230///   functions.
231/// * `complex_symmetric` - If `true`, $`\lozenge = \star`$. If `false`, $`\lozenge = \hat{e}`$.
232/// * `thresh_offdiag` - Threshold to check if the off-diagonal elements in the original orbital
233///   overlap matrix and in the Löwdin-paired orbital overlap matrix $`^{wx}\mathbf{\Lambda}`$ are zero.
234/// * `thresh_zeroov` - Threshold to identify which Löwdin overlaps $`^{wx}\lambda_i`$ are zero.
235///
236/// # Returns
237///
238/// A [`LowdinPairedCoefficients`] structure containing the result of the Löwdin pairing.
239pub fn calc_lowdin_pairing<T>(
240    cw: &ArrayView2<T>,
241    cx: &ArrayView2<T>,
242    sao: &ArrayView2<T>,
243    complex_symmetric: bool,
244    thresh_offdiag: <T as ComplexFloat>::Real,
245    thresh_zeroov: <T as ComplexFloat>::Real,
246) -> Result<LowdinPairedCoefficients<T>, anyhow::Error>
247where
248    T: ComplexFloat + Lapack,
249    <T as ComplexFloat>::Real: PartialOrd + LowerExp,
250{
251    if cw.shape() != cx.shape() {
252        Err(format_err!(
253            "Coefficient dimensions mismatched: cw ({:?}) !~ cx ({:?}).",
254            cw.shape(),
255            cx.shape()
256        ))
257    } else {
258        let init_orb_ovmat = if complex_symmetric {
259            einsum("ji,jk,kl->il", &[cw, sao, cx])
260        } else {
261            einsum("ji,jk,kl->il", &[&cw.map(|x| x.conj()), sao, cx])
262        }
263        .map_err(|err| format_err!(err))?
264        .into_dimensionality::<Ix2>()?;
265
266        let max_offdiag = (&init_orb_ovmat - &Array2::from_diag(&init_orb_ovmat.diag().to_owned()))
267            .iter()
268            .map(|x| ComplexFloat::abs(*x))
269            .max_by(|x, y| {
270                x.partial_cmp(y)
271                    .expect("Unable to compare two `abs` values.")
272            })
273            .ok_or_else(|| {
274                format_err!(
275                    "Unable to determine the maximum off-diagonal element for\n{}.",
276                    &init_orb_ovmat
277                )
278            })?;
279
280        if max_offdiag <= thresh_offdiag {
281            let lowdin_overlaps = init_orb_ovmat.into_diag().to_vec();
282            let zero_indices = lowdin_overlaps
283                .iter()
284                .enumerate()
285                .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
286                .map(|(i, _)| i)
287                .collect::<IndexSet<_>>();
288            LowdinPairedCoefficients::builder()
289                .paired_cw(cw.to_owned())
290                .paired_cx(cx.to_owned())
291                .lowdin_overlaps(lowdin_overlaps)
292                .zero_indices(zero_indices)
293                .thresh_zeroov(thresh_zeroov)
294                .complex_symmetric(complex_symmetric)
295                .build()
296                .map_err(|err| format_err!(err))
297        } else {
298            let (u_opt, _, vh_opt) = init_orb_ovmat.svd(true, true)?;
299            let u = u_opt.ok_or_else(|| format_err!("Unable to compute the U matrix from SVD."))?;
300            let vh =
301                vh_opt.ok_or_else(|| format_err!("Unable to compute the V matrix from SVD."))?;
302            let v = vh.t().map(|x| x.conj());
303            let det_v_c = v.det()?.conj();
304
305            let paired_cw = if complex_symmetric {
306                let uc = u.map(|x| x.conj());
307                let mut cwt = cw.dot(&uc);
308                let det_uc_c = uc.det()?.conj();
309                cwt.column_mut(0).iter_mut().for_each(|x| *x *= det_uc_c);
310                cwt
311            } else {
312                let mut cwt = cw.dot(&u);
313                let det_u_c = u.det()?.conj();
314                cwt.column_mut(0).iter_mut().for_each(|x| *x *= det_u_c);
315                cwt
316            };
317
318            let paired_cx = {
319                let mut cxt = cx.dot(&v);
320                cxt.column_mut(0).iter_mut().for_each(|x| *x *= det_v_c);
321                cxt
322            };
323
324            let lowdin_orb_ovmat = if complex_symmetric {
325                einsum("ji,jk,kl->il", &[&paired_cw, sao, &paired_cx])
326            } else {
327                einsum(
328                    "ji,jk,kl->il",
329                    &[&paired_cw.map(|x| x.conj()), sao, &paired_cx],
330                )
331            }
332            .map_err(|err| format_err!(err))?
333            .into_dimensionality::<Ix2>()?;
334
335            let max_offdiag_lowdin = (&lowdin_orb_ovmat - &Array2::from_diag(&lowdin_orb_ovmat.diag().to_owned()))
336                .iter()
337                .map(|x| ComplexFloat::abs(*x))
338                .max_by(|x, y| {
339                    x.partial_cmp(y)
340                        .expect("Unable to compare two `abs` values.")
341                })
342                .ok_or_else(|| format_err!("Unable to determine the maximum off-diagonal element of the Lowdin-paired overlap matrix."))?;
343            if max_offdiag_lowdin <= thresh_offdiag {
344                let lowdin_overlaps = lowdin_orb_ovmat.into_diag().to_vec();
345                let zero_indices = lowdin_overlaps
346                    .iter()
347                    .enumerate()
348                    .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
349                    .map(|(i, _)| i)
350                    .collect::<IndexSet<_>>();
351                LowdinPairedCoefficients::builder()
352                    .paired_cw(paired_cw.clone())
353                    .paired_cx(paired_cx.clone())
354                    .lowdin_overlaps(lowdin_overlaps)
355                    .zero_indices(zero_indices)
356                    .thresh_zeroov(thresh_zeroov)
357                    .complex_symmetric(complex_symmetric)
358                    .build()
359                    .map_err(|err| format_err!(err))
360            } else {
361                Err(format_err!(
362                    "Löwdin overlap matrix deviates from diagonality. Maximum off-diagonal overlap has magnitude {max_offdiag_lowdin:.3e} > threshold of {thresh_offdiag:.3e}. Löwdin pairing has failed."
363                ))
364            }
365        }
366    }
367}
368
369/// Calculates the matrix element of a zero-particle operator between two Löwdin-paired
370/// determinants.
371///
372/// # Arguments
373///
374/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
375/// subspace determined by the specified structure constraint.
376/// `o0` - The zero-particle operator.
377/// `structure_constraint` - The structure constraint governing the coefficients.
378///
379/// # Returns
380///
381/// The zero-particle matrix element.
382pub fn calc_o0_matrix_element<T, SC>(
383    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
384    o0: T,
385    structure_constraint: &SC,
386) -> Result<T, anyhow::Error>
387where
388    T: ComplexFloat + ScalarOperand + Product,
389    SC: StructureConstraint,
390{
391    let nzeros_explicit: usize = lowdin_paired_coefficientss
392        .iter()
393        .map(|lpc| lpc.n_lowdin_zeros())
394        .sum();
395    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
396    if nzeros > 0 {
397        Ok(T::zero())
398    } else {
399        let reduced_ov_explicit: T = lowdin_paired_coefficientss
400            .iter()
401            .map(|lpc| lpc.reduced_overlap())
402            .product();
403        let reduced_ov = (0..structure_constraint.implicit_factor()?)
404            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
405        Ok(reduced_ov * o0)
406    }
407}
408
409/// Calculates the matrix element of a one-particle operator between two Löwdin-paired
410/// determinants.
411///
412/// # Arguments
413///
414/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
415/// subspace determined by the specified structure constraint.
416/// `o1` - The one-particle operator in the atomic-orbital basis.
417/// `structure_constraint` - The structure constraint governing the coefficients.
418///
419/// # Returns
420///
421/// The one-particle matrix element.
422pub fn calc_o1_matrix_element<T, SC>(
423    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
424    o1: &ArrayView2<T>,
425    structure_constraint: &SC,
426) -> Result<T, anyhow::Error>
427where
428    T: ComplexFloat + ScalarOperand + Product,
429    SC: StructureConstraint,
430{
431    let nzeros_explicit: usize = lowdin_paired_coefficientss
432        .iter()
433        .map(|lpc| lpc.n_lowdin_zeros())
434        .sum();
435    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
436    if nzeros > 1 {
437        Ok(T::zero())
438    } else {
439        let reduced_ov_explicit: T = lowdin_paired_coefficientss
440            .iter()
441            .map(|lpc| lpc.reduced_overlap())
442            .product();
443        let reduced_ov = (0..structure_constraint.implicit_factor()?)
444            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
445
446        if nzeros == 0 {
447            let nbasis = lowdin_paired_coefficientss[0].nbasis();
448            let w = (0..structure_constraint.implicit_factor()?)
449                .cartesian_product(lowdin_paired_coefficientss.iter())
450                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
451                    calc_weighted_codensity_matrix(lpc).map(|w| acc + w)
452                })?;
453            // i = μ, j = μ'
454            einsum("ij,ji->", &[o1, &w.view()])
455                .map_err(|err| format_err!(err))?
456                .into_dimensionality::<Ix0>()?
457                .into_iter()
458                .next()
459                .ok_or_else(|| {
460                    format_err!("Unable to extract the result of the einsum contraction.")
461                })
462                .map(|v| v * reduced_ov)
463        } else {
464            ensure!(
465                nzeros == 1,
466                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
467            );
468            let ps = (0..structure_constraint.implicit_factor()?)
469                .flat_map(|_| {
470                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
471                        lpc.zero_indices()
472                            .iter()
473                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
474                    })
475                })
476                .collect::<Result<Vec<_>, _>>()?;
477            ensure!(
478                ps.len() == 1,
479                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
480                ps.len()
481            );
482            let p_mbar = ps.first().ok_or_else(|| {
483                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
484            })?;
485
486            // i = μ, j = μ'
487            einsum("ij,ji->", &[o1, &p_mbar.view()])
488                .map_err(|err| format_err!(err))?
489                .into_dimensionality::<Ix0>()?
490                .into_iter()
491                .next()
492                .ok_or_else(|| {
493                    format_err!("Unable to extract the result of the einsum contraction.")
494                })
495                .map(|v| v * reduced_ov)
496        }
497    }
498}
499
500/// Calculates the matrix element of a two-particle operator between two Löwdin-paired
501/// determinants.
502///
503/// # Arguments
504///
505/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
506/// subspace determined by the specified structure constraint.
507/// `o2_opt` - The two-particle operator in the atomic-orbital basis.
508/// `structure_constraint` - The structure constraint governing the coefficients.
509///
510/// # Returns
511///
512/// The two-particle matrix element.
513pub fn calc_o2_matrix_element<'a, 'b, T, SC, F>(
514    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
515    o2_opt: Option<&'b ArrayView4<'a, T>>,
516    get_jk_opt: Option<&F>,
517    structure_constraint: &SC,
518) -> Result<T, anyhow::Error>
519where
520    'a: 'b,
521    T: ComplexFloat + ScalarOperand + Product + std::fmt::Display,
522    SC: StructureConstraint,
523    F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error>,
524{
525    let nzeros_explicit: usize = lowdin_paired_coefficientss
526        .iter()
527        .map(|lpc| lpc.n_lowdin_zeros())
528        .sum();
529    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
530    if nzeros > 2 {
531        Ok(T::zero())
532    } else {
533        let reduced_ov_explicit: T = lowdin_paired_coefficientss
534            .iter()
535            .map(|lpc| lpc.reduced_overlap())
536            .product();
537        let reduced_ov = (0..structure_constraint.implicit_factor()?)
538            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
539
540        if nzeros == 0 {
541            let nbasis = lowdin_paired_coefficientss[0].nbasis();
542            let w_sigmas = (0..structure_constraint.implicit_factor()?)
543                .cartesian_product(lowdin_paired_coefficientss.iter())
544                .map(|(_, lpc)| calc_weighted_codensity_matrix(lpc))
545                .collect::<Result<Vec<_>, _>>()?;
546            let w = w_sigmas
547                .iter()
548                .fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, w_sigma| {
549                    acc + w_sigma
550                });
551
552            match (o2_opt, get_jk_opt) {
553                (Some(o2), None) => {
554                    // i = μ, j = μ', k = ν, l = ν'
555                    let j_term = einsum("ikjl,ji,lk->", &[o2, &w.view(), &w.view()])
556                        .map_err(|err| format_err!(err))?
557                        .into_dimensionality::<Ix0>()?
558                        .into_iter()
559                        .next()
560                        .ok_or_else(|| {
561                            format_err!("Unable to extract the result of the einsum contraction.")
562                        })
563                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
564                    let k_term = w_sigmas
565                        .iter()
566                        .try_fold(T::zero(), |acc, w_sigma| {
567                            einsum("ikjl,li,jk->", &[o2, &w_sigma.view(), &w_sigma.view()])
568                                .map_err(|err| format_err!(err))?
569                                .into_dimensionality::<Ix0>()?
570                                .into_iter()
571                                .next()
572                                .ok_or_else(|| {
573                                    format_err!(
574                                        "Unable to extract the result of the einsum contraction."
575                                    )
576                                })
577                                .map(|v| acc + v)
578                        })
579                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
580                    Ok(j_term - k_term)
581                }
582                (None, Some(get_jk)) => {
583                    // i = μ, j = μ', k = ν, l = ν'
584                    let (j_w, _) = get_jk(&w)?;
585                    let j_term = einsum("ij,ji->", &[&j_w.view(), &w.view()])
586                        .map_err(|err| format_err!(err))?
587                        .into_dimensionality::<Ix0>()?
588                        .into_iter()
589                        .next()
590                        .ok_or_else(|| {
591                            format_err!("Unable to extract the result of the einsum contraction.")
592                        })
593                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
594                    let k_term = w_sigmas
595                        .iter()
596                        .try_fold(T::zero(), |acc, w_sigma| {
597                            let (_, k_w_sigma) = get_jk(w_sigma)?;
598                            einsum("il,li->", &[&k_w_sigma.view(), &w_sigma.view()])
599                                .map_err(|err| format_err!(err))?
600                                .into_dimensionality::<Ix0>()?
601                                .into_iter()
602                                .next()
603                                .ok_or_else(|| {
604                                    format_err!(
605                                        "Unable to extract the result of the einsum contraction."
606                                    )
607                                })
608                                .map(|v| acc + v)
609                        })
610                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
611                    Ok(j_term - k_term)
612                }
613                _ => Err(format_err!(
614                    "One and only one of `o2` or `get_jk` should be provided."
615                )),
616            }
617        } else if nzeros == 1 {
618            ensure!(
619                nzeros_explicit == 1,
620                "Unexpected number of explicit zero Löwdin overlaps: {nzeros_explicit} != 1."
621            );
622
623            let nbasis = lowdin_paired_coefficientss[0].nbasis();
624            let w = (0..structure_constraint.implicit_factor()?)
625                .cartesian_product(lowdin_paired_coefficientss.iter())
626                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
627                    calc_weighted_codensity_matrix(lpc).map(|w_sigma| acc + w_sigma)
628                })?;
629
630            lowdin_paired_coefficientss
631                .iter()
632                .filter_map(|lpc| {
633                    if lpc.n_lowdin_zeros() == 1 {
634                        let w_sigma_res = calc_weighted_codensity_matrix(lpc);
635                        let mbar = lpc.zero_indices()[0];
636                        let p_mbar_sigma_res = calc_unweighted_codensity_matrix(lpc, mbar);
637                        Some((w_sigma_res, p_mbar_sigma_res))
638                    } else {
639                        None
640                    }
641                })
642                .try_fold(T::zero(), |acc, (w_sigma_res, p_mbar_sigma_res)| {
643                    w_sigma_res.and_then(|w_sigma| {
644                        p_mbar_sigma_res.and_then(|p_mbar_sigma| {
645                            match (o2_opt, get_jk_opt) {
646                                (Some(o2), None) => {
647                                    // i = μ, j = μ', k = ν, l = ν'
648                                    let j_term_1 = einsum(
649                                        "ikjl,ji,lk->",
650                                        &[o2, &w.view(), &p_mbar_sigma.view()],
651                                    )
652                                    .map_err(|err| format_err!(err))?
653                                    .into_dimensionality::<Ix0>()?
654                                    .into_iter()
655                                    .next()
656                                    .ok_or_else(|| {
657                                        format_err!(
658                                            "Unable to extract the result of the einsum contraction."
659                                        )
660                                    })?;
661                                    let j_term_2 = einsum(
662                                        "ikjl,ji,lk->",
663                                        &[o2, &p_mbar_sigma.view(), &w.view()],
664                                    )
665                                    .map_err(|err| format_err!(err))?
666                                    .into_dimensionality::<Ix0>()?
667                                    .into_iter()
668                                    .next()
669                                    .ok_or_else(|| {
670                                        format_err!(
671                                            "Unable to extract the result of the einsum contraction."
672                                        )
673                                    })?;
674                                    let k_term_1 = einsum(
675                                        "ikjl,li,jk->",
676                                        &[o2, &w_sigma.view(), &p_mbar_sigma.view()],
677                                    )
678                                    .map_err(|err| format_err!(err))?
679                                    .into_dimensionality::<Ix0>()?
680                                    .into_iter()
681                                    .next()
682                                    .ok_or_else(|| {
683                                        format_err!(
684                                            "Unable to extract the result of the einsum contraction."
685                                        )
686                                    })?;
687                                    let k_term_2 = einsum(
688                                        "ikjl,li,jk->",
689                                        &[o2, &p_mbar_sigma.view(), &w_sigma.view()],
690                                    )
691                                    .map_err(|err| format_err!(err))?
692                                    .into_dimensionality::<Ix0>()?
693                                    .into_iter()
694                                    .next()
695                                    .ok_or_else(|| {
696                                        format_err!(
697                                            "Unable to extract the result of the einsum contraction."
698                                        )
699                                    })?;
700                                    Ok(acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
701                                },
702                                (None, Some(get_jk)) => {
703                                    // i = μ, j = μ', k = ν, l = ν'
704                                    let (j_p_mbar_sigma, k_p_mbar_sigma) = get_jk(&p_mbar_sigma)?;
705                                    let j_term_1 = einsum(
706                                        "ij,ji->",
707                                        &[&j_p_mbar_sigma.view(), &w.view()],
708                                    )
709                                    .map_err(|err| format_err!(err))?
710                                    .into_dimensionality::<Ix0>()?
711                                    .into_iter()
712                                    .next()
713                                    .ok_or_else(|| {
714                                        format_err!(
715                                            "Unable to extract the result of the einsum contraction."
716                                        )
717                                    })?;
718                                    let (j_w, _) = get_jk(&w)?;
719                                    let j_term_2 = einsum(
720                                        "ij,ji->",
721                                        &[&j_w.view(), &p_mbar_sigma.view()],
722                                    )
723                                    .map_err(|err| format_err!(err))?
724                                    .into_dimensionality::<Ix0>()?
725                                    .into_iter()
726                                    .next()
727                                    .ok_or_else(|| {
728                                        format_err!(
729                                            "Unable to extract the result of the einsum contraction."
730                                        )
731                                    })?;
732                                    let k_term_1 = einsum(
733                                        "il,li->",
734                                        &[&k_p_mbar_sigma.view(), &w_sigma.view()],
735                                    )
736                                    .map_err(|err| format_err!(err))?
737                                    .into_dimensionality::<Ix0>()?
738                                    .into_iter()
739                                    .next()
740                                    .ok_or_else(|| {
741                                        format_err!(
742                                            "Unable to extract the result of the einsum contraction."
743                                        )
744                                    })?;
745                                    let (_, k_w_sigma) = get_jk(&w_sigma)?;
746                                    let k_term_2 = einsum(
747                                        "il,li->",
748                                        &[&k_w_sigma.view(), &p_mbar_sigma.view()],
749                                    )
750                                    .map_err(|err| format_err!(err))?
751                                    .into_dimensionality::<Ix0>()?
752                                    .into_iter()
753                                    .next()
754                                    .ok_or_else(|| {
755                                        format_err!(
756                                            "Unable to extract the result of the einsum contraction."
757                                        )
758                                    })?;
759                                    Ok(acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
760                                }
761                                _ => Err(format_err!(
762                                    "One and only one of `o2` or `get_jk` should be provided."
763                                )),
764                            }
765                        })
766                    })
767                })
768                .map(|v| v * reduced_ov / (T::one() + T::one()))
769        } else {
770            ensure!(
771                nzeros == 2,
772                "Unexpected number of zero Löwdin overlaps: {nzeros} != 2."
773            );
774
775            let ps = (0..structure_constraint.implicit_factor()?)
776                .flat_map(|_| {
777                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
778                        lpc.zero_indices()
779                            .iter()
780                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
781                    })
782                })
783                .collect::<Result<Vec<_>, _>>()?;
784            ensure!(
785                ps.len() == 2,
786                "Unexpected number of unweighted codensity matrices ({}) for two zero overlaps.",
787                ps.len()
788            );
789            let p_mbar = ps.first().ok_or_else(|| {
790                format_err!("Unable to retrieve the first computed unweighted codensity matrix.")
791            })?;
792            let p_nbar = ps.last().ok_or_else(|| {
793                format_err!("Unable to retrieve the second computed unweighted codensity matrix.")
794            })?;
795
796            match (o2_opt, get_jk_opt) {
797                (Some(o2), None) => {
798                    // i = μ, j = μ', k = ν, l = ν'
799                    let j_term_1 = einsum("ikjl,ji,lk->", &[o2, &p_mbar.view(), &p_nbar.view()])
800                        .map_err(|err| format_err!(err))?
801                        .into_dimensionality::<Ix0>()?
802                        .into_iter()
803                        .next()
804                        .ok_or_else(|| {
805                            format_err!("Unable to extract the result of the einsum contraction.")
806                        })?;
807                    let j_term_2 = einsum("ikjl,ji,lk->", &[o2, &p_nbar.view(), &p_mbar.view()])
808                        .map_err(|err| format_err!(err))?
809                        .into_dimensionality::<Ix0>()?
810                        .into_iter()
811                        .next()
812                        .ok_or_else(|| {
813                            format_err!("Unable to extract the result of the einsum contraction.")
814                        })?;
815
816                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
817                        .iter()
818                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
819                    {
820                        let k_term_1 =
821                            einsum("ikjl,li,jk->", &[o2, &p_mbar.view(), &p_nbar.view()])
822                                .map_err(|err| format_err!(err))?
823                                .into_dimensionality::<Ix0>()?
824                                .into_iter()
825                                .next()
826                                .ok_or_else(|| {
827                                    format_err!(
828                                        "Unable to extract the result of the einsum contraction."
829                                    )
830                                })?;
831                        let k_term_2 =
832                            einsum("ikjl,li,jk->", &[o2, &p_nbar.view(), &p_mbar.view()])
833                                .map_err(|err| format_err!(err))?
834                                .into_dimensionality::<Ix0>()?
835                                .into_iter()
836                                .next()
837                                .ok_or_else(|| {
838                                    format_err!(
839                                        "Unable to extract the result of the einsum contraction."
840                                    )
841                                })?;
842                        (k_term_1, k_term_2)
843                    } else {
844                        (T::zero(), T::zero())
845                    };
846                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
847                        / (T::one() + T::one()))
848                }
849                (None, Some(get_jk)) => {
850                    // i = μ, j = μ', k = ν, l = ν'
851                    let (j_p_nbar, k_p_nbar) = get_jk(p_nbar)?;
852                    let j_term_1 = einsum("ij,ji->", &[&j_p_nbar.view(), &p_mbar.view()])
853                        .map_err(|err| format_err!(err))?
854                        .into_dimensionality::<Ix0>()?
855                        .into_iter()
856                        .next()
857                        .ok_or_else(|| {
858                            format_err!("Unable to extract the result of the einsum contraction.")
859                        })?;
860                    let (j_p_mbar, k_p_mbar) = get_jk(p_mbar)?;
861                    let j_term_2 = einsum("ij,ji->", &[&j_p_mbar.view(), &p_nbar.view()])
862                        .map_err(|err| format_err!(err))?
863                        .into_dimensionality::<Ix0>()?
864                        .into_iter()
865                        .next()
866                        .ok_or_else(|| {
867                            format_err!("Unable to extract the result of the einsum contraction.")
868                        })?;
869
870                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
871                        .iter()
872                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
873                    {
874                        let k_term_1 = einsum("il,li->", &[&k_p_nbar.view(), &p_mbar.view()])
875                            .map_err(|err| format_err!(err))?
876                            .into_dimensionality::<Ix0>()?
877                            .into_iter()
878                            .next()
879                            .ok_or_else(|| {
880                                format_err!(
881                                    "Unable to extract the result of the einsum contraction."
882                                )
883                            })?;
884                        let k_term_2 = einsum("il,li->", &[&k_p_mbar.view(), &p_nbar.view()])
885                            .map_err(|err| format_err!(err))?
886                            .into_dimensionality::<Ix0>()?
887                            .into_iter()
888                            .next()
889                            .ok_or_else(|| {
890                                format_err!(
891                                    "Unable to extract the result of the einsum contraction."
892                                )
893                            })?;
894                        (k_term_1, k_term_2)
895                    } else {
896                        (T::zero(), T::zero())
897                    };
898                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
899                        / (T::one() + T::one()))
900                }
901                _ => Err(format_err!(
902                    "One and only one of `o2` or `get_jk` should be provided."
903                )),
904            }
905        }
906    }
907}
908
909/// Calculates the transition density matrix between two Löwdin-paired determinants.
910///
911/// # Arguments
912///
913/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
914/// subspace determined by the specified structure constraint.
915/// `structure_constraint` - The structure constraint governing the coefficients.
916///
917/// # Returns
918///
919/// The one-particle matrix element.
920pub fn calc_transition_density_matrix<T, SC>(
921    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
922    structure_constraint: &SC,
923) -> Result<Array2<T>, anyhow::Error>
924where
925    T: ComplexFloat + ScalarOperand + Product,
926    SC: StructureConstraint,
927{
928    let nzeros_explicit: usize = lowdin_paired_coefficientss
929        .iter()
930        .map(|lpc| lpc.n_lowdin_zeros())
931        .sum();
932    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
933    let nbasis = lowdin_paired_coefficientss[0].nbasis();
934    if nzeros > 1 {
935        Ok(Array2::<T>::zeros((nbasis, nbasis)))
936    } else {
937        let reduced_ov_explicit: T = lowdin_paired_coefficientss
938            .iter()
939            .map(|lpc| lpc.reduced_overlap())
940            .product();
941        let reduced_ov = (0..structure_constraint.implicit_factor()?)
942            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
943
944        if nzeros == 0 {
945            let nbasis = lowdin_paired_coefficientss[0].nbasis();
946            let w = (0..structure_constraint.implicit_factor()?)
947                .cartesian_product(lowdin_paired_coefficientss.iter())
948                .try_fold(Array2::<T>::zeros((nbasis, nbasis)), |acc, (_, lpc)| {
949                    calc_weighted_codensity_matrix(lpc).map(|w| acc + w)
950                })?;
951            Ok(w.mapv(|v| v * reduced_ov))
952        } else {
953            ensure!(
954                nzeros == 1,
955                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
956            );
957            let ps = (0..structure_constraint.implicit_factor()?)
958                .flat_map(|_| {
959                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
960                        lpc.zero_indices()
961                            .iter()
962                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
963                    })
964                })
965                .collect::<Result<Vec<_>, _>>()?;
966            ensure!(
967                ps.len() == 1,
968                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
969                ps.len()
970            );
971            let p_mbar = ps.first().ok_or_else(|| {
972                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
973            })?;
974            Ok(p_mbar.mapv(|v| v * reduced_ov))
975        }
976    }
977}
978
979/// Performs modified Gram--Schmidt orthonormalisation on a set of column vectors in a matrix with
980/// respect to the complex-symmetric or Hermitian dot product.
981///
982/// # Arguments
983///
984/// * `vmat` - Matrix containing column vectors forming a basis for a subspace.
985/// * `complex_symmetric` - A boolean indicating if the vector dot product is complex-symmetric. If
986///   `false`, the conventional Hermitian dot product is used.
987/// * `thresh` - A threshold for determining self-orthogonal vectors.
988///
989/// # Returns
990///
991/// The orthonormal vectors forming a basis for the same subspace collected as column vectors in a
992/// matrix.
993///
994/// # Errors
995///
996/// Errors when the orthonormalisation procedure fails, which occurs when there is linear dependency
997/// between the basis vectors and/or when self-orthogonal vectors are encountered.
998pub fn complex_modified_gram_schmidt<T>(
999    vmat: &ArrayView2<T>,
1000    complex_symmetric: bool,
1001    thresh: <T as ComplexFloat>::Real,
1002) -> Result<Array2<T>, anyhow::Error>
1003where
1004    T: ComplexFloat + std::fmt::Display + 'static,
1005{
1006    let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
1007    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
1008    for (i, vi) in vmat.columns().into_iter().enumerate() {
1009        // u[i] now initialised with v[i]
1010        us.push(vi.to_owned());
1011
1012        // Project ui onto all uj (0 <= j < i)
1013        // This is the 'modified' part of Gram--Schmidt. We project the current (and being updated)
1014        // ui onto uj, rather than projecting vi onto uj. This enhances numerical stability.
1015        for j in 0..i {
1016            let p_uj_ui = if complex_symmetric {
1017                us[j].t().dot(&us[i]) / us_sq_norm[j]
1018            } else {
1019                us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
1020            };
1021            us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
1022        }
1023
1024        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
1025        // us_sq_norm[i] now available.
1026        let us_sq_norm_i = if complex_symmetric {
1027            us[i].t().dot(&us[i])
1028        } else {
1029            us[i].t().map(|x| x.conj()).dot(&us[i])
1030        };
1031        if us_sq_norm_i.abs() < thresh {
1032            return Err(format_err!("A zero-norm vector found: {}", us[i]));
1033        }
1034        us_sq_norm.push(us_sq_norm_i);
1035    }
1036
1037    // Normalise ui
1038    for i in 0..us.len() {
1039        us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
1040    }
1041
1042    let ortho_check = us.iter().enumerate().all(|(i, ui)| {
1043        us.iter().enumerate().all(|(j, uj)| {
1044            let ov_ij = if complex_symmetric {
1045                ui.dot(uj)
1046            } else {
1047                ui.map(|x| x.conj()).dot(uj)
1048            };
1049            i == j || ov_ij.abs() < thresh
1050        })
1051    });
1052
1053    if ortho_check {
1054        stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| format_err!(err))
1055    } else {
1056        Err(format_err!(
1057            "Post-Gram--Schmidt orthogonality check failed."
1058        ))
1059    }
1060}
1061
1062/// Trait for Löwdin canonical orthogonalisation of a square matrix.
1063pub trait CanonicalOrthogonalisable {
1064    /// Numerical type of the matrix elements.
1065    type NumType;
1066
1067    /// Type of real threshold values.
1068    type RealType;
1069
1070    /// Calculates the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$ for a square
1071    /// matrix.
1072    ///
1073    /// # Arguments
1074    ///
1075    /// * `complex_symmetric` - Boolean indicating if the orthogonalisation is with respect to the
1076    ///   complex-symmetric inner product.
1077    /// * `preserves_full_rank` - Boolean indicating if a full-rank square matrix should be left
1078    ///   unchanged, thus forcing $`\mathbf{X} = \mathbf{I}`$.
1079    /// * `thresh_offdiag` - Threshold for verifying that the orthogonalised matrix is indeed
1080    ///   orthogonal.
1081    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of the input square matrix.
1082    ///
1083    /// # Returns
1084    ///
1085    /// The canonical orthogonalisation result.
1086    fn calc_canonical_orthogonal_matrix(
1087        &self,
1088        complex_symmetric: bool,
1089        preserves_full_rank: bool,
1090        thresh_offdiag: Self::RealType,
1091        thresh_zeroov: Self::RealType,
1092    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error>;
1093}
1094
1095/// Structure containing the results of the Löwdin canonical orthogonalisation.
1096pub struct CanonicalOrthogonalisationResult<T> {
1097    /// The eigenvalues of the input matrix.
1098    eigenvalues: Array1<T>,
1099
1100    /// The Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1101    xmat: Array2<T>,
1102
1103    /// The conjugate of the Löwdin canonical orthogonalisation matrix,
1104    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1105    /// and $`\lozenge = \hat{e}`$ otherwise.
1106    xmat_d: Array2<T>,
1107}
1108
1109impl<T> CanonicalOrthogonalisationResult<T> {
1110    /// Returns the eigenvalues of the input matrix.
1111    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
1112        self.eigenvalues.view()
1113    }
1114
1115    /// Returns the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1116    pub fn xmat(&'_ self) -> ArrayView2<'_, T> {
1117        self.xmat.view()
1118    }
1119
1120    /// Returns the conjugate of the Löwdin canonical orthogonalisation matrix,
1121    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1122    /// and $`\lozenge = \hat{e}`$ otherwise.
1123    pub fn xmat_d(&'_ self) -> ArrayView2<'_, T> {
1124        self.xmat_d.view()
1125    }
1126}
1127
1128#[duplicate_item(
1129    [
1130        dtype_ [ f64 ]
1131    ]
1132    [
1133        dtype_ [ f32 ]
1134    ]
1135)]
1136impl CanonicalOrthogonalisable for ArrayView2<'_, dtype_> {
1137    type NumType = dtype_;
1138
1139    type RealType = dtype_;
1140
1141    fn calc_canonical_orthogonal_matrix(
1142        &self,
1143        _: bool,
1144        preserves_full_rank: bool,
1145        thresh_offdiag: dtype_,
1146        thresh_zeroov: dtype_,
1147    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1148        let smat = self;
1149
1150        // Real, symmetric S
1151        let max_offdiag_s = *(smat.to_owned() - smat.t()).map(|v| v.abs()).iter()
1152                .max_by(|a, b| a.total_cmp(b))
1153                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap symmetric deviation matrix."))?;
1154        ensure!(
1155            max_offdiag_s <= thresh_offdiag,
1156            "Overlap matrix is not real-symmetric: ||S - S^T||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1157        );
1158
1159        // S is real-symmetric, so U is orthogonal, i.e. U^T = U^(-1).
1160        let (s_eig, umat) = smat.eigh(UPLO::Lower).map_err(|err| format_err!(err))?;
1161        // Real eigenvalues, so both comparison modes are the same.
1162        let nonzero_s_indices = s_eig
1163            .iter()
1164            .positions(|x| x.abs() > thresh_zeroov)
1165            .collect_vec();
1166        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1167        if nonzero_s_eig.iter().any(|v| *v < 0.0) {
1168            return Err(format_err!(
1169                "The matrix has negative eigenvalues and therefore cannot be orthogonalised over the reals."
1170            ));
1171        }
1172        let nonzero_umat = umat.select(Axis(1), &nonzero_s_indices);
1173        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1174        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1175            (Array2::eye(smat.shape()[0]), Array2::eye(smat.shape()[0]))
1176        } else {
1177            let s_s = Array2::<dtype_>::from_diag(&nonzero_s_eig.mapv(|x| 1.0 / x.sqrt()));
1178            (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1179        };
1180        let res = CanonicalOrthogonalisationResult {
1181            eigenvalues: s_eig,
1182            xmat,
1183            xmat_d,
1184        };
1185        Ok(res)
1186    }
1187}
1188
1189impl<T> CanonicalOrthogonalisable for ArrayView2<'_, Complex<T>>
1190where
1191    T: Float + Scalar<Complex = Complex<T>>,
1192    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
1193{
1194    type NumType = Complex<T>;
1195
1196    type RealType = T;
1197
1198    fn calc_canonical_orthogonal_matrix(
1199        &self,
1200        complex_symmetric: bool,
1201        preserves_full_rank: bool,
1202        thresh_offdiag: T,
1203        thresh_zeroov: T,
1204    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1205        let smat = self;
1206
1207        if complex_symmetric {
1208            // Complex-symmetric S
1209            let max_offdiag = *(smat.to_owned() - smat.t())
1210                    .mapv(ComplexFloat::abs)
1211                    .iter()
1212                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1213                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
1214            ensure!(
1215                max_offdiag <= thresh_offdiag,
1216                "Overlap matrix is not complex-symmetric."
1217            );
1218        } else {
1219            // Complex-Hermitian S
1220            let max_offdiag = *(smat.to_owned() - smat.map(|v| v.conj()).t())
1221                    .mapv(ComplexFloat::abs)
1222                    .iter()
1223                    .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1224                    .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap complex-symmetric deviation matrix."))?;
1225            ensure!(
1226                max_offdiag <= thresh_offdiag,
1227                "Overlap matrix is not complex-Hermitian."
1228            );
1229        }
1230
1231        let (s_eig, umat_nonortho) = smat.eig().map_err(|err| format_err!(err))?;
1232        log::debug!("Overlap eigenvalues for canonical orthogonalisation:");
1233        for (i, eig) in s_eig.iter().enumerate() {
1234            log::debug!("  {i}: {eig:+.8e}");
1235        }
1236        log::debug!("");
1237
1238        let nonzero_s_indices = s_eig
1239            .iter()
1240            .positions(|x| ComplexFloat::abs(*x) > thresh_zeroov)
1241            .collect_vec();
1242        log::debug!("Non-zero overlap indices w.r.t. threshold {thresh_zeroov:.8e}:");
1243        log::debug!(
1244            "  {}",
1245            nonzero_s_indices.iter().map(|i| i.to_string()).join(", ")
1246        );
1247        log::debug!("");
1248        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1249        let nonzero_umat_nonortho = umat_nonortho.select(Axis(1), &nonzero_s_indices);
1250
1251        // `eig` does not guarantee orthogonality of `nonzero_umat_nonortho`.
1252        // Gram--Schmidt is therefore required.
1253        let nonzero_umat = complex_modified_gram_schmidt(
1254            &nonzero_umat_nonortho.view(),
1255            complex_symmetric,
1256            thresh_zeroov,
1257        )
1258        .map_err(
1259            |_| format_err!("Unable to orthonormalise the linearly-independent eigenvectors of the overlap matrix.")
1260        )?;
1261
1262        let nonzero_s_eig_from_u = if complex_symmetric {
1263            nonzero_umat.t().dot(smat).dot(&nonzero_umat)
1264        } else {
1265            nonzero_umat
1266                .map(|v| v.conj())
1267                .t()
1268                .dot(smat)
1269                .dot(&nonzero_umat)
1270        };
1271
1272        let max_offdiag_s = *(nonzero_s_eig_from_u - Array2::from_diag(&nonzero_s_eig)).mapv(ComplexFloat::abs).iter()
1273                .max_by(|a, b| a.partial_cmp(b).unwrap_or_else(|| panic!("Unable to compare {a} and {b}.")))
1274                .ok_or_else(|| format_err!("Unable to find the maximum absolute value of the overlap symmetric deviation matrix."))?;
1275        ensure!(
1276            max_offdiag_s <= thresh_offdiag,
1277            if complex_symmetric {
1278                "Canonical orthogonalisation has failed: ||U^T.S.U - s||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1279            } else {
1280                "Canonical orthogonalisation has failed: ||U^†.S.U - s||_∞ = {max_offdiag_s:.3e} > {thresh_offdiag:.3e}."
1281            }
1282        );
1283
1284        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1285        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1286            (
1287                Array2::<Complex<T>>::eye(smat.shape()[0]),
1288                Array2::<Complex<T>>::eye(smat.shape()[0]),
1289            )
1290        } else {
1291            let s_s = Array2::<Complex<T>>::from_diag(
1292                &nonzero_s_eig.mapv(|x| Complex::<T>::from(T::one()) / x.sqrt()),
1293            );
1294            let xmat = nonzero_umat.dot(&s_s);
1295            let xmat_d = if complex_symmetric {
1296                // (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1297                xmat.t().to_owned()
1298            } else {
1299                xmat.map(|v| v.conj()).t().to_owned()
1300            };
1301            (xmat, xmat_d)
1302        };
1303        let res = CanonicalOrthogonalisationResult {
1304            eigenvalues: s_eig,
1305            xmat,
1306            xmat_d,
1307        };
1308        Ok(res)
1309    }
1310}