qsym2/target/noci/backend/nonortho/
mod.rs

1use std::fmt::LowerExp;
2use std::iter::Product;
3
4use anyhow::{self, ensure, format_err};
5use derive_builder::Builder;
6use duplicate::duplicate_item;
7use indexmap::IndexSet;
8use itertools::Itertools;
9use ndarray::{
10    Array1, Array2, ArrayView1, ArrayView2, ArrayView4, Axis, Ix0, Ix2, ScalarOperand, stack,
11};
12use ndarray_einsum::einsum;
13use ndarray_linalg::types::Lapack;
14use ndarray_linalg::{Determinant, Eig, Eigh, Norm, SVD, Scalar, UPLO};
15use num::{Complex, Float};
16use num_complex::ComplexFloat;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19
20use super::denmat::{calc_unweighted_codensity_matrix, calc_weighted_codensity_matrix};
21
22#[cfg(test)]
23#[path = "nonortho_tests.rs"]
24mod nonortho_tests;
25
26/// Structure containing Löwdin-paired coefficients, the corresponding Löwdin overlaps, and the
27/// indices of the zero overlaps.
28///
29/// The Löwdin-paired coefficients satisfy
30/// ```math
31///     ^{wx}\mathbf{\Lambda}
32///         = \mathrm{diag}(^{wx}\lambda_i)
33///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
34///           \ \mathbf{S}_{\mathrm{AO}}
35///           \ ^{x}\tilde{\mathbf{C}}.
36/// ```
37#[derive(Builder)]
38#[builder(build_fn(validate = "Self::validate"))]
39pub struct LowdinPairedCoefficients<T: ComplexFloat> {
40    /// The $`^{w}\!\tilde{\mathbf{C}}`$ coefficient matrix.
41    paired_cw: Array2<T>,
42
43    /// The $`^{x}\!\tilde{\mathbf{C}}`$ coefficient matrix.
44    paired_cx: Array2<T>,
45
46    /// The Löwdin overlaps $`\{^{wx}\lambda_i\}`$.
47    lowdin_overlaps: Vec<T>,
48
49    /// The indices of the zero overlaps with respect to [`Self::thresh_zeroov`]. If these are not
50    /// provided, the specified [`Self::thresh_zeroov`] will be used to deduce them from the
51    /// specified `[Self::lowdin_overlaps]`.
52    #[builder(default = "self.default_zero_indices()?")]
53    zero_indices: IndexSet<usize>,
54
55    /// Threshold for determining zero overlaps.
56    thresh_zeroov: T::Real,
57
58    /// Boolean indicating whether the coefficients have been Löwdin-paired with respect to the
59    /// complex-symmetric inner product.
60    complex_symmetric: bool,
61}
62
63impl<T: ComplexFloat> LowdinPairedCoefficientsBuilder<T> {
64    fn validate(&self) -> Result<(), String> {
65        let paired_cw = self
66            .paired_cw
67            .as_ref()
68            .ok_or("Löwdin-paired coefficients `paired_cw` not set.".to_string())?;
69        let paired_cx = self
70            .paired_cx
71            .as_ref()
72            .ok_or("Löwdin-paired coefficients `paired_cx` not set.".to_string())?;
73        let lowdin_overlaps = self
74            .lowdin_overlaps
75            .as_ref()
76            .ok_or("Löwdin overlaps not set.".to_string())?;
77        let zero_indices = self
78            .zero_indices
79            .as_ref()
80            .ok_or("Indices of zero Löwdin overlaps not set.".to_string())?;
81
82        if paired_cw.shape() == paired_cx.shape() {
83            let lowdin_dim = paired_cw.shape()[1];
84            if lowdin_dim == lowdin_overlaps.len() {
85                if zero_indices.iter().all(|i| *i < lowdin_dim) {
86                    Ok(())
87                } else {
88                    Err("Some indices of zero Löwdin overlaps are out-of-bound!".to_string())
89                }
90            } else {
91                Err(
92                    "Inconsistent number of Löwdin-paired orbitals and Löwdin overlaps."
93                        .to_string(),
94                )
95            }
96        } else {
97            Err(format!(
98                "Inconsistent shapes between `paired_cw` ({:?}) and `paired_cx` ({:?}).",
99                paired_cw.shape(),
100                paired_cx.shape()
101            ))
102        }
103    }
104
105    fn default_zero_indices(&self) -> Result<IndexSet<usize>, String> {
106        let lowdin_overlaps = self
107            .lowdin_overlaps
108            .as_ref()
109            .ok_or("Löwdin overlaps not set.".to_string())?;
110        let thresh_zeroov = self
111            .thresh_zeroov
112            .as_ref()
113            .ok_or("threshold for zero Löwdin overlaps not set.".to_string())?;
114        let zero_indices = lowdin_overlaps
115            .iter()
116            .enumerate()
117            .filter(|(_, ov)| ComplexFloat::abs(**ov) < *thresh_zeroov)
118            .map(|(i, _)| i)
119            .collect::<IndexSet<_>>();
120        Ok(zero_indices)
121    }
122}
123
124impl<T: ComplexFloat> LowdinPairedCoefficients<T> {
125    pub fn builder() -> LowdinPairedCoefficientsBuilder<T> {
126        LowdinPairedCoefficientsBuilder::<T>::default()
127    }
128
129    /// Returns the Löwdin-paired coefficients.
130    pub fn paired_coefficients(&self) -> (&Array2<T>, &Array2<T>) {
131        (&self.paired_cw, &self.paired_cx)
132    }
133
134    /// Returns the number of atomic-orbital basis functions in which the coefficient matrices are
135    /// expressed.
136    pub fn nbasis(&self) -> usize {
137        self.paired_cw.nrows()
138    }
139
140    /// Returns the number of molecular orbitals being Löwdin-paired.
141    pub fn lowdin_dim(&self) -> usize {
142        self.lowdin_overlaps.len()
143    }
144
145    /// Returns the number of zero Löwdin overlaps.
146    pub fn n_lowdin_zeros(&self) -> usize {
147        self.zero_indices.len()
148    }
149
150    /// Returns the Löwdin overlaps.
151    pub fn lowdin_overlaps(&self) -> &Vec<T> {
152        &self.lowdin_overlaps
153    }
154
155    /// Returns the indices of the zero Löwdin overlaps.
156    pub fn zero_indices(&self) -> &IndexSet<usize> {
157        &self.zero_indices
158    }
159
160    /// Returns the indices of the non-zero Löwdin overlaps.
161    pub fn nonzero_indices(&self) -> IndexSet<usize> {
162        (0..self.lowdin_dim())
163            .filter(|i| !self.zero_indices.contains(i))
164            .collect::<IndexSet<_>>()
165    }
166
167    /// Returns the boolean indicating whether the Löwdin pairing is with respect to the
168    /// complex-symmetric inner product.
169    pub fn complex_symmetric(&self) -> bool {
170        self.complex_symmetric
171    }
172}
173
174impl<T: ComplexFloat + Product> LowdinPairedCoefficients<T> {
175    /// The reduced overlap between the two determinants.
176    pub fn reduced_overlap(&self) -> T {
177        self.nonzero_indices()
178            .iter()
179            .map(|i| self.lowdin_overlaps[*i])
180            .product()
181    }
182}
183
184/// Performs Löwdin pairing on two coefficient matrices $`^{w}\mathbf{C}`$ and
185/// $`^{x}\mathbf{C}`$.
186///
187/// Löwdin pairing ensures that
188/// ```math
189///     ^{wx}\mathbf{\Lambda}
190///         = \mathrm{diag}(^{wx}\lambda_i)
191///         = ^{w}\!\tilde{\mathbf{C}}^{\dagger\lozenge}
192///           \ \mathbf{S}_{\mathrm{AO}}
193///           \ ^{x}\tilde{\mathbf{C}},
194/// ```
195/// where the Löwdin-paired coefficient matrices are given by
196/// ```math
197///     \begin{align*}
198///         ^{w}\!\tilde{\mathbf{C}}
199///             &=\ ^{w}\mathbf{C}\ ^{wx}\mathbf{U}^{\lozenge} \\
200///         ^{x}\!\tilde{\mathbf{C}}
201///             &=\ ^{x}\mathbf{C}\ ^{wx}\mathbf{V}
202///     \end{align*}
203/// ```
204/// with $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ being SVD factorisation matrices:
205/// ```math
206///     ^{w}\mathbf{C}^{\dagger\lozenge}
207///     \ \mathbf{S}_{\mathrm{AO}}
208///     \ ^{x}\mathbf{C}
209///     =
210///     \ ^{wx}\mathbf{U}
211///     \ ^{wx}\mathbf{\Lambda}
212///     \ ^{wx}\mathbf{V}^{\dagger}.
213/// ```
214///
215/// We note that the first columns of $`^{w}\!\tilde{\mathbf{C}}`$ and $`^{x}\!\tilde{\mathbf{C}}`$
216/// are also adjusted by the determinants of $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$ as
217/// appropriate to ensure that the Slater determinants corresponding to them remain invariant with
218/// respect to the unitary transformations brought about by $`^{wx}\mathbf{U}`$ and $`^{wx}\mathbf{V}`$.
219///
220/// # Arguments
221///
222/// * `cw` - Coefficient matrix $`^{w}\mathbf{C}`$.
223/// * `cx` - Coefficient matrix $`^{x}\mathbf{C}`$.
224/// * `sao` - The overlap matrix $`\mathbf{S}_{\mathrm{AO}}`$ of the underlying atomic basis
225/// functions.
226/// * `complex_symmetric` - If `true`, $`\lozenge = \star`$. If `false`, $`\lozenge = \hat{e}`$.
227/// * `thresh_offdiag` - Threshold to check if the off-diagonal elements in the original orbital
228/// overlap matrix and in the Löwdin-paired orbital overlap matrix $`^{wx}\mathbf{\Lambda}`$ are zero.
229/// * `thresh_zeroov` - Threshold to identify which Löwdin overlaps $`^{wx}\lambda_i`$ are zero.
230///
231/// # Returns
232///
233/// A [`LowdinPairedCoefficients`] structure containing the result of the Löwdin pairing.
234pub fn calc_lowdin_pairing<T>(
235    cw: &ArrayView2<T>,
236    cx: &ArrayView2<T>,
237    sao: &ArrayView2<T>,
238    complex_symmetric: bool,
239    thresh_offdiag: <T as ComplexFloat>::Real,
240    thresh_zeroov: <T as ComplexFloat>::Real,
241) -> Result<LowdinPairedCoefficients<T>, anyhow::Error>
242where
243    T: ComplexFloat + Lapack,
244    <T as ComplexFloat>::Real: PartialOrd + LowerExp,
245{
246    if cw.shape() != cx.shape() {
247        Err(format_err!(
248            "Coefficient dimensions mismatched: cw ({:?}) !~ cx ({:?}).",
249            cw.shape(),
250            cx.shape()
251        ))
252    } else {
253        let init_orb_ovmat = if complex_symmetric {
254            einsum("ji,jk,kl->il", &[cw, sao, cx])
255        } else {
256            einsum("ji,jk,kl->il", &[&cw.map(|x| x.conj()), sao, cx])
257        }
258        .map_err(|err| format_err!(err))?
259        .into_dimensionality::<Ix2>()?;
260
261        let max_offdiag = (&init_orb_ovmat - &Array2::from_diag(&init_orb_ovmat.diag().to_owned()))
262            .iter()
263            .map(|x| ComplexFloat::abs(*x))
264            .max_by(|x, y| {
265                x.partial_cmp(y)
266                    .expect("Unable to compare two `abs` values.")
267            })
268            .ok_or_else(|| {
269                format_err!(
270                    "Unable to determine the maximum off-diagonal element for\n{}.",
271                    &init_orb_ovmat
272                )
273            })?;
274
275        if max_offdiag <= thresh_offdiag {
276            let lowdin_overlaps = init_orb_ovmat.into_diag().to_vec();
277            let zero_indices = lowdin_overlaps
278                .iter()
279                .enumerate()
280                .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
281                .map(|(i, _)| i)
282                .collect::<IndexSet<_>>();
283            LowdinPairedCoefficients::builder()
284                .paired_cw(cw.to_owned())
285                .paired_cx(cx.to_owned())
286                .lowdin_overlaps(lowdin_overlaps)
287                .zero_indices(zero_indices)
288                .thresh_zeroov(thresh_zeroov)
289                .complex_symmetric(complex_symmetric)
290                .build()
291                .map_err(|err| format_err!(err))
292        } else {
293            let (u_opt, _, vh_opt) = init_orb_ovmat.svd(true, true)?;
294            let u = u_opt.ok_or_else(|| format_err!("Unable to compute the U matrix from SVD."))?;
295            let vh =
296                vh_opt.ok_or_else(|| format_err!("Unable to compute the V matrix from SVD."))?;
297            let v = vh.t().map(|x| x.conj());
298            let det_v_c = v.det()?.conj();
299
300            let paired_cw = if complex_symmetric {
301                let uc = u.map(|x| x.conj());
302                let mut cwt = cw.dot(&uc);
303                let det_uc_c = uc.det()?.conj();
304                cwt.column_mut(0)
305                    .iter_mut()
306                    .for_each(|x| *x = *x * det_uc_c);
307                cwt
308            } else {
309                let mut cwt = cw.dot(&u);
310                let det_u_c = u.det()?.conj();
311                cwt.column_mut(0).iter_mut().for_each(|x| *x = *x * det_u_c);
312                cwt
313            };
314
315            let paired_cx = {
316                let mut cxt = cx.dot(&v);
317                cxt.column_mut(0).iter_mut().for_each(|x| *x = *x * det_v_c);
318                cxt
319            };
320
321            let lowdin_orb_ovmat = if complex_symmetric {
322                einsum("ji,jk,kl->il", &[&paired_cw, sao, &paired_cx])
323            } else {
324                einsum(
325                    "ji,jk,kl->il",
326                    &[&paired_cw.map(|x| x.conj()), sao, &paired_cx],
327                )
328            }
329            .map_err(|err| format_err!(err))?
330            .into_dimensionality::<Ix2>()?;
331
332            let max_offdiag_lowdin = (&lowdin_orb_ovmat - &Array2::from_diag(&lowdin_orb_ovmat.diag().to_owned()))
333                .iter()
334                .map(|x| ComplexFloat::abs(*x))
335                .max_by(|x, y| {
336                    x.partial_cmp(y)
337                        .expect("Unable to compare two `abs` values.")
338                })
339                .ok_or_else(|| format_err!("Unable to determine the maximum off-diagonal element of the Lowdin-paired overlap matrix."))?;
340            if max_offdiag_lowdin <= thresh_offdiag {
341                let lowdin_overlaps = lowdin_orb_ovmat.into_diag().to_vec();
342                let zero_indices = lowdin_overlaps
343                    .iter()
344                    .enumerate()
345                    .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
346                    .map(|(i, _)| i)
347                    .collect::<IndexSet<_>>();
348                LowdinPairedCoefficients::builder()
349                    .paired_cw(paired_cw.clone())
350                    .paired_cx(paired_cx.clone())
351                    .lowdin_overlaps(lowdin_overlaps)
352                    .zero_indices(zero_indices)
353                    .thresh_zeroov(thresh_zeroov)
354                    .complex_symmetric(complex_symmetric)
355                    .build()
356                    .map_err(|err| format_err!(err))
357            } else {
358                Err(format_err!(
359                    "Löwdin overlap matrix deviates from diagonality. Maximum off-diagonal overlap has magnitude {max_offdiag_lowdin:.3e} > threshold of {thresh_offdiag:.3e}. Löwdin pairing has failed."
360                ))
361            }
362        }
363    }
364}
365
366/// Calculates the matrix element of a zero-particle operator between two Löwdin-paired
367/// determinants.
368///
369/// # Arguments
370///
371/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
372/// subspace determined by the specified structure constraint.
373/// `o0` - The zero-particle operator.
374/// `structure_constraint` - The structure constraint governing the coefficients.
375///
376/// # Returns
377///
378/// The zero-particle matrix element.
379pub fn calc_o0_matrix_element<T, SC>(
380    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
381    o0: T,
382    structure_constraint: &SC,
383) -> Result<T, anyhow::Error>
384where
385    T: ComplexFloat + ScalarOperand + Product,
386    SC: StructureConstraint,
387{
388    let nzeros_explicit: usize = lowdin_paired_coefficientss
389        .iter()
390        .map(|lpc| lpc.n_lowdin_zeros())
391        .sum();
392    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
393    if nzeros > 0 {
394        Ok(T::zero())
395    } else {
396        let reduced_ov_explicit: T = lowdin_paired_coefficientss
397            .iter()
398            .map(|lpc| lpc.reduced_overlap())
399            .product();
400        let reduced_ov = (0..structure_constraint.implicit_factor()?)
401            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
402        Ok(reduced_ov * o0)
403    }
404}
405
406/// Calculates the matrix element of a one-particle operator between two Löwdin-paired
407/// determinants.
408///
409/// # Arguments
410///
411/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
412/// subspace determined by the specified structure constraint.
413/// `o1` - The one-particle operator in the atomic-orbital basis.
414/// `structure_constraint` - The structure constraint governing the coefficients.
415///
416/// # Returns
417///
418/// The one-particle matrix element.
419pub fn calc_o1_matrix_element<T, SC>(
420    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
421    o1: &ArrayView2<T>,
422    structure_constraint: &SC,
423) -> Result<T, anyhow::Error>
424where
425    T: ComplexFloat + ScalarOperand + Product,
426    SC: StructureConstraint,
427{
428    let nzeros_explicit: usize = lowdin_paired_coefficientss
429        .iter()
430        .map(|lpc| lpc.n_lowdin_zeros())
431        .sum();
432    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
433    if nzeros > 1 {
434        Ok(T::zero())
435    } else {
436        let reduced_ov_explicit: T = lowdin_paired_coefficientss
437            .iter()
438            .map(|lpc| lpc.reduced_overlap())
439            .product();
440        let reduced_ov = (0..structure_constraint.implicit_factor()?)
441            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
442
443        if nzeros == 0 {
444            let nbasis = lowdin_paired_coefficientss[0].nbasis();
445            let w = (0..structure_constraint.implicit_factor()?)
446                .cartesian_product(lowdin_paired_coefficientss.iter())
447                .fold(
448                    Ok(Array2::<T>::zeros((nbasis, nbasis))),
449                    |acc_res, (_, lpc)| {
450                        calc_weighted_codensity_matrix(lpc).and_then(|w| acc_res.map(|acc| acc + w))
451                    },
452                )?;
453            // i = μ, j = μ'
454            einsum("ij,ji->", &[o1, &w.view()])
455                .map_err(|err| format_err!(err))?
456                .into_dimensionality::<Ix0>()?
457                .into_iter()
458                .next()
459                .ok_or_else(|| {
460                    format_err!("Unable to extract the result of the einsum contraction.")
461                })
462                .map(|v| v * reduced_ov)
463        } else {
464            ensure!(
465                nzeros == 1,
466                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
467            );
468            let ps = (0..structure_constraint.implicit_factor()?)
469                .flat_map(|_| {
470                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
471                        lpc.zero_indices()
472                            .iter()
473                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
474                    })
475                })
476                .collect::<Result<Vec<_>, _>>()?;
477            ensure!(
478                ps.len() == 1,
479                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
480                ps.len()
481            );
482            let p_mbar = ps.first().ok_or_else(|| {
483                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
484            })?;
485
486            // i = μ, j = μ'
487            einsum("ij,ji->", &[o1, &p_mbar.view()])
488                .map_err(|err| format_err!(err))?
489                .into_dimensionality::<Ix0>()?
490                .into_iter()
491                .next()
492                .ok_or_else(|| {
493                    format_err!("Unable to extract the result of the einsum contraction.")
494                })
495                .map(|v| v * reduced_ov)
496        }
497    }
498}
499
500/// Calculates the matrix element of a two-particle operator between two Löwdin-paired
501/// determinants.
502///
503/// # Arguments
504///
505/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
506/// subspace determined by the specified structure constraint.
507/// `o2_opt` - The two-particle operator in the atomic-orbital basis.
508/// `structure_constraint` - The structure constraint governing the coefficients.
509///
510/// # Returns
511///
512/// The two-particle matrix element.
513pub fn calc_o2_matrix_element<'a, 'b, T, SC, F>(
514    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
515    o2_opt: Option<&'b ArrayView4<'a, T>>,
516    get_jk_opt: Option<&F>,
517    structure_constraint: &SC,
518) -> Result<T, anyhow::Error>
519where
520    'a: 'b,
521    T: ComplexFloat + ScalarOperand + Product + std::fmt::Display,
522    SC: StructureConstraint,
523    F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error>,
524{
525    let nzeros_explicit: usize = lowdin_paired_coefficientss
526        .iter()
527        .map(|lpc| lpc.n_lowdin_zeros())
528        .sum();
529    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
530    if nzeros > 2 {
531        Ok(T::zero())
532    } else {
533        let reduced_ov_explicit: T = lowdin_paired_coefficientss
534            .iter()
535            .map(|lpc| lpc.reduced_overlap())
536            .product();
537        let reduced_ov = (0..structure_constraint.implicit_factor()?)
538            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
539
540        if nzeros == 0 {
541            let nbasis = lowdin_paired_coefficientss[0].nbasis();
542            let w_sigmas = (0..structure_constraint.implicit_factor()?)
543                .cartesian_product(lowdin_paired_coefficientss.iter())
544                .map(|(_, lpc)| calc_weighted_codensity_matrix(lpc))
545                .collect::<Result<Vec<_>, _>>()?;
546            let w = w_sigmas.iter().fold(
547                Ok::<_, anyhow::Error>(Array2::<T>::zeros((nbasis, nbasis))),
548                |acc_res, w_sigma| acc_res.map(|acc| acc + w_sigma),
549            )?;
550
551            match (o2_opt, get_jk_opt) {
552                (Some(o2), None) => {
553                    // i = μ, j = μ', k = ν, l = ν'
554                    let j_term = einsum("ikjl,ji,lk->", &[o2, &w.view(), &w.view()])
555                        .map_err(|err| format_err!(err))?
556                        .into_dimensionality::<Ix0>()?
557                        .into_iter()
558                        .next()
559                        .ok_or_else(|| {
560                            format_err!("Unable to extract the result of the einsum contraction.")
561                        })
562                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
563                    let k_term = w_sigmas
564                        .iter()
565                        .fold(Ok(T::zero()), |acc_res, w_sigma| {
566                            einsum("ikjl,li,jk->", &[o2, &w_sigma.view(), &w_sigma.view()])
567                                .map_err(|err| format_err!(err))?
568                                .into_dimensionality::<Ix0>()?
569                                .into_iter()
570                                .next()
571                                .ok_or_else(|| {
572                                    format_err!(
573                                        "Unable to extract the result of the einsum contraction."
574                                    )
575                                })
576                                .and_then(|v| acc_res.map(|acc| acc + v))
577                        })
578                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
579                    Ok(j_term - k_term)
580                }
581                (None, Some(get_jk)) => {
582                    // i = μ, j = μ', k = ν, l = ν'
583                    let (j_w, _) = get_jk(&w)?;
584                    let j_term = einsum("ij,ji->", &[&j_w.view(), &w.view()])
585                        .map_err(|err| format_err!(err))?
586                        .into_dimensionality::<Ix0>()?
587                        .into_iter()
588                        .next()
589                        .ok_or_else(|| {
590                            format_err!("Unable to extract the result of the einsum contraction.")
591                        })
592                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
593                    let k_term = w_sigmas
594                        .iter()
595                        .fold(Ok(T::zero()), |acc_res, w_sigma| {
596                            let (_, k_w_sigma) = get_jk(&w_sigma)?;
597                            einsum("il,li->", &[&k_w_sigma.view(), &w_sigma.view()])
598                                .map_err(|err| format_err!(err))?
599                                .into_dimensionality::<Ix0>()?
600                                .into_iter()
601                                .next()
602                                .ok_or_else(|| {
603                                    format_err!(
604                                        "Unable to extract the result of the einsum contraction."
605                                    )
606                                })
607                                .and_then(|v| acc_res.map(|acc| acc + v))
608                        })
609                        .map(|v| v * reduced_ov / (T::one() + T::one()))?;
610                    Ok(j_term - k_term)
611                }
612                _ => Err(format_err!(
613                    "One and only one of `o2` or `get_jk` should be provided."
614                )),
615            }
616        } else if nzeros == 1 {
617            ensure!(
618                nzeros_explicit == 1,
619                "Unexpected number of explicit zero Löwdin overlaps: {nzeros_explicit} != 1."
620            );
621
622            let nbasis = lowdin_paired_coefficientss[0].nbasis();
623            let w = (0..structure_constraint.implicit_factor()?)
624                .cartesian_product(lowdin_paired_coefficientss.iter())
625                .fold(
626                    Ok::<_, anyhow::Error>(Array2::<T>::zeros((nbasis, nbasis))),
627                    |acc_res, (_, lpc)| {
628                        calc_weighted_codensity_matrix(lpc)
629                            .and_then(|w_sigma| acc_res.map(|acc| acc + w_sigma))
630                    },
631                )?;
632
633            lowdin_paired_coefficientss
634                .iter()
635                .filter_map(|lpc| {
636                    if lpc.n_lowdin_zeros() == 1 {
637                        let w_sigma_res = calc_weighted_codensity_matrix(lpc);
638                        let mbar = lpc.zero_indices()[0];
639                        let p_mbar_sigma_res = calc_unweighted_codensity_matrix(lpc, mbar);
640                        Some((w_sigma_res, p_mbar_sigma_res))
641                    } else {
642                        None
643                    }
644                })
645                .fold(Ok(T::zero()), |acc_res, (w_sigma_res, p_mbar_sigma_res)| {
646                    w_sigma_res.and_then(|w_sigma| {
647                        p_mbar_sigma_res.and_then(|p_mbar_sigma| {
648                            match (o2_opt, get_jk_opt) {
649                                (Some(o2), None) => {
650                                    // i = μ, j = μ', k = ν, l = ν'
651                                    let j_term_1 = einsum(
652                                        "ikjl,ji,lk->",
653                                        &[o2, &w.view(), &p_mbar_sigma.view()],
654                                    )
655                                    .map_err(|err| format_err!(err))?
656                                    .into_dimensionality::<Ix0>()?
657                                    .into_iter()
658                                    .next()
659                                    .ok_or_else(|| {
660                                        format_err!(
661                                            "Unable to extract the result of the einsum contraction."
662                                        )
663                                    })?;
664                                    let j_term_2 = einsum(
665                                        "ikjl,ji,lk->",
666                                        &[o2, &p_mbar_sigma.view(), &w.view()],
667                                    )
668                                    .map_err(|err| format_err!(err))?
669                                    .into_dimensionality::<Ix0>()?
670                                    .into_iter()
671                                    .next()
672                                    .ok_or_else(|| {
673                                        format_err!(
674                                            "Unable to extract the result of the einsum contraction."
675                                        )
676                                    })?;
677                                    let k_term_1 = einsum(
678                                        "ikjl,li,jk->",
679                                        &[o2, &w_sigma.view(), &p_mbar_sigma.view()],
680                                    )
681                                    .map_err(|err| format_err!(err))?
682                                    .into_dimensionality::<Ix0>()?
683                                    .into_iter()
684                                    .next()
685                                    .ok_or_else(|| {
686                                        format_err!(
687                                            "Unable to extract the result of the einsum contraction."
688                                        )
689                                    })?;
690                                    let k_term_2 = einsum(
691                                        "ikjl,li,jk->",
692                                        &[o2, &p_mbar_sigma.view(), &w_sigma.view()],
693                                    )
694                                    .map_err(|err| format_err!(err))?
695                                    .into_dimensionality::<Ix0>()?
696                                    .into_iter()
697                                    .next()
698                                    .ok_or_else(|| {
699                                        format_err!(
700                                            "Unable to extract the result of the einsum contraction."
701                                        )
702                                    })?;
703                                    acc_res.map(|acc| acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
704                                },
705                                (None, Some(get_jk)) => {
706                                    // i = μ, j = μ', k = ν, l = ν'
707                                    let (j_p_mbar_sigma, k_p_mbar_sigma) = get_jk(&p_mbar_sigma)?;
708                                    let j_term_1 = einsum(
709                                        "ij,ji->",
710                                        &[&j_p_mbar_sigma.view(), &w.view()],
711                                    )
712                                    .map_err(|err| format_err!(err))?
713                                    .into_dimensionality::<Ix0>()?
714                                    .into_iter()
715                                    .next()
716                                    .ok_or_else(|| {
717                                        format_err!(
718                                            "Unable to extract the result of the einsum contraction."
719                                        )
720                                    })?;
721                                    let (j_w, _) = get_jk(&w)?;
722                                    let j_term_2 = einsum(
723                                        "ij,ji->",
724                                        &[&j_w.view(), &p_mbar_sigma.view()],
725                                    )
726                                    .map_err(|err| format_err!(err))?
727                                    .into_dimensionality::<Ix0>()?
728                                    .into_iter()
729                                    .next()
730                                    .ok_or_else(|| {
731                                        format_err!(
732                                            "Unable to extract the result of the einsum contraction."
733                                        )
734                                    })?;
735                                    let k_term_1 = einsum(
736                                        "il,li->",
737                                        &[&k_p_mbar_sigma.view(), &w_sigma.view()],
738                                    )
739                                    .map_err(|err| format_err!(err))?
740                                    .into_dimensionality::<Ix0>()?
741                                    .into_iter()
742                                    .next()
743                                    .ok_or_else(|| {
744                                        format_err!(
745                                            "Unable to extract the result of the einsum contraction."
746                                        )
747                                    })?;
748                                    let (_, k_w_sigma) = get_jk(&w_sigma)?;
749                                    let k_term_2 = einsum(
750                                        "il,li->",
751                                        &[&k_w_sigma.view(), &p_mbar_sigma.view()],
752                                    )
753                                    .map_err(|err| format_err!(err))?
754                                    .into_dimensionality::<Ix0>()?
755                                    .into_iter()
756                                    .next()
757                                    .ok_or_else(|| {
758                                        format_err!(
759                                            "Unable to extract the result of the einsum contraction."
760                                        )
761                                    })?;
762                                    acc_res.map(|acc| acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
763                                }
764                                _ => Err(format_err!(
765                                    "One and only one of `o2` or `get_jk` should be provided."
766                                )),
767                            }
768                        })
769                    })
770                })
771                .map(|v| v * reduced_ov / (T::one() + T::one()))
772        } else {
773            ensure!(
774                nzeros == 2,
775                "Unexpected number of zero Löwdin overlaps: {nzeros} != 2."
776            );
777
778            let ps = (0..structure_constraint.implicit_factor()?)
779                .flat_map(|_| {
780                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
781                        lpc.zero_indices()
782                            .iter()
783                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
784                    })
785                })
786                .collect::<Result<Vec<_>, _>>()?;
787            ensure!(
788                ps.len() == 2,
789                "Unexpected number of unweighted codensity matrices ({}) for two zero overlaps.",
790                ps.len()
791            );
792            let p_mbar = ps.first().ok_or_else(|| {
793                format_err!("Unable to retrieve the first computed unweighted codensity matrix.")
794            })?;
795            let p_nbar = ps.last().ok_or_else(|| {
796                format_err!("Unable to retrieve the second computed unweighted codensity matrix.")
797            })?;
798
799            match (o2_opt, get_jk_opt) {
800                (Some(o2), None) => {
801                    // i = μ, j = μ', k = ν, l = ν'
802                    let j_term_1 = einsum("ikjl,ji,lk->", &[o2, &p_mbar.view(), &p_nbar.view()])
803                        .map_err(|err| format_err!(err))?
804                        .into_dimensionality::<Ix0>()?
805                        .into_iter()
806                        .next()
807                        .ok_or_else(|| {
808                            format_err!("Unable to extract the result of the einsum contraction.")
809                        })?;
810                    let j_term_2 = einsum("ikjl,ji,lk->", &[o2, &p_nbar.view(), &p_mbar.view()])
811                        .map_err(|err| format_err!(err))?
812                        .into_dimensionality::<Ix0>()?
813                        .into_iter()
814                        .next()
815                        .ok_or_else(|| {
816                            format_err!("Unable to extract the result of the einsum contraction.")
817                        })?;
818
819                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
820                        .iter()
821                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
822                    {
823                        let k_term_1 =
824                            einsum("ikjl,li,jk->", &[o2, &p_mbar.view(), &p_nbar.view()])
825                                .map_err(|err| format_err!(err))?
826                                .into_dimensionality::<Ix0>()?
827                                .into_iter()
828                                .next()
829                                .ok_or_else(|| {
830                                    format_err!(
831                                        "Unable to extract the result of the einsum contraction."
832                                    )
833                                })?;
834                        let k_term_2 =
835                            einsum("ikjl,li,jk->", &[o2, &p_nbar.view(), &p_mbar.view()])
836                                .map_err(|err| format_err!(err))?
837                                .into_dimensionality::<Ix0>()?
838                                .into_iter()
839                                .next()
840                                .ok_or_else(|| {
841                                    format_err!(
842                                        "Unable to extract the result of the einsum contraction."
843                                    )
844                                })?;
845                        (k_term_1, k_term_2)
846                    } else {
847                        (T::zero(), T::zero())
848                    };
849                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
850                        / (T::one() + T::one()))
851                }
852                (None, Some(get_jk)) => {
853                    // i = μ, j = μ', k = ν, l = ν'
854                    let (j_p_nbar, k_p_nbar) = get_jk(&p_nbar)?;
855                    let j_term_1 = einsum("ij,ji->", &[&j_p_nbar.view(), &p_mbar.view()])
856                        .map_err(|err| format_err!(err))?
857                        .into_dimensionality::<Ix0>()?
858                        .into_iter()
859                        .next()
860                        .ok_or_else(|| {
861                            format_err!("Unable to extract the result of the einsum contraction.")
862                        })?;
863                    let (j_p_mbar, k_p_mbar) = get_jk(&p_mbar)?;
864                    let j_term_2 = einsum("ij,ji->", &[&j_p_mbar.view(), &p_nbar.view()])
865                        .map_err(|err| format_err!(err))?
866                        .into_dimensionality::<Ix0>()?
867                        .into_iter()
868                        .next()
869                        .ok_or_else(|| {
870                            format_err!("Unable to extract the result of the einsum contraction.")
871                        })?;
872
873                    let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
874                        .iter()
875                        .any(|lpc| lpc.n_lowdin_zeros() == 2)
876                    {
877                        let k_term_1 = einsum("il,li->", &[&k_p_nbar.view(), &p_mbar.view()])
878                            .map_err(|err| format_err!(err))?
879                            .into_dimensionality::<Ix0>()?
880                            .into_iter()
881                            .next()
882                            .ok_or_else(|| {
883                                format_err!(
884                                    "Unable to extract the result of the einsum contraction."
885                                )
886                            })?;
887                        let k_term_2 = einsum("il,li->", &[&k_p_mbar.view(), &p_nbar.view()])
888                            .map_err(|err| format_err!(err))?
889                            .into_dimensionality::<Ix0>()?
890                            .into_iter()
891                            .next()
892                            .ok_or_else(|| {
893                                format_err!(
894                                    "Unable to extract the result of the einsum contraction."
895                                )
896                            })?;
897                        (k_term_1, k_term_2)
898                    } else {
899                        (T::zero(), T::zero())
900                    };
901                    Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2)
902                        / (T::one() + T::one()))
903                }
904                _ => Err(format_err!(
905                    "One and only one of `o2` or `get_jk` should be provided."
906                )),
907            }
908        }
909    }
910}
911
912/// Calculates the transition density matrix between two Löwdin-paired determinants.
913///
914/// # Arguments
915///
916/// `lowdin_paired_coefficientss` - A sequence of pairs of Löwdin-paired coefficients, one for each
917/// subspace determined by the specified structure constraint.
918/// `structure_constraint` - The structure constraint governing the coefficients.
919///
920/// # Returns
921///
922/// The one-particle matrix element.
923pub fn calc_transition_density_matrix<T, SC>(
924    lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
925    structure_constraint: &SC,
926) -> Result<Array2<T>, anyhow::Error>
927where
928    T: ComplexFloat + ScalarOperand + Product,
929    SC: StructureConstraint,
930{
931    let nzeros_explicit: usize = lowdin_paired_coefficientss
932        .iter()
933        .map(|lpc| lpc.n_lowdin_zeros())
934        .sum();
935    let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
936    let nbasis = lowdin_paired_coefficientss[0].nbasis();
937    if nzeros > 1 {
938        Ok(Array2::<T>::zeros((nbasis, nbasis)))
939    } else {
940        let reduced_ov_explicit: T = lowdin_paired_coefficientss
941            .iter()
942            .map(|lpc| lpc.reduced_overlap())
943            .product();
944        let reduced_ov = (0..structure_constraint.implicit_factor()?)
945            .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
946
947        if nzeros == 0 {
948            let nbasis = lowdin_paired_coefficientss[0].nbasis();
949            let w = (0..structure_constraint.implicit_factor()?)
950                .cartesian_product(lowdin_paired_coefficientss.iter())
951                .fold(
952                    Ok(Array2::<T>::zeros((nbasis, nbasis))),
953                    |acc_res, (_, lpc)| {
954                        calc_weighted_codensity_matrix(lpc).and_then(|w| acc_res.map(|acc| acc + w))
955                    },
956                )?;
957            Ok(w.mapv(|v| v * reduced_ov))
958        } else {
959            ensure!(
960                nzeros == 1,
961                "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
962            );
963            let ps = (0..structure_constraint.implicit_factor()?)
964                .flat_map(|_| {
965                    lowdin_paired_coefficientss.iter().flat_map(|lpc| {
966                        lpc.zero_indices()
967                            .iter()
968                            .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
969                    })
970                })
971                .collect::<Result<Vec<_>, _>>()?;
972            ensure!(
973                ps.len() == 1,
974                "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
975                ps.len()
976            );
977            let p_mbar = ps.first().ok_or_else(|| {
978                format_err!("Unable to retrieve the computed unweighted codensity matrix.")
979            })?;
980            Ok(p_mbar.mapv(|v| v * reduced_ov))
981        }
982    }
983}
984
985/// Performs modified Gram--Schmidt orthonormalisation on a set of column vectors in a matrix with
986/// respect to the complex-symmetric or Hermitian dot product.
987///
988/// # Arguments
989///
990/// * `vmat` - Matrix containing column vectors forming a basis for a subspace.
991/// * `complex_symmetric` - A boolean indicating if the vector dot product is complex-symmetric. If
992/// `false`, the conventional Hermitian dot product is used.
993/// * `thresh` - A threshold for determining self-orthogonal vectors.
994///
995/// # Returns
996///
997/// The orthonormal vectors forming a basis for the same subspace collected as column vectors in a
998/// matrix.
999///
1000/// # Errors
1001///
1002/// Errors when the orthonormalisation procedure fails, which occurs when there is linear dependency
1003/// between the basis vectors and/or when self-orthogonal vectors are encountered.
1004pub fn complex_modified_gram_schmidt<T>(
1005    vmat: &ArrayView2<T>,
1006    complex_symmetric: bool,
1007    thresh: <T as ComplexFloat>::Real,
1008) -> Result<Array2<T>, anyhow::Error>
1009where
1010    T: ComplexFloat + std::fmt::Display + 'static,
1011{
1012    let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
1013    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
1014    for (i, vi) in vmat.columns().into_iter().enumerate() {
1015        // u[i] now initialised with v[i]
1016        us.push(vi.to_owned());
1017
1018        // Project ui onto all uj (0 <= j < i)
1019        // This is the 'modified' part of Gram--Schmidt. We project the current (and being updated)
1020        // ui onto uj, rather than projecting vi onto uj. This enhances numerical stability.
1021        for j in 0..i {
1022            let p_uj_ui = if complex_symmetric {
1023                us[j].t().dot(&us[i]) / us_sq_norm[j]
1024            } else {
1025                us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
1026            };
1027            us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
1028        }
1029
1030        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
1031        // us_sq_norm[i] now available.
1032        let us_sq_norm_i = if complex_symmetric {
1033            us[i].t().dot(&us[i])
1034        } else {
1035            us[i].t().map(|x| x.conj()).dot(&us[i])
1036        };
1037        if us_sq_norm_i.abs() < thresh {
1038            return Err(format_err!("A zero-norm vector found: {}", us[i]));
1039        }
1040        us_sq_norm.push(us_sq_norm_i);
1041    }
1042
1043    // Normalise ui
1044    for i in 0..us.len() {
1045        us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
1046    }
1047
1048    let ortho_check = us.iter().enumerate().all(|(i, ui)| {
1049        us.iter().enumerate().all(|(j, uj)| {
1050            let ov_ij = if complex_symmetric {
1051                ui.dot(uj)
1052            } else {
1053                ui.map(|x| x.conj()).dot(uj)
1054            };
1055            i == j || ov_ij.abs() < thresh
1056        })
1057    });
1058
1059    if ortho_check {
1060        stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| format_err!(err))
1061    } else {
1062        Err(format_err!(
1063            "Post-Gram--Schmidt orthogonality check failed."
1064        ))
1065    }
1066}
1067
1068/// Trait for Löwdin canonical orthogonalisation of a square matrix.
1069pub trait CanonicalOrthogonalisable {
1070    /// Numerical type of the matrix elements.
1071    type NumType;
1072
1073    /// Type of real threshold values.
1074    type RealType;
1075
1076    /// Calculates the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$ for a square
1077    /// matrix.
1078    ///
1079    /// # Arguments
1080    ///
1081    /// * `complex_symmetric` - Boolean indicating if the orthogonalisation is with respect to the
1082    /// complex-symmetric inner product.
1083    /// * `preserves_full_rank` - Boolean indicating if a full-rank square matrix should be left
1084    /// unchanged, thus forcing $`\mathbf{X} = \mathbf{I}`$.
1085    /// * `thresh_offdiag` - Threshold for verifying that the orthogonalised matrix is indeed
1086    /// orthogonal.
1087    /// * `thresh_zeroov` - Threshold for determining zero eigenvalues of the input square matrix.
1088    ///
1089    /// # Returns
1090    ///
1091    /// The canonical orthogonalisation result.
1092    fn calc_canonical_orthogonal_matrix(
1093        &self,
1094        complex_symmetric: bool,
1095        preserves_full_rank: bool,
1096        thresh_offdiag: Self::RealType,
1097        thresh_zeroov: Self::RealType,
1098    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error>;
1099}
1100
1101/// Structure containing the results of the Löwdin canonical orthogonalisation.
1102pub struct CanonicalOrthogonalisationResult<T> {
1103    /// The eigenvalues of the input matrix.
1104    eigenvalues: Array1<T>,
1105
1106    /// The Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1107    xmat: Array2<T>,
1108
1109    /// The conjugate of the Löwdin canonical orthogonalisation matrix,
1110    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1111    /// and $`\lozenge = \hat{e}`$ otherwise.
1112    xmat_d: Array2<T>,
1113}
1114
1115impl<T> CanonicalOrthogonalisationResult<T> {
1116    /// Returns the eigenvalues of the input matrix.
1117    pub fn eigenvalues(&'_ self) -> ArrayView1<'_, T> {
1118        self.eigenvalues.view()
1119    }
1120
1121    /// Returns the Löwdin canonical orthogonalisation matrix $`\mathbf{X}`$.
1122    pub fn xmat(&'_ self) -> ArrayView2<'_, T> {
1123        self.xmat.view()
1124    }
1125
1126    /// Returns the conjugate of the Löwdin canonical orthogonalisation matrix,
1127    /// $`\mathbf{X}^{\dagger\lozenge}`$, where $`\lozenge = \star`$ for complex-symmetric matrices
1128    /// and $`\lozenge = \hat{e}`$ otherwise.
1129    pub fn xmat_d(&'_ self) -> ArrayView2<'_, T> {
1130        self.xmat_d.view()
1131    }
1132}
1133
1134#[duplicate_item(
1135    [
1136        dtype_ [ f64 ]
1137    ]
1138    [
1139        dtype_ [ f32 ]
1140    ]
1141)]
1142impl CanonicalOrthogonalisable for ArrayView2<'_, dtype_> {
1143    type NumType = dtype_;
1144
1145    type RealType = dtype_;
1146
1147    fn calc_canonical_orthogonal_matrix(
1148        &self,
1149        _: bool,
1150        preserves_full_rank: bool,
1151        thresh_offdiag: dtype_,
1152        thresh_zeroov: dtype_,
1153    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1154        let smat = self;
1155
1156        // Real, symmetric S
1157        let deviation_s = (smat.to_owned() - smat.t()).norm_l2();
1158        ensure!(
1159            deviation_s <= thresh_offdiag,
1160            "Overlap matrix is not real-symmetric: ||S - S^T|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
1161        );
1162
1163        // S is real-symmetric, so U is orthogonal, i.e. U^T = U^(-1).
1164        let (s_eig, umat) = smat.eigh(UPLO::Lower).map_err(|err| format_err!(err))?;
1165        // Real eigenvalues, so both comparison modes are the same.
1166        let nonzero_s_indices = s_eig
1167            .iter()
1168            .positions(|x| x.abs() > thresh_zeroov)
1169            .collect_vec();
1170        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1171        if nonzero_s_eig.iter().any(|v| *v < 0.0) {
1172            return Err(format_err!(
1173                "The matrix has negative eigenvalues and therefore cannot be orthogonalised over the reals."
1174            ));
1175        }
1176        let nonzero_umat = umat.select(Axis(1), &nonzero_s_indices);
1177        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1178        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1179            (Array2::eye(smat.shape()[0]), Array2::eye(smat.shape()[0]))
1180        } else {
1181            let s_s = Array2::<dtype_>::from_diag(&nonzero_s_eig.mapv(|x| 1.0 / x.sqrt()));
1182            (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1183        };
1184        let res = CanonicalOrthogonalisationResult {
1185            eigenvalues: s_eig,
1186            xmat,
1187            xmat_d,
1188        };
1189        Ok(res)
1190    }
1191}
1192
1193impl<T> CanonicalOrthogonalisable for ArrayView2<'_, Complex<T>>
1194where
1195    T: Float + Scalar<Complex = Complex<T>>,
1196    Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
1197{
1198    type NumType = Complex<T>;
1199
1200    type RealType = T;
1201
1202    fn calc_canonical_orthogonal_matrix(
1203        &self,
1204        complex_symmetric: bool,
1205        preserves_full_rank: bool,
1206        thresh_offdiag: T,
1207        thresh_zeroov: T,
1208    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
1209        let smat = self;
1210
1211        if complex_symmetric {
1212            // Complex-symmetric S
1213            ensure!(
1214                (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
1215                "Overlap matrix is not complex-symmetric."
1216            );
1217        } else {
1218            // Complex-Hermitian S
1219            ensure!(
1220                (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
1221                "Overlap matrix is not complex-Hermitian."
1222            );
1223        }
1224
1225        let (s_eig, umat_nonortho) = smat.eig().map_err(|err| format_err!(err))?;
1226        log::debug!("Overlap eigenvalues for canonical orthogonalisation:");
1227        for (i, eig) in s_eig.iter().enumerate() {
1228            log::debug!("  {i}: {eig:+.8e}");
1229        }
1230        log::debug!("");
1231
1232        let nonzero_s_indices = s_eig
1233            .iter()
1234            .positions(|x| ComplexFloat::abs(*x) > thresh_zeroov)
1235            .collect_vec();
1236        log::debug!("Non-zero overlap indices w.r.t. threshold {thresh_zeroov:.8e}:");
1237        log::debug!(
1238            "  {}",
1239            nonzero_s_indices.iter().map(|i| i.to_string()).join(", ")
1240        );
1241        log::debug!("");
1242        let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
1243        let nonzero_umat_nonortho = umat_nonortho.select(Axis(1), &nonzero_s_indices);
1244
1245        // `eig` does not guarantee orthogonality of `nonzero_umat_nonortho`.
1246        // Gram--Schmidt is therefore required.
1247        let nonzero_umat = complex_modified_gram_schmidt(
1248            &nonzero_umat_nonortho.view(),
1249            complex_symmetric,
1250            thresh_zeroov,
1251        )
1252        .map_err(
1253            |_| format_err!("Unable to orthonormalise the linearly-independent eigenvectors of the overlap matrix.")
1254        )?;
1255
1256        let nonzero_s_eig_from_u = if complex_symmetric {
1257            nonzero_umat.t().dot(smat).dot(&nonzero_umat)
1258        } else {
1259            nonzero_umat
1260                .map(|v| v.conj())
1261                .t()
1262                .dot(smat)
1263                .dot(&nonzero_umat)
1264        };
1265
1266        let deviation_s = (nonzero_s_eig_from_u - Array2::from_diag(&nonzero_s_eig)).norm_l2();
1267        ensure!(
1268            deviation_s <= thresh_offdiag,
1269            if complex_symmetric {
1270                "Canonical orthogonalisation has failed: ||U^T.S.U - s|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
1271            } else {
1272                "Canonical orthogonalisation has failed: ||U^†.S.U - s|| = {deviation_s:.3e} > {thresh_offdiag:.3e}."
1273            }
1274        );
1275
1276        let nullity = smat.shape()[0] - nonzero_s_indices.len();
1277        let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
1278            (
1279                Array2::<Complex<T>>::eye(smat.shape()[0]),
1280                Array2::<Complex<T>>::eye(smat.shape()[0]),
1281            )
1282        } else {
1283            let s_s = Array2::<Complex<T>>::from_diag(
1284                &nonzero_s_eig.mapv(|x| Complex::<T>::from(T::one()) / x.sqrt()),
1285            );
1286            let xmat = nonzero_umat.dot(&s_s);
1287            let xmat_d = if complex_symmetric {
1288                // (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
1289                xmat.t().to_owned()
1290            } else {
1291                xmat.map(|v| v.conj()).t().to_owned()
1292            };
1293            (xmat, xmat_d)
1294        };
1295        let res = CanonicalOrthogonalisationResult {
1296            eigenvalues: s_eig,
1297            xmat,
1298            xmat_d,
1299        };
1300        Ok(res)
1301    }
1302}