1use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, Array3, ArrayView2, Axis, Ix1, Ix3, ScalarOperand, ShapeBuilder};
13use ndarray_einsum::einsum;
14use ndarray_linalg::types::Lapack;
15use num_complex::ComplexFloat;
16use rayon::prelude::*;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::group::GroupProperties;
20use crate::target::determinant::SlaterDeterminant;
21use crate::target::noci::backend::nonortho::{calc_lowdin_pairing, calc_transition_density_matrix};
22use crate::target::noci::basis::{EagerBasis, OrbitBasis};
23use crate::target::noci::multideterminant::MultiDeterminant;
24
25use super::basis::Basis;
26
27#[cfg(test)]
28#[path = "multideterminants_tests.rs"]
29mod multideterminants_tests;
30
/// A collection of multi-determinantal wavefunctions that all share one basis of Slater
/// determinants.
///
/// The coefficient matrix stores one wavefunction per column: row `w` corresponds to the
/// `w`-th basis determinant and column `m` to the `m`-th wavefunction. Construction goes
/// through the generated [`MultiDeterminantsBuilder`], whose `validate` step checks that the
/// basis and coefficient dimensions agree and that all basis determinants share a single
/// complex-symmetric flag and structure constraint.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MultiDeterminants<'a, T, B, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + Hash + Eq + fmt::Display,
    B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
{
    // Ties the collection to the lifetime `'a` of the basis determinants; carries no data.
    #[builder(setter(skip), default = "PhantomData")]
    _lifetime: PhantomData<&'a ()>,

    // Records the structure-constraint type `SC`; carries no data.
    #[builder(setter(skip), default = "PhantomData")]
    _structure_constraint: PhantomData<SC>,

    /// Whether the inner products involving these wavefunctions are complex-symmetric.
    ///
    /// This cannot be set directly: the builder derives it from the (unique) complex-symmetric
    /// flag shared by all basis determinants.
    #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
    complex_symmetric: bool,

    /// Whether the wavefunctions in this collection are flagged as complex-conjugated.
    /// Defaults to `false`.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The basis of Slater determinants shared by every wavefunction in the collection.
    basis: B,

    /// Expansion coefficients with one row per basis determinant and one column per
    /// multi-determinantal wavefunction.
    coefficients: Array2<T>,

    /// Energies of the wavefunctions (one per coefficient column), or an explanatory error
    /// message if they have not been supplied.
    #[builder(
        default = "Err(\"Multi-determinantal wavefunction energies not yet set.\".to_string())"
    )]
    energies: Result<Array1<T>, String>,

    /// Threshold passed on to every individual multi-determinantal wavefunction produced from
    /// this collection.
    threshold: <T as ComplexFloat>::Real,
}
81
82impl<'a, T, B, SC> MultiDeterminantsBuilder<'a, T, B, SC>
87where
88 T: ComplexFloat + Lapack,
89 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
90 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
91{
92 fn validate(&self) -> Result<(), String> {
93 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
94 let coefficients = self
95 .coefficients
96 .as_ref()
97 .ok_or("No coefficients found.".to_string())?;
98 let nbasis = basis.n_items() == coefficients.nrows();
99 if !nbasis {
100 log::error!(
101 "The number of coefficient rows does not match the number of basis determinants."
102 );
103 }
104
105 let complex_symmetric = basis
106 .iter()
107 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
108 .collect::<Result<HashSet<_>, _>>()
109 .map_err(|err| err.to_string())?
110 .len()
111 == 1;
112 if !complex_symmetric {
113 log::error!("Inconsistent complex-symmetric flag across basis determinants.");
114 }
115
116 let structcons_check = basis
117 .iter()
118 .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
119 .collect::<Result<HashSet<_>, _>>()
120 .map_err(|err| err.to_string())?
121 .len()
122 == 1;
123 if !structcons_check {
124 log::error!("Inconsistent spin constraints across basis determinants.");
125 }
126
127 if nbasis && structcons_check && complex_symmetric {
128 Ok(())
129 } else {
130 Err("Multi-determinantal wavefunction collection validation failed.".to_string())
131 }
132 }
133
134 fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
136 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
137 let complex_symmetric_set = basis
138 .iter()
139 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
140 .collect::<Result<HashSet<_>, _>>()
141 .map_err(|err| err.to_string())?;
142 if complex_symmetric_set.len() == 1 {
143 complex_symmetric_set
144 .into_iter()
145 .next()
146 .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
147 } else {
148 Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
149 }
150 }
151}
152
153impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
154where
155 T: ComplexFloat + Lapack,
156 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
157 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
158{
159 pub fn builder() -> MultiDeterminantsBuilder<'a, T, B, SC> {
161 MultiDeterminantsBuilder::default()
162 }
163
164 pub fn from_multideterminant_vec(
175 mtds: &[&MultiDeterminant<'a, T, B, SC>],
176 ) -> Result<MultiDeterminants<'a, T, B, SC>, anyhow::Error> {
177 log::warn!(
178 "Using basis from the first multi-determinantal wavefunction as the common basis for the collection of multi-determinantal wavefunctions..."
179 );
180 let nmultidets = mtds.len();
181 let dims_set = mtds
182 .iter()
183 .map(|mtd| mtd.basis().n_items())
184 .collect::<HashSet<_>>();
185 let dim = if dims_set.len() == 1 {
186 dims_set
187 .into_iter()
188 .next()
189 .ok_or_else(|| format_err!("Unable to obtain the unique basis size."))
190 } else {
191 Err(format_err!(
192 "Inconsistent basis sizes across the supplied multi-determinantal wavefunctions."
193 ))
194 }?;
195 let coefficients = Array2::from_shape_vec(
196 (dim, nmultidets).f(),
197 mtds.iter()
198 .flat_map(|mtd| mtd.coefficients())
199 .cloned()
200 .collect::<Vec<_>>(),
201 )
202 .map_err(|err| format_err!(err))?;
203
204 let (basis, threshold) = mtds
205 .first()
206 .map(|mtd| (mtd.basis().clone(), mtd.threshold()))
207 .ok_or_else(|| {
208 format_err!("Unable to access the first multi-determinantal wavefunction.")
209 })?;
210
211 MultiDeterminants::builder()
212 .basis(basis)
213 .coefficients(coefficients)
214 .threshold(threshold)
215 .build()
216 .map_err(|err| format_err!(err))
217 }
218
219 pub fn structure_constraint(&self) -> SC {
221 self.basis
222 .iter()
223 .next()
224 .expect("No basis determinant found.")
225 .expect("No basis determinant found.")
226 .structure_constraint()
227 .clone()
228 }
229
230 pub fn iter(&self) -> impl Iterator {
232 let energies = self
233 .energies
234 .as_ref()
235 .map(|energies| energies.mapv(|v| Ok(v)))
236 .unwrap_or(Array1::from_elem(
237 self.coefficients.ncols(),
238 Err("Multi-determinantal energy not available.".to_string()),
239 ));
240 self.coefficients
241 .columns()
242 .into_iter()
243 .zip(energies)
244 .map(|(c, e)| {
245 MultiDeterminant::builder()
246 .complex_conjugated(self.complex_conjugated)
247 .basis(self.basis().clone())
248 .coefficients(c.to_owned())
249 .energy(e)
250 .threshold(self.threshold)
251 .build()
252 .map_err(|err| format_err!(err))
253 })
254 }
255}
256
257impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
258where
259 T: ComplexFloat + Lapack,
260 SC: StructureConstraint + Hash + Eq + fmt::Display,
261 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
262{
263 pub fn complex_conjugated(&self) -> bool {
266 self.complex_conjugated
267 }
268
269 pub fn complex_symmetric(&self) -> bool {
272 self.complex_symmetric
273 }
274
275 pub fn basis(&self) -> &B {
278 &self.basis
279 }
280
281 pub fn coefficients(&self) -> &Array2<T> {
284 &self.coefficients
285 }
286
287 pub fn energies(&self) -> Result<&Array1<T>, &String> {
289 self.energies.as_ref()
290 }
291
292 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
294 self.threshold
295 }
296}
297
298impl<'a, T, G, SC> MultiDeterminants<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
303where
304 T: ComplexFloat + Lapack,
305 G: GroupProperties + Clone,
306 SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
307{
308 #[allow(clippy::type_complexity)]
311 pub fn to_eager_basis(
312 &self,
313 ) -> Result<MultiDeterminants<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
314 {
315 MultiDeterminants::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
316 .complex_conjugated(self.complex_conjugated)
317 .basis(self.basis.to_eager()?)
318 .coefficients(self.coefficients().clone())
319 .energies(self.energies.clone())
320 .threshold(self.threshold)
321 .build()
322 .map_err(|err| format_err!(err))
323 }
324}
325
326impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
331where
332 T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
333 <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
334 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
335 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
336 SlaterDeterminant<'a, T, SC>: Send + Sync,
337{
338 pub fn density_matrices(
361 &self,
362 sao: &ArrayView2<T>,
363 thresh_offdiag: <T as ComplexFloat>::Real,
364 thresh_zeroov: <T as ComplexFloat>::Real,
365 normalised_wavefunctions: bool,
366 ) -> Result<Array3<T>, anyhow::Error> {
367 let nao = sao.nrows();
368 let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
369 let nmultidets = self.coefficients.ncols();
370 let sqnorms_denmats_res = dets.iter()
371 .zip(self.coefficients().rows())
372 .cartesian_product(dets.iter().zip(self.coefficients().rows()))
373 .par_bridge()
374 .fold(
375 || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
376 |acc_res, ((det_w, c_wm), (det_x, c_xm))| {
377 ensure!(
378 det_w.structure_constraint() == det_x.structure_constraint(),
379 "Inconsistent spin constraints: {} != {}.",
380 det_w.structure_constraint(),
381 det_x.structure_constraint(),
382 );
383
384 if det_w.complex_symmetric() != det_x.complex_symmetric() {
385 return Err(format_err!(
386 "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
387 det_w.complex_symmetric(),
388 det_x.complex_symmetric(),
389 ));
390 }
391 let complex_symmetric = det_w.complex_symmetric();
392 let lowdin_paired_coefficientss = det_w
393 .coefficients()
394 .iter()
395 .zip(det_w.occupations().iter())
396 .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
397 .map(|((cw, occw), (cx, occx))| {
398 let occw_indices = occw
399 .iter()
400 .enumerate()
401 .filter_map(|(i, occ_i)| {
402 if occ_i.abs() >= det_w.threshold() {
403 Some(i)
404 } else {
405 None
406 }
407 })
408 .collect::<Vec<_>>();
409 let ne_w = occw_indices.len();
410 let cw_occ = cw.select(Axis(1), &occw_indices);
411 let occx_indices = occx
412 .iter()
413 .enumerate()
414 .filter_map(|(i, occ_i)| {
415 if occ_i.abs() >= det_x.threshold() {
416 Some(i)
417 } else {
418 None
419 }
420 })
421 .collect::<Vec<_>>();
422 let ne_x = occx_indices.len();
423 ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
424 let cx_occ = cx.select(Axis(1), &occx_indices);
425 calc_lowdin_pairing(
426 &cw_occ.view(),
427 &cx_occ.view(),
428 sao,
429 complex_symmetric,
430 thresh_offdiag,
431 thresh_zeroov,
432 )
433 })
434 .collect::<Result<Vec<_>, _>>()?;
435 let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
436 let ov_wx = lowdin_paired_coefficientss
437 .iter()
438 .flat_map(|lpc| lpc.lowdin_overlaps().iter())
439 .fold(T::one(), |acc, ov| acc * *ov);
440
441 let c_wm = if self.complex_conjugated() {
442 if complex_symmetric { c_wm.map(|v| v.conj()) } else { c_wm.to_owned() }
443 } else if complex_symmetric { c_wm.to_owned() } else { c_wm.map(|v| v.conj()) };
444 let c_xm = if self.complex_conjugated() {
445 c_xm.map(|v| v.conj())
446 } else {
447 c_xm.to_owned()
448 };
449 let den_wx = if self.complex_conjugated() {
450 den_wx.mapv(|v| v.conj())
451 } else {
452 den_wx
453 };
454 acc_res.and_then(|(sqnorm_acc, denmat_acc)| {
455 let ov_wx_m = einsum("m,m->m", &[&c_wm.view(), &c_xm.view()])
456 .map_err(|err| format_err!(err))?
457 .into_dimensionality::<Ix1>()
458 .map_err(|err| format_err!(err))?
459 .mapv(|v| v * ov_wx);
460 let denmat_wx_mij = einsum("ij,m,m->mij", &[&den_wx.view(), &c_wm.view(), &c_xm.view()])
461 .map_err(|err| format_err!(err))?
462 .into_dimensionality::<Ix3>()
463 .map_err(|err| format_err!(err))?;
464 Ok((sqnorm_acc + ov_wx_m, denmat_acc + denmat_wx_mij))
465 })
466 },
467 )
468 .reduce(
469 || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
470 |sqnorms_denmats_res_a: Result<(Array1<T>, Array3<T>), anyhow::Error>, sqnorms_denmats_res_b: Result<(Array1<T>, Array3<T>), anyhow::Error>| {
471 sqnorms_denmats_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorms_denmats_res_b.map(|(sqnorm, denmat)| {
472 (
473 sqnorm_acc + sqnorm,
474 denmat_acc + denmat
475 )
476 }))
477 }
478 );
479 sqnorms_denmats_res.and_then(|(sqnorms, denmats)| {
480 if normalised_wavefunctions {
481 let sqnorms_inv = sqnorms.mapv(|v| T::one() / v);
482 einsum("m,mij->mij", &[&sqnorms_inv.view(), &denmats.view()])
483 .map_err(|err| format_err!(err))?
484 .into_dimensionality::<Ix3>()
485 .map_err(|err| format_err!(err))
486 } else {
487 Ok(denmats)
488 }
489 })
490 }
491}
492
493impl<'a, T, B, SC> fmt::Debug for MultiDeterminants<'a, T, B, SC>
497where
498 T: ComplexFloat + Lapack,
499 SC: StructureConstraint + Hash + Eq + fmt::Display,
500 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
501{
502 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
503 write!(
504 f,
505 "MultiDeterminant collection over {} basis Slater determinants",
506 self.coefficients.len(),
507 )?;
508 Ok(())
509 }
510}
511
512impl<'a, T, B, SC> fmt::Display for MultiDeterminants<'a, T, B, SC>
516where
517 T: ComplexFloat + Lapack,
518 SC: StructureConstraint + Hash + Eq + fmt::Display,
519 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
520{
521 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522 write!(
523 f,
524 "MultiDeterminant collection over {} basis Slater determinants",
525 self.coefficients.len(),
526 )?;
527 Ok(())
528 }
529}