qsym2/target/noci/
multideterminants.rs

1//! Collections of multi-determinant wavefunctions for non-orthogonal configuration interaction.
2
3use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, Array3, ArrayView2, Axis, Ix1, Ix3, ScalarOperand, ShapeBuilder};
13use ndarray_einsum::einsum;
14use ndarray_linalg::types::Lapack;
15use num_complex::ComplexFloat;
16use rayon::prelude::*;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::group::GroupProperties;
20use crate::target::determinant::SlaterDeterminant;
21use crate::target::noci::backend::nonortho::{calc_lowdin_pairing, calc_transition_density_matrix};
22use crate::target::noci::basis::{EagerBasis, OrbitBasis};
23use crate::target::noci::multideterminant::MultiDeterminant;
24
25use super::basis::Basis;
26
27#[cfg(test)]
28#[path = "multideterminants_tests.rs"]
29mod multideterminants_tests;
30
31// ------------------
32// Struct definitions
33// ------------------
34
/// Structure to manage collections of multi-determinantal wavefunctions that share the same basis
/// but have different linear combination coefficients.
///
/// Instances are constructed via the derived builder, whose `build` step runs
/// `MultiDeterminantsBuilder::validate` to check that the coefficients and the basis are mutually
/// consistent.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MultiDeterminants<'a, T, B, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + Hash + Eq + fmt::Display,
    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
{
    /// Marker tying the collection to the lifetime `'a` of the basis Slater determinants. This
    /// field cannot be set through the builder.
    #[builder(setter(skip), default = "PhantomData")]
    _lifetime: PhantomData<&'a ()>,

    /// Marker for the structure constraint type `SC` of the basis Slater determinants. This field
    /// cannot be set through the builder.
    #[builder(setter(skip), default = "PhantomData")]
    _structure_constraint: PhantomData<SC>,

    /// A boolean indicating if inner products involving the wavefunctions in this collection should
    /// be the complex-symmetric bilinear form, rather than the conventional Hermitian sesquilinear
    /// form.
    ///
    /// This flag cannot be set directly: it is read off the basis determinants at build time, and
    /// building fails if the basis determinants do not all carry the same flag.
    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
    complex_symmetric: bool,

    /// A boolean indicating if the wavefunctions in this collection have been acted on by an
    /// antiunitary operation. This is so that the correct metric can be used during overlap
    /// evaluation. Defaults to `false` when not specified.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The basis of Slater determinants in which the multi-determinantal wavefunctions in this
    /// collection are defined.
    basis: B,

    /// The linear combination coefficients of the basis determinants to give the
    /// multi-determinantal wavefunctions in this collection. Each column corresponds to one
    /// multi-determinantal wavefunction, and the number of rows must be equal to the number of
    /// basis determinants (this is checked when the collection is built).
    coefficients: Array2<T>,

    /// The energies of the multi-determinantal wavefunctions in this collection. If no energies
    /// are supplied, this holds an `Err` with a message stating that they have not yet been set.
    #[builder(
        default = "Err(\"Multi-determinantal wavefunction energies not yet set.\".to_string())"
    )]
    energies: Result<Array1<T>, String>,

    /// The threshold for comparing wavefunctions.
    threshold: <T as ComplexFloat>::Real,
}
81
82// ----------------------
83// Struct implementations
84// ----------------------
85
86impl<'a, T, B, SC> MultiDeterminantsBuilder<'a, T, B, SC>
87where
88    T: ComplexFloat + Lapack,
89    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
90    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
91{
92    fn validate(&self) -> Result<(), String> {
93        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
94        let coefficients = self
95            .coefficients
96            .as_ref()
97            .ok_or("No coefficients found.".to_string())?;
98        let nbasis = basis.n_items() == coefficients.nrows();
99        if !nbasis {
100            log::error!(
101                "The number of coefficient rows does not match the number of basis determinants."
102            );
103        }
104
105        let complex_symmetric = basis
106            .iter()
107            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
108            .collect::<Result<HashSet<_>, _>>()
109            .map_err(|err| err.to_string())?
110            .len()
111            == 1;
112        if !complex_symmetric {
113            log::error!("Inconsistent complex-symmetric flag across basis determinants.");
114        }
115
116        let structcons_check = basis
117            .iter()
118            .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
119            .collect::<Result<HashSet<_>, _>>()
120            .map_err(|err| err.to_string())?
121            .len()
122            == 1;
123        if !structcons_check {
124            log::error!("Inconsistent spin constraints across basis determinants.");
125        }
126
127        if nbasis && structcons_check && complex_symmetric {
128            Ok(())
129        } else {
130            Err("Multi-determinantal wavefunction collection validation failed.".to_string())
131        }
132    }
133
134    /// Retrieves the consistent complex-symmetric flag from the basis determinants.
135    fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
136        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
137        let complex_symmetric_set = basis
138            .iter()
139            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
140            .collect::<Result<HashSet<_>, _>>()
141            .map_err(|err| err.to_string())?;
142        if complex_symmetric_set.len() == 1 {
143            complex_symmetric_set
144                .into_iter()
145                .next()
146                .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
147        } else {
148            Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
149        }
150    }
151}
152
153impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
154where
155    T: ComplexFloat + Lapack,
156    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
157    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
158{
159    /// Returns a builder to construct a new [`MultiDeterminants`].
160    pub fn builder() -> MultiDeterminantsBuilder<'a, T, B, SC> {
161        MultiDeterminantsBuilder::default()
162    }
163
164    /// Constructs a collection of multi-determinantal wavefunctions from a sequence of individual
165    /// multi-determinantal wavefunctions.
166    ///
167    /// No checks are performed to ensure that the single-determinantal bases are consistent across
168    /// all supplied multi-determinantal wavefunctions. Only the basis of the first
169    /// multi-determinantal wavefunction will be used.
170    ///
171    /// # Arguments
172    ///
173    /// * `mtds` - A sequence of individual multi-determinantal wavefunctions.
174    pub fn from_multideterminant_vec(
175        mtds: &[&MultiDeterminant<'a, T, B, SC>],
176    ) -> Result<MultiDeterminants<'a, T, B, SC>, anyhow::Error> {
177        log::warn!(
178            "Using basis from the first multi-determinantal wavefunction as the common basis for the collection of multi-determinantal wavefunctions..."
179        );
180        let nmultidets = mtds.len();
181        let dims_set = mtds
182            .iter()
183            .map(|mtd| mtd.basis().n_items())
184            .collect::<HashSet<_>>();
185        let dim = if dims_set.len() == 1 {
186            dims_set
187                .into_iter()
188                .next()
189                .ok_or_else(|| format_err!("Unable to obtain the unique basis size."))
190        } else {
191            Err(format_err!(
192                "Inconsistent basis sizes across the supplied multi-determinantal wavefunctions."
193            ))
194        }?;
195        let coefficients = Array2::from_shape_vec(
196            (dim, nmultidets).f(),
197            mtds.iter()
198                .flat_map(|mtd| mtd.coefficients())
199                .cloned()
200                .collect::<Vec<_>>(),
201        )
202        .map_err(|err| format_err!(err))?;
203
204        let (basis, threshold) = mtds
205            .first()
206            .map(|mtd| (mtd.basis().clone(), mtd.threshold()))
207            .ok_or_else(|| {
208                format_err!("Unable to access the first multi-determinantal wavefunction.")
209            })?;
210
211        MultiDeterminants::builder()
212            .basis(basis)
213            .coefficients(coefficients)
214            .threshold(threshold)
215            .build()
216            .map_err(|err| format_err!(err))
217    }
218
219    /// Returns the structure constraint of the multi-determinantal wavefunctions in the collection.
220    pub fn structure_constraint(&self) -> SC {
221        self.basis
222            .iter()
223            .next()
224            .expect("No basis determinant found.")
225            .expect("No basis determinant found.")
226            .structure_constraint()
227            .clone()
228    }
229
230    /// Returns an iterator over the multi-determinantal wavefunctions in this collection.
231    pub fn iter(&self) -> impl Iterator {
232        let energies = self
233            .energies
234            .as_ref()
235            .map(|energies| energies.mapv(|v| Ok(v)))
236            .unwrap_or(Array1::from_elem(
237                self.coefficients.ncols(),
238                Err("Multi-determinantal energy not available.".to_string()),
239            ));
240        self.coefficients
241            .columns()
242            .into_iter()
243            .zip(energies)
244            .map(|(c, e)| {
245                MultiDeterminant::builder()
246                    .complex_conjugated(self.complex_conjugated)
247                    .basis(self.basis().clone())
248                    .coefficients(c.to_owned())
249                    .energy(e)
250                    .threshold(self.threshold)
251                    .build()
252                    .map_err(|err| format_err!(err))
253            })
254    }
255}
256
257impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
258where
259    T: ComplexFloat + Lapack,
260    SC: StructureConstraint + Hash + Eq + fmt::Display,
261    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
262{
263    /// Returns the complex-conjugated flag of the multi-determinantal wavefunctions in the
264    /// collection.
265    pub fn complex_conjugated(&self) -> bool {
266        self.complex_conjugated
267    }
268
269    /// Returns the complex-symmetric flag of the multi-determinantal wavefunctions in the
270    /// collection.
271    pub fn complex_symmetric(&self) -> bool {
272        self.complex_symmetric
273    }
274
275    /// Returns the basis of determinants in which the multi-determinantal wavefunctions in this
276    /// collection are defined.
277    pub fn basis(&self) -> &B {
278        &self.basis
279    }
280
281    /// Returns the coefficients of the basis determinants constituting the multi-determinantal
282    /// wavefunctions in this collection.
283    pub fn coefficients(&self) -> &Array2<T> {
284        &self.coefficients
285    }
286
287    /// Returns the energies of the multi-determinantal wavefunctions in this collection.
288    pub fn energies(&self) -> Result<&Array1<T>, &String> {
289        self.energies.as_ref()
290    }
291
292    /// Returns the threshold with which multi-determinantal wavefunctions are compared.
293    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
294        self.threshold
295    }
296}
297
298// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
299// Specific implementations for OrbitBasis
300// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
301
302impl<'a, T, G, SC> MultiDeterminants<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
303where
304    T: ComplexFloat + Lapack,
305    G: GroupProperties + Clone,
306    SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
307{
308    /// Converts this multi-determinantal wavefunction collection with an orbit basis into one with
309    /// the equivalent eager basis.
310    #[allow(clippy::type_complexity)]
311    pub fn to_eager_basis(
312        &self,
313    ) -> Result<MultiDeterminants<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
314    {
315        MultiDeterminants::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
316            .complex_conjugated(self.complex_conjugated)
317            .basis(self.basis.to_eager()?)
318            .coefficients(self.coefficients().clone())
319            .energies(self.energies.clone())
320            .threshold(self.threshold)
321            .build()
322            .map_err(|err| format_err!(err))
323    }
324}
325
326// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
327// Generic implementation for all Basis
328// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
329
330impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
331where
332    T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
333    <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
334    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
335    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
336    SlaterDeterminant<'a, T, SC>: Send + Sync,
337{
338    /// Calculates the (contravariant) density matrices $`\mathbf{P}_m(\hat{\iota})`$ of all
339    /// multi-determinantal wavefunctions in this collection in the AO basis.
340    ///
341    /// Note that each contravariant density matrix $`\mathbf{P}_m(\hat{\iota})`$ needs to be
342    /// converted to the mixed form $`\tilde{\mathbf{P}}(\hat{\iota})`$ given by
343    /// ```math
344    ///     \tilde{\mathbf{P}}(\hat{\iota}) = \mathbf{P}(\hat{\iota}) \mathbf{S}_{\mathrm{AO}}
345    /// ```
346    /// before being diagonalised to obtain natural orbitals and their occupation numbers.
347    ///
348    /// # Arguments
349    ///
350    /// * `sao` - The atomic-orbital overlap matrix.
351    /// * `thresh_offdiag` - Threshold for determining non-zero off-diagonal elements in the
352    ///   orbital overlap matrix two Slater determinants during Löwdin pairing.
353    /// * `thresh_zeroov` - Threshold for identifying zero Löwdin overlaps.
354    ///
355    /// # Returns
356    ///
357    /// Returns a three-dimensional array $`P`$ containing the density matrices of the
358    /// multi-determinantal wavefunctions in this collection. The array is indexed $`P_{mij}`$
359    /// where $`m`$ is the index for the multi-determinantal wavefunctions in this collection.
360    pub fn density_matrices(
361        &self,
362        sao: &ArrayView2<T>,
363        thresh_offdiag: <T as ComplexFloat>::Real,
364        thresh_zeroov: <T as ComplexFloat>::Real,
365        normalised_wavefunctions: bool,
366    ) -> Result<Array3<T>, anyhow::Error> {
367        let nao = sao.nrows();
368        let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
369        let nmultidets = self.coefficients.ncols();
370        let sqnorms_denmats_res = dets.iter()
371            .zip(self.coefficients().rows())
372            .cartesian_product(dets.iter().zip(self.coefficients().rows()))
373            .par_bridge()
374            .fold(
375                || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
376                |acc_res, ((det_w, c_wm), (det_x, c_xm))| {
377                    ensure!(
378                        det_w.structure_constraint() == det_x.structure_constraint(),
379                        "Inconsistent spin constraints: {} != {}.",
380                        det_w.structure_constraint(),
381                        det_x.structure_constraint(),
382                    );
383
384                    if det_w.complex_symmetric() != det_x.complex_symmetric() {
385                        return Err(format_err!(
386                            "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
387                            det_w.complex_symmetric(),
388                            det_x.complex_symmetric(),
389                        ));
390                    }
391                    let complex_symmetric = det_w.complex_symmetric();
392                    let lowdin_paired_coefficientss = det_w
393                        .coefficients()
394                        .iter()
395                        .zip(det_w.occupations().iter())
396                        .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
397                        .map(|((cw, occw), (cx, occx))| {
398                            let occw_indices = occw
399                                .iter()
400                                .enumerate()
401                                .filter_map(|(i, occ_i)| {
402                                    if occ_i.abs() >= det_w.threshold() {
403                                        Some(i)
404                                    } else {
405                                        None
406                                    }
407                                })
408                                .collect::<Vec<_>>();
409                            let ne_w = occw_indices.len();
410                            let cw_occ = cw.select(Axis(1), &occw_indices);
411                            let occx_indices = occx
412                                .iter()
413                                .enumerate()
414                                .filter_map(|(i, occ_i)| {
415                                    if occ_i.abs() >= det_x.threshold() {
416                                        Some(i)
417                                    } else {
418                                        None
419                                    }
420                                })
421                                .collect::<Vec<_>>();
422                            let ne_x = occx_indices.len();
423                            ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
424                            let cx_occ = cx.select(Axis(1), &occx_indices);
425                            calc_lowdin_pairing(
426                                &cw_occ.view(),
427                                &cx_occ.view(),
428                                sao,
429                                complex_symmetric,
430                                thresh_offdiag,
431                                thresh_zeroov,
432                            )
433                        })
434                        .collect::<Result<Vec<_>, _>>()?;
435                    let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
436                    let ov_wx = lowdin_paired_coefficientss
437                        .iter()
438                        .flat_map(|lpc| lpc.lowdin_overlaps().iter())
439                        .fold(T::one(), |acc, ov| acc * *ov);
440
441                    let c_wm = if self.complex_conjugated() {
442                        if complex_symmetric { c_wm.map(|v| v.conj()) } else { c_wm.to_owned() }
443                    } else if complex_symmetric { c_wm.to_owned() } else { c_wm.map(|v| v.conj()) };
444                    let c_xm = if self.complex_conjugated() {
445                        c_xm.map(|v| v.conj())
446                    } else {
447                        c_xm.to_owned()
448                    };
449                    let den_wx = if self.complex_conjugated() {
450                        den_wx.mapv(|v| v.conj())
451                    } else {
452                        den_wx
453                    };
454                    acc_res.and_then(|(sqnorm_acc, denmat_acc)| {
455                        let ov_wx_m = einsum("m,m->m", &[&c_wm.view(), &c_xm.view()])
456                            .map_err(|err| format_err!(err))?
457                            .into_dimensionality::<Ix1>()
458                            .map_err(|err| format_err!(err))?
459                            .mapv(|v| v * ov_wx);
460                        let denmat_wx_mij = einsum("ij,m,m->mij", &[&den_wx.view(), &c_wm.view(), &c_xm.view()])
461                            .map_err(|err| format_err!(err))?
462                            .into_dimensionality::<Ix3>()
463                            .map_err(|err| format_err!(err))?;
464                        Ok((sqnorm_acc + ov_wx_m, denmat_acc + denmat_wx_mij))
465                    })
466                },
467            )
468            .reduce(
469                || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
470                |sqnorms_denmats_res_a: Result<(Array1<T>, Array3<T>), anyhow::Error>, sqnorms_denmats_res_b: Result<(Array1<T>, Array3<T>), anyhow::Error>| {
471                    sqnorms_denmats_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorms_denmats_res_b.map(|(sqnorm, denmat)| {
472                        (
473                            sqnorm_acc + sqnorm,
474                            denmat_acc + denmat
475                        )
476                    }))
477                }
478            );
479        sqnorms_denmats_res.and_then(|(sqnorms, denmats)| {
480            if normalised_wavefunctions {
481                let sqnorms_inv = sqnorms.mapv(|v| T::one() / v);
482                einsum("m,mij->mij", &[&sqnorms_inv.view(), &denmats.view()])
483                    .map_err(|err| format_err!(err))?
484                    .into_dimensionality::<Ix3>()
485                    .map_err(|err| format_err!(err))
486            } else {
487                Ok(denmats)
488            }
489        })
490    }
491}
492
493// -----
494// Debug
495// -----
496impl<'a, T, B, SC> fmt::Debug for MultiDeterminants<'a, T, B, SC>
497where
498    T: ComplexFloat + Lapack,
499    SC: StructureConstraint + Hash + Eq + fmt::Display,
500    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
501{
502    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503        write!(
504            f,
505            "MultiDeterminant collection over {} basis Slater determinants",
506            self.coefficients.len(),
507        )?;
508        Ok(())
509    }
510}
511
512// -------
513// Display
514// -------
515impl<'a, T, B, SC> fmt::Display for MultiDeterminants<'a, T, B, SC>
516where
517    T: ComplexFloat + Lapack,
518    SC: StructureConstraint + Hash + Eq + fmt::Display,
519    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
520{
521    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522        write!(
523            f,
524            "MultiDeterminant collection over {} basis Slater determinants",
525            self.coefficients.len(),
526        )?;
527        Ok(())
528    }
529}