qsym2/target/noci/
multideterminant.rs

1//! Multi-determinant wavefunctions for non-orthogonal configuration interaction.
2
3use std::collections::HashSet;
4use std::fmt;
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use derive_builder::Builder;
9use log;
10use ndarray::Array1;
11use ndarray_linalg::types::Lapack;
12use num_complex::ComplexFloat;
13
14use crate::angmom::spinor_rotation_3d::StructureConstraint;
15use crate::target::determinant::SlaterDeterminant;
16
17use super::basis::Basis;
18
19#[path = "multideterminant_transformation.rs"]
20pub(crate) mod multideterminant_transformation;
21
22#[path = "multideterminant_analysis.rs"]
23pub(crate) mod multideterminant_analysis;
24
25#[cfg(test)]
26#[path = "multideterminant_tests.rs"]
27mod multideterminant_tests;
28
29// ------------------
30// Struct definitions
31// ------------------
32
33/// Structure to manage multi-determinantal wavefunctions.
34#[derive(Builder, Clone)]
35#[builder(build_fn(validate = "Self::validate"))]
36pub struct MultiDeterminant<'a, T, B, SC>
37where
38    T: ComplexFloat + Lapack,
39    SC: StructureConstraint + Hash + Eq + fmt::Display,
40    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
41{
42    #[builder(setter(skip), default = "PhantomData")]
43    _lifetime: PhantomData<&'a ()>,
44
45    #[builder(setter(skip), default = "PhantomData")]
46    _structure_constraint: PhantomData<SC>,
47
48    /// A boolean indicating if inner products involving this wavefunction should be the
49    /// complex-symmetric bilinear form, rather than the conventional Hermitian sesquilinear form.
50    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
51    complex_symmetric: bool,
52
53    /// A boolean indicating if the wavefunction has been acted on by an antiunitary operation. This
54    /// is so that the correct metric can be used during overlap evaluation.
55    #[builder(default = "false")]
56    complex_conjugated: bool,
57
58    /// The basis of Slater determinants in which this multi-determinantal wavefunction is defined.
59    basis: B,
60
61    /// The linear combination coefficients of the elements in the multi-orbit to give this
62    /// multi-determinant wavefunction.
63    coefficients: Array1<T>,
64
65    /// The energy of this multi-determinantal wavefunction.
66    #[builder(
67        default = "Err(\"Multi-determinantal wavefunction energy not yet set.\".to_string())"
68    )]
69    energy: Result<T, String>,
70
71    /// The threshold for comparing wavefunctions.
72    threshold: <T as ComplexFloat>::Real,
73}
74
75// ----------------------
76// Struct implementations
77// ----------------------
78
79impl<'a, T, B, SC> MultiDeterminantBuilder<'a, T, B, SC>
80where
81    T: ComplexFloat + Lapack,
82    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
83    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
84{
85    fn validate(&self) -> Result<(), String> {
86        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
87        let coefficients = self
88            .coefficients
89            .as_ref()
90            .ok_or("No coefficients found.".to_string())?;
91        let nbasis = basis.n_items() == coefficients.len();
92        if !nbasis {
93            log::error!(
94                "The number of coefficients does not match the number of basis determinants."
95            );
96        }
97
98        let complex_symmetric = basis
99            .iter()
100            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
101            .collect::<Result<HashSet<_>, _>>()
102            .map_err(|err| err.to_string())?
103            .len()
104            == 1;
105        if !complex_symmetric {
106            log::error!("Inconsistent complex-symmetric flag across basis determinants.");
107        }
108
109        let structcons_check = basis
110            .iter()
111            .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
112            .collect::<Result<HashSet<_>, _>>()
113            .map_err(|err| err.to_string())?
114            .len()
115            == 1;
116        if !structcons_check {
117            log::error!("Inconsistent spin constraints across basis determinants.");
118        }
119
120        if nbasis && structcons_check && complex_symmetric {
121            Ok(())
122        } else {
123            Err("Multi-determinant wavefunction validation failed.".to_string())
124        }
125    }
126
127    /// Retrieves the consistent complex-symmetric flag from the basis determinants.
128    fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
129        let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
130        let complex_symmetric_set = basis
131            .iter()
132            .map(|det_res| det_res.map(|det| det.complex_symmetric()))
133            .collect::<Result<HashSet<_>, _>>()
134            .map_err(|err| err.to_string())?;
135        if complex_symmetric_set.len() == 1 {
136            complex_symmetric_set
137                .into_iter()
138                .next()
139                .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
140        } else {
141            Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
142        }
143    }
144}
145
146impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
147where
148    T: ComplexFloat + Lapack,
149    SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
150    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
151{
152    /// Returns a builder to construct a new [`MultiDeterminant`].
153    pub(crate) fn builder() -> MultiDeterminantBuilder<'a, T, B, SC> {
154        MultiDeterminantBuilder::default()
155    }
156
157    /// Returns the structure constraint of the multi-determinantal wavefunction.
158    pub fn structure_constraint(&self) -> SC {
159        self.basis
160            .iter()
161            .next()
162            .expect("No basis determinant found.")
163            .expect("No basis determinant found.")
164            .structure_constraint()
165            .clone()
166    }
167}
168
169impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
170where
171    T: ComplexFloat + Lapack,
172    SC: StructureConstraint + Hash + Eq + fmt::Display,
173    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
174{
175    /// Returns the complex-conjugated flag of the multi-determinantal wavefunction.
176    pub fn complex_conjugated(&self) -> bool {
177        self.complex_conjugated
178    }
179
180    /// Returns the basis of determinants in which this multi-determinantal wavefunction is
181    /// defined.
182    pub fn basis(&self) -> &B {
183        &self.basis
184    }
185
186    /// Returns the coefficients of the basis determinants constituting this multi-determinantal
187    /// wavefunction.
188    pub fn coefficients(&self) -> &Array1<T> {
189        &self.coefficients
190    }
191
192    /// Returns the energy of the multi-determinantal wavefunction.
193    pub fn energy(&self) -> Result<&T, &String> {
194        self.energy.as_ref()
195    }
196
197    /// Returns the threshold with which multi-determinantal wavefunctions are compared.
198    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
199        self.threshold
200    }
201}
202
203// -----
204// Debug
205// -----
206impl<'a, T, B, SC> fmt::Debug for MultiDeterminant<'a, T, B, SC>
207where
208    T: ComplexFloat + Lapack,
209    SC: StructureConstraint + Hash + Eq + fmt::Display,
210    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
211{
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        write!(
214            f,
215            "MultiDeterminant over {} basis Slater determinants",
216            self.coefficients.len(),
217        )?;
218        Ok(())
219    }
220}
221
222// -------
223// Display
224// -------
225impl<'a, T, B, SC> fmt::Display for MultiDeterminant<'a, T, B, SC>
226where
227    T: ComplexFloat + Lapack,
228    SC: StructureConstraint + Hash + Eq + fmt::Display,
229    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
230{
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        write!(
233            f,
234            "MultiDeterminant over {} basis Slater determinants",
235            self.coefficients.len(),
236        )?;
237        Ok(())
238    }
239}