qsym2/bindings/python/representation_analysis/multideterminant/
mod.rs

1//! Python bindings for QSym² symmetry analysis of multi-determinants.
2
use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;

use anyhow::format_err;
use itertools::Itertools;
use ndarray::{Array1, Array2};
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyArrayMethods, ToPyArray};
use pyo3::exceptions::{PyIndexError, PyRuntimeError};
use pyo3::prelude::*;

use crate::angmom::spinor_rotation_3d::StructureConstraint;
use crate::auxiliary::molecule::Molecule;
use crate::basis::ao::BasisAngularOrder;
use crate::bindings::python::integrals::PyStructureConstraint;
use crate::bindings::python::representation_analysis::slater_determinant::{
    PySlaterDeterminantComplex, PySlaterDeterminantReal,
};
use crate::target::determinant::SlaterDeterminant;
use crate::target::noci::basis::EagerBasis;
use crate::target::noci::multideterminant::MultiDeterminant;
25
26type C128 = Complex<f64>;
27
28// ==================
29// Struct definitions
30// ==================
31
32// -----------------
33// Multi-determinant
34// -----------------
35
36// ~~~~
37// Real
38// ~~~~
39//
40/// Python-exposed structure to marshall real multi-determinant information between Rust and Python.
41#[pyclass]
42#[derive(Clone)]
43pub struct PyMultiDeterminantsReal {
44    /// The basis of Slater determinants in which the multi-determinantal states are expressed.
45    #[pyo3(get)]
46    basis: Vec<PySlaterDeterminantReal>,
47
48    /// The coefficients for the multi-determinantal states in the specified basis. Each column of
49    /// the coefficient matrix contains the coefficients for one state.
50    coefficients: Array2<f64>,
51
52    /// The energies of the multi-determinantal states.
53    energies: Array1<f64>,
54
55    /// The density matrices for the multi-determinantal states in the specified basis.
56    density_matrices: Option<Vec<Array2<f64>>>,
57
58    /// The threshold for comparisons.
59    #[pyo3(get)]
60    threshold: f64,
61}
62
63#[pymethods]
64impl PyMultiDeterminantsReal {
65    /// Constructs a set of real Python-exposed multi-determinants.
66    ///
67    /// # Arguments
68    ///
69    /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
70    /// expressed.
71    /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
72    /// Each column of the coefficient matrix contains the coefficients for one state.
73    /// * `energies` - The energies of the multi-determinantal states.
74    /// * `density_matrices` - The optional density matrices of the multi-determinantal states.
75    /// * `threshold` - The threshold for comparisons.
76    #[new]
77    #[pyo3(signature = (basis, coefficients, energies, density_matrices, threshold))]
78    pub fn new(
79        basis: Vec<PySlaterDeterminantReal>,
80        coefficients: Bound<'_, PyArray2<f64>>,
81        energies: Bound<'_, PyArray1<f64>>,
82        density_matrices: Option<Vec<Bound<'_, PyArray2<f64>>>>,
83        threshold: f64,
84    ) -> Self {
85        let coefficients = coefficients.to_owned_array();
86        let energies = energies.to_owned_array();
87        let density_matrices = density_matrices.map(|denmats| {
88            denmats
89                .into_iter()
90                .map(|denmat| denmat.to_owned_array())
91                .collect_vec()
92        });
93        if let Some(ref denmats) = density_matrices {
94            if denmats.len() != coefficients.ncols()
95                || denmats.len() != energies.len()
96                || coefficients.ncols() != energies.len()
97            {
98                panic!(
99                    "Inconsistent numbers of multi-determinantal states in `coefficients`, `energies`, and `density_matrices`."
100                )
101            }
102        };
103        Self {
104            basis,
105            coefficients,
106            energies,
107            density_matrices,
108            threshold,
109        }
110    }
111
112    #[getter]
113    pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
114        Ok(self.coefficients.to_pyarray(py))
115    }
116
117    #[getter]
118    pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
119        Ok(self.energies.to_pyarray(py))
120    }
121
122    #[getter]
123    pub fn density_matrices<'py>(
124        &self,
125        py: Python<'py>,
126    ) -> PyResult<Option<Vec<Bound<'py, PyArray2<f64>>>>> {
127        Ok(self.density_matrices.as_ref().map(|denmats| {
128            denmats
129                .iter()
130                .map(|denmat| denmat.to_pyarray(py))
131                .collect_vec()
132        }))
133    }
134
135    /// Boolean indicating whether inner products involving these multi-determinantal states are
136    /// complex-symmetric.
137    pub fn complex_symmetric<'py>(&self, _py: Python<'py>) -> PyResult<bool> {
138        let complex_symmetric_set = self
139            .basis
140            .iter()
141            .map(|pydet| pydet.complex_symmetric)
142            .collect::<HashSet<_>>();
143        if complex_symmetric_set.len() != 1 {
144            Err(PyRuntimeError::new_err(
145                "Inconsistent complex-symmetric flags across basis functions.",
146            ))
147        } else {
148            complex_symmetric_set.into_iter().next().ok_or_else(|| {
149                PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
150            })
151        }
152    }
153
154    /// Returns the coefficients for a particular state.
155    pub fn state_coefficients<'py>(
156        &self,
157        py: Python<'py>,
158        state_index: usize,
159    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
160        Ok(self.coefficients.column(state_index).to_pyarray(py))
161    }
162
163    /// Returns the energy for a particular state.
164    pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<f64> {
165        Ok(self.energies[state_index])
166    }
167
168    /// Returns the density matrix for a particular state.
169    pub fn state_density_matrix<'py>(
170        &self,
171        py: Python<'py>,
172        state_index: usize,
173    ) -> PyResult<Bound<'py, PyArray2<f64>>> {
174        self.density_matrices
175            .as_ref()
176            .ok_or_else(|| {
177                PyRuntimeError::new_err(
178                    "No multi-determinantal density matrices found.".to_string(),
179                )
180            })
181            .map(|denmats| denmats[state_index].to_pyarray(py))
182    }
183}
184
185impl PyMultiDeterminantsReal {
186    /// Extracts the information in the [`PyMultiDeterminantsReal`] structure into a vector of
187    /// `QSym2`'s native [`MultiDeterminant`] structures.
188    ///
189    /// # Arguments
190    ///
191    /// * `baos` - The [`BasisAngularOrder`]s for the basis set in which the Slater determinant is
192    /// given, one for each explicit component per coefficient matrix.
193    /// * `mol` - The molecule with which the Slater determinant is associated.
194    ///
195    /// # Returns
196    ///
197    /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
198    /// contained in the Python version.
199    ///
200    /// # Errors
201    ///
202    /// Errors if the [`MultiDeterminant`] structures fail to build.
203    pub fn to_qsym2<'b, 'a: 'b, SC>(
204        &'b self,
205        baos: &[&'a BasisAngularOrder],
206        mol: &'a Molecule,
207    ) -> Result<
208        Vec<MultiDeterminant<'b, f64, EagerBasis<SlaterDeterminant<'b, f64, SC>>, SC>>,
209        anyhow::Error,
210    >
211    where
212        SC: StructureConstraint
213            + Eq
214            + Hash
215            + Clone
216            + fmt::Display
217            + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
218    {
219        let eager_basis = EagerBasis::builder()
220            .elements(
221                self.basis
222                    .iter()
223                    .map(|pydet| pydet.to_qsym2(baos, mol))
224                    .collect::<Result<Vec<_>, _>>()?,
225            )
226            .build()?;
227        let multidets = self
228            .energies
229            .iter()
230            .zip(self.coefficients.columns())
231            .map(|(e, c)| {
232                MultiDeterminant::builder()
233                    .basis(eager_basis.clone())
234                    .coefficients(c.to_owned())
235                    .energy(Ok(*e))
236                    .threshold(self.threshold)
237                    .build()
238            })
239            .collect::<Result<Vec<_>, _>>()
240            .map_err(|err| format_err!(err));
241        multidets
242    }
243}
244
245// ~~~~~~~
246// Complex
247// ~~~~~~~
248//
249/// Python-exposed structure to marshall complex multi-determinant information between Rust and
250/// Python.
251#[pyclass]
252#[derive(Clone)]
253pub struct PyMultiDeterminantsComplex {
254    /// The basis of Slater determinants in which the multi-determinantal states are expressed.
255    #[pyo3(get)]
256    basis: Vec<PySlaterDeterminantComplex>,
257
258    /// The coefficients for the multi-determinantal states in the specified basis. Each column of
259    /// the coefficient matrix contains the coefficients for one state.
260    coefficients: Array2<C128>,
261
262    /// The energies of the multi-determinantal states.
263    energies: Array1<C128>,
264
265    /// The density matrices for the multi-determinantal states in the specified basis.
266    density_matrices: Option<Vec<Array2<C128>>>,
267
268    /// The threshold for comparisons.
269    #[pyo3(get)]
270    threshold: f64,
271}
272
273#[pymethods]
274impl PyMultiDeterminantsComplex {
275    /// Constructs a set of complex Python-exposed multi-determinants.
276    ///
277    /// # Arguments
278    ///
279    /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
280    /// expressed.
281    /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
282    /// Each column of the coefficient matrix contains the coefficients for one state.
283    /// * `energies` - The energies of the multi-determinantal states.
284    /// * `density_matrices` - The optional density matrices of the multi-determinantal states.
285    /// * `threshold` - The threshold for comparisons.
286    #[new]
287    #[pyo3(signature = (basis, coefficients, energies, density_matrices, threshold))]
288    pub fn new(
289        basis: Vec<PySlaterDeterminantComplex>,
290        coefficients: Bound<'_, PyArray2<C128>>,
291        energies: Bound<'_, PyArray1<C128>>,
292        density_matrices: Option<Vec<Bound<'_, PyArray2<C128>>>>,
293        threshold: f64,
294    ) -> Self {
295        let coefficients = coefficients.to_owned_array();
296        let energies = energies.to_owned_array();
297        let density_matrices = density_matrices.map(|denmats| {
298            denmats
299                .into_iter()
300                .map(|denmat| denmat.to_owned_array())
301                .collect_vec()
302        });
303        if let Some(ref denmats) = density_matrices {
304            if denmats.len() != coefficients.ncols()
305                || denmats.len() != energies.len()
306                || coefficients.ncols() != energies.len()
307            {
308                panic!(
309                    "Inconsistent numbers of multi-determinantal states in `coefficients`, `energies`, and `density_matrices`."
310                )
311            }
312        };
313        Self {
314            basis,
315            coefficients,
316            energies,
317            density_matrices,
318            threshold,
319        }
320    }
321
322    #[getter]
323    pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<C128>>> {
324        Ok(self.coefficients.to_pyarray(py))
325    }
326
327    #[getter]
328    pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<C128>>> {
329        Ok(self.energies.to_pyarray(py))
330    }
331
332    #[getter]
333    pub fn density_matrices<'py>(
334        &self,
335        py: Python<'py>,
336    ) -> PyResult<Option<Vec<Bound<'py, PyArray2<C128>>>>> {
337        Ok(self.density_matrices.as_ref().map(|denmats| {
338            denmats
339                .iter()
340                .map(|denmat| denmat.to_pyarray(py))
341                .collect_vec()
342        }))
343    }
344
345    /// Boolean indicating whether inner products involving these multi-determinantal states are
346    /// complex-symmetric.
347    pub fn complex_symmetric<'py>(&self, _py: Python<'py>) -> PyResult<bool> {
348        let complex_symmetric_set = self
349            .basis
350            .iter()
351            .map(|pydet| pydet.complex_symmetric)
352            .collect::<HashSet<_>>();
353        if complex_symmetric_set.len() != 1 {
354            Err(PyRuntimeError::new_err(
355                "Inconsistent complex-symmetric flags across basis functions.",
356            ))
357        } else {
358            complex_symmetric_set.into_iter().next().ok_or_else(|| {
359                PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
360            })
361        }
362    }
363
364    /// Returns the coefficients for a particular state.
365    pub fn state_coefficients<'py>(
366        &self,
367        py: Python<'py>,
368        state_index: usize,
369    ) -> PyResult<Bound<'py, PyArray1<C128>>> {
370        Ok(self.coefficients.column(state_index).to_pyarray(py))
371    }
372
373    /// Returns the energy for a particular state.
374    pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<C128> {
375        Ok(self.energies[state_index])
376    }
377
378    /// Returns the density matrix for a particular state.
379    pub fn state_density_matrix<'py>(
380        &self,
381        py: Python<'py>,
382        state_index: usize,
383    ) -> PyResult<Bound<'py, PyArray2<C128>>> {
384        self.density_matrices
385            .as_ref()
386            .ok_or_else(|| {
387                PyRuntimeError::new_err(
388                    "No multi-determinantal density matrices found.".to_string(),
389                )
390            })
391            .map(|denmats| denmats[state_index].to_pyarray(py))
392    }
393}
394
395impl PyMultiDeterminantsComplex {
396    /// Extracts the information in the [`PyMultiDeterminantsComplex`] structure into `QSym2`'s native
397    /// [`MultiDeterminant`] structure.
398    ///
399    /// # Arguments
400    ///
401    /// * `baos` - The [`BasisAngularOrder`]s for the basis set in which the Slater determinant is
402    /// given, one for each explicit component per coefficient matrix.
403    /// * `mol` - The molecule with which the Slater determinant is associated.
404    ///
405    /// # Returns
406    ///
407    /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
408    /// contained in the Python version.
409    ///
410    /// # Errors
411    ///
412    /// Errors if the [`MultiDeterminant`] structures fail to build.
413    pub fn to_qsym2<'b, 'a: 'b, SC>(
414        &'b self,
415        baos: &[&'a BasisAngularOrder],
416        mol: &'a Molecule,
417    ) -> Result<
418        Vec<MultiDeterminant<'b, C128, EagerBasis<SlaterDeterminant<'b, C128, SC>>, SC>>,
419        anyhow::Error,
420    >
421    where
422        SC: StructureConstraint
423            + Eq
424            + Hash
425            + Clone
426            + fmt::Display
427            + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
428    {
429        let eager_basis = EagerBasis::builder()
430            .elements(
431                self.basis
432                    .iter()
433                    .map(|pydet| pydet.to_qsym2(baos, mol))
434                    .collect::<Result<Vec<_>, _>>()?,
435            )
436            .build()?;
437        let multidets = self
438            .energies
439            .iter()
440            .zip(self.coefficients.columns())
441            .map(|(e, c)| {
442                MultiDeterminant::builder()
443                    .basis(eager_basis.clone())
444                    .coefficients(c.to_owned())
445                    .energy(Ok(*e))
446                    .threshold(self.threshold)
447                    .build()
448            })
449            .collect::<Result<Vec<_>, _>>()
450            .map_err(|err| format_err!(err));
451        multidets
452    }
453}
454
455// ================
456// Enum definitions
457// ================
458
459/// Python-exposed enumerated type to handle the union type
460/// `PyMultiDeterminantsReal | PyMultiDeterminantsComplex` in Python.
461#[derive(FromPyObject)]
462pub enum PyMultiDeterminants {
463    /// Variant for real Python-exposed multi-determinants.
464    Real(PyMultiDeterminantsReal),
465
466    /// Variant for complex Python-exposed multi-determinants.
467    Complex(PyMultiDeterminantsComplex),
468}
469
470// =====================
471// Functions definitions
472// =====================
473
474mod multideterminant_eager_basis;
475mod multideterminant_orbit_basis_external_solver;
476mod multideterminant_orbit_basis_internal_solver;
477
478pub use multideterminant_eager_basis::rep_analyse_multideterminants_eager_basis;
479pub use multideterminant_orbit_basis_external_solver::rep_analyse_multideterminants_orbit_basis_external_solver;
480pub use multideterminant_orbit_basis_internal_solver::rep_analyse_multideterminants_orbit_basis_internal_solver;