qsym2/bindings/python/representation_analysis/multideterminant/mod.rs
1//! Python bindings for QSym² symmetry analysis of multi-determinants.
2
use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;

use anyhow::format_err;
use ndarray::{Array1, Array2};
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyArrayMethods, ToPyArray};
use pyo3::exceptions::{PyIndexError, PyRuntimeError};
use pyo3::prelude::*;

use crate::angmom::spinor_rotation_3d::StructureConstraint;
use crate::auxiliary::molecule::Molecule;
use crate::basis::ao::BasisAngularOrder;
use crate::bindings::python::integrals::PyStructureConstraint;
use crate::bindings::python::representation_analysis::slater_determinant::{
    PySlaterDeterminantComplex, PySlaterDeterminantReal,
};
use crate::target::determinant::SlaterDeterminant;
use crate::target::noci::basis::EagerBasis;
use crate::target::noci::multideterminant::MultiDeterminant;
24
25type C128 = Complex<f64>;
26
27// ==================
28// Struct definitions
29// ==================
30
31// -----------------
32// Multi-determinant
33// -----------------
34
35// ~~~~
36// Real
37// ~~~~
38//
39/// Python-exposed structure to marshall real multi-determinant information between Rust and
40/// Python.
41///
42/// # Constructor arguments
43///
44/// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
45/// expressed. Python type: `list[PySlaterDeterminantReal]`.
46/// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
47/// Each column of the coefficient matrix contains the coefficients for one state.
48/// Python type: `numpy.2darray[float]`.
49/// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[float]`.
50/// * `threshold` - The threshold for comparisons. Python type: `float`.
51#[pyclass]
52#[derive(Clone)]
53pub struct PyMultiDeterminantsReal {
54 /// The basis of Slater determinants in which the multi-determinantal states are expressed.
55 ///
56 /// Python type: `list[PySlaterDeterminantReal]`.
57 #[pyo3(get)]
58 basis: Vec<PySlaterDeterminantReal>,
59
60 /// The coefficients for the multi-determinantal states in the specified basis. Each column of
61 /// the coefficient matrix contains the coefficients for one state.
62 ///
63 /// Python type: `numpy.2darray[float]`.
64 coefficients: Array2<f64>,
65
66 /// The energies of the multi-determinantal states.
67 ///
68 /// Python type: `numpy.1darray[float]`.
69 energies: Array1<f64>,
70
71 /// The threshold for comparisons.
72 ///
73 /// Python type: `float`.
74 #[pyo3(get)]
75 threshold: f64,
76}
77
78#[pymethods]
79impl PyMultiDeterminantsReal {
80 /// Constructs a set of real Python-exposed multi-determinants.
81 ///
82 /// # Arguments
83 ///
84 /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
85 /// expressed. Python type: `list[PySlaterDeterminantReal]`.
86 /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
87 /// Each column of the coefficient matrix contains the coefficients for one state.
88 /// Python type: `numpy.2darray[float]`.
89 /// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[float]`.
90 /// * `threshold` - The threshold for comparisons. Python type: `float`.
91 #[new]
92 #[pyo3(signature = (basis, coefficients, energies, threshold))]
93 pub fn new(
94 basis: Vec<PySlaterDeterminantReal>,
95 coefficients: Bound<'_, PyArray2<f64>>,
96 energies: Bound<'_, PyArray1<f64>>,
97 threshold: f64,
98 ) -> Self {
99 let multidet = Self {
100 basis,
101 coefficients: coefficients.to_owned_array(),
102 energies: energies.to_owned_array(),
103 threshold,
104 };
105 multidet
106 }
107
108 #[getter]
109 pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
110 Ok(self.coefficients.to_pyarray(py))
111 }
112
113 #[getter]
114 pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
115 Ok(self.energies.to_pyarray(py))
116 }
117
118 pub fn complex_symmetric<'py>(&self, py: Python<'py>) -> PyResult<bool> {
119 let complex_symmetric_set = self
120 .basis
121 .iter()
122 .map(|pydet| pydet.complex_symmetric(py))
123 .collect::<Result<HashSet<_>, _>>()?;
124 if complex_symmetric_set.len() != 1 {
125 Err(PyRuntimeError::new_err(
126 "Inconsistent complex-symmetric flags across basis functions.",
127 ))
128 } else {
129 complex_symmetric_set.into_iter().next().ok_or_else(|| {
130 PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
131 })
132 }
133 }
134
135 pub fn state_coefficients<'py>(
136 &self,
137 py: Python<'py>,
138 state_index: usize,
139 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
140 Ok(self.coefficients.column(state_index).to_pyarray(py))
141 }
142
143 pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<f64> {
144 Ok(self.energies[state_index])
145 }
146}
147
148impl PyMultiDeterminantsReal {
149 /// Extracts the information in the [`PyMultiDeterminantsReal`] structure into a vector of
150 /// `QSym2`'s native [`MultiDeterminant`] structures.
151 ///
152 /// # Arguments
153 ///
154 /// * `bao` - The [`BasisAngularOrder`] for the basis set in which the Slater determinant is
155 /// given.
156 /// * `mol` - The molecule with which the Slater determinant is associated.
157 ///
158 /// # Returns
159 ///
160 /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
161 /// contained in the Python version.
162 ///
163 /// # Errors
164 ///
165 /// Errors if the [`MultiDeterminant`] structures fail to build.
166 pub fn to_qsym2<'b, 'a: 'b, SC>(
167 &'b self,
168 bao: &'a BasisAngularOrder,
169 mol: &'a Molecule,
170 ) -> Result<
171 Vec<MultiDeterminant<'b, f64, EagerBasis<SlaterDeterminant<'b, f64, SC>>, SC>>,
172 anyhow::Error,
173 >
174 where
175 SC: StructureConstraint
176 + Eq
177 + Hash
178 + Clone
179 + fmt::Display
180 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
181 {
182 let eager_basis = EagerBasis::builder()
183 .elements(
184 self.basis
185 .iter()
186 .map(|pydet| pydet.to_qsym2(bao, mol))
187 .collect::<Result<Vec<_>, _>>()?,
188 )
189 .build()?;
190 let multidets = self
191 .energies
192 .iter()
193 .zip(self.coefficients.columns())
194 .map(|(e, c)| {
195 MultiDeterminant::builder()
196 .basis(eager_basis.clone())
197 .coefficients(c.to_owned())
198 .energy(Ok(*e))
199 .threshold(self.threshold)
200 .build()
201 })
202 .collect::<Result<Vec<_>, _>>()
203 .map_err(|err| format_err!(err));
204 multidets
205 }
206}
207
208// ~~~~~~~
209// Complex
210// ~~~~~~~
211//
212/// Python-exposed structure to marshall complex multi-determinant information between Rust and
213/// Python.
214///
215/// # Constructor arguments
216///
217/// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
218/// expressed. Python type: `list[PySlaterDeterminantComplex]`.
219/// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
220/// Each column of the coefficient matrix contains the coefficients for one state.
221/// Python type: `numpy.2darray[complex]`.
222/// * `energies` - The energies of the multi-determinantal states. Python type:
223/// `numpy.1darray[complex]`.
224/// * `threshold` - The threshold for comparisons. Python type: `float`.
225#[pyclass]
226#[derive(Clone)]
227pub struct PyMultiDeterminantsComplex {
228 /// The basis of Slater determinants in which the multi-determinantal states are expressed.
229 ///
230 /// Python type: `list[PySlaterDeterminantReal]`.
231 #[pyo3(get)]
232 basis: Vec<PySlaterDeterminantComplex>,
233
234 /// The coefficients for the multi-determinantal states in the specified basis. Each column of
235 /// the coefficient matrix contains the coefficients for one state.
236 ///
237 /// Python type: `numpy.2darray[complex]`.
238 coefficients: Array2<C128>,
239
240 /// The energies of the multi-determinantal states.
241 ///
242 /// Python type: `numpy.1darray[complex]`.
243 energies: Array1<C128>,
244
245 /// The threshold for comparisons.
246 ///
247 /// Python type: `float`.
248 #[pyo3(get)]
249 threshold: f64,
250}
251
252#[pymethods]
253impl PyMultiDeterminantsComplex {
254 /// Constructs a set of complex Python-exposed multi-determinants.
255 ///
256 /// # Arguments
257 ///
258 /// * `basis` - The basis of Slater determinants in which the multi-determinantal states are
259 /// expressed. Python type: `list[PySlaterDeterminantComplex]`.
260 /// * `coefficients` - The coefficients for the multi-determinantal states in the specified basis.
261 /// Each column of the coefficient matrix contains the coefficients for one state.
262 /// Python type: `numpy.2darray[complex]`.
263 /// * `energies` - The energies of the multi-determinantal states. Python type: `numpy.1darray[complex]`.
264 /// * `threshold` - The threshold for comparisons. Python type: `float`.
265 #[new]
266 #[pyo3(signature = (basis, coefficients, energies, threshold))]
267 pub fn new(
268 basis: Vec<PySlaterDeterminantComplex>,
269 coefficients: Bound<'_, PyArray2<C128>>,
270 energies: Bound<'_, PyArray1<C128>>,
271 threshold: f64,
272 ) -> Self {
273 let multidet = Self {
274 basis,
275 coefficients: coefficients.to_owned_array(),
276 energies: energies.to_owned_array(),
277 threshold,
278 };
279 multidet
280 }
281
282 #[getter]
283 pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<C128>>> {
284 Ok(self.coefficients.to_pyarray(py))
285 }
286
287 #[getter]
288 pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<C128>>> {
289 Ok(self.energies.to_pyarray(py))
290 }
291
292 pub fn complex_symmetric<'py>(&self, py: Python<'py>) -> PyResult<bool> {
293 let complex_symmetric_set = self
294 .basis
295 .iter()
296 .map(|pydet| pydet.complex_symmetric(py))
297 .collect::<Result<HashSet<_>, _>>()?;
298 if complex_symmetric_set.len() != 1 {
299 Err(PyRuntimeError::new_err(
300 "Inconsistent complex-symmetric flags across basis functions.",
301 ))
302 } else {
303 complex_symmetric_set.into_iter().next().ok_or_else(|| {
304 PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
305 })
306 }
307 }
308
309 pub fn state_coefficients<'py>(
310 &self,
311 py: Python<'py>,
312 state_index: usize,
313 ) -> PyResult<Bound<'py, PyArray1<C128>>> {
314 Ok(self.coefficients.column(state_index).to_pyarray(py))
315 }
316
317 pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<C128> {
318 Ok(self.energies[state_index])
319 }
320}
321
322impl PyMultiDeterminantsComplex {
323 /// Extracts the information in the [`PyMultiDeterminantsComplex`] structure into `QSym2`'s native
324 /// [`MultiDeterminant`] structure.
325 ///
326 /// # Arguments
327 ///
328 /// * `bao` - The [`BasisAngularOrder`] for the basis set in which the Slater determinant is
329 /// given.
330 /// * `mol` - The molecule with which the Slater determinant is associated.
331 ///
332 /// # Returns
333 ///
334 /// The A vector of [`MultiDeterminant`] structures, one for each multi-determinantal state
335 /// contained in the Python version.
336 ///
337 /// # Errors
338 ///
339 /// Errors if the [`MultiDeterminant`] structures fail to build.
340 pub fn to_qsym2<'b, 'a: 'b, SC>(
341 &'b self,
342 bao: &'a BasisAngularOrder,
343 mol: &'a Molecule,
344 ) -> Result<
345 Vec<MultiDeterminant<'b, C128, EagerBasis<SlaterDeterminant<'b, C128, SC>>, SC>>,
346 anyhow::Error,
347 >
348 where
349 SC: StructureConstraint
350 + Eq
351 + Hash
352 + Clone
353 + fmt::Display
354 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
355 {
356 let eager_basis = EagerBasis::builder()
357 .elements(
358 self.basis
359 .iter()
360 .map(|pydet| pydet.to_qsym2(bao, mol))
361 .collect::<Result<Vec<_>, _>>()?,
362 )
363 .build()?;
364 let multidets = self
365 .energies
366 .iter()
367 .zip(self.coefficients.columns())
368 .map(|(e, c)| {
369 MultiDeterminant::builder()
370 .basis(eager_basis.clone())
371 .coefficients(c.to_owned())
372 .energy(Ok(*e))
373 .threshold(self.threshold)
374 .build()
375 })
376 .collect::<Result<Vec<_>, _>>()
377 .map_err(|err| format_err!(err));
378 multidets
379 }
380}
381
382// ================
383// Enum definitions
384// ================
385
386/// Python-exposed enumerated type to handle the union type
387/// `PyMultiDeterminantsReal | PyMultiDeterminantsComplex` in Python.
388#[derive(FromPyObject)]
389pub enum PyMultiDeterminants {
390 /// Variant for real Python-exposed multi-determinants.
391 Real(PyMultiDeterminantsReal),
392
393 /// Variant for complex Python-exposed multi-determinants.
394 Complex(PyMultiDeterminantsComplex),
395}
396
// ====================
// Function definitions
// ====================
400
401mod multideterminant_eager_basis;
402mod multideterminant_orbit_basis_external_solver;
403mod multideterminant_orbit_basis_internal_solver;
404
405pub use multideterminant_eager_basis::rep_analyse_multideterminants_eager_basis;
406pub use multideterminant_orbit_basis_external_solver::rep_analyse_multideterminants_orbit_basis_external_solver;
407pub use multideterminant_orbit_basis_internal_solver::rep_analyse_multideterminants_orbit_basis_internal_solver;