1use std::collections::HashSet;
4use std::fmt::{self, LowerExp};
5use std::hash::Hash;
6use std::marker::PhantomData;
7
8use anyhow::{ensure, format_err};
9use derive_builder::Builder;
10use itertools::Itertools;
11use log;
12use ndarray::{Array1, Array2, Array3, ArrayView2, Axis, Ix1, Ix3, ScalarOperand, ShapeBuilder};
13use ndarray_einsum::einsum;
14use ndarray_linalg::types::Lapack;
15use num_complex::ComplexFloat;
16use rayon::prelude::*;
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::group::GroupProperties;
20use crate::target::determinant::SlaterDeterminant;
21use crate::target::noci::backend::nonortho::{
22 calc_lowdin_pairing, calc_o0_matrix_element, calc_transition_density_matrix,
23};
24use crate::target::noci::basis::{EagerBasis, OrbitBasis};
25use crate::target::noci::multideterminant::MultiDeterminant;
26
27use super::basis::Basis;
28
// Unit tests for this module live in `multideterminants_tests.rs` and are compiled only for
// test builds.
#[cfg(test)]
#[path = "multideterminants_tests.rs"]
mod multideterminants_tests;
32
33#[derive(Builder, Clone)]
40#[builder(build_fn(validate = "Self::validate"))]
41pub struct MultiDeterminants<'a, T, B, SC>
42where
43 T: ComplexFloat + Lapack,
44 SC: StructureConstraint + Hash + Eq + fmt::Display,
45 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
46{
47 #[builder(setter(skip), default = "PhantomData")]
48 _lifetime: PhantomData<&'a ()>,
49
50 #[builder(setter(skip), default = "PhantomData")]
51 _structure_constraint: PhantomData<SC>,
52
53 #[builder(setter(skip), default = "self.complex_symmetric_from_basis()?")]
57 complex_symmetric: bool,
58
59 #[builder(default = "false")]
63 complex_conjugated: bool,
64
65 basis: B,
68
69 coefficients: Array2<T>,
73
74 #[builder(
76 default = "Err(\"Multi-determinantal wavefunction energies not yet set.\".to_string())"
77 )]
78 energies: Result<Array1<T>, String>,
79
80 threshold: <T as ComplexFloat>::Real,
82}
83
84impl<'a, T, B, SC> MultiDeterminantsBuilder<'a, T, B, SC>
89where
90 T: ComplexFloat + Lapack,
91 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
92 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
93{
94 fn validate(&self) -> Result<(), String> {
95 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
96 let coefficients = self
97 .coefficients
98 .as_ref()
99 .ok_or("No coefficients found.".to_string())?;
100 let nbasis = basis.n_items() == coefficients.nrows();
101 if !nbasis {
102 log::error!(
103 "The number of coefficient rows does not match the number of basis determinants."
104 );
105 }
106
107 let complex_symmetric = basis
108 .iter()
109 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
110 .collect::<Result<HashSet<_>, _>>()
111 .map_err(|err| err.to_string())?
112 .len()
113 == 1;
114 if !complex_symmetric {
115 log::error!("Inconsistent complex-symmetric flag across basis determinants.");
116 }
117
118 let structcons_check = basis
119 .iter()
120 .map(|det_res| det_res.map(|det| det.structure_constraint().clone()))
121 .collect::<Result<HashSet<_>, _>>()
122 .map_err(|err| err.to_string())?
123 .len()
124 == 1;
125 if !structcons_check {
126 log::error!("Inconsistent spin constraints across basis determinants.");
127 }
128
129 if nbasis && structcons_check && complex_symmetric {
130 Ok(())
131 } else {
132 Err("Multi-determinantal wavefunction collection validation failed.".to_string())
133 }
134 }
135
136 fn complex_symmetric_from_basis(&self) -> Result<bool, String> {
138 let basis = self.basis.as_ref().ok_or("No basis found.".to_string())?;
139 let complex_symmetric_set = basis
140 .iter()
141 .map(|det_res| det_res.map(|det| det.complex_symmetric()))
142 .collect::<Result<HashSet<_>, _>>()
143 .map_err(|err| err.to_string())?;
144 if complex_symmetric_set.len() == 1 {
145 complex_symmetric_set
146 .into_iter()
147 .next()
148 .ok_or("Unable to retrieve the complex-symmetric flag from the basis.".to_string())
149 } else {
150 Err("Inconsistent complex-symmetric flag across basis determinants.".to_string())
151 }
152 }
153}
154
155impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
156where
157 T: ComplexFloat + Lapack,
158 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display,
159 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
160{
161 pub fn builder() -> MultiDeterminantsBuilder<'a, T, B, SC> {
163 MultiDeterminantsBuilder::default()
164 }
165
166 pub fn from_multideterminant_vec(
177 mtds: &[&MultiDeterminant<'a, T, B, SC>],
178 ) -> Result<MultiDeterminants<'a, T, B, SC>, anyhow::Error> {
179 log::warn!(
180 "Using basis from the first multi-determinantal wavefunction as the common basis for the collection of multi-determinantal wavefunctions..."
181 );
182 let nmultidets = mtds.len();
183 let dims_set = mtds
184 .iter()
185 .map(|mtd| mtd.basis().n_items())
186 .collect::<HashSet<_>>();
187 let dim = if dims_set.len() == 1 {
188 dims_set
189 .into_iter()
190 .next()
191 .ok_or_else(|| format_err!("Unable to obtain the unique basis size."))
192 } else {
193 Err(format_err!(
194 "Inconsistent basis sizes across the supplied multi-determinantal wavefunctions."
195 ))
196 }?;
197 let coefficients = Array2::from_shape_vec(
198 (dim, nmultidets).f(),
199 mtds.iter()
200 .flat_map(|mtd| mtd.coefficients())
201 .cloned()
202 .collect::<Vec<_>>(),
203 )
204 .map_err(|err| format_err!(err))?;
205
206 let (basis, threshold) = mtds
207 .first()
208 .map(|mtd| (mtd.basis().clone(), mtd.threshold()))
209 .ok_or_else(|| {
210 format_err!("Unable to access the first multi-determinantal wavefunction.")
211 })?;
212
213 MultiDeterminants::builder()
214 .basis(basis)
215 .coefficients(coefficients)
216 .threshold(threshold)
217 .build()
218 .map_err(|err| format_err!(err))
219 }
220
221 pub fn structure_constraint(&self) -> SC {
223 self.basis
224 .iter()
225 .next()
226 .expect("No basis determinant found.")
227 .expect("No basis determinant found.")
228 .structure_constraint()
229 .clone()
230 }
231
232 pub fn iter(&self) -> impl Iterator {
234 let energies = self
235 .energies
236 .as_ref()
237 .map(|energies| energies.mapv(|v| Ok(v)))
238 .unwrap_or(Array1::from_elem(
239 self.coefficients.ncols(),
240 Err("Multi-determinantal energy not available.".to_string()),
241 ));
242 self.coefficients
243 .columns()
244 .into_iter()
245 .zip(energies)
246 .map(|(c, e)| {
247 MultiDeterminant::builder()
248 .complex_conjugated(self.complex_conjugated)
249 .basis(self.basis().clone())
250 .coefficients(c.to_owned())
251 .energy(e)
252 .threshold(self.threshold)
253 .build()
254 .map_err(|err| format_err!(err))
255 })
256 }
257}
258
259impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
260where
261 T: ComplexFloat + Lapack,
262 SC: StructureConstraint + Hash + Eq + fmt::Display,
263 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
264{
265 pub fn complex_conjugated(&self) -> bool {
268 self.complex_conjugated
269 }
270
271 pub fn complex_symmetric(&self) -> bool {
274 self.complex_symmetric
275 }
276
277 pub fn basis(&self) -> &B {
280 &self.basis
281 }
282
283 pub fn coefficients(&self) -> &Array2<T> {
286 &self.coefficients
287 }
288
289 pub fn energies(&self) -> Result<&Array1<T>, &String> {
291 self.energies.as_ref()
292 }
293
294 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
296 self.threshold
297 }
298}
299
300impl<'a, T, G, SC> MultiDeterminants<'a, T, OrbitBasis<'a, G, SlaterDeterminant<'a, T, SC>>, SC>
305where
306 T: ComplexFloat + Lapack,
307 G: GroupProperties + Clone,
308 SC: StructureConstraint + Hash + Eq + fmt::Display + Clone,
309{
310 #[allow(clippy::type_complexity)]
313 pub fn to_eager_basis(
314 &self,
315 ) -> Result<MultiDeterminants<'a, T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>, anyhow::Error>
316 {
317 MultiDeterminants::<T, EagerBasis<SlaterDeterminant<'a, T, SC>>, SC>::builder()
318 .complex_conjugated(self.complex_conjugated)
319 .basis(self.basis.to_eager()?)
320 .coefficients(self.coefficients().clone())
321 .energies(self.energies.clone())
322 .threshold(self.threshold)
323 .build()
324 .map_err(|err| format_err!(err))
325 }
326}
327
328impl<'a, T, B, SC> MultiDeterminants<'a, T, B, SC>
333where
334 T: ComplexFloat + Lapack + ScalarOperand + Send + Sync,
335 <T as ComplexFloat>::Real: LowerExp + fmt::Display + Sync,
336 SC: StructureConstraint + Hash + Eq + Clone + fmt::Display + Sync,
337 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone + Sync,
338 SlaterDeterminant<'a, T, SC>: Send + Sync,
339{
340 pub fn density_matrices(
363 &self,
364 sao: &ArrayView2<T>,
365 thresh_offdiag: <T as ComplexFloat>::Real,
366 thresh_zeroov: <T as ComplexFloat>::Real,
367 normalised_wavefunctions: bool,
368 ) -> Result<Array3<T>, anyhow::Error> {
369 let nao = sao.nrows();
370 let dets = self.basis().iter().collect::<Result<Vec<_>, _>>()?;
371 let nmultidets = self.coefficients.ncols();
372 let sqnorms_denmats_res = dets.iter()
373 .zip(self.coefficients().rows())
374 .cartesian_product(dets.iter().zip(self.coefficients().rows()))
375 .par_bridge()
376 .fold(
377 || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
378 |acc_res, ((det_w, c_wm), (det_x, c_xm))| {
379 ensure!(
380 det_w.structure_constraint() == det_x.structure_constraint(),
381 "Inconsistent spin constraints: {} != {}.",
382 det_w.structure_constraint(),
383 det_x.structure_constraint(),
384 );
385
386 if det_w.complex_symmetric() != det_x.complex_symmetric() {
387 return Err(format_err!(
388 "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
389 det_w.complex_symmetric(),
390 det_x.complex_symmetric(),
391 ));
392 }
393 let complex_symmetric = det_w.complex_symmetric();
394 let lowdin_paired_coefficientss = det_w
395 .coefficients()
396 .iter()
397 .zip(det_w.occupations().iter())
398 .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
399 .map(|((cw, occw), (cx, occx))| {
400 let occw_indices = occw
401 .iter()
402 .enumerate()
403 .filter_map(|(i, occ_i)| {
404 if occ_i.abs() >= det_w.threshold() {
405 Some(i)
406 } else {
407 None
408 }
409 })
410 .collect::<Vec<_>>();
411 let ne_w = occw_indices.len();
412 let cw_occ = cw.select(Axis(1), &occw_indices);
413 let occx_indices = occx
414 .iter()
415 .enumerate()
416 .filter_map(|(i, occ_i)| {
417 if occ_i.abs() >= det_x.threshold() {
418 Some(i)
419 } else {
420 None
421 }
422 })
423 .collect::<Vec<_>>();
424 let ne_x = occx_indices.len();
425 ensure!(ne_w == ne_x, "Inconsistent number of electrons: {ne_w} != {ne_x}.");
426 let cx_occ = cx.select(Axis(1), &occx_indices);
427 calc_lowdin_pairing(
428 &cw_occ.view(),
429 &cx_occ.view(),
430 sao,
431 complex_symmetric,
432 thresh_offdiag,
433 thresh_zeroov,
434 )
435 })
436 .collect::<Result<Vec<_>, _>>()?;
437 let den_wx = calc_transition_density_matrix(&lowdin_paired_coefficientss, &self.structure_constraint())?;
438 let ov_wx = calc_o0_matrix_element(&lowdin_paired_coefficientss, T::one(), &self.structure_constraint())?;
439
440 let c_wm = if self.complex_conjugated() {
441 if complex_symmetric { c_wm.map(|v| v.conj()) } else { c_wm.to_owned() }
442 } else if complex_symmetric { c_wm.to_owned() } else { c_wm.map(|v| v.conj()) };
443 let c_xm = if self.complex_conjugated() {
444 c_xm.map(|v| v.conj())
445 } else {
446 c_xm.to_owned()
447 };
448 let den_wx = if self.complex_conjugated() {
449 den_wx.mapv(|v| v.conj())
450 } else {
451 den_wx
452 };
453 acc_res.and_then(|(sqnorm_acc, denmat_acc)| {
454 let ov_wx_m = einsum("m,m->m", &[&c_wm.view(), &c_xm.view()])
455 .map_err(|err| format_err!(err))?
456 .into_dimensionality::<Ix1>()
457 .map_err(|err| format_err!(err))?
458 .mapv(|v| v * ov_wx);
459 let denmat_wx_mij = einsum("ij,m,m->mij", &[&den_wx.view(), &c_wm.view(), &c_xm.view()])
460 .map_err(|err| format_err!(err))?
461 .into_dimensionality::<Ix3>()
462 .map_err(|err| format_err!(err))?;
463 Ok((sqnorm_acc + ov_wx_m, denmat_acc + denmat_wx_mij))
464 })
465 },
466 )
467 .reduce(
468 || Ok((Array1::<T>::zeros(nmultidets), Array3::<T>::zeros((nmultidets, nao, nao)))),
469 |sqnorms_denmats_res_a: Result<(Array1<T>, Array3<T>), anyhow::Error>, sqnorms_denmats_res_b: Result<(Array1<T>, Array3<T>), anyhow::Error>| {
470 sqnorms_denmats_res_a.and_then(|(sqnorm_acc, denmat_acc)| sqnorms_denmats_res_b.map(|(sqnorm, denmat)| {
471 (
472 sqnorm_acc + sqnorm,
473 denmat_acc + denmat
474 )
475 }))
476 }
477 );
478 sqnorms_denmats_res.and_then(|(sqnorms, denmats)| {
479 if normalised_wavefunctions {
480 let sqnorms_inv = sqnorms.mapv(|v| T::one() / v);
481 einsum("m,mij->mij", &[&sqnorms_inv.view(), &denmats.view()])
482 .map_err(|err| format_err!(err))?
483 .into_dimensionality::<Ix3>()
484 .map_err(|err| format_err!(err))
485 } else {
486 Ok(denmats)
487 }
488 })
489 }
490}
491
492impl<'a, T, B, SC> fmt::Debug for MultiDeterminants<'a, T, B, SC>
496where
497 T: ComplexFloat + Lapack,
498 SC: StructureConstraint + Hash + Eq + fmt::Display,
499 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
500{
501 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
502 write!(
503 f,
504 "MultiDeterminant collection over {} basis Slater determinants",
505 self.coefficients.len(),
506 )?;
507 Ok(())
508 }
509}
510
511impl<'a, T, B, SC> fmt::Display for MultiDeterminants<'a, T, B, SC>
515where
516 T: ComplexFloat + Lapack,
517 SC: StructureConstraint + Hash + Eq + fmt::Display,
518 B: Basis<SlaterDeterminant<'a, T, SC>> + Clone,
519{
520 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
521 write!(
522 f,
523 "MultiDeterminant collection over {} basis Slater determinants",
524 self.coefficients.len(),
525 )?;
526 Ok(())
527 }
528}