qsym2/target/noci/
multideterminants.rs

1//! Collections of multi-determinant wavefunctions for non-orthogonal configuration interaction.
2
3use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, Array3, ArrayView2, Axis, Ix1, Ix3, ScalarOperand, ShapeBuilder};
13use ndarray_einsum::einsum;
14use ndarray_linalg::types::Lapack;
15use num_complex::ComplexFloat;
16use rayon::prelude::*;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::group::GroupProperties;
20use crate::target::determinant::SlaterDeterminant;
21use crate::target::noci::backend::nonortho::{
22    calc_lowdin_pairing, calc_o0_matrix_element, calc_transition_density_matrix,
23};
24use crate::target::noci::basis::{EagerBasis, OrbitBasis};
25use crate::target::noci::multideterminant::MultiDeterminant;
26
27use super::basis::Basis;
28
29#[cfg(test)]
30#[path = "multideterminants_tests.rs"]
31mod multideterminants_tests;
32
33// ------------------
34// Struct definitions
35// ------------------
36
37/// Structure to manage collections of multi-determinantal wavefunctions that share the same basis
38/// but have different linear combination coefficients.
39#[derive(Builder, Clone)]
40#[builder(build_fn(validate = "Self::validate"))]
41pub struct MultiDeterminants<'a, T, B, SC>
42where
43    T: ComplexFloat + Lapack,
44    SC: StructureConstraint + Hash + Eq + fmt::Display,
45    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
46{
47    #[builder(setter(skip), default = "PhantomData")]
48    _lifetime: PhantomData<&'a ()>,
49
50    #[builder(setter(skip), default = "PhantomData")]
51    _structure_constraint: PhantomData<SC>,
52
53    /// A boolean indicating if inner products involving the wavefunctions in this collection should
54    /// be the complex-symmetric bilinear form, rather than the conventional Hermitian sesquilinear
55    /// form.
56    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
57    complex_symmetric: bool,
58
59    /// A boolean indicating if the wavefunctions in this collection have been acted on by an
60    /// antiunitary operation. This is so that the correct metric can be used during overlap
61    /// evaluation.
62    #[builder(default = "false")]
63    complex_conjugated: bool,
64
65    /// The basis of Slater determinants in which the multi-determinantal wavefunctions in this
66    /// collection are defined.
67    basis: B,
68
69    /// The linear combination coefficients of the elements in the multi-orbit to give the
70    /// multi-determinantal wavefunctions in this collection. Each column corresponds to one
71    /// multi-determinantal wavefunction.
72    coefficients: Array2<T>,
73
74    /// The energies of the multi-determinantal wavefunctions in this collection.
75    #[builder(
76        default = "Err(\"Multi-determinantal wavefunction energies not yet set.\".to_string())"
77    )]
78    energies: Result<Array1<T>, String>,
79
80    /// The threshold for comparing wavefunctions.
81    threshold: <T as ComplexFloat>::Real,
82}
83
84// ----------------------
85// Struct implementations
86// ----------------------
87
88impl<'a, T, B, SC> MultiDeterminantsBuilder<'a, T, B, SC>
89where
90    T: ComplexFloat + Lapack,
91    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
92    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
93{
94    fn validate(&self) -> Result<(), String> {
95        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
96        let coefficients = self
97            .coefficients
98            .as_ref()
99            .ok_or("No coefficients found.".to_string())?;
100        let nbasis = basis.n_items() == coefficients.nrows();
101        if !nbasis {
102            log::error!(
103                "The number of coefficient rows does not match the number of basis determinants."
104            );
105        }
106
107        let complex_symmetric = basis
108            .iter()
109            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
110            .collect::<Result<HashSet<_>, _>>()
111            .map_err(|err| err.to_string())?
112            .len()
113            == 1;
114        if !complex_symmetric {
115            log::error!("Inconsistent complex-symmetric flag across basis determinants.");
116        }
117
118        let structcons_check = basis
119            .iter()
120            .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
121            .collect::<Result<HashSet<_>, _>>()
122            .map_err(|err| err.to_string())?
123            .len()
124            == 1;
125        if !structcons_check {
126            log::error!("Inconsistent spin constraints across basis determinants.");
127        }
128
129        if nbasis && structcons_check && complex_symmetric {
130            Ok(())
131        } else {
132            Err("Multi-determinantal wavefunction collection validation failed.".to_string())
133        }
134    }
135
136    /// Retrieves the consistent complex-symmetric flag from the basis determinants.
137    fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
138        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
139        let complex_symmetric_set = basis
140            .iter()
141            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
142            .collect::<Result<HashSet<_>, _>>()
143            .map_err(|err| err.to_string())?;
144        if complex_symmetric_set.len() == 1 {
145            complex_symmetric_set
146                .into_iter()
147                .next()
148                .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
149        } else {
150            Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
151        }
152    }
153}
154
155impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
156where
157    T: ComplexFloat + Lapack,
158    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
159    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
160{
161    /// Returns a builder to construct a new [`MultiDeterminants`].
162    pub fn builder() -> MultiDeterminantsBuilder<'a, T, B, SC> {
163        MultiDeterminantsBuilder::default()
164    }
165
166    /// Constructs a collection of multi-determinantal wavefunctions from a sequence of individual
167    /// multi-determinantal wavefunctions.
168    ///
169    /// No checks are performed to ensure that the single-determinantal bases are consistent across
170    /// all supplied multi-determinantal wavefunctions. Only the basis of the first
171    /// multi-determinantal wavefunction will be used.
172    ///
173    /// # Arguments
174    ///
175    /// * `mtds` - A sequence of individual multi-determinantal wavefunctions.
176    pub fn from_multideterminant_vec(
177        mtds: &[&MultiDeterminant<'a, T, B, SC>],
178    ) -> Result<MultiDeterminants<'a, T, B, SC>, anyhow::Error> {
179        log::warn!(
180            "Using basis from the first multi-determinantal wavefunction as the common basis for the collection of multi-determinantal wavefunctions..."
181        );
182        let nmultidets = mtds.len();
183        let dims_set = mtds
184            .iter()
185            .map(|mtd| mtd.basis().n_items())
186            .collect::<HashSet<_>>();
187        let dim = if dims_set.len() == 1 {
188            dims_set
189                .into_iter()
190                .next()
191                .ok_or_else(|| format_err!("Unable to obtain the unique basis size."))
192        } else {
193            Err(format_err!(
194                "Inconsistent basis sizes across the supplied multi-determinantal wavefunctions."
195            ))
196        }?;
197        let coefficients = Array2::from_shape_vec(
198            (dim, nmultidets).f(),
199            mtds.iter()
200                .flat_map(|mtd| mtd.coefficients())
201                .cloned()
202                .collect::<Vec<_>>(),
203        )
204        .map_err(|err| format_err!(err))?;
205
206        let (basis, threshold) = mtds
207            .first()
208            .map(|mtd| (mtd.basis().clone(), mtd.threshold()))
209            .ok_or_else(|| {
210                format_err!("Unable to access the first multi-determinantal wavefunction.")
211            })?;
212
213        MultiDeterminants::builder()
214            .basis(basis)
215            .coefficients(coefficients)
216            .threshold(threshold)
217            .build()
218            .map_err(|err| format_err!(err))
219    }
220
221    /// Returns the structure constraint of the multi-determinantal wavefunctions in the collection.
222    pub fn structure_constraint(&self) -> SC {
223        self.basis
224            .iter()
225            .next()
226            .expect("No basis determinant found.")
227            .expect("No basis determinant found.")
228            .structure_constraint()
229            .clone()
230    }
231
232    /// Returns an iterator over the multi-determinantal wavefunctions in this collection.
233    pub fn iter(&self) -> impl Iterator {
234        let energies = self
235            .energies
236            .as_ref()
237            .map(|energies| energies.mapv(|v| Ok(v)))
238            .unwrap_or(Array1::from_elem(
239                self.coefficients.ncols(),
240                Err("Multi-determinantal energy not available.".to_string()),
241            ));
242        self.coefficients
243            .columns()
244            .into_iter()
245            .zip(energies)
246            .map(|(c, e)| {
247                MultiDeterminant::builder()
248                    .complex_conjugated(self.complex_conjugated)
249                    .basis(self.basis().clone())
250                    .coefficients(c.to_owned())
251                    .energy(e)
252                    .threshold(self.threshold)
253                    .build()
254                    .map_err(|err| format_err!(err))
255            })
256    }
257}
258
259impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
260where
261    T: ComplexFloat + Lapack,
262    SC: StructureConstraint + Hash + Eq + fmt::Display,
263    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
264{
265    /// Returns the complex-conjugated flag of the multi-determinantal wavefunctions in the
266    /// collection.
267    pub fn complex_conjugated(&self) -> bool {
268        self.complex_conjugated
269    }
270
271    /// Returns the complex-symmetric flag of the multi-determinantal wavefunctions in the
272    /// collection.
273    pub fn complex_symmetric(&self) -> bool {
274        self.complex_symmetric
275    }
276
277    /// Returns the basis of determinants in which the multi-determinantal wavefunctions in this
278    /// collection are defined.
279    pub fn basis(&self) -> &B {
280        &self.basis
281    }
282
283    /// Returns the coefficients of the basis determinants constituting the multi-determinantal
284    /// wavefunctions in this collection.
285    pub fn coefficients(&self) -> &Array2<T> {
286        &self.coefficients
287    }
288
289    /// Returns the energies of the multi-determinantal wavefunctions in this collection.
290    pub fn energies(&self) -> Result<&Array1<T>, &String> {
291        self.energies.as_ref()
292    }
293
294    /// Returns the threshold with which multi-determinantal wavefunctions are compared.
295    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
296        self.threshold
297    }
298}
299
300// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
301// Specific implementations for OrbitBasis
302// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
303
304impl<'a, T, G, SC> MultiDeterminants<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
305where
306    T: ComplexFloat + Lapack,
307    G: GroupProperties + Clone,
308    SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
309{
310    /// Converts this multi-determinantal wavefunction collection with an orbit basis into one with
311    /// the equivalent eager basis.
312    #[allow(clippy::type_complexity)]
313    pub fn to_eager_basis(
314        &self,
315    ) -> Result<MultiDeterminants<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
316    {
317        MultiDeterminants::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
318            .complex_conjugated(self.complex_conjugated)
319            .basis(self.basis.to_eager()?)
320            .coefficients(self.coefficients().clone())
321            .energies(self.energies.clone())
322            .threshold(self.threshold)
323            .build()
324            .map_err(|err| format_err!(err))
325    }
326}
327
328// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
329// Generic implementation for all Basis
330// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
331
332impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
333where
334    T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
335    <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
336    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
337    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
338    SlaterDeterminant<'a, T, SC>: Send + Sync,
339{
340    /// Calculates the (contravariant) density matrices $`\mathbf{P}_m(\hat{\iota})`$ of all
341    /// multi-determinantal wavefunctions in this collection in the AO basis.
342    ///
343    /// Note that each contravariant density matrix $`\mathbf{P}_m(\hat{\iota})`$ needs to be
344    /// converted to the mixed form $`\tilde{\mathbf{P}}(\hat{\iota})`$ given by
345    /// ```math
346    ///     \tilde{\mathbf{P}}(\hat{\iota}) = \mathbf{P}(\hat{\iota}) \mathbf{S}_{\mathrm{AO}}
347    /// ```
348    /// before being diagonalised to obtain natural orbitals and their occupation numbers.
349    ///
350    /// # Arguments
351    ///
352    /// * `sao` - The atomic-orbital overlap matrix.
353    /// * `thresh_offdiag` - Threshold for determining non-zero off-diagonal elements in the
354    ///   orbital overlap matrix two Slater determinants during Löwdin pairing.
355    /// * `thresh_zeroov` - Threshold for identifying zero Löwdin overlaps.
356    ///
357    /// # Returns
358    ///
359    /// Returns a three-dimensional array $`P`$ containing the density matrices of the
360    /// multi-determinantal wavefunctions in this collection. The array is indexed $`P_{mij}`$
361    /// where $`m`$ is the index for the multi-determinantal wavefunctions in this collection.
362    pub fn density_matrices(
363        &self,
364        sao: &ArrayView2<T>,
365        thresh_offdiag: <T as ComplexFloat>::Real,
366        thresh_zeroov: <T as ComplexFloat>::Real,
367        normalised_wavefunctions: bool,
368    ) -> Result<Array3<T>, anyhow::Error> {
369        let nao = sao.nrows();
370        let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
371        let nmultidets = self.coefficients.ncols();
372        let sqnorms_denmats_res = dets.iter()
373            .zip(self.coefficients().rows())
374            .cartesian_product(dets.iter().zip(self.coefficients().rows()))
375            .par_bridge()
376            .fold(
377                || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
378                |acc_res, ((det_w, c_wm), (det_x, c_xm))| {
379                    ensure!(
380                        det_w.structure_constraint() == det_x.structure_constraint(),
381                        "Inconsistent spin constraints: {} != {}.",
382                        det_w.structure_constraint(),
383                        det_x.structure_constraint(),
384                    );
385
386                    if det_w.complex_symmetric() != det_x.complex_symmetric() {
387                        return Err(format_err!(
388                            "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
389                            det_w.complex_symmetric(),
390                            det_x.complex_symmetric(),
391                        ));
392                    }
393                    let complex_symmetric = det_w.complex_symmetric();
394                    let lowdin_paired_coefficientss = det_w
395                        .coefficients()
396                        .iter()
397                        .zip(det_w.occupations().iter())
398                        .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
399                        .map(|((cw, occw), (cx, occx))| {
400                            let occw_indices = occw
401                                .iter()
402                                .enumerate()
403                                .filter_map(|(i, occ_i)| {
404                                    if occ_i.abs() >= det_w.threshold() {
405                                        Some(i)
406                                    } else {
407                                        None
408                                    }
409                                })
410                                .collect::<Vec<_>>();
411                            let ne_w = occw_indices.len();
412                            let cw_occ = cw.select(Axis(1), &occw_indices);
413                            let occx_indices = occx
414                                .iter()
415                                .enumerate()
416                                .filter_map(|(i, occ_i)| {
417                                    if occ_i.abs() >= det_x.threshold() {
418                                        Some(i)
419                                    } else {
420                                        None
421                                    }
422                                })
423                                .collect::<Vec<_>>();
424                            let ne_x = occx_indices.len();
425                            ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
426                            let cx_occ = cx.select(Axis(1), &occx_indices);
427                            calc_lowdin_pairing(
428                                &cw_occ.view(),
429                                &cx_occ.view(),
430                                sao,
431                                complex_symmetric,
432                                thresh_offdiag,
433                                thresh_zeroov,
434                            )
435                        })
436                        .collect::<Result<Vec<_>, _>>()?;
437                    let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
438                    let ov_wx = calc_o0_matrix_element(&lowdin_paired_coefficientss, T::one(), &self.structure_constraint())?;
439
440                    let c_wm = if self.complex_conjugated() {
441                        if complex_symmetric { c_wm.map(|v| v.conj()) } else { c_wm.to_owned() }
442                    } else if complex_symmetric { c_wm.to_owned() } else { c_wm.map(|v| v.conj()) };
443                    let c_xm = if self.complex_conjugated() {
444                        c_xm.map(|v| v.conj())
445                    } else {
446                        c_xm.to_owned()
447                    };
448                    let den_wx = if self.complex_conjugated() {
449                        den_wx.mapv(|v| v.conj())
450                    } else {
451                        den_wx
452                    };
453                    acc_res.and_then(|(sqnorm_acc, denmat_acc)| {
454                        let ov_wx_m = einsum("m,m->m", &[&c_wm.view(), &c_xm.view()])
455                            .map_err(|err| format_err!(err))?
456                            .into_dimensionality::<Ix1>()
457                            .map_err(|err| format_err!(err))?
458                            .mapv(|v| v * ov_wx);
459                        let denmat_wx_mij = einsum("ij,m,m->mij", &[&den_wx.view(), &c_wm.view(), &c_xm.view()])
460                            .map_err(|err| format_err!(err))?
461                            .into_dimensionality::<Ix3>()
462                            .map_err(|err| format_err!(err))?;
463                        Ok((sqnorm_acc + ov_wx_m, denmat_acc + denmat_wx_mij))
464                    })
465                },
466            )
467            .reduce(
468                || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
469                |sqnorms_denmats_res_a: Result<(Array1<T>, Array3<T>), anyhow::Error>, sqnorms_denmats_res_b: Result<(Array1<T>, Array3<T>), anyhow::Error>| {
470                    sqnorms_denmats_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorms_denmats_res_b.map(|(sqnorm, denmat)| {
471                        (
472                            sqnorm_acc + sqnorm,
473                            denmat_acc + denmat
474                        )
475                    }))
476                }
477            );
478        sqnorms_denmats_res.and_then(|(sqnorms, denmats)| {
479            if normalised_wavefunctions {
480                let sqnorms_inv = sqnorms.mapv(|v| T::one() / v);
481                einsum("m,mij->mij", &[&sqnorms_inv.view(), &denmats.view()])
482                    .map_err(|err| format_err!(err))?
483                    .into_dimensionality::<Ix3>()
484                    .map_err(|err| format_err!(err))
485            } else {
486                Ok(denmats)
487            }
488        })
489    }
490}
491
492// -----
493// Debug
494// -----
495impl<'a, T, B, SC> fmt::Debug for MultiDeterminants<'a, T, B, SC>
496where
497    T: ComplexFloat + Lapack,
498    SC: StructureConstraint + Hash + Eq + fmt::Display,
499    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
500{
501    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
502        write!(
503            f,
504            "MultiDeterminant collection over {} basis Slater determinants",
505            self.coefficients.len(),
506        )?;
507        Ok(())
508    }
509}
510
511// -------
512// Display
513// -------
514impl<'a, T, B, SC> fmt::Display for MultiDeterminants<'a, T, B, SC>
515where
516    T: ComplexFloat + Lapack,
517    SC: StructureConstraint + Hash + Eq + fmt::Display,
518    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
519{
520    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
521        write!(
522            f,
523            "MultiDeterminant collection over {} basis Slater determinants",
524            self.coefficients.len(),
525        )?;
526        Ok(())
527    }
528}