use std::collections::HashSet;
use std::fmt;
use std::hash::Hash;
use std::path::PathBuf;

use anyhow::format_err;
use itertools::Itertools;
use ndarray::{Array1, Array2};
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyArrayMethods, ToPyArray};
use pyo3::exceptions::{PyIndexError, PyRuntimeError};
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};

use crate::angmom::spinor_rotation_3d::StructureConstraint;
use crate::auxiliary::molecule::Molecule;
use crate::basis::ao::BasisAngularOrder;
use crate::bindings::python::integrals::PyStructureConstraint;
use crate::bindings::python::representation_analysis::slater_determinant::{
    PySlaterDeterminantComplex, PySlaterDeterminantReal,
};
use crate::io::format::qsym2_output;
use crate::io::{QSym2FileType, read_qsym2_binary, write_qsym2_binary};
use crate::target::determinant::SlaterDeterminant;
use crate::target::noci::basis::EagerBasis;
use crate::target::noci::multideterminant::MultiDeterminant;
use crate::target::noci::multideterminants::MultiDeterminants;
30
/// Shorthand for a double-precision complex number (`Complex<f64>`), used as the element type
/// of the complex multi-determinant arrays in this module.
type C128 = Complex<f64>;
32
/// Python-exposed container for a set of real multi-determinantal states sharing one basis of
/// Slater determinants.
#[pyclass]
// NOTE: field order matters for the binary (de)serialisation performed via
// `write_qsym2_binary`/`read_qsym2_binary`; do not reorder fields.
#[derive(Clone, Serialize, Deserialize)]
pub struct PyMultiDeterminantsReal {
    /// The Slater determinants in which all multi-determinantal states are expanded.
    #[pyo3(get)]
    basis: Vec<PySlaterDeterminantReal>,

    /// The expansion coefficients. Each column holds the coefficients of one
    /// multi-determinantal state; rows presumably run over the determinants in `basis`
    /// (TODO confirm).
    coefficients: Array2<f64>,

    /// The energies of the multi-determinantal states, one per column of `coefficients`.
    energies: Array1<f64>,

    /// Optional density matrices, one per multi-determinantal state when present.
    density_matrices: Option<Vec<Array2<f64>>>,

    /// Threshold forwarded to the QSym² multi-determinant builders (presumably a numerical
    /// comparison tolerance — confirm against the builders).
    #[pyo3(get)]
    threshold: f64,
}
67
68#[pymethods]
69impl PyMultiDeterminantsReal {
70 #[new]
82 #[pyo3(signature = (basis, coefficients, energies, density_matrices, threshold))]
83 pub fn new(
84 basis: Vec<PySlaterDeterminantReal>,
85 coefficients: Bound<'_, PyArray2<f64>>,
86 energies: Bound<'_, PyArray1<f64>>,
87 density_matrices: Option<Vec<Bound<'_, PyArray2<f64>>>>,
88 threshold: f64,
89 ) -> Self {
90 let coefficients = coefficients.to_owned_array();
91 let energies = energies.to_owned_array();
92 let density_matrices = density_matrices.map(|denmats| {
93 denmats
94 .into_iter()
95 .map(|denmat| denmat.to_owned_array())
96 .collect_vec()
97 });
98 if let Some(ref denmats) = density_matrices
99 && (denmats.len() != coefficients.ncols()
100 || denmats.len() != energies.len()
101 || coefficients.ncols() != energies.len())
102 {
103 panic!(
104 "Inconsistent numbers of multi-determinantal states in `coefficients`, `energies`, and `density_matrices`."
105 )
106 };
107 Self {
108 basis,
109 coefficients,
110 energies,
111 density_matrices,
112 threshold,
113 }
114 }
115
116 #[getter]
117 pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
118 Ok(self.coefficients.to_pyarray(py))
119 }
120
121 #[getter]
122 pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
123 Ok(self.energies.to_pyarray(py))
124 }
125
126 #[getter]
127 pub fn density_matrices<'py>(
128 &self,
129 py: Python<'py>,
130 ) -> PyResult<Option<Vec<Bound<'py, PyArray2<f64>>>>> {
131 Ok(self.density_matrices.as_ref().map(|denmats| {
132 denmats
133 .iter()
134 .map(|denmat| denmat.to_pyarray(py))
135 .collect_vec()
136 }))
137 }
138
139 pub fn complex_symmetric<'py>(&self, _py: Python<'py>) -> PyResult<bool> {
142 let complex_symmetric_set = self
143 .basis
144 .iter()
145 .map(|pydet| pydet.complex_symmetric)
146 .collect::<HashSet<_>>();
147 if complex_symmetric_set.len() != 1 {
148 Err(PyRuntimeError::new_err(
149 "Inconsistent complex-symmetric flags across basis functions.",
150 ))
151 } else {
152 complex_symmetric_set.into_iter().next().ok_or_else(|| {
153 PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
154 })
155 }
156 }
157
158 pub fn state_coefficients<'py>(
160 &self,
161 py: Python<'py>,
162 state_index: usize,
163 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
164 Ok(self.coefficients.column(state_index).to_pyarray(py))
165 }
166
167 pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<f64> {
169 Ok(self.energies[state_index])
170 }
171
172 pub fn state_density_matrix<'py>(
174 &self,
175 py: Python<'py>,
176 state_index: usize,
177 ) -> PyResult<Bound<'py, PyArray2<f64>>> {
178 self.density_matrices
179 .as_ref()
180 .ok_or_else(|| {
181 PyRuntimeError::new_err(
182 "No multi-determinantal density matrices found.".to_string(),
183 )
184 })
185 .map(|denmats| denmats[state_index].to_pyarray(py))
186 }
187
188 pub fn to_qsym2_binary<'py>(&self, _py: Python<'py>, name: PathBuf) -> PyResult<usize> {
199 let mut path = name.to_path_buf();
200 path.set_extension(QSym2FileType::Pymdet.ext());
201 qsym2_output!(
202 "Real Python-exposed multi-determinants saved as {}.",
203 path.display().to_string()
204 );
205 write_qsym2_binary(name, QSym2FileType::Pymdet, self)
206 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
207 }
208
209 #[staticmethod]
220 pub fn from_qsym2_binary(name: PathBuf) -> PyResult<Self> {
221 let mut path = name.to_path_buf();
222 path.set_extension(QSym2FileType::Pymdet.ext());
223 qsym2_output!(
224 "Real Python-exposed multi-determinants read in from {}.",
225 path.display().to_string()
226 );
227 read_qsym2_binary(name, QSym2FileType::Pymdet)
228 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
229 }
230}
231
232impl PyMultiDeterminantsReal {
233 #[allow(clippy::type_complexity)]
251 pub fn to_qsym2_individuals<'b, 'a: 'b, SC>(
252 &'b self,
253 baos: &[&'a BasisAngularOrder],
254 mol: &'a Molecule,
255 ) -> Result<
256 Vec<MultiDeterminant<'b, f64, EagerBasis<SlaterDeterminant<'b, f64, SC>>, SC>>,
257 anyhow::Error,
258 >
259 where
260 SC: StructureConstraint
261 + Eq
262 + Hash
263 + Clone
264 + fmt::Display
265 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
266 {
267 let eager_basis = EagerBasis::builder()
268 .elements(
269 self.basis
270 .iter()
271 .map(|pydet| pydet.to_qsym2(baos, mol))
272 .collect::<Result<Vec<_>, _>>()?,
273 )
274 .build()?;
275 self.energies
276 .iter()
277 .zip(self.coefficients.columns())
278 .map(|(e, c)| {
279 MultiDeterminant::builder()
280 .basis(eager_basis.clone())
281 .coefficients(c.to_owned())
282 .energy(Ok(*e))
283 .threshold(self.threshold)
284 .build()
285 })
286 .collect::<Result<Vec<_>, _>>()
287 .map_err(|err| format_err!(err))
288 }
289
290 #[allow(clippy::type_complexity)]
307 pub fn to_qsym2_collection<'b, 'a: 'b, SC>(
308 &'b self,
309 baos: &[&'a BasisAngularOrder],
310 mol: &'a Molecule,
311 ) -> Result<
312 MultiDeterminants<'b, f64, EagerBasis<SlaterDeterminant<'b, f64, SC>>, SC>,
313 anyhow::Error,
314 >
315 where
316 SC: StructureConstraint
317 + Eq
318 + Hash
319 + Clone
320 + fmt::Display
321 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
322 {
323 let eager_basis = EagerBasis::builder()
324 .elements(
325 self.basis
326 .iter()
327 .map(|pydet| pydet.to_qsym2(baos, mol))
328 .collect::<Result<Vec<_>, _>>()?,
329 )
330 .build()?;
331 MultiDeterminants::builder()
332 .basis(eager_basis)
333 .coefficients(self.coefficients.clone())
334 .energies(Ok(self.energies.clone()))
335 .threshold(self.threshold)
336 .build()
337 .map_err(|err| format_err!(err))
338 }
339}
340
/// Python-exposed container for a set of complex multi-determinantal states sharing one basis
/// of Slater determinants.
#[pyclass]
// NOTE: field order matters for the binary (de)serialisation performed via
// `write_qsym2_binary`/`read_qsym2_binary`; do not reorder fields.
#[derive(Clone, Serialize, Deserialize)]
pub struct PyMultiDeterminantsComplex {
    /// The Slater determinants in which all multi-determinantal states are expanded.
    #[pyo3(get)]
    basis: Vec<PySlaterDeterminantComplex>,

    /// The complex expansion coefficients. Each column holds the coefficients of one
    /// multi-determinantal state; rows presumably run over the determinants in `basis`
    /// (TODO confirm).
    coefficients: Array2<C128>,

    /// The complex energies of the multi-determinantal states, one per column of
    /// `coefficients`.
    energies: Array1<C128>,

    /// Optional complex density matrices, one per multi-determinantal state when present.
    density_matrices: Option<Vec<Array2<C128>>>,

    /// Threshold forwarded to the QSym² multi-determinant builders (presumably a numerical
    /// comparison tolerance — confirm against the builders).
    #[pyo3(get)]
    threshold: f64,
}
368
369#[pymethods]
370impl PyMultiDeterminantsComplex {
371 #[new]
383 #[pyo3(signature = (basis, coefficients, energies, density_matrices, threshold))]
384 pub fn new(
385 basis: Vec<PySlaterDeterminantComplex>,
386 coefficients: Bound<'_, PyArray2<C128>>,
387 energies: Bound<'_, PyArray1<C128>>,
388 density_matrices: Option<Vec<Bound<'_, PyArray2<C128>>>>,
389 threshold: f64,
390 ) -> Self {
391 let coefficients = coefficients.to_owned_array();
392 let energies = energies.to_owned_array();
393 let density_matrices = density_matrices.map(|denmats| {
394 denmats
395 .into_iter()
396 .map(|denmat| denmat.to_owned_array())
397 .collect_vec()
398 });
399 if let Some(ref denmats) = density_matrices
400 && (denmats.len() != coefficients.ncols()
401 || denmats.len() != energies.len()
402 || coefficients.ncols() != energies.len())
403 {
404 panic!(
405 "Inconsistent numbers of multi-determinantal states in `coefficients`, `energies`, and `density_matrices`."
406 )
407 };
408 Self {
409 basis,
410 coefficients,
411 energies,
412 density_matrices,
413 threshold,
414 }
415 }
416
417 #[getter]
418 pub fn coefficients<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<C128>>> {
419 Ok(self.coefficients.to_pyarray(py))
420 }
421
422 #[getter]
423 pub fn energies<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<C128>>> {
424 Ok(self.energies.to_pyarray(py))
425 }
426
427 #[getter]
428 pub fn density_matrices<'py>(
429 &self,
430 py: Python<'py>,
431 ) -> PyResult<Option<Vec<Bound<'py, PyArray2<C128>>>>> {
432 Ok(self.density_matrices.as_ref().map(|denmats| {
433 denmats
434 .iter()
435 .map(|denmat| denmat.to_pyarray(py))
436 .collect_vec()
437 }))
438 }
439
440 pub fn complex_symmetric<'py>(&self, _py: Python<'py>) -> PyResult<bool> {
443 let complex_symmetric_set = self
444 .basis
445 .iter()
446 .map(|pydet| pydet.complex_symmetric)
447 .collect::<HashSet<_>>();
448 if complex_symmetric_set.len() != 1 {
449 Err(PyRuntimeError::new_err(
450 "Inconsistent complex-symmetric flags across basis functions.",
451 ))
452 } else {
453 complex_symmetric_set.into_iter().next().ok_or_else(|| {
454 PyRuntimeError::new_err("Unable to extract the complex-symmetric flag.")
455 })
456 }
457 }
458
459 pub fn state_coefficients<'py>(
461 &self,
462 py: Python<'py>,
463 state_index: usize,
464 ) -> PyResult<Bound<'py, PyArray1<C128>>> {
465 Ok(self.coefficients.column(state_index).to_pyarray(py))
466 }
467
468 pub fn state_energy<'py>(&self, _py: Python<'py>, state_index: usize) -> PyResult<C128> {
470 Ok(self.energies[state_index])
471 }
472
473 pub fn state_density_matrix<'py>(
475 &self,
476 py: Python<'py>,
477 state_index: usize,
478 ) -> PyResult<Bound<'py, PyArray2<C128>>> {
479 self.density_matrices
480 .as_ref()
481 .ok_or_else(|| {
482 PyRuntimeError::new_err(
483 "No multi-determinantal density matrices found.".to_string(),
484 )
485 })
486 .map(|denmats| denmats[state_index].to_pyarray(py))
487 }
488
489 pub fn to_qsym2_binary<'py>(&self, _py: Python<'py>, name: PathBuf) -> PyResult<usize> {
500 let mut path = name.to_path_buf();
501 path.set_extension(QSym2FileType::Pymdet.ext());
502 qsym2_output!(
503 "Complex Python-exposed multi-determinants saved as {}.",
504 path.display().to_string()
505 );
506 write_qsym2_binary(name, QSym2FileType::Pymdet, self)
507 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
508 }
509
510 #[staticmethod]
521 pub fn from_qsym2_binary(name: PathBuf) -> PyResult<Self> {
522 let mut path = name.to_path_buf();
523 path.set_extension(QSym2FileType::Pymdet.ext());
524 qsym2_output!(
525 "Complex Python-exposed multi-determinants read in from {}.",
526 path.display().to_string()
527 );
528 read_qsym2_binary(name, QSym2FileType::Pymdet)
529 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
530 }
531}
532
533impl PyMultiDeterminantsComplex {
534 #[allow(clippy::type_complexity)]
552 pub fn to_qsym2_individuals<'b, 'a: 'b, SC>(
553 &'b self,
554 baos: &[&'a BasisAngularOrder],
555 mol: &'a Molecule,
556 ) -> Result<
557 Vec<MultiDeterminant<'b, C128, EagerBasis<SlaterDeterminant<'b, C128, SC>>, SC>>,
558 anyhow::Error,
559 >
560 where
561 SC: StructureConstraint
562 + Eq
563 + Hash
564 + Clone
565 + fmt::Display
566 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
567 {
568 let eager_basis = EagerBasis::builder()
569 .elements(
570 self.basis
571 .iter()
572 .map(|pydet| pydet.to_qsym2(baos, mol))
573 .collect::<Result<Vec<_>, _>>()?,
574 )
575 .build()?;
576 self.energies
577 .iter()
578 .zip(self.coefficients.columns())
579 .map(|(e, c)| {
580 MultiDeterminant::builder()
581 .basis(eager_basis.clone())
582 .coefficients(c.to_owned())
583 .energy(Ok(*e))
584 .threshold(self.threshold)
585 .build()
586 })
587 .collect::<Result<Vec<_>, _>>()
588 .map_err(|err| format_err!(err))
589 }
590
591 #[allow(clippy::type_complexity)]
608 pub fn to_qsym2_collection<'b, 'a: 'b, SC>(
609 &'b self,
610 baos: &[&'a BasisAngularOrder],
611 mol: &'a Molecule,
612 ) -> Result<
613 MultiDeterminants<'b, C128, EagerBasis<SlaterDeterminant<'b, C128, SC>>, SC>,
614 anyhow::Error,
615 >
616 where
617 SC: StructureConstraint
618 + Eq
619 + Hash
620 + Clone
621 + fmt::Display
622 + TryFrom<PyStructureConstraint, Error = anyhow::Error>,
623 {
624 let eager_basis = EagerBasis::builder()
625 .elements(
626 self.basis
627 .iter()
628 .map(|pydet| pydet.to_qsym2(baos, mol))
629 .collect::<Result<Vec<_>, _>>()?,
630 )
631 .build()?;
632 MultiDeterminants::builder()
633 .basis(eager_basis)
634 .coefficients(self.coefficients.clone())
635 .energies(Ok(self.energies.clone()))
636 .threshold(self.threshold)
637 .build()
638 .map_err(|err| format_err!(err))
639 }
640}
641
/// Python-exposed enum that accepts either real or complex multi-determinants from Python.
// NOTE(review): derived extraction presumably tries variants in declaration order, so keep
// `Real` before `Complex` unless that ordering is confirmed to be irrelevant.
#[derive(FromPyObject)]
pub enum PyMultiDeterminants {
    /// Real multi-determinants.
    Real(PyMultiDeterminantsReal),

    /// Complex multi-determinants.
    Complex(PyMultiDeterminantsComplex),
}
656
657mod multideterminant_eager_basis;
662mod multideterminant_orbit_basis_external_solver;
663mod multideterminant_orbit_basis_internal_solver;
664
665pub use multideterminant_eager_basis::rep_analyse_multideterminants_eager_basis;
666pub use multideterminant_orbit_basis_external_solver::rep_analyse_multideterminants_orbit_basis_external_solver;
667pub use multideterminant_orbit_basis_internal_solver::rep_analyse_multideterminants_orbit_basis_internal_solver;