qsym2/target/noci/backend/matelem/overlap.rs

1use std::marker::PhantomData;
2
3use anyhow::{self, format_err};
4use derive_builder::Builder;
5use itertools::Itertools;
6use ndarray::{Array2, ArrayView2, Ix2};
7use ndarray_linalg::Lapack;
8use num_complex::ComplexFloat;
9
10use crate::analysis::Overlap;
11use crate::angmom::spinor_rotation_3d::StructureConstraint;
12use crate::symmetry::symmetry_transformation::SymmetryTransformable;
13use crate::target::determinant::SlaterDeterminant;
14
15use super::OrbitMatrix;
16
17/// Structure for managing the overlap integrals in an atomic-orbital basis.
18#[derive(Builder)]
19pub struct OverlapAO<'a, T, SC>
20where
21    T: ComplexFloat + Lapack,
22    SC: StructureConstraint + Clone,
23{
24    /// The overlap integral in an atomic-orbital basis.
25    sao: ArrayView2<'a, T>,
26
27    /// The structure constraint for the wavefunctions on the Hilbert space with this overlap
28    /// metric.
29    #[builder(setter(skip), default = "PhantomData")]
30    structure_constraint: PhantomData<SC>,
31}
32
33impl<'a, T, SC> OverlapAO<'a, T, SC>
34where
35    T: ComplexFloat + Lapack,
36    SC: StructureConstraint + Clone,
37{
38    /// Returns a builder for [`OverlapAO`].
39    pub fn builder() -> OverlapAOBuilder<'a, T, SC> {
40        OverlapAOBuilder::<T, SC>::default()
41    }
42
43    /// Returns the overlap integrals in an atomic-orbital basis.
44    pub fn sao(&'a self) -> &'a ArrayView2<'a, T> {
45        &self.sao
46    }
47}
48
49impl<'a, T, SC> OverlapAO<'a, T, SC>
50where
51    T: ComplexFloat + Lapack,
52    SC: StructureConstraint + Clone + std::fmt::Display,
53    for<'b> SlaterDeterminant<'b, T, SC>: Overlap<T, Ix2>,
54{
55    /// Calculates the overlap matrix element between two determinants.
56    pub fn calc_overlap_matrix_element(
57        &self,
58        det_w: &SlaterDeterminant<T, SC>,
59        det_x: &SlaterDeterminant<T, SC>,
60    ) -> Result<T, anyhow::Error> {
61        if det_w.complex_symmetric() != det_x.complex_symmetric() {
62            return Err(format_err!(
63                "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
64                det_w.complex_symmetric(),
65                det_x.complex_symmetric(),
66            ));
67        }
68        det_w.overlap(det_x, Some(&self.sao.to_owned()), None)
69    }
70
71    pub fn calc_overlap_matrix(
72        &self,
73        dets: &[&SlaterDeterminant<T, SC>],
74    ) -> Result<Array2<T>, anyhow::Error> {
75        let dim = dets.len();
76        let mut smat = Array2::<T>::zeros((dim, dim));
77        for pair in dets.iter().enumerate().combinations_with_replacement(2) {
78            let (w, det_w) = &pair[0];
79            let (x, det_x) = &pair[1];
80            let ov_wx = self.calc_overlap_matrix_element(det_w, det_x)?;
81            smat[(*w, *x)] = ov_wx;
82            if *w != *x {
83                let ov_xw = self.calc_overlap_matrix_element(det_x, det_w)?;
84                smat[(*x, *w)] = ov_xw;
85            }
86        }
87        Ok(smat)
88    }
89}
90
91impl<'a, T, SC> OrbitMatrix<'a, T, SC> for &OverlapAO<'a, T, SC>
92where
93    T: ComplexFloat + Lapack,
94    SC: StructureConstraint + Clone + std::fmt::Display,
95    for<'b> SlaterDeterminant<'b, T, SC>: Overlap<T, Ix2>,
96    SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
97{
98    fn calc_matrix_element(
99        &self,
100        det_w: &SlaterDeterminant<T, SC>,
101        det_x: &SlaterDeterminant<T, SC>,
102        _sao: &ArrayView2<T>,
103        _thresh_offdiag: <T as ComplexFloat>::Real,
104        _thresh_zeroov: <T as ComplexFloat>::Real,
105    ) -> Result<T, anyhow::Error> {
106        self.calc_overlap_matrix_element(det_w, det_x)
107    }
108}
109
110impl<'a, T, SC> OrbitMatrix<'a, T, SC> for OverlapAO<'a, T, SC>
111where
112    T: ComplexFloat + Lapack,
113    SC: StructureConstraint + Clone + std::fmt::Display,
114    for<'b> SlaterDeterminant<'b, T, SC>: Overlap<T, Ix2>,
115    SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
116{
117    fn calc_matrix_element(
118        &self,
119        det_w: &SlaterDeterminant<T, SC>,
120        det_x: &SlaterDeterminant<T, SC>,
121        _sao: &ArrayView2<T>,
122        _thresh_offdiag: <T as ComplexFloat>::Real,
123        _thresh_zeroov: <T as ComplexFloat>::Real,
124    ) -> Result<T, anyhow::Error> {
125        (&self).calc_matrix_element(det_w, det_x, _sao, _thresh_offdiag, _thresh_zeroov)
126    }
127}