// qsym2/target/noci/backend/matelem/mod.rs

1use std::collections::HashSet;
2use std::fmt::{self, LowerExp};
3
4use anyhow::{self, ensure, format_err};
5use itertools::Itertools;
6use log;
7use ndarray::{Array2, Array3, ArrayView2, s};
8use ndarray_linalg::types::Lapack;
9use num_complex::ComplexFloat;
10
11use crate::angmom::spinor_rotation_3d::StructureConstraint;
12use crate::symmetry::symmetry_element::SpecialSymmetryTransformation;
13use crate::symmetry::symmetry_group::SymmetryGroupProperties;
14use crate::symmetry::symmetry_transformation::SymmetryTransformable;
15use crate::target::determinant::SlaterDeterminant;
16use crate::target::noci::basis::{Basis, OrbitBasis};
17
18pub mod hamiltonian;
19pub mod overlap;
20
21pub trait OrbitMatrix<'a, T, SC>
22where
23    T: Lapack + ComplexFloat,
24    SC: StructureConstraint + Clone + fmt::Display,
25    SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
26{
27    /// The type of the matrix elements.
28    type MatrixElement;
29
30    // ----------------
31    // Required methods
32    // ----------------
33    /// Calculates the matrix element between two Slater determinants.
34    ///
35    /// # Arguments
36    ///
37    /// * `det_w` - The determinant $`^{w}\Psi`$.
38    /// * `det_x` - The determinant $`^{x}\Psi`$.
39    /// * `sao` - The atomic-orbital overlap matrix.
40    /// * `thresh_offdiag` - Threshold for determining non-zero off-diagonal elements in the
41    ///   orbital overlap matrix between $`^{w}\Psi`$ and $`^{x}\Psi`$ during Löwdin pairing.
42    /// * `thresh_zeroov` - Threshold for identifying zero Löwdin overlaps.
43    ///
44    /// # Returns
45    ///
46    /// The resulting matrix element.
47    fn calc_matrix_element(
48        &self,
49        det_w: &SlaterDeterminant<T, SC>,
50        det_x: &SlaterDeterminant<T, SC>,
51        sao: &ArrayView2<T>,
52        thresh_offdiag: <T as ComplexFloat>::Real,
53        thresh_zeroov: <T as ComplexFloat>::Real,
54    ) -> Result<Self::MatrixElement, anyhow::Error>;
55
56    /// Returns a string representing the operator for the matrix element.
57    fn op() -> &'a str;
58
59    /// Computes the transpose of a matrix element.
60    fn t(x: &Self::MatrixElement) -> Self::MatrixElement;
61
62    /// Computes the complex conjugation of a matrix element.
63    fn conj(x: &Self::MatrixElement) -> Self::MatrixElement;
64
65    /// Returns the zero matrix element.
66    fn zero(&self) -> Self::MatrixElement;
67
68    // ----------------
69    // Provided methods
70    // ----------------
71
72    /// Returns the norm-presearving scalar map connecting diagonally-symmetric elements in the
73    /// matrix.
74    #[allow(clippy::type_complexity)]
75    fn norm_preserving_scalar_map<'b, G>(
76        &self,
77        i: usize,
78        orbit_basis: &'b OrbitBasis<'b, G, SlaterDeterminant<'a, T, SC>>,
79    ) -> Result<fn(&Self::MatrixElement) -> Self::MatrixElement, anyhow::Error>
80    where
81        G: SymmetryGroupProperties + Clone,
82        'a: 'b,
83    {
84        let group = orbit_basis.group();
85        let complex_symmetric_set = orbit_basis
86            .origins()
87            .iter()
88            .map(|det| det.complex_symmetric())
89            .collect::<HashSet<_>>();
90        ensure!(
91            complex_symmetric_set.len() == 1,
92            "Inconsistent complex-symmetric flags across origin determinants."
93        );
94        let complex_symmetric = *complex_symmetric_set
95            .iter()
96            .next()
97            .ok_or(format_err!("Unable to obtain the complex-symmetric flag."))?;
98        if complex_symmetric {
99            Err(format_err!(
100                "`norm_preserving_scalar_map` is currently not implemented for complex-symmetric inner products. This thus precludes the use of the Cayley table to speed up the computation of orbit matrices."
101            ))
102        } else if group
103            .get_index(i)
104            .unwrap_or_else(|| panic!("Group operation index `{i}` not found."))
105            .contains_time_reversal()
106        {
107            Ok(Self::conj)
108        } else {
109            Ok(Self::t)
110        }
111    }
112
113    /// Computes the entire matrix of matrix elements in an orbit basis, making use of group
114    /// closure for optimisation.
115    ///
116    /// # Arguments
117    ///
118    /// * `orbit_basis` - The orbit basis in which the matrix elements are to be computed.
119    /// * `use_cayley_table` - Boolean indicating whether group closure should be used to speed up
120    ///   the computation.
121    /// * `sao` - The atomic-orbital overlap matrix.
122    /// * `thresh_offdiag` - Threshold for determining non-zero off-diagonal elements in the
123    ///   orbital overlap matrix between two Slater determinants during Löwdin pairing.
124    /// * `thresh_zeroov` - Threshold for identifying zero Löwdin overlaps.
125    fn calc_orbit_matrix<'g, G>(
126        &self,
127        orbit_basis: &'g OrbitBasis<'g, G, SlaterDeterminant<'a, T, SC>>,
128        use_cayley_table: bool,
129        sao: &ArrayView2<T>,
130        thresh_offdiag: <T as ComplexFloat>::Real,
131        thresh_zeroov: <T as ComplexFloat>::Real,
132    ) -> Result<Array2<Self::MatrixElement>, anyhow::Error>
133    where
134        G: SymmetryGroupProperties + Clone,
135        T: Sync + Send,
136        <T as ComplexFloat>::Real: Sync,
137        SlaterDeterminant<'a, T, SC>: Sync,
138        Self: Sync,
139        Self::MatrixElement: Send + LowerExp,
140        'a: 'g,
141        Self::MatrixElement: Clone,
142    {
143        let group = orbit_basis.group();
144        let order = group.order();
145        let det_origins = orbit_basis.origins();
146        let n_det_origins = det_origins.len();
147        let mut mat = Array2::<Self::MatrixElement>::from_elem(
148            (n_det_origins * order, n_det_origins * order),
149            self.zero(),
150        );
151
152        if let (Some(ctb), true) = (group.cayley_table(), use_cayley_table) {
153            log::debug!(
154                "Cayley table available and its use requested. Group closure will be used to speed up orbit matrix computation."
155            );
156            // Compute unique matrix elements
157            let mut ov_elems = orbit_basis
158                .iter()
159                .collect::<Result<Vec<_>, _>>()?
160                .iter()
161                .enumerate()
162                .cartesian_product(orbit_basis.origins().iter().enumerate())
163                // .par_bridge()
164                .map(|((k_ii, k_ii_det), (jj, jj_det))| {
165                    let k = k_ii.div_euclid(n_det_origins);
166                    let ii = k_ii.rem_euclid(n_det_origins);
167                    (
168                        ii,
169                        jj,
170                        k,
171                        self.calc_matrix_element(
172                            k_ii_det,
173                            jj_det,
174                            sao,
175                            thresh_offdiag,
176                            thresh_zeroov,
177                        ),
178                    )
179                })
180                .collect::<Vec<_>>();
181            ov_elems.sort_by_key(|v| (v.0, v.1, v.2));
182            let mut ov_ii_jj_k =
183                Array3::from_elem((n_det_origins, n_det_origins, order), self.zero());
184            for (ii, jj, k, elem_res) in ov_elems {
185                log::debug!(
186                    "⟨g_{k} Ψ_{ii} {} Ψ_{jj}⟩ = ⟨{} Ψ_{ii} {} Ψ_{jj}⟩ = {}",
187                    Self::op(),
188                    group
189                        .get_index(k)
190                        .map(|g| g.to_string())
191                        .unwrap_or_else(|| format!("g_{k}")),
192                    Self::op(),
193                    elem_res
194                        .as_ref()
195                        .map(|v| format!("{v:+.8e}"))
196                        .unwrap_or_else(|err| err.to_string())
197                );
198                ov_ii_jj_k[(ii, jj, k)] = elem_res?;
199            }
200
201            // Populate all matrix elements
202            for v in [
203                (0..order),
204                (0..n_det_origins),
205                (0..order),
206                (0..n_det_origins),
207            ]
208            .into_iter()
209            .multi_cartesian_product()
210            {
211                let i = v[0];
212                let ii = v[1];
213                let j = v[2];
214                let jj = v[3];
215
216                let jinv = ctb
217                    .slice(s![.., j])
218                    .iter()
219                    .position(|&x| x == 0)
220                    .ok_or(format_err!(
221                        "Unable to find the inverse of group element `{j}`."
222                    ))?;
223                let k = ctb[(jinv, i)];
224                log::debug!(
225                    "{}^(-1) = {} ⇒ ⟨g_{i} Ψ_{ii} {} g_{j} Ψ_{jj}⟩ = ⟨{} Ψ_{ii} {} {} Ψ_{jj}⟩ = ⟨{} Ψ_{ii} {} Ψ_{jj}⟩ = {:+8e}",
226                    group
227                        .get_index(j)
228                        .map(|g| g.to_string())
229                        .unwrap_or_else(|| format!("g_{j}")),
230                    group
231                        .get_index(jinv)
232                        .map(|g| g.to_string())
233                        .unwrap_or_else(|| format!("g_{jinv}")),
234                    Self::op(),
235                    group
236                        .get_index(i)
237                        .map(|g| g.to_string())
238                        .unwrap_or_else(|| format!("g_{i}")),
239                    Self::op(),
240                    group
241                        .get_index(j)
242                        .map(|g| g.to_string())
243                        .unwrap_or_else(|| format!("g_{j}")),
244                    group
245                        .get_index(k)
246                        .map(|g| g.to_string())
247                        .unwrap_or_else(|| format!("g_{k}")),
248                    Self::op(),
249                    ov_ii_jj_k[(ii, jj, k)],
250                );
251                mat[(i * n_det_origins + ii, j * n_det_origins + jj)] =
252                    self.norm_preserving_scalar_map(jinv, orbit_basis)?(&ov_ii_jj_k[(ii, jj, k)]);
253            }
254        } else {
255            log::debug!(
256                "Cayley table not available or its use not requested. Group closure will not be used for orbit matrix computation."
257            );
258            let orbit_basis_vec = orbit_basis.iter().collect::<Result<Vec<_>, _>>()?;
259            let mut elems = orbit_basis_vec
260                .iter()
261                .enumerate()
262                .cartesian_product(orbit_basis_vec.iter().enumerate())
263                .map(|((i_ii, i_ii_det), (j_jj, j_jj_det))| {
264                    let i = i_ii.div_euclid(n_det_origins);
265                    let ii = i_ii.rem_euclid(n_det_origins);
266                    let j = j_jj.div_euclid(n_det_origins);
267                    let jj = j_jj.rem_euclid(n_det_origins);
268                    let elem_res = self.calc_matrix_element(
269                        i_ii_det,
270                        j_jj_det,
271                        sao,
272                        thresh_offdiag,
273                        thresh_zeroov,
274                    );
275                    (i, ii, j, jj, elem_res)
276                })
277                .collect::<Vec<_>>();
278            elems.sort_by_key(|v| (v.1, v.0, v.3, v.2));
279            for (i, ii, j, jj, elem_res) in elems {
280                log::debug!(
281                    "⟨g_{i} Ψ_{ii} {} g_{j} Ψ_{jj}⟩ = ⟨{} Ψ_{ii} {} {} Ψ_{jj}⟩ = {}",
282                    Self::op(),
283                    group
284                        .get_index(i)
285                        .map(|g| g.to_string())
286                        .unwrap_or_else(|| format!("g_{i}")),
287                    Self::op(),
288                    group
289                        .get_index(j)
290                        .map(|g| g.to_string())
291                        .unwrap_or_else(|| format!("g_{j}")),
292                    elem_res
293                        .as_ref()
294                        .map(|v| format!("{v:+.8e}"))
295                        .unwrap_or_else(|err| err.to_string())
296                );
297                mat[(i * n_det_origins + ii, j * n_det_origins + jj)] = elem_res?;
298            }
299        }
300        Ok(mat)
301    }
302}