use std::collections::HashSet;
use std::fmt;
use std::marker::PhantomData;
use derive_builder::Builder;
use log;
use ndarray::Array1;
use ndarray_linalg::types::Lapack;
use num_complex::ComplexFloat;
use crate::angmom::spinor_rotation_3d::SpinConstraint;
use crate::target::determinant::SlaterDeterminant;
use super::basis::Basis;
#[path = "multideterminant_transformation.rs"]
pub(crate) mod multideterminant_transformation;
#[path = "multideterminant_analysis.rs"]
pub(crate) mod multideterminant_analysis;
#[cfg(test)]
#[path = "multideterminant_tests.rs"]
mod multideterminant_tests;
/// A multi-determinantal wavefunction, specified by a basis of Slater determinants together with
/// one coefficient per basis determinant.
///
/// Instances are constructed through [`MultiDeterminantBuilder`]. Its validation requires that the
/// number of coefficients matches the number of basis determinants. It also requires that all
/// basis determinants share the same complex-symmetric flag and the same spin constraint.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MultiDeterminant<'a, T, B>
where
    T: ComplexFloat + Lapack,
    B: Basis<SlaterDeterminant<'a, T>> + Clone,
{
    /// Zero-sized marker that ties this struct to the lifetime `'a` of the basis determinants.
    /// It has no builder setter.
    #[builder(setter(skip), default = "PhantomData")]
    _lifetime: PhantomData<&'a ()>,

    /// Complex-symmetric flag shared by all basis determinants.
    ///
    /// This cannot be set on the builder. It is derived from the basis by
    /// `complex_symmetric_from_basis`, which fails unless every basis determinant reports the
    /// same flag.
    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
    complex_symmetric: bool,

    /// Flag that presumably indicates whether this wavefunction is complex-conjugated. This is
    /// inferred from the name; confirm against the transformation code. Defaults to `false`.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The basis of Slater determinants.
    basis: B,

    /// Coefficients, one per basis determinant. The builder validates that their number matches
    /// the number of items in `basis`.
    coefficients: Array1<T>,

    /// Energy of the wavefunction. Defaults to an `Err` carrying a message stating that the
    /// energy has not yet been set.
    #[builder(
        default = "Err(\"Multi-determinantal wavefunction energy not yet set.\".to_string())"
    )]
    energy: Result<T, String>,

    /// Real-valued threshold associated with this wavefunction. It is presumably used for
    /// numerical comparisons; confirm against usage.
    threshold: <T as ComplexFloat>::Real,
}
impl<'a, T, B> MultiDeterminantBuilder<'a, T, B>
where
T: ComplexFloat + Lapack,
B: Basis<SlaterDeterminant<'a, T>> + Clone,
{
fn validate(&self) -> Result<(), String> {
let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
let coefficients = self
.coefficients
.as_ref()
.ok_or("No coefficients found.".to_string())?;
let nbasis = basis.n_items() == coefficients.len();
if !nbasis {
log::error!(
"The number of coefficients does not match the number of basis determinants."
);
}
let complex_symmetric = basis
.iter()
.map(|det_res| det_res.map(|det| det.complex_symmetric()))
.collect::<Result<HashSet<_>, _>>()
.map_err(|err| err.to_string())?
.len()
== 1;
if !complex_symmetric {
log::error!("Inconsistent complex-symmetric flag across basis determinants.");
}
let spincons = basis
.iter()
.map(|det_res| det_res.map(|det| det.spin_constraint().clone()))
.collect::<Result<HashSet<_>, _>>()
.map_err(|err| err.to_string())?
.len()
== 1;
if !spincons {
log::error!("Inconsistent spin constraints across basis determinants.");
}
if nbasis && spincons && complex_symmetric {
Ok(())
} else {
Err("Multi-determinant wavefunction validation failed.".to_string())
}
}
fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
let complex_symmetric_set = basis
.iter()
.map(|det_res| det_res.map(|det| det.complex_symmetric()))
.collect::<Result<HashSet<_>, _>>()
.map_err(|err| err.to_string())?;
if complex_symmetric_set.len() == 1 {
complex_symmetric_set
.into_iter()
.next()
.ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
} else {
Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
}
}
}
impl<'a, T, B> MultiDeterminant<'a, T, B>
where
    T: ComplexFloat + Lapack,
    B: Basis<SlaterDeterminant<'a, T>> + Clone,
{
    /// Returns a default builder for constructing a new [`MultiDeterminant`].
    pub(crate) fn builder() -> MultiDeterminantBuilder<'a, T, B> {
        MultiDeterminantBuilder::default()
    }

    /// Returns the spin constraint of the multi-determinantal wavefunction.
    ///
    /// The builder's validation requires all basis determinants to share the same spin
    /// constraint, so this returns the constraint of the first basis determinant.
    ///
    /// # Panics
    ///
    /// Panics if the basis contains no determinants, or if retrieving the first basis
    /// determinant from the basis yields an error.
    pub fn spin_constraint(&self) -> SpinConstraint {
        self.basis
            .iter()
            .next()
            .expect("No basis determinant found.")
            // Distinct from the message above: here a determinant exists, but producing it failed.
            .expect("Unable to retrieve the first basis determinant.")
            .spin_constraint()
            .clone()
    }

    /// Returns the complex-conjugation flag of this wavefunction.
    pub fn complex_conjugated(&self) -> bool {
        self.complex_conjugated
    }

    /// Returns a shared reference to the basis of Slater determinants.
    pub fn basis(&self) -> &B {
        &self.basis
    }

    /// Returns a shared reference to the coefficients, one per basis determinant.
    pub fn coefficients(&self) -> &Array1<T> {
        &self.coefficients
    }

    /// Returns the energy, or a reference to the error message if the energy has not been set.
    pub fn energy(&self) -> Result<&T, &String> {
        self.energy.as_ref()
    }

    /// Returns the real-valued threshold associated with this wavefunction.
    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
        self.threshold
    }
}
impl<'a, T, B> fmt::Debug for MultiDeterminant<'a, T, B>
where
    T: ComplexFloat + Lapack,
    B: Basis<SlaterDeterminant<'a, T>> + Clone,
{
    /// Writes a one-line summary giving the number of coefficients, which the builder ties to
    /// the number of basis Slater determinants.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let n_dets = self.coefficients.len();
        write!(
            f,
            "MultiDeterminant over {} basis Slater determinants",
            n_dets
        )
    }
}
impl<'a, T, B> fmt::Display for MultiDeterminant<'a, T, B>
where
    T: ComplexFloat + Lapack,
    B: Basis<SlaterDeterminant<'a, T>> + Clone,
{
    /// Writes a human-readable one-line summary giving the number of coefficients (equivalently,
    /// basis Slater determinants) in this wavefunction.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!(
            "MultiDeterminant over {} basis Slater determinants",
            self.coefficients.len()
        ))
    }
}