qsym2/target/noci/
multideterminant.rs

1//! Multi-determinant wavefunctions for non-orthogonal configuration interaction.
2
3use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, ArrayView2, Axis, ScalarOperand};
13use ndarray_linalg::types::Lapack;
14use num_complex::ComplexFloat;
15use rayon::prelude::*;
16
17use crate::angmom::spinor_rotation_3d::StructureConstraint;
18use crate::group::GroupProperties;
19use crate::target::determinant::SlaterDeterminant;
20use crate::target::noci::backend::nonortho::{calc_lowdin_pairing, calc_transition_density_matrix};
21use crate::target::noci::basis::{EagerBasis, OrbitBasis};
22
23use super::basis::Basis;
24
25#[path = "multideterminant_transformation.rs"]
26pub(crate) mod multideterminant_transformation;
27
28#[path = "multideterminant_analysis.rs"]
29pub(crate) mod multideterminant_analysis;
30
31#[cfg(test)]
32#[path = "multideterminant_tests.rs"]
33mod multideterminant_tests;
34
35// ------------------
36// Struct definitions
37// ------------------
38
39/// Structure to manage multi-determinantal wavefunctions.
40#[derive(Builder, Clone)]
41#[builder(build_fn(validate = "Self::validate"))]
42pub struct MultiDeterminant<'a, T, B, SC>
43where
44    T: ComplexFloat + Lapack,
45    SC: StructureConstraint + Hash + Eq + fmt::Display,
46    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
47{
48    #[builder(setter(skip), default = "PhantomData")]
49    _lifetime: PhantomData<&'a ()>,
50
51    #[builder(setter(skip), default = "PhantomData")]
52    _structure_constraint: PhantomData<SC>,
53
54    /// A boolean indicating if inner products involving this wavefunction should be the
55    /// complex-symmetric bilinear form, rather than the conventional Hermitian sesquilinear form.
56    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
57    complex_symmetric: bool,
58
59    /// A boolean indicating if the wavefunction has been acted on by an antiunitary operation. This
60    /// is so that the correct metric can be used during overlap evaluation.
61    #[builder(default = "false")]
62    complex_conjugated: bool,
63
64    /// The basis of Slater determinants in which this multi-determinantal wavefunction is defined.
65    basis: B,
66
67    /// The linear combination coefficients of the elements in the multi-orbit to give this
68    /// multi-determinant wavefunction.
69    coefficients: Array1<T>,
70
71    /// The energy of this multi-determinantal wavefunction.
72    #[builder(
73        default = "Err(\"Multi-determinantal wavefunction energy not yet set.\".to_string())"
74    )]
75    energy: Result<T, String>,
76
77    /// The threshold for comparing wavefunctions.
78    threshold: <T as ComplexFloat>::Real,
79}
80
81// ----------------------
82// Struct implementations
83// ----------------------
84
85impl<'a, T, B, SC> MultiDeterminantBuilder<'a, T, B, SC>
86where
87    T: ComplexFloat + Lapack,
88    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
89    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
90{
91    fn validate(&self) -> Result<(), String> {
92        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
93        let coefficients = self
94            .coefficients
95            .as_ref()
96            .ok_or("No coefficients found.".to_string())?;
97        let nbasis = basis.n_items() == coefficients.len();
98        if !nbasis {
99            log::error!(
100                "The number of coefficients does not match the number of basis determinants."
101            );
102        }
103
104        let complex_symmetric = basis
105            .iter()
106            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
107            .collect::<Result<HashSet<_>, _>>()
108            .map_err(|err| err.to_string())?
109            .len()
110            == 1;
111        if !complex_symmetric {
112            log::error!("Inconsistent complex-symmetric flag across basis determinants.");
113        }
114
115        let structcons_check = basis
116            .iter()
117            .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
118            .collect::<Result<HashSet<_>, _>>()
119            .map_err(|err| err.to_string())?
120            .len()
121            == 1;
122        if !structcons_check {
123            log::error!("Inconsistent spin constraints across basis determinants.");
124        }
125
126        if nbasis && structcons_check && complex_symmetric {
127            Ok(())
128        } else {
129            Err("Multi-determinant wavefunction validation failed.".to_string())
130        }
131    }
132
133    /// Retrieves the consistent complex-symmetric flag from the basis determinants.
134    fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
135        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
136        let complex_symmetric_set = basis
137            .iter()
138            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
139            .collect::<Result<HashSet<_>, _>>()
140            .map_err(|err| err.to_string())?;
141        if complex_symmetric_set.len() == 1 {
142            complex_symmetric_set
143                .into_iter()
144                .next()
145                .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
146        } else {
147            Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
148        }
149    }
150}
151
152impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
153where
154    T: ComplexFloat + Lapack,
155    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
156    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
157{
158    /// Returns a builder to construct a new [`MultiDeterminant`].
159    pub fn builder() -> MultiDeterminantBuilder<'a, T, B, SC> {
160        MultiDeterminantBuilder::default()
161    }
162
163    /// Returns the structure constraint of the multi-determinantal wavefunction.
164    pub fn structure_constraint(&self) -> SC {
165        self.basis
166            .iter()
167            .next()
168            .expect("No basis determinant found.")
169            .expect("No basis determinant found.")
170            .structure_constraint()
171            .clone()
172    }
173}
174
175impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
176where
177    T: ComplexFloat + Lapack,
178    SC: StructureConstraint + Hash + Eq + fmt::Display,
179    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
180{
181    /// Returns the complex-conjugated flag of the multi-determinantal wavefunction.
182    pub fn complex_conjugated(&self) -> bool {
183        self.complex_conjugated
184    }
185
186    /// Returns the complex-symmetric flag of the multi-determinantal wavefunction.
187    pub fn complex_symmetric(&self) -> bool {
188        self.complex_symmetric
189    }
190
191    /// Returns the basis of determinants in which this multi-determinantal wavefunction is
192    /// defined.
193    pub fn basis(&self) -> &B {
194        &self.basis
195    }
196
197    /// Returns the coefficients of the basis determinants constituting this multi-determinantal
198    /// wavefunction.
199    pub fn coefficients(&self) -> &Array1<T> {
200        &self.coefficients
201    }
202
203    /// Returns the energy of the multi-determinantal wavefunction.
204    pub fn energy(&self) -> Result<&T, &String> {
205        self.energy.as_ref()
206    }
207
208    /// Returns the threshold with which multi-determinantal wavefunctions are compared.
209    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
210        self.threshold
211    }
212}
213
214// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
215// Specific implementations for OrbitBasis
216// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
217
218impl<'a, T, G, SC> MultiDeterminant<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
219where
220    T: ComplexFloat + Lapack,
221    G: GroupProperties + Clone,
222    SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
223{
224    /// Converts this multi-determinant with an orbit basis into a multi-determinant with the
225    /// equivalent eager basis.
226    #[allow(clippy::type_complexity)]
227    pub fn to_eager_basis(
228        &self,
229    ) -> Result<MultiDeterminant<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
230    {
231        MultiDeterminant::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
232            .complex_conjugated(self.complex_conjugated)
233            .basis(self.basis.to_eager()?)
234            .coefficients(self.coefficients().clone())
235            .energy(self.energy.clone())
236            .threshold(self.threshold)
237            .build()
238            .map_err(|err| format_err!(err))
239    }
240}
241
242// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
243// Generic implementation for all Basis
244// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
245
246impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
247where
248    T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
249    <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
250    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
251    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
252    SlaterDeterminant<'a, T, SC>: Send + Sync,
253{
254    /// Calculates the (contravariant) density matrix $`\mathbf{P}(\hat{\iota})`$ of the
255    /// multi-determinantal wavefunction in the AO basis.
256    ///
257    /// Note that the contravariant density matrix $`\mathbf{P}(\hat{\iota})`$ needs to be converted
258    /// to the mixed form $`\tilde{\mathbf{P}}(\hat{\iota})`$ given by
259    /// ```math
260    ///     \tilde{\mathbf{P}}(\hat{\iota}) = \mathbf{P}(\hat{\iota}) \mathbf{S}_{\mathrm{AO}}
261    /// ```
262    /// before being diagonalised to obtain natural orbitals and their occupation numbers.
263    ///
264    /// # Arguments
265    ///
266    /// * `sao` - The atomic-orbital overlap matrix.
267    /// * `thresh_offdiag` - Threshold for determining non-zero off-diagonal elements in the
268    ///   orbital overlap matrix two Slater determinants during Löwdin pairing.
269    /// * `thresh_zeroov` - Threshold for identifying zero Löwdin overlaps.
270    pub fn density_matrix(
271        &self,
272        sao: &ArrayView2<T>,
273        thresh_offdiag: <T as ComplexFloat>::Real,
274        thresh_zeroov: <T as ComplexFloat>::Real,
275        normalised_wavefunction: bool,
276    ) -> Result<Array2<T>, anyhow::Error> {
277        let nao = sao.nrows();
278        let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
279        let sqnorm_denmat_res = dets.iter()
280            .zip(self.coefficients().iter())
281            .cartesian_product(dets.iter().zip(self.coefficients().iter()))
282            .par_bridge()
283            .fold(
284                || Ok((T::zero(), Array2::<T>::zeros((nao, nao)))),
285                |acc_res, ((det_w, c_w), (det_x, c_x))| {
286                    ensure!(
287                        det_w.structure_constraint() == det_x.structure_constraint(),
288                        "Inconsistent spin constraints: {} != {}.",
289                        det_w.structure_constraint(),
290                        det_x.structure_constraint(),
291                    );
292
293                    if det_w.complex_symmetric() != det_x.complex_symmetric() {
294                        return Err(format_err!(
295                            "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
296                            det_w.complex_symmetric(),
297                            det_x.complex_symmetric(),
298                        ));
299                    }
300                    let complex_symmetric = det_w.complex_symmetric();
301                    let lowdin_paired_coefficientss = det_w
302                        .coefficients()
303                        .iter()
304                        .zip(det_w.occupations().iter())
305                        .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
306                        .map(|((cw, occw), (cx, occx))| {
307                            let occw_indices = occw
308                                .iter()
309                                .enumerate()
310                                .filter_map(|(i, occ_i)| {
311                                    if occ_i.abs() >= det_w.threshold() {
312                                        Some(i)
313                                    } else {
314                                        None
315                                    }
316                                })
317                                .collect::<Vec<_>>();
318                            let ne_w = occw_indices.len();
319                            let cw_occ = cw.select(Axis(1), &occw_indices);
320                            let occx_indices = occx
321                                .iter()
322                                .enumerate()
323                                .filter_map(|(i, occ_i)| {
324                                    if occ_i.abs() >= det_x.threshold() {
325                                        Some(i)
326                                    } else {
327                                        None
328                                    }
329                                })
330                                .collect::<Vec<_>>();
331                            let ne_x = occx_indices.len();
332                            ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
333                            let cx_occ = cx.select(Axis(1), &occx_indices);
334                            calc_lowdin_pairing(
335                                &cw_occ.view(),
336                                &cx_occ.view(),
337                                sao,
338                                complex_symmetric,
339                                thresh_offdiag,
340                                thresh_zeroov,
341                            )
342                        })
343                        .collect::<Result<Vec<_>, _>>()?;
344                    let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
345                    let ov_wx = lowdin_paired_coefficientss
346                        .iter()
347                        .flat_map(|lpc| lpc.lowdin_overlaps().iter())
348                        .fold(T::one(), |acc, ov| acc * *ov);
349
350                    let c_w = if self.complex_conjugated() {
351                        if complex_symmetric { c_w.conj() } else { *c_w }
352                    } else if complex_symmetric { *c_w } else { c_w.conj() };
353                    let c_x = if self.complex_conjugated() {
354                        c_x.conj()
355                    } else {
356                        *c_x
357                    };
358                    let den_wx = if self.complex_conjugated() {
359                        den_wx.mapv(|v| v.conj())
360                    } else {
361                        den_wx
362                    };
363                    acc_res.map(|(sqnorm_acc, denmat_acc)| (sqnorm_acc + ov_wx * c_w * c_x, denmat_acc + den_wx * c_w * c_x))
364                },
365            )
366            .reduce(
367                || Ok((T::zero(), Array2::<T>::zeros((nao, nao)))),
368                |sqnorm_denmat_res_a: Result<(T, Array2<T>), anyhow::Error>, sqnorm_denmat_res_b: Result<(T, Array2<T>), anyhow::Error>| {
369                    sqnorm_denmat_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorm_denmat_res_b.map(|(sqnorm, denmat)| {
370                        (
371                            sqnorm_acc + sqnorm,
372                            denmat_acc + denmat
373                        )
374                    }))
375                }
376            );
377        sqnorm_denmat_res.map(|(sqnorm, denmat)| {
378            if normalised_wavefunction {
379                denmat / sqnorm
380            } else {
381                denmat
382            }
383        })
384    }
385}
386
387// -----
388// Debug
389// -----
390impl<'a, T, B, SC> fmt::Debug for MultiDeterminant<'a, T, B, SC>
391where
392    T: ComplexFloat + Lapack,
393    SC: StructureConstraint + Hash + Eq + fmt::Display,
394    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
395{
396    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397        write!(
398            f,
399            "MultiDeterminant over {} basis Slater determinants",
400            self.coefficients.len(),
401        )?;
402        Ok(())
403    }
404}
405
406// -------
407// Display
408// -------
409impl<'a, T, B, SC> fmt::Display for MultiDeterminant<'a, T, B, SC>
410where
411    T: ComplexFloat + Lapack,
412    SC: StructureConstraint + Hash + Eq + fmt::Display,
413    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
414{
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        write!(
417            f,
418            "MultiDeterminant over {} basis Slater determinants",
419            self.coefficients.len(),
420        )?;
421        Ok(())
422    }
423}