1use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, ArrayView2, Axis, ScalarOperand};
13use ndarray_linalg::types::Lapack;
14use num_complex::ComplexFloat;
15use rayon::prelude::*;
16
17use crate::angmom::spinor_rotation_3d::StructureConstraint;
18use crate::group::GroupProperties;
19use crate::target::determinant::SlaterDeterminant;
20use crate::target::noci::backend::nonortho::{calc_lowdin_pairing, calc_transition_density_matrix};
21use crate::target::noci::basis::{EagerBasis, OrbitBasis};
22
23use super::basis::Basis;
24
25#[path = "multideterminant_transformation.rs"]
26pub(crate) mod multideterminant_transformation;
27
28#[path = "multideterminant_analysis.rs"]
29pub(crate) mod multideterminant_analysis;
30
31#[cfg(test)]
32#[path = "multideterminant_tests.rs"]
33mod multideterminant_tests;
34
/// A multi-determinantal wavefunction: a linear combination of Slater determinants taken from
/// the basis `B`, with one coefficient per basis determinant.
///
/// Instances are constructed via the generated builder (see `MultiDeterminant::builder`), whose
/// `build` step runs `MultiDeterminantBuilder::validate` before the struct is assembled.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MultiDeterminant<'a, T, B, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + Hash + Eq + fmt::Display,
    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
{
    // Zero-sized marker tying this struct to the lifetime `'a` used by the basis determinants.
    #[builder(setter(skip), default = "PhantomData")]
    _lifetime: PhantomData<&'a ()>,

    // Zero-sized marker for the structure-constraint type parameter `SC`.
    #[builder(setter(skip), default = "PhantomData")]
    _structure_constraint: PhantomData<SC>,

    /// Whether the basis determinants are complex-symmetric. This cannot be set directly: the
    /// builder derives it from the basis, and building fails if the basis determinants disagree
    /// on this flag.
    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
    complex_symmetric: bool,

    /// Whether this wavefunction is the complex conjugate of the one described by `basis` and
    /// `coefficients` (defaults to `false`). When set, `density_matrix` conjugates the
    /// coefficients and the transition density matrices it accumulates.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The basis of Slater determinants spanning this wavefunction.
    basis: B,

    /// The expansion coefficients, one per basis determinant; validation rejects a length that
    /// differs from the number of basis items.
    coefficients: Array1<T>,

    /// The wavefunction energy, or an error message if no energy has been set (the default is
    /// an `Err` carrying "not yet set").
    #[builder(
        default = "Err(\"Multi-determinantal wavefunction energy not yet set.\".to_string())"
    )]
    energy: Result<T, String>,

    /// A numerical threshold associated with this wavefunction.
    // NOTE(review): its exact role is not visible in this file — confirm against callers.
    threshold: <T as ComplexFloat>::Real,
}
80
81impl<'a, T, B, SC> MultiDeterminantBuilder<'a, T, B, SC>
86where
87 T: ComplexFloat + Lapack,
88 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
89 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
90{
91 fn validate(&self) -> Result<(), String> {
92 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
93 let coefficients = self
94 .coefficients
95 .as_ref()
96 .ok_or("No coefficients found.".to_string())?;
97 let nbasis = basis.n_items() == coefficients.len();
98 if !nbasis {
99 log::error!(
100 "The number of coefficients does not match the number of basis determinants."
101 );
102 }
103
104 let complex_symmetric = basis
105 .iter()
106 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
107 .collect::<Result<HashSet<_>, _>>()
108 .map_err(|err| err.to_string())?
109 .len()
110 == 1;
111 if !complex_symmetric {
112 log::error!("Inconsistent complex-symmetric flag across basis determinants.");
113 }
114
115 let structcons_check = basis
116 .iter()
117 .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
118 .collect::<Result<HashSet<_>, _>>()
119 .map_err(|err| err.to_string())?
120 .len()
121 == 1;
122 if !structcons_check {
123 log::error!("Inconsistent spin constraints across basis determinants.");
124 }
125
126 if nbasis && structcons_check && complex_symmetric {
127 Ok(())
128 } else {
129 Err("Multi-determinant wavefunction validation failed.".to_string())
130 }
131 }
132
133 fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
135 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
136 let complex_symmetric_set = basis
137 .iter()
138 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
139 .collect::<Result<HashSet<_>, _>>()
140 .map_err(|err| err.to_string())?;
141 if complex_symmetric_set.len() == 1 {
142 complex_symmetric_set
143 .into_iter()
144 .next()
145 .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
146 } else {
147 Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
148 }
149 }
150}
151
152impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
153where
154 T: ComplexFloat + Lapack,
155 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
156 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
157{
158 pub fn builder() -> MultiDeterminantBuilder<'a, T, B, SC> {
160 MultiDeterminantBuilder::default()
161 }
162
163 pub fn structure_constraint(&self) -> SC {
165 self.basis
166 .iter()
167 .next()
168 .expect("No basis determinant found.")
169 .expect("No basis determinant found.")
170 .structure_constraint()
171 .clone()
172 }
173}
174
175impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
176where
177 T: ComplexFloat + Lapack,
178 SC: StructureConstraint + Hash + Eq + fmt::Display,
179 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
180{
181 pub fn complex_conjugated(&self) -> bool {
183 self.complex_conjugated
184 }
185
186 pub fn complex_symmetric(&self) -> bool {
188 self.complex_symmetric
189 }
190
191 pub fn basis(&self) -> &B {
194 &self.basis
195 }
196
197 pub fn coefficients(&self) -> &Array1<T> {
200 &self.coefficients
201 }
202
203 pub fn energy(&self) -> Result<&T, &String> {
205 self.energy.as_ref()
206 }
207
208 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
210 self.threshold
211 }
212}
213
214impl<'a, T, G, SC> MultiDeterminant<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
219where
220 T: ComplexFloat + Lapack,
221 G: GroupProperties + Clone,
222 SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
223{
224 #[allow(clippy::type_complexity)]
227 pub fn to_eager_basis(
228 &self,
229 ) -> Result<MultiDeterminant<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
230 {
231 MultiDeterminant::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
232 .complex_conjugated(self.complex_conjugated)
233 .basis(self.basis.to_eager()?)
234 .coefficients(self.coefficients().clone())
235 .energy(self.energy.clone())
236 .threshold(self.threshold)
237 .build()
238 .map_err(|err| format_err!(err))
239 }
240}
241
242impl<'a, T, B, SC> MultiDeterminant<'a, T, B, SC>
247where
248 T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
249 <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
250 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
251 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
252 SlaterDeterminant<'a, T, SC>: Send + Sync,
253{
254 pub fn density_matrix(
271 &self,
272 sao: &ArrayView2<T>,
273 thresh_offdiag: <T as ComplexFloat>::Real,
274 thresh_zeroov: <T as ComplexFloat>::Real,
275 normalised_wavefunction: bool,
276 ) -> Result<Array2<T>, anyhow::Error> {
277 let nao = sao.nrows();
278 let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
279 let sqnorm_denmat_res = dets.iter()
280 .zip(self.coefficients().iter())
281 .cartesian_product(dets.iter().zip(self.coefficients().iter()))
282 .par_bridge()
283 .fold(
284 || Ok((T::zero(), Array2::<T>::zeros((nao, nao)))),
285 |acc_res, ((det_w, c_w), (det_x, c_x))| {
286 ensure!(
287 det_w.structure_constraint() == det_x.structure_constraint(),
288 "Inconsistent spin constraints: {} != {}.",
289 det_w.structure_constraint(),
290 det_x.structure_constraint(),
291 );
292
293 if det_w.complex_symmetric() != det_x.complex_symmetric() {
294 return Err(format_err!(
295 "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
296 det_w.complex_symmetric(),
297 det_x.complex_symmetric(),
298 ));
299 }
300 let complex_symmetric = det_w.complex_symmetric();
301 let lowdin_paired_coefficientss = det_w
302 .coefficients()
303 .iter()
304 .zip(det_w.occupations().iter())
305 .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
306 .map(|((cw, occw), (cx, occx))| {
307 let occw_indices = occw
308 .iter()
309 .enumerate()
310 .filter_map(|(i, occ_i)| {
311 if occ_i.abs() >= det_w.threshold() {
312 Some(i)
313 } else {
314 None
315 }
316 })
317 .collect::<Vec<_>>();
318 let ne_w = occw_indices.len();
319 let cw_occ = cw.select(Axis(1), &occw_indices);
320 let occx_indices = occx
321 .iter()
322 .enumerate()
323 .filter_map(|(i, occ_i)| {
324 if occ_i.abs() >= det_x.threshold() {
325 Some(i)
326 } else {
327 None
328 }
329 })
330 .collect::<Vec<_>>();
331 let ne_x = occx_indices.len();
332 ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
333 let cx_occ = cx.select(Axis(1), &occx_indices);
334 calc_lowdin_pairing(
335 &cw_occ.view(),
336 &cx_occ.view(),
337 sao,
338 complex_symmetric,
339 thresh_offdiag,
340 thresh_zeroov,
341 )
342 })
343 .collect::<Result<Vec<_>, _>>()?;
344 let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
345 let ov_wx = lowdin_paired_coefficientss
346 .iter()
347 .flat_map(|lpc| lpc.lowdin_overlaps().iter())
348 .fold(T::one(), |acc, ov| acc * *ov);
349
350 let c_w = if self.complex_conjugated() {
351 if complex_symmetric { c_w.conj() } else { *c_w }
352 } else if complex_symmetric { *c_w } else { c_w.conj() };
353 let c_x = if self.complex_conjugated() {
354 c_x.conj()
355 } else {
356 *c_x
357 };
358 let den_wx = if self.complex_conjugated() {
359 den_wx.mapv(|v| v.conj())
360 } else {
361 den_wx
362 };
363 acc_res.map(|(sqnorm_acc, denmat_acc)| (sqnorm_acc + ov_wx * c_w * c_x, denmat_acc + den_wx * c_w * c_x))
364 },
365 )
366 .reduce(
367 || Ok((T::zero(), Array2::<T>::zeros((nao, nao)))),
368 |sqnorm_denmat_res_a: Result<(T, Array2<T>), anyhow::Error>, sqnorm_denmat_res_b: Result<(T, Array2<T>), anyhow::Error>| {
369 sqnorm_denmat_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorm_denmat_res_b.map(|(sqnorm, denmat)| {
370 (
371 sqnorm_acc + sqnorm,
372 denmat_acc + denmat
373 )
374 }))
375 }
376 );
377 sqnorm_denmat_res.map(|(sqnorm, denmat)| {
378 if normalised_wavefunction {
379 denmat / sqnorm
380 } else {
381 denmat
382 }
383 })
384 }
385}
386
387impl<'a, T, B, SC> fmt::Debug for MultiDeterminant<'a, T, B, SC>
391where
392 T: ComplexFloat + Lapack,
393 SC: StructureConstraint + Hash + Eq + fmt::Display,
394 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
395{
396 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
397 write!(
398 f,
399 "MultiDeterminant over {} basis Slater determinants",
400 self.coefficients.len(),
401 )?;
402 Ok(())
403 }
404}
405
406impl<'a, T, B, SC> fmt::Display for MultiDeterminant<'a, T, B, SC>
410where
411 T: ComplexFloat + Lapack,
412 SC: StructureConstraint + Hash + Eq + fmt::Display,
413 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
414{
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416 write!(
417 f,
418 "MultiDeterminant over {} basis Slater determinants",
419 self.coefficients.len(),
420 )?;
421 Ok(())
422 }
423}