1use std::collections::HashSet;
4use std::fmt;
5use std::iter::Sum;
6use std::ops::{Add, Index, Sub};
7
8use anyhow::{ensure, format_err};
9use approx;
10use derive_builder::Builder;
11use itertools::Itertools;
12use log;
13use ndarray::Array2;
14use ndarray_linalg::types::Lapack;
15use num_complex::{Complex, ComplexFloat};
16use num_traits::float::{Float, FloatConst};
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::auxiliary::molecule::Molecule;
20use crate::basis::ao::BasisAngularOrder;
21
22#[cfg(test)]
23mod density_tests;
24
25pub mod density_analysis;
26pub mod density_projection;
27mod density_transformation;
28
29#[derive(Builder, Clone)]
35#[builder(build_fn(validate = "Self::validate"))]
36pub struct Densities<'a, T, SC>
37where
38 T: ComplexFloat + Lapack,
39 SC: StructureConstraint + fmt::Display,
40{
41 structure_constraint: SC,
43
44 densities: Vec<&'a Density<'a, T>>,
46}
47
48impl<'a, T, SC> Densities<'a, T, SC>
49where
50 T: ComplexFloat + Lapack,
51 SC: StructureConstraint + Clone + fmt::Display,
52{
53 pub fn builder() -> DensitiesBuilder<'a, T, SC> {
55 DensitiesBuilder::default()
56 }
57
58 pub fn structure_constraint(&self) -> &SC {
60 &self.structure_constraint
61 }
62
63 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
64 self.densities.iter().cloned()
65 }
66}
67
68impl<'a, T, SC> DensitiesBuilder<'a, T, SC>
69where
70 T: ComplexFloat + Lapack,
71 SC: StructureConstraint + fmt::Display,
72{
73 fn validate(&self) -> Result<(), String> {
74 let densities = self
75 .densities
76 .as_ref()
77 .ok_or("No `densities` found.".to_string())?;
78 let structure_constraint = self
79 .structure_constraint
80 .as_ref()
81 .ok_or("No structure constraint found.".to_string())?;
82 let num_dens = structure_constraint.n_coefficient_matrices()
83 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
84 if densities.len() != num_dens {
85 Err(format!(
86 "{} {} expected in structure constraint {}, but {} found.",
87 num_dens,
88 structure_constraint,
89 if num_dens == 1 {
90 "density"
91 } else {
92 "densities"
93 },
94 densities.len()
95 ))
96 } else {
97 Ok(())
98 }
99 }
100}
101
102impl<'a, T, SC> Index<usize> for Densities<'a, T, SC>
103where
104 T: ComplexFloat + Lapack,
105 SC: StructureConstraint + fmt::Display,
106{
107 type Output = Density<'a, T>;
108
109 fn index(&self, index: usize) -> &Self::Output {
110 self.densities[index]
111 }
112}
113
114#[derive(Builder, Clone)]
116#[builder(build_fn(validate = "Self::validate"))]
117pub struct DensitiesOwned<'a, T, SC>
118where
119 T: ComplexFloat + Lapack,
120 SC: StructureConstraint + fmt::Display,
121{
122 structure_constraint: SC,
124
125 densities: Vec<Density<'a, T>>,
127}
128
129impl<'a, T, SC> DensitiesOwned<'a, T, SC>
130where
131 T: ComplexFloat + Lapack,
132 SC: StructureConstraint + Clone + fmt::Display,
133{
134 pub fn builder() -> DensitiesOwnedBuilder<'a, T, SC> {
136 DensitiesOwnedBuilder::default()
137 }
138
139 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
140 self.densities.iter()
141 }
142
143 pub fn calc_extra_densities<'b: 'a>(
145 &'b self,
146 ) -> Result<Vec<(String, Density<'a, T>)>, anyhow::Error> {
147 let nspatials = self
148 .iter()
149 .map(|den| den.bao().n_funcs())
150 .collect::<HashSet<usize>>();
151 ensure!(
152 nspatials.len() == 1,
153 "Inconsistent number of spatial functions."
154 );
155 let nspatial = *nspatials.iter().next().ok_or_else(|| {
156 format_err!(
157 "Unable to retrieve the number of spatial functions of the density matrices."
158 )
159 })?;
160
161 let den0 = self
162 .iter()
163 .next()
164 .ok_or_else(|| format_err!("Unable to retrieve the first density."))?;
165 ensure!(
166 nspatial == den0.density_matrix.nrows(),
167 "Unexpected density matrix dimension: {} != {}",
168 nspatial,
169 den0.density_matrix.nrows()
170 );
171
172 let total_denmat = self
174 .iter()
175 .fold(Array2::<T>::zeros((nspatial, nspatial)), |acc, den| {
176 acc + den.density_matrix()
177 });
178 let total_den = Density::<T>::builder()
179 .density_matrix(total_denmat)
180 .bao(den0.bao())
181 .mol(den0.mol)
182 .complex_symmetric(den0.complex_symmetric())
183 .threshold(den0.threshold())
184 .build()?;
185
186 vec![Ok(("Total density".to_string(), total_den))]
188 .into_iter()
189 .chain((0..self.densities.len()).combinations(2).map(|indices| {
190 let i0 = indices[0];
191 let i1 = indices[1];
192 let denmat_0 = self.densities[i0].density_matrix();
193 let denmat_1 = self.densities[i1].density_matrix();
194 let denmat_01 = denmat_0 - denmat_1;
195 let den_01 = Density::<T>::builder()
196 .density_matrix(denmat_01)
197 .bao(den0.bao())
198 .mol(den0.mol)
199 .complex_symmetric(den0.complex_symmetric())
200 .threshold(den0.threshold())
201 .build()?;
202 Ok((
203 format!("Density (component {i0}) - Density (component {i1})"),
204 den_01,
205 ))
206 }))
207 .collect::<Result<Vec<_>, _>>()
208 }
209}
210
211impl<'b, 'a: 'b, T, SC> DensitiesOwned<'a, T, SC>
212where
213 T: ComplexFloat + Lapack,
214 SC: StructureConstraint + Clone + fmt::Display,
215{
216 pub fn as_ref(&'a self) -> Densities<'b, T, SC> {
217 Densities::builder()
218 .structure_constraint(self.structure_constraint.clone())
219 .densities(self.iter().collect_vec())
220 .build()
221 .expect("Unable to convert `DensitiesOwned` to `Densities`.")
222 }
223}
224
225impl<'a, T, SC> DensitiesOwnedBuilder<'a, T, SC>
226where
227 T: ComplexFloat + Lapack,
228 SC: StructureConstraint + Clone + fmt::Display,
229{
230 fn validate(&self) -> Result<(), String> {
231 let densities = self
232 .densities
233 .as_ref()
234 .ok_or("No `densities` found.".to_string())?;
235 let structure_constraint = self
236 .structure_constraint
237 .as_ref()
238 .ok_or("No spin constraint found.".to_string())?;
239 let num_dens = structure_constraint.n_coefficient_matrices()
240 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
241 if densities.len() != num_dens {
242 Err(format!(
243 "{} {} expected in structure constraint {}, but {} found.",
244 num_dens,
245 structure_constraint,
246 if num_dens == 1 {
247 "density"
248 } else {
249 "densities"
250 },
251 densities.len()
252 ))
253 } else {
254 Ok(())
255 }
256 }
257}
258
259impl<'a, T, SC> Index<usize> for DensitiesOwned<'a, T, SC>
260where
261 T: ComplexFloat + Lapack,
262 SC: StructureConstraint + fmt::Display,
263{
264 type Output = Density<'a, T>;
265
266 fn index(&self, index: usize) -> &Self::Output {
267 &self.densities[index]
268 }
269}
270
271#[derive(Builder, Clone)]
274#[builder(build_fn(validate = "Self::validate"))]
275pub struct Density<'a, T>
276where
277 T: ComplexFloat + Lapack,
278{
279 bao: &'a BasisAngularOrder<'a>,
282
283 complex_symmetric: bool,
286
287 #[builder(default = "false")]
290 complex_conjugated: bool,
291
292 mol: &'a Molecule,
294
295 density_matrix: Array2<T>,
297
298 threshold: <T as ComplexFloat>::Real,
300}
301
302impl<'a, T> DensityBuilder<'a, T>
303where
304 T: ComplexFloat + Lapack,
305{
306 fn validate(&self) -> Result<(), String> {
307 let bao = self
308 .bao
309 .ok_or("No `BasisAngularOrder` found.".to_string())?;
310 let nbas = bao.n_funcs();
311 let density_matrix = self
312 .density_matrix
313 .as_ref()
314 .ok_or("No density matrices found.".to_string())?;
315
316 let denmat_shape = density_matrix.shape() == [nbas, nbas];
317 if !denmat_shape {
318 log::error!(
319 "The density matrix dimensions ({:?}) are incompatible with the basis ({nbas} {}).",
320 density_matrix.shape(),
321 if nbas != 1 { "functions" } else { "function" }
322 );
323 }
324
325 let mol = self.mol.ok_or("No molecule found.".to_string())?;
326 let natoms = mol.atoms.len() == bao.n_atoms();
327 if !natoms {
328 log::error!(
329 "The number of atoms in the molecule does not match the number of local sites in the basis."
330 );
331 }
332 if denmat_shape && natoms {
333 Ok(())
334 } else {
335 Err(format!(
336 "Density validation failed: `denmat_shape`: {denmat_shape}, `natoms`: {natoms}"
337 ))
338 }
339 }
340}
341
342impl<'a, T> Density<'a, T>
343where
344 T: ComplexFloat + Clone + Lapack,
345{
346 pub fn builder() -> DensityBuilder<'a, T> {
348 DensityBuilder::default()
349 }
350
351 pub fn complex_symmetric(&self) -> bool {
353 self.complex_symmetric
354 }
355
356 pub fn complex_conjugated(&self) -> bool {
358 self.complex_conjugated
359 }
360
361 pub fn mol(&self) -> &Molecule {
363 self.mol
364 }
365
366 pub fn bao(&'_ self) -> &'_ BasisAngularOrder<'_> {
369 self.bao
370 }
371
372 pub fn density_matrix(&self) -> &Array2<T> {
374 &self.density_matrix
375 }
376
377 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
379 self.threshold
380 }
381}
382
383impl<'a, T> From<Density<'a, T>> for Density<'a, Complex<T>>
391where
392 T: Float + FloatConst + Lapack,
393 Complex<T>: Lapack,
394{
395 fn from(value: Density<'a, T>) -> Self {
396 Density::<'a, Complex<T>>::builder()
397 .density_matrix(value.density_matrix.map(Complex::from))
398 .bao(value.bao)
399 .mol(value.mol)
400 .complex_symmetric(value.complex_symmetric)
401 .threshold(value.threshold)
402 .build()
403 .expect("Unable to complexify a `Density`.")
404 }
405}
406
407impl<'a, T> PartialEq for Density<'a, T>
411where
412 T: ComplexFloat<Real = f64> + Lapack,
413{
414 fn eq(&self, other: &Self) -> bool {
415 let thresh = (self.threshold * other.threshold).sqrt();
416 let density_matrix_eq = approx::relative_eq!(
417 (&self.density_matrix - &other.density_matrix)
418 .map(|x| ComplexFloat::abs(*x).powi(2))
419 .sum()
420 .sqrt(),
421 0.0,
422 epsilon = thresh,
423 max_relative = thresh,
424 );
425 self.bao == other.bao && self.mol == other.mol && density_matrix_eq
426 }
427}
428
429impl<'a, T> Eq for Density<'a, T> where T: ComplexFloat<Real = f64> + Lapack {}
433
434impl<'a, T> fmt::Debug for Density<'a, T>
438where
439 T: fmt::Debug + ComplexFloat + Lapack,
440 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Debug,
441{
442 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
443 write!(
444 f,
445 "Density[density matrix of dimensions {}]",
446 self.density_matrix
447 .shape()
448 .iter()
449 .map(|x| x.to_string())
450 .collect::<Vec<_>>()
451 .join("×")
452 )?;
453 Ok(())
454 }
455}
456
457impl<'a, T> fmt::Display for Density<'a, T>
461where
462 T: fmt::Display + ComplexFloat + Lapack,
463 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Display,
464{
465 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
466 write!(
467 f,
468 "Density[density matrix of dimensions {}]",
469 self.density_matrix
470 .shape()
471 .iter()
472 .map(|x| x.to_string())
473 .collect::<Vec<_>>()
474 .join("×")
475 )?;
476 Ok(())
477 }
478}
479
480impl<'a, T> Add<&'_ Density<'a, T>> for &Density<'a, T>
484where
485 T: ComplexFloat + Lapack,
486{
487 type Output = Density<'a, T>;
488
489 fn add(self, rhs: &Density<'a, T>) -> Self::Output {
490 assert_eq!(
491 self.density_matrix.shape(),
492 rhs.density_matrix.shape(),
493 "Inconsistent shapes of density matrices between `self` and `rhs`."
494 );
495 assert_eq!(
496 self.bao, rhs.bao,
497 "Inconsistent basis angular order between `self` and `rhs`."
498 );
499 Density::<T>::builder()
500 .density_matrix(&self.density_matrix + &rhs.density_matrix)
501 .bao(self.bao)
502 .mol(self.mol)
503 .complex_symmetric(self.complex_symmetric)
504 .threshold(self.threshold)
505 .build()
506 .expect("Unable to add two densities together.")
507 }
508}
509
510impl<'a, T> Add<&'_ Density<'a, T>> for Density<'a, T>
511where
512 T: ComplexFloat + Lapack,
513{
514 type Output = Density<'a, T>;
515
516 fn add(self, rhs: &Self) -> Self::Output {
517 &self + rhs
518 }
519}
520
521impl<'a, T> Add<Density<'a, T>> for Density<'a, T>
522where
523 T: ComplexFloat + Lapack,
524{
525 type Output = Density<'a, T>;
526
527 fn add(self, rhs: Self) -> Self::Output {
528 &self + &rhs
529 }
530}
531
532impl<'a, T> Add<Density<'a, T>> for &Density<'a, T>
533where
534 T: ComplexFloat + Lapack,
535{
536 type Output = Density<'a, T>;
537
538 fn add(self, rhs: Density<'a, T>) -> Self::Output {
539 self + &rhs
540 }
541}
542
543impl<'a, T> Sub<&'_ Density<'a, T>> for &Density<'a, T>
547where
548 T: ComplexFloat + Lapack,
549{
550 type Output = Density<'a, T>;
551
552 fn sub(self, rhs: &Density<'a, T>) -> Self::Output {
553 assert_eq!(
554 self.density_matrix.shape(),
555 rhs.density_matrix.shape(),
556 "Inconsistent shapes of density matrices between `self` and `rhs`."
557 );
558 assert_eq!(
559 self.bao, rhs.bao,
560 "Inconsistent basis angular order between `self` and `rhs`."
561 );
562 Density::<T>::builder()
563 .density_matrix(&self.density_matrix - &rhs.density_matrix)
564 .bao(self.bao)
565 .mol(self.mol)
566 .complex_symmetric(self.complex_symmetric)
567 .threshold(self.threshold)
568 .build()
569 .expect("Unable to subtract two densities.")
570 }
571}
572
573impl<'a, T> Sub<&'_ Density<'a, T>> for Density<'a, T>
574where
575 T: ComplexFloat + Lapack,
576{
577 type Output = Density<'a, T>;
578
579 fn sub(self, rhs: &Self) -> Self::Output {
580 &self - rhs
581 }
582}
583
584impl<'a, T> Sub<Density<'a, T>> for Density<'a, T>
585where
586 T: ComplexFloat + Lapack,
587{
588 type Output = Density<'a, T>;
589
590 fn sub(self, rhs: Self) -> Self::Output {
591 &self - &rhs
592 }
593}
594
595impl<'a, T> Sub<Density<'a, T>> for &Density<'a, T>
596where
597 T: ComplexFloat + Lapack,
598{
599 type Output = Density<'a, T>;
600
601 fn sub(self, rhs: Density<'a, T>) -> Self::Output {
602 self - &rhs
603 }
604}