qsym2/bindings/python/representation_analysis/multideterminant/
mod.rs

1//! Python bindings for QSym² symmetry analysis of multi-determinants.
2
use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;

use anyhow::format_err;
use ndarray::{Array1, Array2};
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyArrayMethods, ToPyArray};
use pyo3::exceptions::{PyIndexError, PyRuntimeError};
use pyo3::prelude::*;

use crate::angmom::spinor_rotation_3d::StructureConstraint;
use crate::auxiliary::molecule::Molecule;
use crate::basis::ao::BasisAngularOrder;
use crate::bindings::python::integrals::PyStructureConstraint;
use crate::bindings::python::representation_analysis::slater_determinant::{
    PySlaterDeterminantComplex, PySlaterDeterminantReal,
};
use crate::target::determinant::SlaterDeterminant;
use crate::target::noci::basis::EagerBasis;
use crate::target::noci::multideterminant::MultiDeterminant;
24
/// Double-precision complex scalar, used for the complex coefficients and energies below.
type C128 = Complex<f64>;
26
27// ==================
28// Struct definitions
29// ==================
30
31// -----------------
32// Multi-determinant
33// -----------------
34
35// ~~~~
36// Real
37// ~~~~
38//
39/// Python-exposed structure to marshall real multi-determinant information between Rust and
40/// Python.
41///
42/// # Constructor arguments
43///
44/// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
45/// expressed. Python type: `list[PySlaterDeterminantReal]`.
46/// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
47/// Each column of the coefficient matrix contains the coefficients for one state.
48/// Python type: `numpy.2darray[float]`.
49/// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[float]`.
50/// * `threshold` - The threshold for comparisons. Python type: `float`.
51#[pyclass]
52#[derive(Clone)]
53pub struct PyMultiDeterminantsReal {
54    /// The basis of Slater determinants in which the multi-determinantal states are expressed.
55    ///
56    /// Python type: `list[PySlaterDeterminantReal]`.
57    #[pyo3(get)]
58    basis: Vec<PySlaterDeterminantReal>,
59
60    /// The coefficients for the multi-determinantal states in the specified basis. Each column of
61    /// the coefficient matrix contains the coefficients for one state.
62    ///
63    /// Python type: `numpy.2darray[float]`.
64    coefficients: Array2<f64>,
65
66    /// The energies of the multi-determinantal states.
67    ///
68    /// Python type: `numpy.1darray[float]`.
69    energies: Array1<f64>,
70
71    /// The threshold for comparisons.
72    ///
73    /// Python type: `float`.
74    #[pyo3(get)]
75    threshold: f64,
76}
77
78#[pymethods]
79impl PyMultiDeterminantsReal {
80    /// Constructs a set of real Python-exposed multi-determinants.
81    ///
82    /// # Arguments
83    ///
84    /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
85    /// expressed. Python type: `list[PySlaterDeterminantReal]`.
86    /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
87    /// Each column of the coefficient matrix contains the coefficients for one state.
88    /// Python type: `numpy.2darray[float]`.
89    /// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[float]`.
90    /// * `threshold` - The threshold for comparisons. Python type: `float`.
91    #[new]
92    #[pyo3(signature = (basis, coefficients, energies, threshold))]
93    pub fn new(
94        basis: Vec<PySlaterDeterminantReal>,
95        coefficients: Bound<'_, PyArray2<f64>>,
96        energies: Bound<'_, PyArray1<f64>>,
97        threshold: f64,
98    ) -> Self {
99        let multidet = Self {
100            basis,
101            coefficients: coefficients.to_owned_array(),
102            energies: energies.to_owned_array(),
103            threshold,
104        };
105        multidet
106    }
107
108    #[getter]
109    pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
110        Ok(self.coefficients.to_pyarray(py))
111    }
112
113    #[getter]
114    pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
115        Ok(self.energies.to_pyarray(py))
116    }
117
118    pub fn complex_symmetric<'py>(&self, py: Python<'py>) -> PyResult<bool> {
119        let complex_symmetric_set = self
120            .basis
121            .iter()
122            .map(|pydet| pydet.complex_symmetric(py))
123            .collect::<Result<HashSet<_>, _>>()?;
124        if complex_symmetric_set.len() != 1 {
125            Err(PyRuntimeError::new_err(
126                "Inconsistent complex-symmetric flags across basis functions.",
127            ))
128        } else {
129            complex_symmetric_set.into_iter().next().ok_or_else(|| {
130                PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
131            })
132        }
133    }
134
135    pub fn state_coefficients<'py>(
136        &self,
137        py: Python<'py>,
138        state_index: usize,
139    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
140        Ok(self.coefficients.column(state_index).to_pyarray(py))
141    }
142
143    pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<f64> {
144        Ok(self.energies[state_index])
145    }
146}
147
148impl PyMultiDeterminantsReal {
149    /// Extracts the information in the [`PyMultiDeterminantsReal`] structure into a vector of
150    /// `QSym2`'s native [`MultiDeterminant`] structures.
151    ///
152    /// # Arguments
153    ///
154    /// * `bao` - The [`BasisAngularOrder`] for the basis set in which the Slater determinant is
155    /// given.
156    /// * `mol` - The molecule with which the Slater determinant is associated.
157    ///
158    /// # Returns
159    ///
160    /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
161    /// contained in the Python version.
162    ///
163    /// # Errors
164    ///
165    /// Errors if the [`MultiDeterminant`] structures fail to build.
166    pub fn to_qsym2<'b, 'a: 'b, SC>(
167        &'b self,
168        bao: &'a BasisAngularOrder,
169        mol: &'a Molecule,
170    ) -> Result<
171        Vec<MultiDeterminant<'b, f64, EagerBasis<SlaterDeterminant<'b, f64, SC>>, SC>>,
172        anyhow::Error,
173    >
174    where
175        SC: StructureConstraint
176            + Eq
177            + Hash
178            + Clone
179            + fmt::Display
180            + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
181    {
182        let eager_basis = EagerBasis::builder()
183            .elements(
184                self.basis
185                    .iter()
186                    .map(|pydet| pydet.to_qsym2(bao, mol))
187                    .collect::<Result<Vec<_>, _>>()?,
188            )
189            .build()?;
190        let multidets = self
191            .energies
192            .iter()
193            .zip(self.coefficients.columns())
194            .map(|(e, c)| {
195                MultiDeterminant::builder()
196                    .basis(eager_basis.clone())
197                    .coefficients(c.to_owned())
198                    .energy(Ok(*e))
199                    .threshold(self.threshold)
200                    .build()
201            })
202            .collect::<Result<Vec<_>, _>>()
203            .map_err(|err| format_err!(err));
204        multidets
205    }
206}
207
208// ~~~~~~~
209// Complex
210// ~~~~~~~
211//
212/// Python-exposed structure to marshall complex multi-determinant information between Rust and
213/// Python.
214///
215/// # Constructor arguments
216///
217/// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
218/// expressed. Python type: `list[PySlaterDeterminantComplex]`.
219/// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
220/// Each column of the coefficient matrix contains the coefficients for one state.
221/// Python type: `numpy.2darray[complex]`.
222/// * `energies` - The energies of the multi-determinantal states. Python type:
223/// `numpy.1darray[complex]`.
224/// * `threshold` - The threshold for comparisons. Python type: `float`.
225#[pyclass]
226#[derive(Clone)]
227pub struct PyMultiDeterminantsComplex {
228    /// The basis of Slater determinants in which the multi-determinantal states are expressed.
229    ///
230    /// Python type: `list[PySlaterDeterminantReal]`.
231    #[pyo3(get)]
232    basis: Vec<PySlaterDeterminantComplex>,
233
234    /// The coefficients for the multi-determinantal states in the specified basis. Each column of
235    /// the coefficient matrix contains the coefficients for one state.
236    ///
237    /// Python type: `numpy.2darray[complex]`.
238    coefficients: Array2<C128>,
239
240    /// The energies of the multi-determinantal states.
241    ///
242    /// Python type: `numpy.1darray[complex]`.
243    energies: Array1<C128>,
244
245    /// The threshold for comparisons.
246    ///
247    /// Python type: `float`.
248    #[pyo3(get)]
249    threshold: f64,
250}
251
252#[pymethods]
253impl PyMultiDeterminantsComplex {
254    /// Constructs a set of complex Python-exposed multi-determinants.
255    ///
256    /// # Arguments
257    ///
258    /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
259    /// expressed. Python type: `list[PySlaterDeterminantComplex]`.
260    /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
261    /// Each column of the coefficient matrix contains the coefficients for one state.
262    /// Python type: `numpy.2darray[complex]`.
263    /// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[complex]`.
264    /// * `threshold` - The threshold for comparisons. Python type: `float`.
265    #[new]
266    #[pyo3(signature = (basis, coefficients, energies, threshold))]
267    pub fn new(
268        basis: Vec<PySlaterDeterminantComplex>,
269        coefficients: Bound<'_, PyArray2<C128>>,
270        energies: Bound<'_, PyArray1<C128>>,
271        threshold: f64,
272    ) -> Self {
273        let multidet = Self {
274            basis,
275            coefficients: coefficients.to_owned_array(),
276            energies: energies.to_owned_array(),
277            threshold,
278        };
279        multidet
280    }
281
282    #[getter]
283    pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<C128>>> {
284        Ok(self.coefficients.to_pyarray(py))
285    }
286
287    #[getter]
288    pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<C128>>> {
289        Ok(self.energies.to_pyarray(py))
290    }
291
292    pub fn complex_symmetric<'py>(&self, py: Python<'py>) -> PyResult<bool> {
293        let complex_symmetric_set = self
294            .basis
295            .iter()
296            .map(|pydet| pydet.complex_symmetric(py))
297            .collect::<Result<HashSet<_>, _>>()?;
298        if complex_symmetric_set.len() != 1 {
299            Err(PyRuntimeError::new_err(
300                "Inconsistent complex-symmetric flags across basis functions.",
301            ))
302        } else {
303            complex_symmetric_set.into_iter().next().ok_or_else(|| {
304                PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
305            })
306        }
307    }
308
309    pub fn state_coefficients<'py>(
310        &self,
311        py: Python<'py>,
312        state_index: usize,
313    ) -> PyResult<Bound<'py, PyArray1<C128>>> {
314        Ok(self.coefficients.column(state_index).to_pyarray(py))
315    }
316
317    pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<C128> {
318        Ok(self.energies[state_index])
319    }
320}
321
322impl PyMultiDeterminantsComplex {
323    /// Extracts the information in the [`PyMultiDeterminantsComplex`] structure into `QSym2`'s native
324    /// [`MultiDeterminant`] structure.
325    ///
326    /// # Arguments
327    ///
328    /// * `bao` - The [`BasisAngularOrder`] for the basis set in which the Slater determinant is
329    /// given.
330    /// * `mol` - The molecule with which the Slater determinant is associated.
331    ///
332    /// # Returns
333    ///
334    /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
335    /// contained in the Python version.
336    ///
337    /// # Errors
338    ///
339    /// Errors if the [`MultiDeterminant`] structures fail to build.
340    pub fn to_qsym2<'b, 'a: 'b, SC>(
341        &'b self,
342        bao: &'a BasisAngularOrder,
343        mol: &'a Molecule,
344    ) -> Result<
345        Vec<MultiDeterminant<'b, C128, EagerBasis<SlaterDeterminant<'b, C128, SC>>, SC>>,
346        anyhow::Error,
347    >
348    where
349        SC: StructureConstraint
350            + Eq
351            + Hash
352            + Clone
353            + fmt::Display
354            + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
355    {
356        let eager_basis = EagerBasis::builder()
357            .elements(
358                self.basis
359                    .iter()
360                    .map(|pydet| pydet.to_qsym2(bao, mol))
361                    .collect::<Result<Vec<_>, _>>()?,
362            )
363            .build()?;
364        let multidets = self
365            .energies
366            .iter()
367            .zip(self.coefficients.columns())
368            .map(|(e, c)| {
369                MultiDeterminant::builder()
370                    .basis(eager_basis.clone())
371                    .coefficients(c.to_owned())
372                    .energy(Ok(*e))
373                    .threshold(self.threshold)
374                    .build()
375            })
376            .collect::<Result<Vec<_>, _>>()
377            .map_err(|err| format_err!(err));
378        multidets
379    }
380}
381
382// ================
383// Enum definitions
384// ================
385
386/// Python-exposed enumerated type to handle the union type
387/// `PyMultiDeterminantsReal | PyMultiDeterminantsComplex` in Python.
388#[derive(FromPyObject)]
389pub enum PyMultiDeterminants {
390    /// Variant for real Python-exposed multi-determinants.
391    Real(PyMultiDeterminantsReal),
392
393    /// Variant for complex Python-exposed multi-determinants.
394    Complex(PyMultiDeterminantsComplex),
395}
396
// ====================
// Function definitions
// ====================
400
401mod multideterminant_eager_basis;
402mod multideterminant_orbit_basis_external_solver;
403mod multideterminant_orbit_basis_internal_solver;
404
405pub use multideterminant_eager_basis::rep_analyse_multideterminants_eager_basis;
406pub use multideterminant_orbit_basis_external_solver::rep_analyse_multideterminants_orbit_basis_external_solver;
407pub use multideterminant_orbit_basis_internal_solver::rep_analyse_multideterminants_orbit_basis_internal_solver;