1use std::collections::HashSet;
4use std::fmt;
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use derive_builder::Builder;
9use log;
10use ndarray::Array1;
11use ndarray_linalg::types::Lapack;
12use num_complex::ComplexFloat;
13
14use crate::angmom::spinor_rotation_3d::StructureConstraint;
15use crate::target::determinant::SlaterDeterminant;
16
17use super::basis::Basis;
18
19#[path = "multideterminant_transformation.rs"]
20pub(crate) mod multideterminant_transformation;
21
22#[path = "multideterminant_analysis.rs"]
23pub(crate) mod multideterminant_analysis;
24
25#[cfg(test)]
26#[path = "multideterminant_tests.rs"]
27mod multideterminant_tests;
28
/// A wavefunction expressed as a linear combination of Slater determinants.
///
/// Instances are constructed through [`MultiDeterminantBuilder`]. Its `build` step runs
/// `MultiDeterminantBuilder::validate` and derives the `complex_symmetric` field from the basis
/// via `MultiDeterminantBuilder::complex_symmetric_from_basis`.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MultiDeterminant<'a, T, B, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + Hash + Eq + fmt::Display,
    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
{
    /// Zero-sized marker tying this struct to the lifetime `'a` used by the basis determinants.
    #[builder(setter(skip), default = "PhantomData")]
    _lifetime: PhantomData<&'a ()>,

    /// Zero-sized marker for the structure-constraint type parameter `SC`.
    #[builder(setter(skip), default = "PhantomData")]
    _structure_constraint: PhantomData<SC>,

    /// Complex-symmetric flag shared by all basis determinants.
    ///
    /// This has no setter. It is computed at build time from the basis, and building fails if
    /// the basis determinants disagree on this flag.
    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
    complex_symmetric: bool,

    /// Complex-conjugation flag of this wavefunction; defaults to `false` in the builder.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// Basis of Slater determinants spanned by this wavefunction.
    ///
    /// Validation requires its number of items to equal the number of `coefficients`, and all
    /// of its determinants to share one complex-symmetric flag and one structure constraint.
    basis: B,

    /// Expansion coefficients, one per basis determinant (presumably in the order in which
    /// `basis` iterates its determinants — TODO confirm).
    coefficients: Array1<T>,

    /// Energy of this wavefunction. Defaults to an `Err` holding a message saying the energy
    /// has not been set yet.
    #[builder(
        default = "Err(\"Multi-determinantal wavefunction energy not yet set.\".to_string())"
    )]
    energy: Result<T, String>,

    /// Real-valued threshold; required by the builder (no default). Its use is outside this
    /// view — presumably a numerical comparison tolerance, verify against callers.
    threshold: <T as ComplexFloat>::Real,
}
74
75impl<'a, T, B, SC> MultiDeterminantBuilder<'a, T, B, SC>
80where
81 T: ComplexFloat + Lapack,
82 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
83 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
84{
85 fn validate(&self) -> Result<(), String> {
86 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
87 let coefficients = self
88 .coefficients
89 .as_ref()
90 .ok_or("No coefficients found.".to_string())?;
91 let nbasis = basis.n_items() == coefficients.len();
92 if !nbasis {
93 log::error!(
94 "The number of coefficients does not match the number of basis determinants."
95 );
96 }
97
98 let complex_symmetric = basis
99 .iter()
100 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
101 .collect::<Result<HashSet<_>, _>>()
102 .map_err(|err| err.to_string())?
103 .len()
104 == 1;
105 if !complex_symmetric {
106 log::error!("Inconsistent complex-symmetric flag across basis determinants.");
107 }
108
109 let structcons_check = basis
110 .iter()
111 .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
112 .collect::<Result<HashSet<_>, _>>()
113 .map_err(|err| err.to_string())?
114 .len()
115 == 1;
116 if !structcons_check {
117 log::error!("Inconsistent spin constraints across basis determinants.");
118 }
119
120 if nbasis && structcons_check && complex_symmetric {
121 Ok(())
122 } else {
123 Err("Multi-determinant wavefunction validation failed.".to_string())
124 }
125 }
126
127 fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
129 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
130 let complex_symmetric_set = basis
131 .iter()
132 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
133 .collect::<Result<HashSet<_>, _>>()
134 .map_err(|err| err.to_string())?;
135 if complex_symmetric_set.len() == 1 {
136 complex_symmetric_set
137 .into_iter()
138 .next()
139 .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
140 } else {
141 Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
142 }
143 }
144}
145
146impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
147where
148 T: ComplexFloat + Lapack,
149 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
150 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
151{
152 pub(crate) fn builder() -> MultiDeterminantBuilder<'a, T, B, SC> {
154 MultiDeterminantBuilder::default()
155 }
156
157 pub fn structure_constraint(&self) -> SC {
159 self.basis
160 .iter()
161 .next()
162 .expect("No basis determinant found.")
163 .expect("No basis determinant found.")
164 .structure_constraint()
165 .clone()
166 }
167}
168
169impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
170where
171 T: ComplexFloat + Lapack,
172 SC: StructureConstraint + Hash + Eq + fmt::Display,
173 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
174{
175 pub fn complex_conjugated(&self) -> bool {
177 self.complex_conjugated
178 }
179
180 pub fn basis(&self) -> &B {
183 &self.basis
184 }
185
186 pub fn coefficients(&self) -> &Array1<T> {
189 &self.coefficients
190 }
191
192 pub fn energy(&self) -> Result<&T, &String> {
194 self.energy.as_ref()
195 }
196
197 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
199 self.threshold
200 }
201}
202
203impl<'a, T, B, SC> fmt::Debug for MultiDeterminant<'a, T, B, SC>
207where
208 T: ComplexFloat + Lapack,
209 SC: StructureConstraint + Hash + Eq + fmt::Display,
210 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 write!(
214 f,
215 "MultiDeterminant over {} basis Slater determinants",
216 self.coefficients.len(),
217 )?;
218 Ok(())
219 }
220}
221
222impl<'a, T, B, SC> fmt::Display for MultiDeterminant<'a, T, B, SC>
226where
227 T: ComplexFloat + Lapack,
228 SC: StructureConstraint + Hash + Eq + fmt::Display,
229 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
230{
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 write!(
233 f,
234 "MultiDeterminant over {} basis Slater determinants",
235 self.coefficients.len(),
236 )?;
237 Ok(())
238 }
239}