1use std::collections::HashSet;
4use std::fmt;
5use std::iter::Sum;
6use std::ops::{Add, Index, Sub};
7
8use anyhow::{ensure, format_err};
9use approx;
10use derive_builder::Builder;
11use itertools::Itertools;
12use log;
13use ndarray::Array2;
14use ndarray_linalg::types::Lapack;
15use num_complex::{Complex, ComplexFloat};
16use num_traits::float::{Float, FloatConst};
17
18use crate::angmom::spinor_rotation_3d::StructureConstraint;
19use crate::auxiliary::molecule::Molecule;
20use crate::basis::ao::BasisAngularOrder;
21
22#[cfg(test)]
23mod density_tests;
24
25pub mod density_analysis;
26mod density_transformation;
27
/// Collection of borrowed [`Density`] objects grouped under a single structure constraint.
///
/// Values are constructed through [`Densities::builder`]; the generated build function runs
/// `DensitiesBuilder::validate` before the structure is created.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Densities<'a, T, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + fmt::Display,
{
    /// The structure constraint associated with the stored densities.
    structure_constraint: SC,

    /// References to the individual densities. The builder validation requires the length of
    /// this vector to match the count derived from the structure constraint.
    densities: Vec<&'a Density<'a, T>>,
}
46
47impl<'a, T, SC> Densities<'a, T, SC>
48where
49 T: ComplexFloat + Lapack,
50 SC: StructureConstraint + Clone + fmt::Display,
51{
52 pub fn builder() -> DensitiesBuilder<'a, T, SC> {
54 DensitiesBuilder::default()
55 }
56
57 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
58 self.densities.iter().cloned()
59 }
60}
61
62impl<'a, T, SC> DensitiesBuilder<'a, T, SC>
63where
64 T: ComplexFloat + Lapack,
65 SC: StructureConstraint + fmt::Display,
66{
67 fn validate(&self) -> Result<(), String> {
68 let densities = self
69 .densities
70 .as_ref()
71 .ok_or("No `densities` found.".to_string())?;
72 let structure_constraint = self
73 .structure_constraint
74 .as_ref()
75 .ok_or("No structure constraint found.".to_string())?;
76 let num_dens = structure_constraint.n_coefficient_matrices()
77 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
78 if densities.len() != num_dens {
79 Err(format!(
80 "{} {} expected in structure constraint {}, but {} found.",
81 num_dens,
82 structure_constraint,
83 if num_dens == 1 {
84 "density"
85 } else {
86 "densities"
87 },
88 densities.len()
89 ))
90 } else {
91 Ok(())
92 }
93 }
94}
95
96impl<'a, T, SC> Index<usize> for Densities<'a, T, SC>
97where
98 T: ComplexFloat + Lapack,
99 SC: StructureConstraint + fmt::Display,
100{
101 type Output = Density<'a, T>;
102
103 fn index(&self, index: usize) -> &Self::Output {
104 self.densities[index]
105 }
106}
107
/// Collection of owned [`Density`] objects grouped under a single structure constraint.
///
/// Values are constructed through [`DensitiesOwned::builder`]; the generated build function
/// runs `DensitiesOwnedBuilder::validate` before the structure is created.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct DensitiesOwned<'a, T, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint + fmt::Display,
{
    /// The structure constraint associated with the stored densities.
    structure_constraint: SC,

    /// The individual densities, owned by this collection. The builder validation requires the
    /// length of this vector to match the count derived from the structure constraint.
    densities: Vec<Density<'a, T>>,
}
122
123impl<'a, T, SC> DensitiesOwned<'a, T, SC>
124where
125 T: ComplexFloat + Lapack,
126 SC: StructureConstraint + Clone + fmt::Display,
127{
128 pub fn builder() -> DensitiesOwnedBuilder<'a, T, SC> {
130 DensitiesOwnedBuilder::default()
131 }
132
133 pub fn iter(&self) -> impl Iterator<Item = &Density<'a, T>> {
134 self.densities.iter()
135 }
136
137 pub fn calc_extra_densities<'b: 'a>(
139 &'b self,
140 ) -> Result<Vec<(String, Density<'a, T>)>, anyhow::Error> {
141 let nspatials = self
142 .iter()
143 .map(|den| den.bao().n_funcs())
144 .collect::<HashSet<usize>>();
145 ensure!(
146 nspatials.len() == 1,
147 "Inconsistent number of spatial functions."
148 );
149 let nspatial = *nspatials.iter().next().ok_or_else(|| {
150 format_err!(
151 "Unable to retrieve the number of spatial functions of the density matrices."
152 )
153 })?;
154
155 let den0 = self
156 .iter()
157 .next()
158 .ok_or_else(|| format_err!("Unable to retrieve the first density."))?;
159 ensure!(
160 nspatial == den0.density_matrix.nrows(),
161 "Unexpected density matrix dimension: {} != {}",
162 nspatial,
163 den0.density_matrix.nrows()
164 );
165
166 let total_denmat = self
168 .iter()
169 .fold(Array2::<T>::zeros((nspatial, nspatial)), |acc, den| {
170 acc + den.density_matrix()
171 });
172 let total_den = Density::<T>::builder()
173 .density_matrix(total_denmat)
174 .bao(den0.bao())
175 .mol(den0.mol)
176 .complex_symmetric(den0.complex_symmetric())
177 .threshold(den0.threshold())
178 .build()?;
179
180 let extra_dens = vec![Ok(("Total density".to_string(), total_den))]
182 .into_iter()
183 .chain((0..self.densities.len()).combinations(2).map(|indices| {
184 let i0 = indices[0];
185 let i1 = indices[1];
186 let denmat_0 = self.densities[i0].density_matrix();
187 let denmat_1 = self.densities[i1].density_matrix();
188 let denmat_01 = denmat_0 - denmat_1;
189 let den_01 = Density::<T>::builder()
190 .density_matrix(denmat_01)
191 .bao(den0.bao())
192 .mol(den0.mol)
193 .complex_symmetric(den0.complex_symmetric())
194 .threshold(den0.threshold())
195 .build()?;
196 Ok((
197 format!("Density (component {i0}) - Density (component {i1})"),
198 den_01,
199 ))
200 }))
201 .collect::<Result<Vec<_>, _>>();
202
203 extra_dens
204 }
205}
206
207impl<'b, 'a: 'b, T, SC> DensitiesOwned<'a, T, SC>
208where
209 T: ComplexFloat + Lapack,
210 SC: StructureConstraint + Clone + fmt::Display,
211{
212 pub fn as_ref(&'a self) -> Densities<'b, T, SC> {
213 Densities::builder()
214 .structure_constraint(self.structure_constraint.clone())
215 .densities(self.iter().collect_vec())
216 .build()
217 .expect("Unable to convert `DensitiesOwned` to `Densities`.")
218 }
219}
220
221impl<'a, T, SC> DensitiesOwnedBuilder<'a, T, SC>
222where
223 T: ComplexFloat + Lapack,
224 SC: StructureConstraint + Clone + fmt::Display,
225{
226 fn validate(&self) -> Result<(), String> {
227 let densities = self
228 .densities
229 .as_ref()
230 .ok_or("No `densities` found.".to_string())?;
231 let structure_constraint = self
232 .structure_constraint
233 .as_ref()
234 .ok_or("No spin constraint found.".to_string())?;
235 let num_dens = structure_constraint.n_coefficient_matrices()
236 * structure_constraint.n_explicit_comps_per_coefficient_matrix();
237 if densities.len() != num_dens {
238 Err(format!(
239 "{} {} expected in structure constraint {}, but {} found.",
240 num_dens,
241 structure_constraint,
242 if num_dens == 1 {
243 "density"
244 } else {
245 "densities"
246 },
247 densities.len()
248 ))
249 } else {
250 Ok(())
251 }
252 }
253}
254
255impl<'a, T, SC> Index<usize> for DensitiesOwned<'a, T, SC>
256where
257 T: ComplexFloat + Lapack,
258 SC: StructureConstraint + fmt::Display,
259{
260 type Output = Density<'a, T>;
261
262 fn index(&self, index: usize) -> &Self::Output {
263 &self.densities[index]
264 }
265}
266
/// A density described by a square matrix expressed in a given atomic-orbital basis.
///
/// Values are constructed through [`Density::builder`]; the generated build function runs
/// `DensityBuilder::validate`, which checks the density matrix dimensions against the basis
/// and the number of atoms in the molecule against the basis.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct Density<'a, T>
where
    T: ComplexFloat + Lapack,
{
    /// The basis angular order of the basis in which the density matrix is expressed.
    bao: &'a BasisAngularOrder<'a>,

    /// Flag marking the density as complex-symmetric.
    // NOTE(review): the exact convention this flag selects is not visible in this file — confirm.
    complex_symmetric: bool,

    /// Flag marking the density as complex-conjugated. Defaults to `false` in the builder.
    // NOTE(review): presumably tracks an applied complex conjugation — confirm against callers.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// The molecule associated with this density.
    mol: &'a Molecule,

    /// The density matrix. Validation requires its shape to be `[n, n]`, where `n` is
    /// `bao.n_funcs()`.
    density_matrix: Array2<T>,

    /// Threshold used as the tolerance when two densities are compared for equality.
    threshold: <T as ComplexFloat>::Real,
}
297
298impl<'a, T> DensityBuilder<'a, T>
299where
300 T: ComplexFloat + Lapack,
301{
302 fn validate(&self) -> Result<(), String> {
303 let bao = self
304 .bao
305 .ok_or("No `BasisAngularOrder` found.".to_string())?;
306 let nbas = bao.n_funcs();
307 let density_matrix = self
308 .density_matrix
309 .as_ref()
310 .ok_or("No density matrices found.".to_string())?;
311
312 let denmat_shape = density_matrix.shape() == [nbas, nbas];
313 if !denmat_shape {
314 log::error!(
315 "The density matrix dimensions ({:?}) are incompatible with the basis ({nbas} {}).",
316 density_matrix.shape(),
317 if nbas != 1 { "functions" } else { "function" }
318 );
319 }
320
321 let mol = self.mol.ok_or("No molecule found.".to_string())?;
322 let natoms = mol.atoms.len() == bao.n_atoms();
323 if !natoms {
324 log::error!("The number of atoms in the molecule does not match the number of local sites in the basis.");
325 }
326 if denmat_shape && natoms {
327 Ok(())
328 } else {
329 Err("Density validation failed.".to_string())
330 }
331 }
332}
333
impl<'a, T> Density<'a, T>
where
    T: ComplexFloat + Clone + Lapack,
{
    /// Returns an empty builder for constructing a new [`Density`].
    pub fn builder() -> DensityBuilder<'a, T> {
        DensityBuilder::default()
    }

    /// Returns the complex-symmetric flag of this density.
    pub fn complex_symmetric(&self) -> bool {
        self.complex_symmetric
    }

    /// Returns a reference to the basis angular order associated with this density.
    pub fn bao(&self) -> &BasisAngularOrder {
        self.bao
    }

    /// Returns a shared reference to the density matrix.
    pub fn density_matrix(&self) -> &Array2<T> {
        &self.density_matrix
    }

    /// Returns the threshold stored with this density.
    pub fn threshold(&self) -> <T as ComplexFloat>::Real {
        self.threshold
    }
}
364
365impl<'a, T> From<Density<'a, T>> for Density<'a, Complex<T>>
373where
374 T: Float + FloatConst + Lapack,
375 Complex<T>: Lapack,
376{
377 fn from(value: Density<'a, T>) -> Self {
378 Density::<'a, Complex<T>>::builder()
379 .density_matrix(value.density_matrix.map(Complex::from))
380 .bao(value.bao)
381 .mol(value.mol)
382 .complex_symmetric(value.complex_symmetric)
383 .threshold(value.threshold)
384 .build()
385 .expect("Unable to complexify a `Density`.")
386 }
387}
388
389impl<'a, T> PartialEq for Density<'a, T>
393where
394 T: ComplexFloat<Real = f64> + Lapack,
395{
396 fn eq(&self, other: &Self) -> bool {
397 let thresh = (self.threshold * other.threshold).sqrt();
398 let density_matrix_eq = approx::relative_eq!(
399 (&self.density_matrix - &other.density_matrix)
400 .map(|x| ComplexFloat::abs(*x).powi(2))
401 .sum()
402 .sqrt(),
403 0.0,
404 epsilon = thresh,
405 max_relative = thresh,
406 );
407 self.bao == other.bao && self.mol == other.mol && density_matrix_eq
408 }
409}
410
// Marker implementation building on the `PartialEq` implementation for `Density`.
// NOTE(review): that comparison is tolerance-based, so transitivity may not hold strictly.
impl<'a, T> Eq for Density<'a, T> where T: ComplexFloat<Real = f64> + Lapack {}
415
416impl<'a, T> fmt::Debug for Density<'a, T>
420where
421 T: fmt::Debug + ComplexFloat + Lapack,
422 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Debug,
423{
424 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
425 write!(
426 f,
427 "Density[density matrix of dimensions {}]",
428 self.density_matrix
429 .shape()
430 .iter()
431 .map(|x| x.to_string())
432 .collect::<Vec<_>>()
433 .join("×")
434 )?;
435 Ok(())
436 }
437}
438
439impl<'a, T> fmt::Display for Density<'a, T>
443where
444 T: fmt::Display + ComplexFloat + Lapack,
445 <T as ComplexFloat>::Real: Sum + From<u16> + fmt::Display,
446{
447 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448 write!(
449 f,
450 "Density[density matrix of dimensions {}]",
451 self.density_matrix
452 .shape()
453 .iter()
454 .map(|x| x.to_string())
455 .collect::<Vec<_>>()
456 .join("×")
457 )?;
458 Ok(())
459 }
460}
461
462impl<'a, T> Add<&'_ Density<'a, T>> for &Density<'a, T>
466where
467 T: ComplexFloat + Lapack,
468{
469 type Output = Density<'a, T>;
470
471 fn add(self, rhs: &Density<'a, T>) -> Self::Output {
472 assert_eq!(
473 self.density_matrix.shape(),
474 rhs.density_matrix.shape(),
475 "Inconsistent shapes of density matrices between `self` and `rhs`."
476 );
477 assert_eq!(
478 self.bao, rhs.bao,
479 "Inconsistent basis angular order between `self` and `rhs`."
480 );
481 Density::<T>::builder()
482 .density_matrix(&self.density_matrix + &rhs.density_matrix)
483 .bao(self.bao)
484 .mol(self.mol)
485 .complex_symmetric(self.complex_symmetric)
486 .threshold(self.threshold)
487 .build()
488 .expect("Unable to add two densities together.")
489 }
490}
491
492impl<'a, T> Add<&'_ Density<'a, T>> for Density<'a, T>
493where
494 T: ComplexFloat + Lapack,
495{
496 type Output = Density<'a, T>;
497
498 fn add(self, rhs: &Self) -> Self::Output {
499 &self + rhs
500 }
501}
502
503impl<'a, T> Add<Density<'a, T>> for Density<'a, T>
504where
505 T: ComplexFloat + Lapack,
506{
507 type Output = Density<'a, T>;
508
509 fn add(self, rhs: Self) -> Self::Output {
510 &self + &rhs
511 }
512}
513
514impl<'a, T> Add<Density<'a, T>> for &Density<'a, T>
515where
516 T: ComplexFloat + Lapack,
517{
518 type Output = Density<'a, T>;
519
520 fn add(self, rhs: Density<'a, T>) -> Self::Output {
521 self + &rhs
522 }
523}
524
525impl<'a, T> Sub<&'_ Density<'a, T>> for &Density<'a, T>
529where
530 T: ComplexFloat + Lapack,
531{
532 type Output = Density<'a, T>;
533
534 fn sub(self, rhs: &Density<'a, T>) -> Self::Output {
535 assert_eq!(
536 self.density_matrix.shape(),
537 rhs.density_matrix.shape(),
538 "Inconsistent shapes of density matrices between `self` and `rhs`."
539 );
540 assert_eq!(
541 self.bao, rhs.bao,
542 "Inconsistent basis angular order between `self` and `rhs`."
543 );
544 Density::<T>::builder()
545 .density_matrix(&self.density_matrix - &rhs.density_matrix)
546 .bao(self.bao)
547 .mol(self.mol)
548 .complex_symmetric(self.complex_symmetric)
549 .threshold(self.threshold)
550 .build()
551 .expect("Unable to subtract two densities.")
552 }
553}
554
555impl<'a, T> Sub<&'_ Density<'a, T>> for Density<'a, T>
556where
557 T: ComplexFloat + Lapack,
558{
559 type Output = Density<'a, T>;
560
561 fn sub(self, rhs: &Self) -> Self::Output {
562 &self - rhs
563 }
564}
565
566impl<'a, T> Sub<Density<'a, T>> for Density<'a, T>
567where
568 T: ComplexFloat + Lapack,
569{
570 type Output = Density<'a, T>;
571
572 fn sub(self, rhs: Self) -> Self::Output {
573 &self - &rhs
574 }
575}
576
577impl<'a, T> Sub<Density<'a, T>> for &Density<'a, T>
578where
579 T: ComplexFloat + Lapack,
580{
581 type Output = Density<'a, T>;
582
583 fn sub(self, rhs: Density<'a, T>) -> Self::Output {
584 self - &rhs
585 }
586}