1use std::fmt::LowerExp;
2use std::iter::Product;
3
4use anyhow::{self, ensure, format_err};
5use derive_builder::Builder;
6use duplicate::duplicate_item;
7use indexmap::IndexSet;
8use itertools::Itertools;
9use ndarray::{
10 stack, Array1, Array2, ArrayView1, ArrayView2, ArrayView4, Axis, Ix0, Ix2, ScalarOperand,
11};
12use ndarray_einsum::einsum;
13use ndarray_linalg::types::Lapack;
14use ndarray_linalg::{Determinant, Eig, Eigh, Norm, Scalar, SVD, UPLO};
15use num::{Complex, Float};
16use num_complex::ComplexFloat;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19
20use super::denmat::{calc_unweighted_codensity_matrix, calc_weighted_codensity_matrix};
21
22#[cfg(test)]
23#[path = "nonortho_tests.rs"]
24mod nonortho_tests;
25
/// Löwdin-paired orbital coefficients of two determinants, together with the overlaps between
/// corresponding paired orbitals.
///
/// Instances are created through [`LowdinPairedCoefficients::builder`]. The builder's validator
/// checks the following before construction:
/// - the two paired coefficient matrices have identical shapes;
/// - there is one Löwdin overlap per paired orbital (column);
/// - every recorded zero-overlap index is in bounds.
#[derive(Builder)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct LowdinPairedCoefficients<T: ComplexFloat> {
    /// Löwdin-paired coefficients of the first determinant (`w`); each column is one orbital and
    /// the number of rows is the basis dimension.
    paired_cw: Array2<T>,

    /// Löwdin-paired coefficients of the second determinant (`x`), with the same shape as
    /// `paired_cw`.
    paired_cx: Array2<T>,

    /// Overlaps between corresponding Löwdin-paired orbitals of `w` and `x`, i.e. the diagonal of
    /// the paired orbital overlap matrix. There is one entry per column of `paired_cw`.
    lowdin_overlaps: Vec<T>,

    /// Indices into `lowdin_overlaps` of the overlaps treated as zero.
    ///
    /// The builder attribute names `default_zero_indices` as this field's default. That function
    /// flags the overlaps whose magnitude is below `thresh_zeroov`.
    #[builder(default = "self.default_zero_indices()?")]
    zero_indices: IndexSet<usize>,

    /// Threshold below which the magnitude of a Löwdin overlap is considered zero.
    thresh_zeroov: T::Real,

    /// If `true`, overlaps were computed with the complex-symmetric (bilinear, unconjugated)
    /// inner product; otherwise with the Hermitian one.
    complex_symmetric: bool,
}
62
63impl<T: ComplexFloat> LowdinPairedCoefficientsBuilder<T> {
64 fn validate(&self) -> Result<(), String> {
65 let paired_cw = self
66 .paired_cw
67 .as_ref()
68 .ok_or("Löwdin-paired coefficients `paired_cw` not set.".to_string())?;
69 let paired_cx = self
70 .paired_cx
71 .as_ref()
72 .ok_or("Löwdin-paired coefficients `paired_cx` not set.".to_string())?;
73 let lowdin_overlaps = self
74 .lowdin_overlaps
75 .as_ref()
76 .ok_or("Löwdin overlaps not set.".to_string())?;
77 let zero_indices = self
78 .zero_indices
79 .as_ref()
80 .ok_or("Indices of zero Löwdin overlaps not set.".to_string())?;
81
82 if paired_cw.shape() == paired_cx.shape() {
83 let lowdin_dim = paired_cw.shape()[1];
84 if lowdin_dim == lowdin_overlaps.len() {
85 if zero_indices.iter().all(|i| *i < lowdin_dim) {
86 Ok(())
87 } else {
88 Err("Some indices of zero Löwdin overlaps are out-of-bound!".to_string())
89 }
90 } else {
91 Err(
92 "Inconsistent number of Löwdin-paired orbitals and Löwdin overlaps."
93 .to_string(),
94 )
95 }
96 } else {
97 Err(format!(
98 "Inconsistent shapes between `paired_cw` ({:?}) and `paired_cx` ({:?}).",
99 paired_cw.shape(),
100 paired_cx.shape()
101 ))
102 }
103 }
104
105 fn default_zero_indices(&self) -> Result<IndexSet<usize>, String> {
106 let lowdin_overlaps = self
107 .lowdin_overlaps
108 .as_ref()
109 .ok_or("Löwdin overlaps not set.".to_string())?;
110 let thresh_zeroov = self
111 .thresh_zeroov
112 .as_ref()
113 .ok_or("threshold for zero Löwdin overlaps not set.".to_string())?;
114 let zero_indices = lowdin_overlaps
115 .iter()
116 .enumerate()
117 .filter(|(_, ov)| ComplexFloat::abs(**ov) < *thresh_zeroov)
118 .map(|(i, _)| i)
119 .collect::<IndexSet<_>>();
120 Ok(zero_indices)
121 }
122}
123
124impl<T: ComplexFloat> LowdinPairedCoefficients<T> {
125 pub fn builder() -> LowdinPairedCoefficientsBuilder<T> {
126 LowdinPairedCoefficientsBuilder::<T>::default()
127 }
128
129 pub fn paired_coefficients(&self) -> (&Array2<T>, &Array2<T>) {
131 (&self.paired_cw, &self.paired_cx)
132 }
133
134 pub fn nbasis(&self) -> usize {
137 self.paired_cw.nrows()
138 }
139
140 pub fn lowdin_dim(&self) -> usize {
142 self.lowdin_overlaps.len()
143 }
144
145 pub fn n_lowdin_zeros(&self) -> usize {
147 self.zero_indices.len()
148 }
149
150 pub fn lowdin_overlaps(&self) -> &Vec<T> {
152 &self.lowdin_overlaps
153 }
154
155 pub fn zero_indices(&self) -> &IndexSet<usize> {
157 &self.zero_indices
158 }
159
160 pub fn nonzero_indices(&self) -> IndexSet<usize> {
162 (0..self.lowdin_dim())
163 .filter(|i| !self.zero_indices.contains(i))
164 .collect::<IndexSet<_>>()
165 }
166
167 pub fn complex_symmetric(&self) -> bool {
170 self.complex_symmetric
171 }
172}
173
174impl<T: ComplexFloat + Product> LowdinPairedCoefficients<T> {
175 pub fn reduced_overlap(&self) -> T {
177 self.nonzero_indices()
178 .iter()
179 .map(|i| self.lowdin_overlaps[*i])
180 .product()
181 }
182}
183
184pub fn calc_lowdin_pairing<T>(
235 cw: &ArrayView2<T>,
236 cx: &ArrayView2<T>,
237 sao: &ArrayView2<T>,
238 complex_symmetric: bool,
239 thresh_offdiag: <T as ComplexFloat>::Real,
240 thresh_zeroov: <T as ComplexFloat>::Real,
241) -> Result<LowdinPairedCoefficients<T>, anyhow::Error>
242where
243 T: ComplexFloat + Lapack,
244 <T as ComplexFloat>::Real: PartialOrd + LowerExp,
245{
246 if cw.shape() != cx.shape() {
247 Err(format_err!(
248 "Coefficient dimensions mismatched: cw ({:?}) !~ cx ({:?}).",
249 cw.shape(),
250 cx.shape()
251 ))
252 } else {
253 let init_orb_ovmat = if complex_symmetric {
254 einsum("ji,jk,kl->il", &[cw, sao, cx])
255 } else {
256 einsum("ji,jk,kl->il", &[&cw.map(|x| x.conj()), sao, cx])
257 }
258 .map_err(|err| format_err!(err))?
259 .into_dimensionality::<Ix2>()?;
260
261 let max_offdiag = (&init_orb_ovmat - &Array2::from_diag(&init_orb_ovmat.diag().to_owned()))
262 .iter()
263 .map(|x| ComplexFloat::abs(*x))
264 .max_by(|x, y| {
265 x.partial_cmp(y)
266 .expect("Unable to compare two `abs` values.")
267 })
268 .ok_or_else(|| format_err!("Unable to determine the maximum off-diagonal element."))?;
269
270 if max_offdiag <= thresh_offdiag {
271 let lowdin_overlaps = init_orb_ovmat.into_diag().to_vec();
272 let zero_indices = lowdin_overlaps
273 .iter()
274 .enumerate()
275 .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
276 .map(|(i, _)| i)
277 .collect::<IndexSet<_>>();
278 LowdinPairedCoefficients::builder()
279 .paired_cw(cw.to_owned())
280 .paired_cx(cx.to_owned())
281 .lowdin_overlaps(lowdin_overlaps)
282 .zero_indices(zero_indices)
283 .thresh_zeroov(thresh_zeroov)
284 .complex_symmetric(complex_symmetric)
285 .build()
286 .map_err(|err| format_err!(err))
287 } else {
288 let (u_opt, _, vh_opt) = init_orb_ovmat.svd(true, true)?;
289 let u = u_opt.ok_or_else(|| format_err!("Unable to compute the U matrix from SVD."))?;
290 let vh =
291 vh_opt.ok_or_else(|| format_err!("Unable to compute the V matrix from SVD."))?;
292 let v = vh.t().map(|x| x.conj());
293 let det_v_c = v.det()?.conj();
294
295 let paired_cw = if complex_symmetric {
296 let uc = u.map(|x| x.conj());
297 let mut cwt = cw.dot(&uc);
298 let det_uc_c = uc.det()?.conj();
299 cwt.column_mut(0)
300 .iter_mut()
301 .for_each(|x| *x = *x * det_uc_c);
302 cwt
303 } else {
304 let mut cwt = cw.dot(&u);
305 let det_u_c = u.det()?.conj();
306 cwt.column_mut(0).iter_mut().for_each(|x| *x = *x * det_u_c);
307 cwt
308 };
309
310 let paired_cx = {
311 let mut cxt = cx.dot(&v);
312 cxt.column_mut(0).iter_mut().for_each(|x| *x = *x * det_v_c);
313 cxt
314 };
315
316 let lowdin_orb_ovmat = if complex_symmetric {
317 einsum("ji,jk,kl->il", &[&paired_cw, sao, &paired_cx])
318 } else {
319 einsum(
320 "ji,jk,kl->il",
321 &[&paired_cw.map(|x| x.conj()), sao, &paired_cx],
322 )
323 }
324 .map_err(|err| format_err!(err))?
325 .into_dimensionality::<Ix2>()?;
326
327 let max_offdiag_lowdin = (&lowdin_orb_ovmat - &Array2::from_diag(&lowdin_orb_ovmat.diag().to_owned()))
328 .iter()
329 .map(|x| ComplexFloat::abs(*x))
330 .max_by(|x, y| {
331 x.partial_cmp(y)
332 .expect("Unable to compare two `abs` values.")
333 })
334 .ok_or_else(|| format_err!("Unable to determine the maximum off-diagonal element of the Lowdin-paired overlap matrix."))?;
335 if max_offdiag_lowdin <= thresh_offdiag {
336 let lowdin_overlaps = lowdin_orb_ovmat.into_diag().to_vec();
337 let zero_indices = lowdin_overlaps
338 .iter()
339 .enumerate()
340 .filter(|(_, ov)| ComplexFloat::abs(**ov) < thresh_zeroov)
341 .map(|(i, _)| i)
342 .collect::<IndexSet<_>>();
343 LowdinPairedCoefficients::builder()
344 .paired_cw(paired_cw.clone())
345 .paired_cx(paired_cx.clone())
346 .lowdin_overlaps(lowdin_overlaps)
347 .zero_indices(zero_indices)
348 .thresh_zeroov(thresh_zeroov)
349 .complex_symmetric(complex_symmetric)
350 .build()
351 .map_err(|err| format_err!(err))
352 } else {
353 Err(format_err!(
354 "Löwdin overlap matrix deviates from diagonality. Maximum off-diagonal overlap has magnitude {max_offdiag_lowdin:.3e} > threshold of {thresh_offdiag:.3e}. Löwdin pairing has failed."
355 ))
356 }
357 }
358 }
359}
360
361pub fn calc_o0_matrix_element<T, SC>(
375 lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
376 o0: T,
377 structure_constraint: &SC,
378) -> Result<T, anyhow::Error>
379where
380 T: ComplexFloat + ScalarOperand + Product,
381 SC: StructureConstraint,
382{
383 let nzeros_explicit: usize = lowdin_paired_coefficientss
384 .iter()
385 .map(|lpc| lpc.n_lowdin_zeros())
386 .sum();
387 let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
388 if nzeros > 0 {
389 Ok(T::zero())
390 } else {
391 let reduced_ov_explicit: T = lowdin_paired_coefficientss
392 .iter()
393 .map(|lpc| lpc.reduced_overlap())
394 .product();
395 let reduced_ov = (0..structure_constraint.implicit_factor()?)
396 .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
397 Ok(reduced_ov * o0)
398 }
399}
400
401pub fn calc_o1_matrix_element<T, SC>(
415 lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
416 o1: &ArrayView2<T>,
417 structure_constraint: &SC,
418) -> Result<T, anyhow::Error>
419where
420 T: ComplexFloat + ScalarOperand + Product,
421 SC: StructureConstraint,
422{
423 let nzeros_explicit: usize = lowdin_paired_coefficientss
424 .iter()
425 .map(|lpc| lpc.n_lowdin_zeros())
426 .sum();
427 let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
428 if nzeros > 1 {
429 Ok(T::zero())
430 } else {
431 let reduced_ov_explicit: T = lowdin_paired_coefficientss
432 .iter()
433 .map(|lpc| lpc.reduced_overlap())
434 .product();
435 let reduced_ov = (0..structure_constraint.implicit_factor()?)
436 .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
437
438 if nzeros == 0 {
439 let nbasis = lowdin_paired_coefficientss[0].nbasis();
440 let w = (0..structure_constraint.implicit_factor()?)
441 .cartesian_product(lowdin_paired_coefficientss.iter())
442 .fold(
443 Ok(Array2::<T>::zeros((nbasis, nbasis))),
444 |acc_res, (_, lpc)| {
445 calc_weighted_codensity_matrix(lpc).and_then(|w| acc_res.map(|acc| acc + w))
446 },
447 )?;
448 einsum("ij,ji->", &[o1, &w.view()])
450 .map_err(|err| format_err!(err))?
451 .into_dimensionality::<Ix0>()?
452 .into_iter()
453 .next()
454 .ok_or_else(|| {
455 format_err!("Unable to extract the result of the einsum contraction.")
456 })
457 .map(|v| v * reduced_ov)
458 } else {
459 ensure!(
460 nzeros == 1,
461 "Unexpected number of zero Löwdin overlaps: {nzeros} != 1."
462 );
463 let ps = (0..structure_constraint.implicit_factor()?)
464 .flat_map(|_| {
465 lowdin_paired_coefficientss.iter().flat_map(|lpc| {
466 lpc.zero_indices()
467 .iter()
468 .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
469 })
470 })
471 .collect::<Result<Vec<_>, _>>()?;
472 ensure!(
473 ps.len() == 1,
474 "Unexpected number of unweighted codensity matrices ({}) for one zero overlap.",
475 ps.len()
476 );
477 let p_mbar = ps.first().ok_or_else(|| {
478 format_err!("Unable to retrieve the computed unweighted codensity matrix.")
479 })?;
480
481 einsum("ij,ji->", &[o1, &p_mbar.view()])
483 .map_err(|err| format_err!(err))?
484 .into_dimensionality::<Ix0>()?
485 .into_iter()
486 .next()
487 .ok_or_else(|| {
488 format_err!("Unable to extract the result of the einsum contraction.")
489 })
490 .map(|v| v * reduced_ov)
491 }
492 }
493}
494
495pub fn calc_o2_matrix_element<T, SC>(
509 lowdin_paired_coefficientss: &[LowdinPairedCoefficients<T>],
510 o2: &ArrayView4<T>,
511 structure_constraint: &SC,
512) -> Result<T, anyhow::Error>
513where
514 T: ComplexFloat + ScalarOperand + Product + std::fmt::Display,
515 SC: StructureConstraint,
516{
517 let nzeros_explicit: usize = lowdin_paired_coefficientss
518 .iter()
519 .map(|lpc| lpc.n_lowdin_zeros())
520 .sum();
521 let nzeros = nzeros_explicit * structure_constraint.implicit_factor()?;
522 if nzeros > 2 {
523 Ok(T::zero())
524 } else {
525 let reduced_ov_explicit: T = lowdin_paired_coefficientss
526 .iter()
527 .map(|lpc| lpc.reduced_overlap())
528 .product();
529 let reduced_ov = (0..structure_constraint.implicit_factor()?)
530 .fold(T::one(), |acc, _| acc * reduced_ov_explicit);
531
532 if nzeros == 0 {
533 let nbasis = lowdin_paired_coefficientss[0].nbasis();
534 let w_sigmas = (0..structure_constraint.implicit_factor()?)
535 .cartesian_product(lowdin_paired_coefficientss.iter())
536 .map(|(_, lpc)| calc_weighted_codensity_matrix(lpc))
537 .collect::<Result<Vec<_>, _>>()?;
538 let w = w_sigmas.iter().fold(
539 Ok::<_, anyhow::Error>(Array2::<T>::zeros((nbasis, nbasis))),
540 |acc_res, w_sigma| acc_res.map(|acc| acc + w_sigma),
541 )?;
542
543 let j_term = einsum("ikjl,ji,lk->", &[o2, &w.view(), &w.view()])
545 .map_err(|err| format_err!(err))?
546 .into_dimensionality::<Ix0>()?
547 .into_iter()
548 .next()
549 .ok_or_else(|| {
550 format_err!("Unable to extract the result of the einsum contraction.")
551 })
552 .map(|v| v * reduced_ov / (T::one() + T::one()))?;
553 let k_term = w_sigmas
554 .iter()
555 .fold(Ok(T::zero()), |acc_res, w_sigma| {
556 einsum("ikjl,li,jk->", &[o2, &w_sigma.view(), &w_sigma.view()])
557 .map_err(|err| format_err!(err))?
558 .into_dimensionality::<Ix0>()?
559 .into_iter()
560 .next()
561 .ok_or_else(|| {
562 format_err!("Unable to extract the result of the einsum contraction.")
563 })
564 .and_then(|v| acc_res.map(|acc| acc + v))
565 })
566 .map(|v| v * reduced_ov / (T::one() + T::one()))?;
567 Ok(j_term - k_term)
568 } else if nzeros == 1 {
569 ensure!(
570 nzeros_explicit == 1,
571 "Unexpected number of explicit zero Löwdin overlaps: {nzeros_explicit} != 1."
572 );
573
574 let nbasis = lowdin_paired_coefficientss[0].nbasis();
575 let w = (0..structure_constraint.implicit_factor()?)
576 .cartesian_product(lowdin_paired_coefficientss.iter())
577 .fold(
578 Ok::<_, anyhow::Error>(Array2::<T>::zeros((nbasis, nbasis))),
579 |acc_res, (_, lpc)| {
580 calc_weighted_codensity_matrix(lpc)
581 .and_then(|w_sigma| acc_res.map(|acc| acc + w_sigma))
582 },
583 )?;
584
585 lowdin_paired_coefficientss
586 .iter()
587 .filter_map(|lpc| {
588 if lpc.n_lowdin_zeros() == 1 {
589 let w_sigma_res = calc_weighted_codensity_matrix(lpc);
590 let mbar = lpc.zero_indices()[0];
591 let p_mbar_sigma_res = calc_unweighted_codensity_matrix(lpc, mbar);
592 Some((w_sigma_res, p_mbar_sigma_res))
593 } else {
594 None
595 }
596 })
597 .fold(Ok(T::zero()), |acc_res, (w_sigma_res, p_mbar_sigma_res)| {
598 w_sigma_res.and_then(|w_sigma| {
599 p_mbar_sigma_res.and_then(|p_mbar_sigma| {
600 let j_term_1 =
602 einsum("ikjl,ji,lk->", &[o2, &w.view(), &p_mbar_sigma.view()])
603 .map_err(|err| format_err!(err))?
604 .into_dimensionality::<Ix0>()?
605 .into_iter()
606 .next()
607 .ok_or_else(|| {
608 format_err!(
609 "Unable to extract the result of the einsum contraction."
610 )
611 })?;
612 let j_term_2 =
613 einsum("ikjl,ji,lk->", &[o2, &p_mbar_sigma.view(), &w.view()])
614 .map_err(|err| format_err!(err))?
615 .into_dimensionality::<Ix0>()?
616 .into_iter()
617 .next()
618 .ok_or_else(|| {
619 format_err!(
620 "Unable to extract the result of the einsum contraction."
621 )
622 })?;
623 let k_term_1 = einsum(
624 "ikjl,li,jk->",
625 &[o2, &w_sigma.view(), &p_mbar_sigma.view()],
626 )
627 .map_err(|err| format_err!(err))?
628 .into_dimensionality::<Ix0>()?
629 .into_iter()
630 .next()
631 .ok_or_else(|| {
632 format_err!(
633 "Unable to extract the result of the einsum contraction."
634 )
635 })?;
636 let k_term_2 = einsum(
637 "ikjl,li,jk->",
638 &[o2, &p_mbar_sigma.view(), &w_sigma.view()],
639 )
640 .map_err(|err| format_err!(err))?
641 .into_dimensionality::<Ix0>()?
642 .into_iter()
643 .next()
644 .ok_or_else(|| {
645 format_err!(
646 "Unable to extract the result of the einsum contraction."
647 )
648 })?;
649 acc_res.map(|acc| acc + j_term_1 + j_term_2 - k_term_1 - k_term_2)
650 })
651 })
652 })
653 .map(|v| v * reduced_ov / (T::one() + T::one()))
654 } else {
655 ensure!(
656 nzeros == 2,
657 "Unexpected number of zero Löwdin overlaps: {nzeros} != 2."
658 );
659
660 let ps = (0..structure_constraint.implicit_factor()?)
661 .flat_map(|_| {
662 lowdin_paired_coefficientss.iter().flat_map(|lpc| {
663 lpc.zero_indices()
664 .iter()
665 .map(|mbar| calc_unweighted_codensity_matrix(lpc, *mbar))
666 })
667 })
668 .collect::<Result<Vec<_>, _>>()?;
669 ensure!(
670 ps.len() == 2,
671 "Unexpected number of unweighted codensity matrices ({}) for two zero overlaps.",
672 ps.len()
673 );
674 let p_mbar = ps.first().ok_or_else(|| {
675 format_err!("Unable to retrieve the first computed unweighted codensity matrix.")
676 })?;
677 let p_nbar = ps.last().ok_or_else(|| {
678 format_err!("Unable to retrieve the second computed unweighted codensity matrix.")
679 })?;
680
681 let j_term_1 = einsum("ikjl,ji,lk->", &[o2, &p_mbar.view(), &p_nbar.view()])
683 .map_err(|err| format_err!(err))?
684 .into_dimensionality::<Ix0>()?
685 .into_iter()
686 .next()
687 .ok_or_else(|| {
688 format_err!("Unable to extract the result of the einsum contraction.")
689 })?;
690 let j_term_2 = einsum("ikjl,ji,lk->", &[o2, &p_nbar.view(), &p_mbar.view()])
691 .map_err(|err| format_err!(err))?
692 .into_dimensionality::<Ix0>()?
693 .into_iter()
694 .next()
695 .ok_or_else(|| {
696 format_err!("Unable to extract the result of the einsum contraction.")
697 })?;
698
699 let (k_term_1, k_term_2) = if lowdin_paired_coefficientss
700 .iter()
701 .any(|lpc| lpc.n_lowdin_zeros() == 2)
702 {
703 let k_term_1 = einsum("ikjl,li,jk->", &[o2, &p_mbar.view(), &p_nbar.view()])
704 .map_err(|err| format_err!(err))?
705 .into_dimensionality::<Ix0>()?
706 .into_iter()
707 .next()
708 .ok_or_else(|| {
709 format_err!("Unable to extract the result of the einsum contraction.")
710 })?;
711 let k_term_2 = einsum("ikjl,li,jk->", &[o2, &p_nbar.view(), &p_mbar.view()])
712 .map_err(|err| format_err!(err))?
713 .into_dimensionality::<Ix0>()?
714 .into_iter()
715 .next()
716 .ok_or_else(|| {
717 format_err!("Unable to extract the result of the einsum contraction.")
718 })?;
719 (k_term_1, k_term_2)
720 } else {
721 (T::zero(), T::zero())
722 };
723 Ok(reduced_ov * (j_term_1 - k_term_1 + j_term_2 - k_term_2) / (T::one() + T::one()))
724 }
725 }
726}
727
728pub fn complex_modified_gram_schmidt<T>(
748 vmat: &ArrayView2<T>,
749 complex_symmetric: bool,
750 thresh: <T as ComplexFloat>::Real,
751) -> Result<Array2<T>, anyhow::Error>
752where
753 T: ComplexFloat + std::fmt::Display + 'static,
754{
755 let mut us: Vec<Array1<T>> = Vec::with_capacity(vmat.shape()[1]);
756 let mut us_sq_norm: Vec<T> = Vec::with_capacity(vmat.shape()[1]);
757 for (i, vi) in vmat.columns().into_iter().enumerate() {
758 us.push(vi.to_owned());
760
761 for j in 0..i {
765 let p_uj_ui = if complex_symmetric {
766 us[j].t().dot(&us[i]) / us_sq_norm[j]
767 } else {
768 us[j].t().map(|x| x.conj()).dot(&us[i]) / us_sq_norm[j]
769 };
770 us[i] = &us[i] - us[j].map(|&x| x * p_uj_ui);
771 }
772
773 let us_sq_norm_i = if complex_symmetric {
776 us[i].t().dot(&us[i])
777 } else {
778 us[i].t().map(|x| x.conj()).dot(&us[i])
779 };
780 if us_sq_norm_i.abs() < thresh {
781 return Err(format_err!("A zero-norm vector found: {}", us[i]));
782 }
783 us_sq_norm.push(us_sq_norm_i);
784 }
785
786 for i in 0..us.len() {
788 us[i].mapv_inplace(|x| x / us_sq_norm[i].sqrt());
789 }
790
791 let ortho_check = us.iter().enumerate().all(|(i, ui)| {
792 us.iter().enumerate().all(|(j, uj)| {
793 let ov_ij = if complex_symmetric {
794 ui.dot(uj)
795 } else {
796 ui.map(|x| x.conj()).dot(uj)
797 };
798 i == j || ov_ij.abs() < thresh
799 })
800 });
801
802 if ortho_check {
803 stack(Axis(1), &us.iter().map(|u| u.view()).collect_vec()).map_err(|err| format_err!(err))
804 } else {
805 Err(format_err!(
806 "Post-Gram--Schmidt orthogonality check failed."
807 ))
808 }
809}
810
/// Types (overlap matrices) that can be canonically orthogonalised.
///
/// The implementations in this module are for real (`f32`/`f64`) and complex matrix views. Each
/// implementation performs these steps:
/// 1. Check the required symmetry.
/// 2. Diagonalise the matrix.
/// 3. Discard eigenvalues whose magnitude does not exceed a zero threshold.
/// 4. Build an orthogonalisation matrix from the retained eigenvectors, scaled by the inverse
///    square roots of their eigenvalues.
pub trait CanonicalOrthogonalisable {
    /// Element type of the orthogonalisation matrices produced.
    type NumType;

    /// Real type used for thresholds.
    type RealType;

    /// Computes the canonical orthogonalisation of `self`.
    ///
    /// # Arguments
    ///
    /// * `complex_symmetric` - For complex matrices, selects the complex-symmetric (`true`) or
    ///   Hermitian (`false`) convention. The real implementations ignore it.
    /// * `preserves_full_rank` - If `true` and no eigenvalue is discarded, identity matrices are
    ///   returned instead of the eigenvector-based construction.
    /// * `thresh_offdiag` - Tolerance on the `norm_l2` of the difference between the matrix and
    ///   its transpose (or conjugate transpose), used to check the required symmetry.
    /// * `thresh_zeroov` - Eigenvalues with magnitude not exceeding this value are discarded.
    ///
    /// # Errors
    ///
    /// Errors if the symmetry check fails or the diagonalisation fails. Implementations may also
    /// error for further reasons; in this module, the real implementation rejects negative
    /// retained eigenvalues and the complex implementation can fail to re-orthonormalise the
    /// eigenvectors.
    fn calc_canonical_orthogonal_matrix(
        &self,
        complex_symmetric: bool,
        preserves_full_rank: bool,
        thresh_offdiag: Self::RealType,
        thresh_zeroov: Self::RealType,
    ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error>;
}
843
844pub struct CanonicalOrthogonalisationResult<T> {
846 eigenvalues: Array1<T>,
848
849 xmat: Array2<T>,
851
852 xmat_d: Array2<T>,
856}
857
858impl<T> CanonicalOrthogonalisationResult<T> {
859 pub fn eigenvalues(&self) -> ArrayView1<T> {
861 self.eigenvalues.view()
862 }
863
864 pub fn xmat(&self) -> ArrayView2<T> {
866 self.xmat.view()
867 }
868
869 pub fn xmat_d(&self) -> ArrayView2<T> {
873 self.xmat_d.view()
874 }
875}
876
877#[duplicate_item(
878 [
879 dtype_ [ f64 ]
880 ]
881 [
882 dtype_ [ f32 ]
883 ]
884)]
885impl CanonicalOrthogonalisable for ArrayView2<'_, dtype_> {
886 type NumType = dtype_;
887
888 type RealType = dtype_;
889
890 fn calc_canonical_orthogonal_matrix(
891 &self,
892 _: bool,
893 preserves_full_rank: bool,
894 thresh_offdiag: dtype_,
895 thresh_zeroov: dtype_,
896 ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
897 let smat = self;
898
899 ensure!(
901 (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
902 "Overlap matrix is not real-symmetric."
903 );
904
905 let (s_eig, umat) = smat.eigh(UPLO::Lower).map_err(|err| format_err!(err))?;
907 let nonzero_s_indices = s_eig
909 .iter()
910 .positions(|x| x.abs() > thresh_zeroov)
911 .collect_vec();
912 let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
913 if nonzero_s_eig.iter().any(|v| *v < 0.0) {
914 return Err(format_err!(
915 "The matrix has negative eigenvalues and therefore cannot be orthogonalised over the reals."
916 ));
917 }
918 let nonzero_umat = umat.select(Axis(1), &nonzero_s_indices);
919 let nullity = smat.shape()[0] - nonzero_s_indices.len();
920 let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
921 (Array2::eye(smat.shape()[0]), Array2::eye(smat.shape()[0]))
922 } else {
923 let s_s = Array2::<dtype_>::from_diag(&nonzero_s_eig.mapv(|x| 1.0 / x.sqrt()));
924 (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
925 };
926 let res = CanonicalOrthogonalisationResult {
927 eigenvalues: s_eig,
928 xmat,
929 xmat_d,
930 };
931 Ok(res)
932 }
933}
934
935impl<T> CanonicalOrthogonalisable for ArrayView2<'_, Complex<T>>
936where
937 T: Float + Scalar<Complex = Complex<T>>,
938 Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
939{
940 type NumType = Complex<T>;
941
942 type RealType = T;
943
944 fn calc_canonical_orthogonal_matrix(
945 &self,
946 complex_symmetric: bool,
947 preserves_full_rank: bool,
948 thresh_offdiag: T,
949 thresh_zeroov: T,
950 ) -> Result<CanonicalOrthogonalisationResult<Self::NumType>, anyhow::Error> {
951 let smat = self;
952
953 if complex_symmetric {
954 ensure!(
956 (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
957 "Overlap matrix is not complex-symmetric."
958 );
959 } else {
960 ensure!(
962 (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
963 "Overlap matrix is not complex-Hermitian."
964 );
965 }
966
967 let (s_eig, umat_nonortho) = smat.eig().map_err(|err| format_err!(err))?;
968
969 let nonzero_s_indices = s_eig
970 .iter()
971 .positions(|x| ComplexFloat::abs(*x) > thresh_zeroov)
972 .collect_vec();
973 let nonzero_s_eig = s_eig.select(Axis(0), &nonzero_s_indices);
974 let nonzero_umat_nonortho = umat_nonortho.select(Axis(1), &nonzero_s_indices);
975
976 let nonzero_umat = complex_modified_gram_schmidt(
979 &nonzero_umat_nonortho.view(),
980 complex_symmetric,
981 thresh_zeroov,
982 )
983 .map_err(
984 |_| format_err!("Unable to orthonormalise the linearly-independent eigenvectors of the overlap matrix.")
985 )?;
986
987 let nullity = smat.shape()[0] - nonzero_s_indices.len();
988 let (xmat, xmat_d) = if nullity == 0 && preserves_full_rank {
989 (
990 Array2::<Complex<T>>::eye(smat.shape()[0]),
991 Array2::<Complex<T>>::eye(smat.shape()[0]),
992 )
993 } else {
994 let s_s = Array2::<Complex<T>>::from_diag(
995 &nonzero_s_eig.mapv(|x| Complex::<T>::from(T::one()) / x.sqrt()),
996 );
997 if complex_symmetric {
998 (nonzero_umat.dot(&s_s), s_s.dot(&nonzero_umat.t()))
999 } else {
1000 let xmat = nonzero_umat.dot(&s_s);
1001 let xmat_d = xmat.map(|v| v.conj()).t().to_owned();
1002 (xmat, xmat_d)
1003 }
1004 };
1005 let res = CanonicalOrthogonalisationResult {
1006 eigenvalues: s_eig,
1007 xmat,
1008 xmat_d,
1009 };
1010 Ok(res)
1011 }
1012}