1use std::collections::HashSet;
4use std::fmt;
5
6use anyhow::{self, ensure, format_err};
7use approx;
8use derive_builder::Builder;
9use ndarray::{Array1, Array2, Ix2, s};
10use ndarray_einsum::*;
11use ndarray_linalg::types::Lapack;
12use num_complex::{Complex, ComplexFloat};
13use num_traits::float::{Float, FloatConst};
14
15use crate::angmom::spinor_rotation_3d::{SpinConstraint, SpinOrbitCoupled, StructureConstraint};
16use crate::auxiliary::molecule::Molecule;
17use crate::basis::ao::BasisAngularOrder;
18use crate::target::density::Density;
19
20#[cfg(test)]
21mod orbital_tests;
22
23pub mod orbital_analysis;
24pub mod orbital_projection;
25mod orbital_transformation;
26
/// A single molecular orbital, stored as a coefficient vector over one or more sets of basis
/// functions (one `BasisAngularOrder` per explicit component of the structure constraint).
///
/// Instances are created through the derived builder; its `build` step runs
/// `MolecularOrbitalBuilder::validate` to check the consistency of the supplied fields.
#[derive(Builder, Clone)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MolecularOrbital<'a, T, SC>
where
    T: ComplexFloat + Lapack,
    SC: StructureConstraint,
{
    /// The structure constraint (e.g. a spin constraint) applied to the coefficients of this
    /// molecular orbital.
    structure_constraint: SC,

    /// Index of the explicit component that this orbital's coefficients belong to. For example,
    /// `to_generalised` uses it to choose the block of the generalised coefficient vector that
    /// receives these coefficients.
    component_index: usize,

    /// Flag recording whether this orbital has been complex-conjugated. Defaults to `false`.
    /// NOTE(review): presumably set after the action of an antiunitary operation — confirm.
    #[builder(default = "false")]
    complex_conjugated: bool,

    /// Basis angular orders, one per explicit component. Validation requires their number to
    /// match the structure constraint and their total function count to match the length of
    /// `coefficients`.
    baos: Vec<&'a BasisAngularOrder<'a>>,

    /// Flag forwarded to densities built from this orbital.
    /// NOTE(review): presumably selects a complex-symmetric bilinear form instead of the
    /// Hermitian inner product — confirm.
    complex_symmetric: bool,

    /// The molecule on which the basis functions are centred. Validation requires its atom
    /// count to match the number of atoms in the basis angular orders.
    mol: &'a Molecule,

    /// The coefficient vector describing this molecular orbital.
    coefficients: Array1<T>,

    /// The orbital energy, if known. Defaults to `None`.
    #[builder(default = "None")]
    energy: Option<T>,

    /// Threshold used for approximate comparisons (e.g. in `PartialEq`), also forwarded to
    /// derived densities.
    threshold: <T as ComplexFloat>::Real,
}
74
75impl<'a, T, SC> MolecularOrbitalBuilder<'a, T, SC>
76where
77 T: ComplexFloat + Lapack,
78 SC: StructureConstraint,
79{
80 fn validate(&self) -> Result<(), String> {
81 let structcons = self
82 .structure_constraint
83 .as_ref()
84 .ok_or("No structure constraints found.".to_string())?;
85 let baos = self
86 .baos
87 .as_ref()
88 .ok_or("No `BasisAngularOrder`s found.".to_string())?;
89 let baos_length_check = baos.len() == structcons.n_explicit_comps_per_coefficient_matrix();
90 if !baos_length_check {
91 log::error!(
92 "The number of `BasisAngularOrder`s provided does not match the number of explicit components per coefficient matrix."
93 );
94 }
95
96 let nbas_tot = baos.iter().map(|bao| bao.n_funcs()).sum::<usize>();
97 let coefficients = self
98 .coefficients
99 .as_ref()
100 .ok_or("No coefficients found.".to_string())?;
101
102 let coefficients_shape_check = {
103 let nrows = nbas_tot;
104 if !coefficients.shape()[0] == nrows {
105 log::error!(
106 "Unexpected shapes of coefficient vector: {} {} expected, but {} found.",
107 nrows,
108 if nrows == 1 { "row" } else { "rows" },
109 coefficients.shape()[0],
110 );
111 false
112 } else {
113 true
114 }
115 };
116
117 let mol = self.mol.ok_or("No molecule found.".to_string())?;
118 let natoms_set = baos.iter().map(|bao| bao.n_atoms()).collect::<HashSet<_>>();
119 if natoms_set.len() != 1 {
120 return Err("Inconsistent numbers of atoms between `BasisAngularOrder`s of different explicit components.".to_string());
121 };
122 let n_atoms = natoms_set.iter().next().ok_or_else(|| {
123 "Unable to retrieve the number of atoms from the `BasisAngularOrder`s.".to_string()
124 })?;
125 let natoms_check = mol.atoms.len() == *n_atoms;
126 if !natoms_check {
127 log::error!(
128 "The number of atoms in the molecule does not match the number of local sites in the basis."
129 );
130 }
131
132 if baos_length_check && coefficients_shape_check && natoms_check {
133 Ok(())
134 } else {
135 Err(format!(
136 "Molecular orbital validation failed:
137 baos_length ({baos_length_check}),
138 coefficients_shape ({coefficients_shape_check}),
139 natoms ({natoms_check})."
140 ))
141 }
142 }
143}
144
145impl<'a, T, SC> MolecularOrbital<'a, T, SC>
146where
147 T: ComplexFloat + Clone + Lapack,
148 SC: StructureConstraint + Clone,
149{
150 pub fn builder() -> MolecularOrbitalBuilder<'a, T, SC> {
152 MolecularOrbitalBuilder::default()
153 }
154}
155
156impl<'a, T, SC> MolecularOrbital<'a, T, SC>
157where
158 T: ComplexFloat + Clone + Lapack,
159 SC: StructureConstraint,
160{
161 pub fn coefficients(&self) -> &Array1<T> {
163 &self.coefficients
164 }
165
166 pub fn structure_constraint(&self) -> &SC {
168 &self.structure_constraint
169 }
170
171 pub fn baos(&'_ self) -> &Vec<&'_ BasisAngularOrder<'_>> {
174 &self.baos
175 }
176
177 pub fn mol(&self) -> &Molecule {
179 self.mol
180 }
181
182 pub fn complex_symmetric(&self) -> bool {
184 self.complex_symmetric
185 }
186
187 pub fn threshold(&self) -> <T as ComplexFloat>::Real {
189 self.threshold
190 }
191}
192
193impl<'a, T> MolecularOrbital<'a, T, SpinConstraint>
194where
195 T: ComplexFloat + Clone + Lapack,
196{
197 pub fn to_generalised(&self) -> Self {
205 match self.structure_constraint {
206 SpinConstraint::Restricted(n) => {
207 let bao = self.baos[0];
208 let nbas = bao.n_funcs();
209
210 let cr = &self.coefficients;
211 let mut cg = Array1::<T>::zeros(nbas * usize::from(n));
212 let start = nbas * self.component_index;
213 let end = nbas * (self.component_index + 1);
214 cg.slice_mut(s![start..end]).assign(cr);
215 Self::builder()
216 .coefficients(cg)
217 .baos((0..n).map(|_| bao).collect::<Vec<_>>())
218 .mol(self.mol)
219 .structure_constraint(SpinConstraint::Generalised(n, false))
220 .component_index(0)
221 .complex_symmetric(self.complex_symmetric)
222 .threshold(self.threshold)
223 .build()
224 .expect("Unable to construct a generalised molecular orbital.")
225 }
226 SpinConstraint::Unrestricted(n, increasingm) => {
227 let bao = self.baos[0];
228 let nbas = bao.n_funcs();
229
230 let cr = &self.coefficients;
231 let mut cg = Array1::<T>::zeros(nbas * usize::from(n));
232 let start = nbas * self.component_index;
233 let end = nbas * (self.component_index + 1);
234 cg.slice_mut(s![start..end]).assign(cr);
235 Self::builder()
236 .coefficients(cg)
237 .baos((0..n).map(|_| bao).collect::<Vec<_>>())
238 .mol(self.mol)
239 .structure_constraint(SpinConstraint::Generalised(n, increasingm))
240 .component_index(0)
241 .complex_symmetric(self.complex_symmetric)
242 .threshold(self.threshold)
243 .build()
244 .expect("Unable to construct a generalised molecular orbital.")
245 }
246 SpinConstraint::Generalised(_, _) => self.clone(),
247 }
248 }
249}
250
251impl<'a> MolecularOrbital<'a, f64, SpinConstraint> {
252 pub fn to_total_density(&'a self) -> Result<Density<'a, f64>, anyhow::Error> {
254 match self.structure_constraint {
255 SpinConstraint::Restricted(nspins) => {
256 let denmat = f64::from(nspins)
257 * einsum(
258 "m,n->mn",
259 &[&self.coefficients.view(), &self.coefficients.view()],
260 )
261 .expect("Unable to construct a density matrix from the coefficient matrix.")
262 .into_dimensionality::<Ix2>()
263 .expect("Unable to convert the resultant density matrix to two dimensions.");
264 Density::<f64>::builder()
265 .density_matrix(denmat)
266 .bao(self.baos()[0])
267 .mol(self.mol())
268 .complex_symmetric(self.complex_symmetric())
269 .threshold(self.threshold())
270 .build()
271 .map_err(|err| format_err!(err))
272 }
273 SpinConstraint::Unrestricted(_, _) => {
274 let denmat = einsum(
275 "m,n->mn",
276 &[&self.coefficients.view(), &self.coefficients.view()],
277 )
278 .expect("Unable to construct a density matrix from the coefficient matrix.")
279 .into_dimensionality::<Ix2>()
280 .expect("Unable to convert the resultant density matrix to two dimensions.");
281 Density::<f64>::builder()
282 .density_matrix(denmat)
283 .bao(self.baos()[0])
284 .mol(self.mol())
285 .complex_symmetric(self.complex_symmetric())
286 .threshold(self.threshold())
287 .build()
288 .map_err(|err| format_err!(err))
289 }
290 SpinConstraint::Generalised(nspins, _) => {
291 let full_denmat = einsum(
292 "m,n->mn",
293 &[&self.coefficients.view(), &self.coefficients.view()],
294 )
295 .expect("Unable to construct a density matrix from the coefficient matrix.")
296 .into_dimensionality::<Ix2>()
297 .expect("Unable to convert the resultant density matrix to two dimensions.");
298
299 let nspatial_set = self
300 .baos()
301 .iter()
302 .map(|bao| bao.n_funcs())
303 .collect::<HashSet<_>>();
304 ensure!(
305 nspatial_set.len() == 1,
306 "Mismatched numbers of basis functions between the explicit components."
307 );
308 let nspatial = *nspatial_set.iter().next().ok_or_else(|| {
309 format_err!(
310 "Unable to extract the number of basis functions per explicit component."
311 )
312 })?;
313
314 let denmat = (0..usize::from(nspins)).fold(
315 Array2::<f64>::zeros((nspatial, nspatial)),
316 |acc, ispin| {
317 acc + full_denmat.slice(s![
318 ispin * nspatial..(ispin + 1) * nspatial,
319 ispin * nspatial..(ispin + 1) * nspatial
320 ])
321 },
322 );
323 Density::<f64>::builder()
324 .density_matrix(denmat)
325 .bao(self.baos()[0])
326 .mol(self.mol())
327 .complex_symmetric(self.complex_symmetric())
328 .threshold(self.threshold())
329 .build()
330 .map_err(|err| format_err!(err))
331 }
332 }
333 }
334}
335
336impl<'a, T> MolecularOrbital<'a, Complex<T>, SpinConstraint>
337where
338 T: Float + FloatConst + Lapack + From<u16>,
339 Complex<T>: Lapack,
340{
341 pub fn to_total_density(&'a self) -> Result<Density<'a, Complex<T>>, anyhow::Error> {
343 match self.structure_constraint {
344 SpinConstraint::Restricted(nspins) => {
345 let nspins_t = Complex::<T>::from(<T as From<u16>>::from(nspins));
346 let denmat = einsum(
347 "m,n->mn",
348 &[
349 &self.coefficients.view(),
350 &self.coefficients.map(Complex::conj).view(),
351 ],
352 )
353 .expect("Unable to construct a density matrix from the coefficient matrix.")
354 .into_dimensionality::<Ix2>()
355 .expect("Unable to convert the resultant density matrix to two dimensions.")
356 .map(|x| x * nspins_t);
357 Density::<Complex<T>>::builder()
358 .density_matrix(denmat)
359 .bao(self.baos()[0])
360 .mol(self.mol())
361 .complex_symmetric(self.complex_symmetric())
362 .threshold(self.threshold())
363 .build()
364 .map_err(|err| format_err!(err))
365 }
366 SpinConstraint::Unrestricted(_, _) => {
367 let denmat = einsum(
368 "m,n->mn",
369 &[
370 &self.coefficients.view(),
371 &self.coefficients.map(Complex::conj).view(),
372 ],
373 )
374 .expect("Unable to construct a density matrix from the coefficient matrix.")
375 .into_dimensionality::<Ix2>()
376 .expect("Unable to convert the resultant density matrix to two dimensions.");
377 Density::<Complex<T>>::builder()
378 .density_matrix(denmat)
379 .bao(self.baos()[0])
380 .mol(self.mol())
381 .complex_symmetric(self.complex_symmetric())
382 .threshold(self.threshold())
383 .build()
384 .map_err(|err| format_err!(err))
385 }
386 SpinConstraint::Generalised(nspins, _) => {
387 let full_denmat = einsum(
388 "m,n->mn",
389 &[
390 &self.coefficients.view(),
391 &self.coefficients.map(Complex::conj).view(),
392 ],
393 )
394 .expect("Unable to construct a density matrix from the coefficient matrix.")
395 .into_dimensionality::<Ix2>()
396 .expect("Unable to convert the resultant density matrix to two dimensions.");
397
398 let nspatial_set = self
399 .baos()
400 .iter()
401 .map(|bao| bao.n_funcs())
402 .collect::<HashSet<_>>();
403 ensure!(
404 nspatial_set.len() == 1,
405 "Mismatched numbers of basis functions between the explicit components in the generalised spin constraint."
406 );
407 let nspatial = *nspatial_set.iter().next().ok_or_else(|| {
408 format_err!(
409 "Unable to extract the number of basis functions per explicit component."
410 )
411 })?;
412
413 let denmat = (0..usize::from(nspins)).fold(
414 Array2::<Complex<T>>::zeros((nspatial, nspatial)),
415 |acc, ispin| {
416 acc + full_denmat.slice(s![
417 ispin * nspatial..(ispin + 1) * nspatial,
418 ispin * nspatial..(ispin + 1) * nspatial
419 ])
420 },
421 );
422 Density::<Complex<T>>::builder()
423 .density_matrix(denmat)
424 .bao(self.baos()[0])
425 .mol(self.mol())
426 .complex_symmetric(self.complex_symmetric())
427 .threshold(self.threshold())
428 .build()
429 .map_err(|err| format_err!(err))
430 }
431 }
432 }
433}
434
435impl<'a, T> MolecularOrbital<'a, Complex<T>, SpinOrbitCoupled>
436where
437 T: Float + FloatConst + Lapack + From<u16>,
438 Complex<T>: Lapack,
439{
440 pub fn to_total_density(&'a self) -> Result<Density<'a, Complex<T>>, anyhow::Error> {
442 Err(format_err!(
443 "The total density of a spin--orbit-coupled molecular orbital is not implemented."
444 ))
445 }
493}
494
495impl<'a, T, SC> From<MolecularOrbital<'a, T, SC>> for MolecularOrbital<'a, Complex<T>, SC>
503where
504 T: Float + FloatConst + Lapack,
505 Complex<T>: Lapack,
506 SC: StructureConstraint + Clone,
507{
508 fn from(value: MolecularOrbital<'a, T, SC>) -> Self {
509 MolecularOrbital::<'a, Complex<T>, SC>::builder()
510 .coefficients(value.coefficients.map(Complex::from))
511 .baos(value.baos.clone())
512 .mol(value.mol)
513 .structure_constraint(value.structure_constraint)
514 .component_index(value.component_index)
515 .complex_symmetric(value.complex_symmetric)
516 .threshold(value.threshold)
517 .build()
518 .expect("Unable to construct a complex molecular orbital.")
519 }
520}
521
522impl<'a, T, SC> PartialEq for MolecularOrbital<'a, T, SC>
526where
527 T: ComplexFloat<Real = f64> + Lapack,
528 SC: StructureConstraint + PartialEq,
529{
530 fn eq(&self, other: &Self) -> bool {
531 let thresh = (self.threshold * other.threshold).sqrt();
532 let coefficients_eq = approx::relative_eq!(
533 (&self.coefficients - &other.coefficients)
534 .map(|x| ComplexFloat::abs(*x).powi(2))
535 .sum()
536 .sqrt(),
537 0.0,
538 epsilon = thresh,
539 max_relative = thresh,
540 );
541 self.structure_constraint == other.structure_constraint
542 && self.component_index == other.component_index
543 && self.baos == other.baos
544 && self.mol == other.mol
545 && coefficients_eq
546 }
547}
548
549impl<'a, T, SC> Eq for MolecularOrbital<'a, T, SC>
553where
554 T: ComplexFloat<Real = f64> + Lapack,
555 SC: StructureConstraint + Eq,
556{
557}
558
559impl<'a, T, SC> fmt::Debug for MolecularOrbital<'a, T, SC>
563where
564 T: fmt::Debug + ComplexFloat + Lapack,
565 SC: StructureConstraint + fmt::Debug,
566{
567 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568 write!(
569 f,
570 "MolecularOrbital[{:?} (spin index {}): coefficient array of length {}]",
571 self.structure_constraint,
572 self.component_index,
573 self.coefficients.len()
574 )?;
575 Ok(())
576 }
577}
578
579impl<'a, T, SC> fmt::Display for MolecularOrbital<'a, T, SC>
583where
584 T: fmt::Display + ComplexFloat + Lapack,
585 SC: StructureConstraint + fmt::Display,
586{
587 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588 write!(
589 f,
590 "MolecularOrbital[{} (spin index {}): coefficient array of length {}]",
591 self.structure_constraint,
592 self.component_index,
593 self.coefficients.len()
594 )?;
595 Ok(())
596 }
597}