1use std::collections::HashSet;
4use std::fmt;
5use std::iter::Sum;
6use std::ops::{Add, Index, Sub};
7
8use anyhow::{ensure, format_err};
9use approx;
10use derive_builder::Builder;
11use itertools::Itertools;
12use log;
13use ndarray::Array2;
14use ndarray_linalg::types::Lapack;
15use num_complex::{Complex, ComplexFloat};
16use num_traits::float::{Float, FloatConst};
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::auxiliary::molecule::Molecule;
20use crate::basis::ao::BasisAngularOrder;
21
22#[cfg(test)]
23mod density_tests;
24
25pub mod density_analysis;
26pub mod density_projection;
27mod density_transformation;
28
/// A collection of borrowed densities that share a single structure constraint.
///
/// When built through [`DensitiesBuilder`], the number of densities is checked against
/// `n_coefficient_matrices() * n_explicit_comps_per_coefficient_matrix()` of the structure
/// constraint.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Densities<'a, T, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + fmt::Display,
{
    /// The structure constraint associated with the densities.
    structure_constraint: SC,

    /// References to the densities, in component order.
    densities: Vec<&'a Density<'a, T>>,
}
47
48impl<'a, T, SC> Densities<'a, T, SC>
49where
50 T: ComplexFloat + Lapack,
51 SC: StructureConstraint + Clone + fmt::Display,
52{
53 pub fn builder() -> DensitiesBuilder<'a, T, SC> {
55 DensitiesBuilder::default()
56 }
57
58 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
59 self.densities.iter().cloned()
60 }
61}
62
63impl<'a, T, SC> DensitiesBuilder<'a, T, SC>
64where
65 T: ComplexFloat + Lapack,
66 SC: StructureConstraint + fmt::Display,
67{
68 fn validate(&self) -> Result<(), String> {
69 let densities = self
70 .densities
71 .as_ref()
72 .ok_or("No `densities` found.".to_string())?;
73 let structure_constraint = self
74 .structure_constraint
75 .as_ref()
76 .ok_or("No structure constraint found.".to_string())?;
77 let num_dens = structure_constraint.n_coefficient_matrices()
78 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
79 if densities.len() != num_dens {
80 Err(format!(
81 "{} {} expected in structure constraint {}, but {} found.",
82 num_dens,
83 structure_constraint,
84 if num_dens == 1 {
85 "density"
86 } else {
87 "densities"
88 },
89 densities.len()
90 ))
91 } else {
92 Ok(())
93 }
94 }
95}
96
97impl<'a, T, SC> Index<usize> for Densities<'a, T, SC>
98where
99 T: ComplexFloat + Lapack,
100 SC: StructureConstraint + fmt::Display,
101{
102 type Output = Density<'a, T>;
103
104 fn index(&self, index: usize) -> &Self::Output {
105 self.densities[index]
106 }
107}
108
/// A collection of owned densities that share a single structure constraint.
///
/// When built through [`DensitiesOwnedBuilder`], the number of densities is checked against
/// `n_coefficient_matrices() * n_explicit_comps_per_coefficient_matrix()` of the structure
/// constraint.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct DensitiesOwned<'a, T, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + fmt::Display,
{
    /// The structure constraint associated with the densities.
    structure_constraint: SC,

    /// The owned densities, in component order.
    densities: Vec<Density<'a, T>>,
}
123
124impl<'a, T, SC> DensitiesOwned<'a, T, SC>
125where
126 T: ComplexFloat + Lapack,
127 SC: StructureConstraint + Clone + fmt::Display,
128{
129 pub fn builder() -> DensitiesOwnedBuilder<'a, T, SC> {
131 DensitiesOwnedBuilder::default()
132 }
133
134 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
135 self.densities.iter()
136 }
137
138 pub fn calc_extra_densities<'b: 'a>(
140 &'b self,
141 ) -> Result<Vec<(String, Density<'a, T>)>, anyhow::Error> {
142 let nspatials = self
143 .iter()
144 .map(|den| den.bao().n_funcs())
145 .collect::<HashSet<usize>>();
146 ensure!(
147 nspatials.len() == 1,
148 "Inconsistent number of spatial functions."
149 );
150 let nspatial = *nspatials.iter().next().ok_or_else(|| {
151 format_err!(
152 "Unable to retrieve the number of spatial functions of the density matrices."
153 )
154 })?;
155
156 let den0 = self
157 .iter()
158 .next()
159 .ok_or_else(|| format_err!("Unable to retrieve the first density."))?;
160 ensure!(
161 nspatial == den0.density_matrix.nrows(),
162 "Unexpected density matrix dimension: {} != {}",
163 nspatial,
164 den0.density_matrix.nrows()
165 );
166
167 let total_denmat = self
169 .iter()
170 .fold(Array2::<T>::zeros((nspatial, nspatial)), |acc, den| {
171 acc + den.density_matrix()
172 });
173 let total_den = Density::<T>::builder()
174 .density_matrix(total_denmat)
175 .bao(den0.bao())
176 .mol(den0.mol)
177 .complex_symmetric(den0.complex_symmetric())
178 .threshold(den0.threshold())
179 .build()?;
180
181 vec![Ok(("Total density".to_string(), total_den))]
183 .into_iter()
184 .chain((0..self.densities.len()).combinations(2).map(|indices| {
185 let i0 = indices[0];
186 let i1 = indices[1];
187 let denmat_0 = self.densities[i0].density_matrix();
188 let denmat_1 = self.densities[i1].density_matrix();
189 let denmat_01 = denmat_0 - denmat_1;
190 let den_01 = Density::<T>::builder()
191 .density_matrix(denmat_01)
192 .bao(den0.bao())
193 .mol(den0.mol)
194 .complex_symmetric(den0.complex_symmetric())
195 .threshold(den0.threshold())
196 .build()?;
197 Ok((
198 format!("Density (component {i0}) - Density (component {i1})"),
199 den_01,
200 ))
201 }))
202 .collect::<Result<Vec<_>, _>>()
203 }
204}
205
206impl<'b, 'a: 'b, T, SC> DensitiesOwned<'a, T, SC>
207where
208 T: ComplexFloat + Lapack,
209 SC: StructureConstraint + Clone + fmt::Display,
210{
211 pub fn as_ref(&'a self) -> Densities<'b, T, SC> {
212 Densities::builder()
213 .structure_constraint(self.structure_constraint.clone())
214 .densities(self.iter().collect_vec())
215 .build()
216 .expect("Unable to convert `DensitiesOwned` to `Densities`.")
217 }
218}
219
220impl<'a, T, SC> DensitiesOwnedBuilder<'a, T, SC>
221where
222 T: ComplexFloat + Lapack,
223 SC: StructureConstraint + Clone + fmt::Display,
224{
225 fn validate(&self) -> Result<(), String> {
226 let densities = self
227 .densities
228 .as_ref()
229 .ok_or("No `densities` found.".to_string())?;
230 let structure_constraint = self
231 .structure_constraint
232 .as_ref()
233 .ok_or("No spin constraint found.".to_string())?;
234 let num_dens = structure_constraint.n_coefficient_matrices()
235 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
236 if densities.len() != num_dens {
237 Err(format!(
238 "{} {} expected in structure constraint {}, but {} found.",
239 num_dens,
240 structure_constraint,
241 if num_dens == 1 {
242 "density"
243 } else {
244 "densities"
245 },
246 densities.len()
247 ))
248 } else {
249 Ok(())
250 }
251 }
252}
253
254impl<'a, T, SC> Index<usize> for DensitiesOwned<'a, T, SC>
255where
256 T: ComplexFloat + Lapack,
257 SC: StructureConstraint + fmt::Display,
258{
259 type Output = Density<'a, T>;
260
261 fn index(&self, index: usize) -> &Self::Output {
262 &self.densities[index]
263 }
264}
265
/// A single density described by a density matrix in a given basis for a given molecule.
///
/// Builder validation requires the density matrix to be `n_funcs() × n_funcs()` for `bao`, and
/// the molecule's atom count to equal `bao.n_atoms()`.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Density<'a, T>
where
    T: ComplexFloat + Lapack,
{
    /// The angular order of the basis functions in which the density matrix is expressed.
    bao: &'a BasisAngularOrder<'a>,

    /// Flag marking this density as complex-symmetric. It presumably selects a bilinear rather
    /// than a sesquilinear inner product; TODO confirm against the overlap code.
    complex_symmetric: bool,

    /// Flag marking this density as complex-conjugated. It defaults to `false` when not set on
    /// the builder. NOTE(review): presumably set after an antiunitary action; confirm.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The molecule associated with this density.
    mol: &'a Molecule,

    /// The density matrix. Builder validation checks that it is square with one row/column per
    /// basis function.
    density_matrix: Array2<T>,

    /// The tolerance used when comparing densities for approximate equality.
    threshold: <T as ComplexFloat>::Real,
}
296
297impl<'a, T> DensityBuilder<'a, T>
298where
299 T: ComplexFloat + Lapack,
300{
301 fn validate(&self) -> Result<(), String> {
302 let bao = self
303 .bao
304 .ok_or("No `BasisAngularOrder` found.".to_string())?;
305 let nbas = bao.n_funcs();
306 let density_matrix = self
307 .density_matrix
308 .as_ref()
309 .ok_or("No density matrices found.".to_string())?;
310
311 let denmat_shape = density_matrix.shape() == [nbas, nbas];
312 if !denmat_shape {
313 log::error!(
314 "The density matrix dimensions ({:?}) are incompatible with the basis ({nbas} {}).",
315 density_matrix.shape(),
316 if nbas != 1 { "functions" } else { "function" }
317 );
318 }
319
320 let mol = self.mol.ok_or("No molecule found.".to_string())?;
321 let natoms = mol.atoms.len() == bao.n_atoms();
322 if !natoms {
323 log::error!(
324 "The number of atoms in the molecule does not match the number of local sites in the basis."
325 );
326 }
327 if denmat_shape && natoms {
328 Ok(())
329 } else {
330 Err(format!(
331 "Density validation failed: `denmat_shape`: {denmat_shape}, `natoms`: {natoms}"
332 ))
333 }
334 }
335}
336
337impl<'a, T> Density<'a, T>
338where
339 T: ComplexFloat + Clone + Lapack,
340{
341 pub fn builder() -> DensityBuilder<'a, T> {
343 DensityBuilder::default()
344 }
345
346 pub fn complex_symmetric(&self) -> bool {
348 self.complex_symmetric
349 }
350
351 pub fn complex_conjugated(&self) -> bool {
353 self.complex_conjugated
354 }
355
356 pub fn mol(&self) -> &Molecule {
358 self.mol
359 }
360
361 pub fn bao(&'_ self) -> &'_ BasisAngularOrder<'_> {
364 self.bao
365 }
366
367 pub fn density_matrix(&self) -> &Array2<T> {
369 &self.density_matrix
370 }
371
372 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
374 self.threshold
375 }
376}
377
378impl<'a, T> From<Density<'a, T>> for Density<'a, Complex<T>>
386where
387 T: Float + FloatConst + Lapack,
388 Complex<T>: Lapack,
389{
390 fn from(value: Density<'a, T>) -> Self {
391 Density::<'a, Complex<T>>::builder()
392 .density_matrix(value.density_matrix.map(Complex::from))
393 .bao(value.bao)
394 .mol(value.mol)
395 .complex_symmetric(value.complex_symmetric)
396 .threshold(value.threshold)
397 .build()
398 .expect("Unable to complexify a `Density`.")
399 }
400}
401
402impl<'a, T> PartialEq for Density<'a, T>
406where
407 T: ComplexFloat<Real = f64> + Lapack,
408{
409 fn eq(&self, other: &Self) -> bool {
410 let thresh = (self.threshold * other.threshold).sqrt();
411 let density_matrix_eq = approx::relative_eq!(
412 (&self.density_matrix - &other.density_matrix)
413 .map(|x| ComplexFloat::abs(*x).powi(2))
414 .sum()
415 .sqrt(),
416 0.0,
417 epsilon = thresh,
418 max_relative = thresh,
419 );
420 self.bao == other.bao && self.mol == other.mol && density_matrix_eq
421 }
422}
423
424impl<'a, T> Eq for Density<'a, T> where T: ComplexFloat<Real = f64> + Lapack {}
428
429impl<'a, T> fmt::Debug for Density<'a, T>
433where
434 T: fmt::Debug + ComplexFloat + Lapack,
435 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Debug,
436{
437 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438 write!(
439 f,
440 "Density[density matrix of dimensions {}]",
441 self.density_matrix
442 .shape()
443 .iter()
444 .map(|x| x.to_string())
445 .collect::<Vec<_>>()
446 .join("×")
447 )?;
448 Ok(())
449 }
450}
451
452impl<'a, T> fmt::Display for Density<'a, T>
456where
457 T: fmt::Display + ComplexFloat + Lapack,
458 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Display,
459{
460 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
461 write!(
462 f,
463 "Density[density matrix of dimensions {}]",
464 self.density_matrix
465 .shape()
466 .iter()
467 .map(|x| x.to_string())
468 .collect::<Vec<_>>()
469 .join("×")
470 )?;
471 Ok(())
472 }
473}
474
475impl<'a, T> Add<&'_ Density<'a, T>> for &Density<'a, T>
479where
480 T: ComplexFloat + Lapack,
481{
482 type Output = Density<'a, T>;
483
484 fn add(self, rhs: &Density<'a, T>) -> Self::Output {
485 assert_eq!(
486 self.density_matrix.shape(),
487 rhs.density_matrix.shape(),
488 "Inconsistent shapes of density matrices between `self` and `rhs`."
489 );
490 assert_eq!(
491 self.bao, rhs.bao,
492 "Inconsistent basis angular order between `self` and `rhs`."
493 );
494 Density::<T>::builder()
495 .density_matrix(&self.density_matrix + &rhs.density_matrix)
496 .bao(self.bao)
497 .mol(self.mol)
498 .complex_symmetric(self.complex_symmetric)
499 .threshold(self.threshold)
500 .build()
501 .expect("Unable to add two densities together.")
502 }
503}
504
505impl<'a, T> Add<&'_ Density<'a, T>> for Density<'a, T>
506where
507 T: ComplexFloat + Lapack,
508{
509 type Output = Density<'a, T>;
510
511 fn add(self, rhs: &Self) -> Self::Output {
512 &self + rhs
513 }
514}
515
516impl<'a, T> Add<Density<'a, T>> for Density<'a, T>
517where
518 T: ComplexFloat + Lapack,
519{
520 type Output = Density<'a, T>;
521
522 fn add(self, rhs: Self) -> Self::Output {
523 &self + &rhs
524 }
525}
526
527impl<'a, T> Add<Density<'a, T>> for &Density<'a, T>
528where
529 T: ComplexFloat + Lapack,
530{
531 type Output = Density<'a, T>;
532
533 fn add(self, rhs: Density<'a, T>) -> Self::Output {
534 self + &rhs
535 }
536}
537
538impl<'a, T> Sub<&'_ Density<'a, T>> for &Density<'a, T>
542where
543 T: ComplexFloat + Lapack,
544{
545 type Output = Density<'a, T>;
546
547 fn sub(self, rhs: &Density<'a, T>) -> Self::Output {
548 assert_eq!(
549 self.density_matrix.shape(),
550 rhs.density_matrix.shape(),
551 "Inconsistent shapes of density matrices between `self` and `rhs`."
552 );
553 assert_eq!(
554 self.bao, rhs.bao,
555 "Inconsistent basis angular order between `self` and `rhs`."
556 );
557 Density::<T>::builder()
558 .density_matrix(&self.density_matrix - &rhs.density_matrix)
559 .bao(self.bao)
560 .mol(self.mol)
561 .complex_symmetric(self.complex_symmetric)
562 .threshold(self.threshold)
563 .build()
564 .expect("Unable to subtract two densities.")
565 }
566}
567
568impl<'a, T> Sub<&'_ Density<'a, T>> for Density<'a, T>
569where
570 T: ComplexFloat + Lapack,
571{
572 type Output = Density<'a, T>;
573
574 fn sub(self, rhs: &Self) -> Self::Output {
575 &self - rhs
576 }
577}
578
579impl<'a, T> Sub<Density<'a, T>> for Density<'a, T>
580where
581 T: ComplexFloat + Lapack,
582{
583 type Output = Density<'a, T>;
584
585 fn sub(self, rhs: Self) -> Self::Output {
586 &self - &rhs
587 }
588}
589
590impl<'a, T> Sub<Density<'a, T>> for &Density<'a, T>
591where
592 T: ComplexFloat + Lapack,
593{
594 type Output = Density<'a, T>;
595
596 fn sub(self, rhs: Density<'a, T>) -> Self::Output {
597 self - &rhs
598 }
599}