qsym2/target/noci/backend/matelem/mod.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use anyhow::{self, ensure, format_err};
5use itertools::Itertools;
6use log;
7use ndarray::{s, Array2, Array3, ArrayView2};
8use ndarray_linalg::types::Lapack;
9use num_complex::ComplexFloat;
10use rayon::iter::{ParallelBridge, ParallelIterator};
11
12use crate::angmom::spinor_rotation_3d::StructureConstraint;
13use crate::symmetry::symmetry_element::SpecialSymmetryTransformation;
14use crate::symmetry::symmetry_group::SymmetryGroupProperties;
15use crate::symmetry::symmetry_transformation::SymmetryTransformable;
16use crate::target::determinant::SlaterDeterminant;
17use crate::target::noci::basis::{Basis, OrbitBasis};
18
19pub mod hamiltonian;
20pub mod overlap;
21
22pub trait OrbitMatrix<'a, T, SC>
23where
24    T: Lapack + ComplexFloat,
25    SC: StructureConstraint + Clone + fmt::Display,
26    SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
27{
28    // ----------------
29    // Required methods
30    // ----------------
31    fn calc_matrix_element(
32        &self,
33        det_w: &SlaterDeterminant<T, SC>,
34        det_x: &SlaterDeterminant<T, SC>,
35        sao: &ArrayView2<T>,
36        thresh_offdiag: <T as ComplexFloat>::Real,
37        thresh_zeroov: <T as ComplexFloat>::Real,
38    ) -> Result<T, anyhow::Error>;
39
40    // ----------------
41    // Provided methods
42    // ----------------
43    fn norm_preserving_scalar_map<'b, G>(
44        &self,
45        i: usize,
46        orbit_basis: &'b OrbitBasis<'b, G, SlaterDeterminant<'a, T, SC>>,
47    ) -> Result<fn(T) -> T, anyhow::Error>
48    where
49        G: SymmetryGroupProperties + Clone,
50        'a: 'b,
51    {
52        let group = orbit_basis.group();
53        let complex_symmetric_set = orbit_basis
54            .origins()
55            .iter()
56            .map(|det| det.complex_symmetric())
57            .collect::<HashSet<_>>();
58        ensure!(
59            complex_symmetric_set.len() == 1,
60            "Inconsistent complex-symmetric flags across origin determinants."
61        );
62        let complex_symmetric = *complex_symmetric_set
63            .iter()
64            .next()
65            .ok_or(format_err!("Unable to obtain the complex-symmetric flag."))?;
66        if complex_symmetric {
67            Err(format_err!(
68                "`norm_preserving_scalar_map` is currently not implemented for complex-symmetric inner products. This thus precludes the use of the Cayley table to speed up the computation of orbit matrices."
69            ))
70        } else {
71            if group
72                .get_index(i)
73                .unwrap_or_else(|| panic!("Group operation index `{i}` not found."))
74                .contains_time_reversal()
75            {
76                Ok(ComplexFloat::conj)
77            } else {
78                Ok(|x| x)
79            }
80        }
81    }
82
83    fn calc_orbit_matrix<'g, G>(
84        &self,
85        orbit_basis: &'g OrbitBasis<'g, G, SlaterDeterminant<'a, T, SC>>,
86        use_cayley_table: bool,
87        sao: &ArrayView2<T>,
88        thresh_offdiag: <T as ComplexFloat>::Real,
89        thresh_zeroov: <T as ComplexFloat>::Real,
90    ) -> Result<Array2<T>, anyhow::Error>
91    where
92        G: SymmetryGroupProperties + Clone,
93        T: Sync + Send,
94        <T as ComplexFloat>::Real: Sync,
95        SlaterDeterminant<'a, T, SC>: Sync,
96        Self: Sync,
97        'a: 'g,
98    {
99        let group = orbit_basis.group();
100        let order = group.order();
101        let det_origins = orbit_basis.origins();
102        let n_det_origins = det_origins.len();
103        let mut mat = Array2::<T>::zeros((n_det_origins * order, n_det_origins * order));
104
105        if let (Some(ctb), true) = (group.cayley_table(), use_cayley_table) {
106            log::debug!(
107                "Cayley table available. Group closure will be used to speed up orbit matrix computation."
108            );
109            // Compute unique matrix elements
110            let ov_elems = orbit_basis
111                .iter()
112                .collect::<Result<Vec<_>, _>>()?
113                .iter()
114                .enumerate()
115                .cartesian_product(orbit_basis.origins().iter().enumerate())
116                .par_bridge()
117                .map(|((k_ii, k_ii_det), (jj, jj_det))| {
118                    let k = k_ii.div_euclid(n_det_origins);
119                    let ii = k_ii.rem_euclid(n_det_origins);
120                    (
121                        ii,
122                        jj,
123                        k,
124                        self.calc_matrix_element(
125                            k_ii_det,
126                            jj_det,
127                            sao,
128                            thresh_offdiag,
129                            thresh_zeroov,
130                        ),
131                    )
132                })
133                .collect::<Vec<_>>();
134            let mut ov_ii_jj_k = Array3::zeros((n_det_origins, n_det_origins, order));
135            for (ii, jj, k, elem_res) in ov_elems {
136                ov_ii_jj_k[(ii, jj, k)] = elem_res?;
137            }
138
139            // Populate all matrix elements
140            for v in [
141                (0..order),
142                (0..n_det_origins),
143                (0..order),
144                (0..n_det_origins),
145            ]
146            .into_iter()
147            .multi_cartesian_product()
148            {
149                let i = v[0];
150                let ii = v[1];
151                let j = v[2];
152                let jj = v[3];
153
154                let jinv = ctb
155                    .slice(s![.., j])
156                    .iter()
157                    .position(|&x| x == 0)
158                    .ok_or(format_err!(
159                        "Unable to find the inverse of group element `{j}`."
160                    ))?;
161                let k = ctb[(jinv, i)];
162                mat[(i + ii * order, j + jj * order)] =
163                    self.norm_preserving_scalar_map(jinv, orbit_basis)?(ov_ii_jj_k[(ii, jj, k)]);
164            }
165        } else {
166            let orbit_basis_vec = orbit_basis.iter().collect::<Result<Vec<_>, _>>()?;
167            let elems = orbit_basis_vec
168                .iter()
169                .enumerate()
170                .cartesian_product(orbit_basis_vec.iter().enumerate())
171                .par_bridge()
172                .map(|((i_ii, i_ii_det), (j_jj, j_jj_det))| {
173                    let i = i_ii.div_euclid(n_det_origins);
174                    let ii = i_ii.rem_euclid(n_det_origins);
175                    let j = j_jj.div_euclid(n_det_origins);
176                    let jj = j_jj.rem_euclid(n_det_origins);
177                    let elem_res = self.calc_matrix_element(
178                        i_ii_det,
179                        j_jj_det,
180                        sao,
181                        thresh_offdiag,
182                        thresh_zeroov,
183                    );
184                    (i, ii, j, jj, elem_res)
185                })
186                .collect::<Vec<_>>();
187            for (i, ii, j, jj, elem_res) in elems {
188                mat[(i + ii * order, j + jj * order)] = elem_res?;
189            }
190        }
191        Ok(mat)
192    }
193}