1use anyhow::{self, ensure, format_err};
2use duplicate::duplicate_item;
3use itertools::Itertools;
4use ndarray::{stack, Array1, Array2, ArrayView1, ArrayView2, Axis, Ix2, LinalgScalar};
5use ndarray_einsum::einsum;
6use ndarray_linalg::{
7 Eig, EigGeneralized, Eigh, GeneralizedEigenvalue, Lapack, Norm, Scalar, UPLO,
8};
9use num::traits::FloatConst;
10use num::{Float, One};
11use num_complex::{Complex, ComplexFloat};
12
13use crate::analysis::EigenvalueComparisonMode;
14
15use crate::target::noci::backend::nonortho::CanonicalOrthogonalisable;
16
17pub mod noci;
18
19#[cfg(test)]
20#[path = "solver_tests.rs"]
21mod solver_tests;
22
/// Trait for objects defining a generalised eigenvalue problem `H c = λ S c`, where `H` is
/// referred to as the Hamiltonian matrix and `S` as the overlap matrix.
///
/// In this module the trait is implemented for tuples `(&H, &S)` of real (`f64`/`f32`) and of
/// complex matrix views.
pub trait GeneralisedEigenvalueSolvable {
    /// The numerical type of the eigenvalues and eigenvector elements produced by the solvers.
    type NumType;

    /// The real type used for the numerical thresholds.
    type RealType;

    /// Solves the generalised eigenvalue problem by canonical orthogonalisation of the overlap
    /// matrix, followed by an ordinary eigendecomposition in the orthogonalised basis.
    ///
    /// # Arguments
    ///
    /// * `complex_symmetric` - For the complex implementation, selects complex-symmetric
    ///   (`true`) rather than complex-Hermitian (`false`) treatment of the matrices. The real
    ///   implementations ignore this flag.
    /// * `thresh_offdiag` - Tolerance used by the implementations for the symmetry/Hermiticity
    ///   checks, for checking the orthogonalised overlap matrix, for checking that the
    ///   normalisation matrix is diagonal, and for eigenvector sign regularisation.
    /// * `thresh_zeroov` - Threshold forwarded to the canonical orthogonalisation (presumably the
    ///   cut-off below which overlap eigenvalues are treated as zero — TODO confirm).
    /// * `eigenvalue_comparison_mode` - Whether eigenvalues are sorted by modulus or by real part.
    ///
    /// # Returns
    ///
    /// The sorted eigenvalues together with the matching normalised, sign-regularised
    /// eigenvectors, stored as columns.
    ///
    /// # Errors
    ///
    /// In the implementations of this module: the input fails a symmetry/Hermiticity check, the
    /// orthogonalisation or eigendecomposition fails, or the eigenvectors cannot be normalised.
    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
        &self,
        complex_symmetric: bool,
        thresh_offdiag: Self::RealType,
        thresh_zeroov: Self::RealType,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;

    /// Solves the generalised eigenvalue problem directly with the LAPACK generalised
    /// eigensolver (`?ggev`, via `ndarray_linalg::EigGeneralized`). Only eigenvalues reported
    /// as finite are kept.
    ///
    /// # Arguments
    ///
    /// * `complex_symmetric` - For the complex implementation, selects complex-symmetric
    ///   (`true`) rather than complex-Hermitian (`false`) checks and normalisation. The real
    ///   implementations ignore this flag.
    /// * `thresh_offdiag` - Tolerance for the symmetry/Hermiticity checks, for the
    ///   imaginary-part checks (real implementations), for the diagonality check during
    ///   normalisation, and for sign regularisation.
    /// * `thresh_zeroov` - Threshold forwarded to `eig_generalized`, which uses it to decide
    ///   which eigenvalues are reported as finite.
    /// * `eigenvalue_comparison_mode` - Whether eigenvalues are sorted by modulus or by real part.
    ///
    /// # Returns
    ///
    /// The sorted finite eigenvalues together with the matching normalised, sign-regularised
    /// eigenvectors, stored as columns.
    ///
    /// # Errors
    ///
    /// In the implementations of this module: the input fails a symmetry/Hermiticity check,
    /// the generalised eigensolver fails, unexpected complex values appear (real case), or the
    /// eigenvectors cannot be normalised.
    fn solve_generalised_eigenvalue_problem_with_ggev(
        &self,
        complex_symmetric: bool,
        thresh_offdiag: Self::RealType,
        thresh_zeroov: Self::RealType,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error>;
}
93
94pub struct GeneralisedEigenvalueResult<T> {
97 eigenvalues: Array1<T>,
99
100 eigenvectors: Array2<T>,
102}
103
104impl<T> GeneralisedEigenvalueResult<T> {
105 pub fn eigenvalues(&self) -> ArrayView1<T> {
107 self.eigenvalues.view()
108 }
109
110 pub fn eigenvectors(&self) -> ArrayView2<T> {
112 self.eigenvectors.view()
113 }
114}
115
// Real implementation of `GeneralisedEigenvalueSolvable`. `duplicate_item` expands this impl
// once per listed type, substituting `dtype_` with `f64` and then `f32`.
#[duplicate_item(
    [
        dtype_ [ f64 ]
    ]
    [
        dtype_ [ f32 ]
    ]
)]
impl GeneralisedEigenvalueSolvable for (&ArrayView2<'_, dtype_>, &ArrayView2<'_, dtype_>) {
    type NumType = dtype_;
    type RealType = dtype_;

    /// Solves the real generalised eigenvalue problem by canonical orthogonalisation of `S`
    /// followed by a symmetric eigendecomposition (`eigh`) of the transformed `H`.
    ///
    /// The `complex_symmetric` flag is ignored for real matrices.
    fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
        &self,
        _: bool,
        thresh_offdiag: dtype_,
        thresh_zeroov: dtype_,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());

        // Symmetry check on H: ||H - H^T||_2 must not exceed `thresh_offdiag`.
        // NOTE(review): S is not checked here; presumably `calc_canonical_orthogonal_matrix`
        // validates it — confirm.
        ensure!(
            (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
            "Hamiltonian matrix is not real-symmetric."
        );

        // NOTE(review): the boolean arguments are defined by `CanonicalOrthogonalisable`; the
        // complex impl passes its `complex_symmetric` flag in the first position, so `true` here
        // appears to request symmetric (transpose-based) treatment — confirm.
        let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
            true,
            false,
            thresh_offdiag,
            thresh_zeroov,
        )?;

        let xmat = xmat_res.xmat();
        let xmat_d = xmat_res.xmat_d();

        // Transform H and S into the orthogonalised basis: X_d . M . X.
        let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
        let smat_t = xmat_d.dot(&smat).dot(&xmat);

        // The transformed overlap must equal the identity to within `thresh_offdiag`
        // (largest element-wise deviation); otherwise the orthogonalisation is rejected.
        let max_diff = (&smat_t - &Array2::<dtype_>::eye(smat_t.nrows()))
            .iter()
            .map(|x| ComplexFloat::abs(*x))
            .max_by(|x, y| {
                x.partial_cmp(y)
                    .expect("Unable to compare two `abs` values.")
            })
            .ok_or_else(|| {
                format_err!("Unable to determine the maximum element of the |S - I| matrix.")
            })?;
        ensure!(
            max_diff <= thresh_offdiag,
            "The orthogonalised overlap matrix is not the identity matrix."
        );

        // With the transformed overlap equal to the identity, only the transformed H needs to be
        // diagonalised, using the symmetric solver on its lower triangle.
        let (eigvals_t, eigvecs_t) = hmat_t.eigh(UPLO::Lower)?;

        let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
            &eigvals_t.view(),
            &eigvecs_t.view(),
            &eigenvalue_comparison_mode,
        );
        // Back-transform the eigenvectors to the original (non-orthogonal) basis.
        let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);

        let eigvecs_sorted_normalised =
            normalise_eigenvectors_real(&eigvecs_sorted.view(), &smat.view(), thresh_offdiag)?;

        let eigvecs_sorted_normalised_regularised =
            regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);

        Ok(GeneralisedEigenvalueResult {
            eigenvalues: eigvals_t_sorted,
            eigenvectors: eigvecs_sorted_normalised_regularised,
        })
    }

    /// Solves the real generalised eigenvalue problem directly with the LAPACK generalised
    /// eigensolver (via `eig_generalized`), keeping only eigenvalues reported as finite.
    ///
    /// The `complex_symmetric` flag is ignored for real matrices.
    fn solve_generalised_eigenvalue_problem_with_ggev(
        &self,
        _: bool,
        thresh_offdiag: dtype_,
        thresh_zeroov: dtype_,
        eigenvalue_comparison_mode: EigenvalueComparisonMode,
    ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
        let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());

        // Both H and S must be symmetric to within `thresh_offdiag` (2-norm of M - M^T).
        ensure!(
            (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
            "Hamiltonian matrix is not real-symmetric."
        );
        ensure!(
            (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
            "Overlap matrix is not real-symmetric."
        );

        // `thresh_zeroov` is forwarded to `eig_generalized`, which uses it to classify
        // eigenvalues as `Finite` or not. The eigenvalues and eigenvectors come back complex.
        let (geneigvals, eigvecs) =
            (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;

        // Any finite eigenvalue with a non-negligible imaginary part is treated as an error.
        for gv in geneigvals.iter() {
            if let GeneralizedEigenvalue::Finite(v, _) = gv {
                ensure!(
                    v.im().abs() <= thresh_offdiag,
                    "Unexpected complex eigenvalue {v} for real, symmetric S and H."
                );
            }
        }

        // Keep only the indices of finite eigenvalues; all others are discarded.
        let mut indices = (0..geneigvals.len())
            .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
            .collect_vec();

        // Sort the retained indices by modulus or by real part. Every retained index refers to
        // a `Finite` eigenvalue, so the `panic!` branches are not expected to be reached; the
        // `unwrap()` calls would panic on NaN eigenvalues.
        match eigenvalue_comparison_mode {
            EigenvalueComparisonMode::Modulus => {
                indices.sort_by(|i, j| {
                    if let (
                        GeneralizedEigenvalue::Finite(e_i, _),
                        GeneralizedEigenvalue::Finite(e_j, _),
                    ) = (&geneigvals[*i], &geneigvals[*j])
                    {
                        ComplexFloat::abs(*e_i)
                            .partial_cmp(&ComplexFloat::abs(*e_j))
                            .unwrap()
                    } else {
                        panic!("Unable to compare some eigenvalues.")
                    }
                });
            }
            EigenvalueComparisonMode::Real => {
                indices.sort_by(|i, j| {
                    if let (
                        GeneralizedEigenvalue::Finite(e_i, _),
                        GeneralizedEigenvalue::Finite(e_j, _),
                    ) = (&geneigvals[*i], &geneigvals[*j])
                    {
                        e_i.re().partial_cmp(&e_j.re()).unwrap()
                    } else {
                        panic!("Unable to compare some eigenvalues.")
                    }
                });
            }
        }

        // Keep only the real parts of the eigenvalues; their imaginary parts were checked above.
        let eigvals_re_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
            if let GeneralizedEigenvalue::Finite(v, _) = gv {
                v.re()
            } else {
                panic!("Unexpected indeterminate eigenvalue.")
            }
        });
        let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
        // Eigenvectors must be real before their imaginary parts are dropped. Note this uses a
        // strict `<`, unlike the `<=` comparisons elsewhere in this impl.
        ensure!(
            eigvecs_sorted.iter().all(|v| v.im().abs() < thresh_offdiag),
            "Unexpected complex eigenvectors."
        );
        let eigvecs_re_sorted = eigvecs_sorted.map(|v| v.re());

        let eigvecs_re_sorted_normalised =
            normalise_eigenvectors_real(&eigvecs_re_sorted.view(), &smat.view(), thresh_offdiag)?;

        let eigvecs_re_sorted_normalised_regularised =
            regularise_eigenvectors(&eigvecs_re_sorted_normalised.view(), thresh_offdiag);

        Ok(GeneralisedEigenvalueResult {
            eigenvalues: eigvals_re_sorted,
            eigenvectors: eigvecs_re_sorted_normalised_regularised,
        })
    }
}
295
296impl<T> GeneralisedEigenvalueSolvable for (&ArrayView2<'_, Complex<T>>, &ArrayView2<'_, Complex<T>>)
297where
298 T: Float + FloatConst + Scalar<Complex = Complex<T>>,
299 Complex<T>: ComplexFloat<Real = T> + Scalar<Real = T, Complex = Complex<T>> + Lapack,
300 for<'a> ArrayView2<'a, Complex<T>>:
301 CanonicalOrthogonalisable<NumType = Complex<T>, RealType = T>,
302{
303 type NumType = Complex<T>;
304
305 type RealType = T;
306
307 fn solve_generalised_eigenvalue_problem_with_canonical_orthogonalisation(
308 &self,
309 complex_symmetric: bool,
310 thresh_offdiag: T,
311 thresh_zeroov: T,
312 eigenvalue_comparison_mode: EigenvalueComparisonMode,
313 ) -> Result<GeneralisedEigenvalueResult<Complex<T>>, anyhow::Error> {
314 let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
315
316 if complex_symmetric {
317 ensure!(
319 (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
320 "Hamiltonian matrix is not complex-symmetric."
321 );
322 } else {
323 ensure!(
325 (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
326 "Hamiltonian matrix is not complex-Hermitian."
327 );
328 }
329
330 let xmat_res = smat.view().calc_canonical_orthogonal_matrix(
333 complex_symmetric,
334 false,
335 thresh_offdiag,
336 thresh_zeroov,
337 )?;
338
339 let xmat = xmat_res.xmat();
340 let xmat_d = xmat_res.xmat_d();
341
342 let hmat_t = xmat_d.dot(&hmat).dot(&xmat);
343 let smat_t = xmat_d.dot(&smat).dot(&xmat);
344 let smat_t_d = smat_t.map(|v| v.conj()).t().to_owned();
345
346 let max_diff = (&smat_t_d.dot(&smat_t) - &Array2::<T>::eye(smat_t.nrows()))
348 .iter()
349 .map(|x| ComplexFloat::abs(*x))
350 .max_by(|x, y| {
351 x.partial_cmp(y)
352 .expect("Unable to compare two `abs` values.")
353 })
354 .ok_or_else(|| {
355 format_err!("Unable to determine the maximum element of the |S - I| matrix.")
356 })?;
357 ensure!(
358 max_diff <= thresh_offdiag,
359 "The orthogonalised overlap matrix is not the identity matrix."
360 );
361 let smat_t_d_hmat_t = smat_t_d.dot(&hmat_t);
362
363 let (eigvals_t, eigvecs_t) = smat_t_d_hmat_t.eig()?;
364
365 let (eigvals_t_sorted, eigvecs_t_sorted) = sort_eigenvalues_eigenvectors(
367 &eigvals_t.view(),
368 &eigvecs_t.view(),
369 &eigenvalue_comparison_mode,
370 );
371 let eigvecs_sorted = xmat.dot(&eigvecs_t_sorted);
372
373 let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
375 &eigvecs_sorted.view(),
376 &smat.view(),
377 complex_symmetric,
378 thresh_offdiag,
379 )?;
380
381 let eigvecs_sorted_normalised_regularised =
383 regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
384
385 Ok(GeneralisedEigenvalueResult {
386 eigenvalues: eigvals_t_sorted,
387 eigenvectors: eigvecs_sorted_normalised_regularised,
388 })
389 }
390
391 fn solve_generalised_eigenvalue_problem_with_ggev(
392 &self,
393 complex_symmetric: bool,
394 thresh_offdiag: T,
395 thresh_zeroov: T,
396 eigenvalue_comparison_mode: EigenvalueComparisonMode,
397 ) -> Result<GeneralisedEigenvalueResult<Self::NumType>, anyhow::Error> {
398 let (hmat, smat) = (self.0.to_owned(), self.1.to_owned());
399
400 if complex_symmetric {
401 ensure!(
403 (hmat.to_owned() - hmat.t()).norm_l2() <= thresh_offdiag,
404 "Hamiltonian matrix is not complex-symmetric."
405 );
406 ensure!(
407 (smat.to_owned() - smat.t()).norm_l2() <= thresh_offdiag,
408 "Overlap matrix is not complex-symmetric."
409 );
410 } else {
411 ensure!(
413 (hmat.to_owned() - hmat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
414 "Hamiltonian matrix is not complex-Hermitian."
415 );
416 ensure!(
417 (smat.to_owned() - smat.map(|v| v.conj()).t()).norm_l2() <= thresh_offdiag,
418 "Overlap matrix is not complex-Hermitian."
419 );
420 }
421
422 let (geneigvals, eigvecs) =
423 (hmat.clone(), smat.clone()).eig_generalized(Some(thresh_zeroov))?;
424
425 let mut indices = (0..geneigvals.len())
427 .filter(|i| matches!(geneigvals[*i], GeneralizedEigenvalue::Finite(_, _)))
428 .collect_vec();
429
430 match eigenvalue_comparison_mode {
431 EigenvalueComparisonMode::Modulus => {
432 indices.sort_by(|i, j| {
433 if let (
434 GeneralizedEigenvalue::Finite(e_i, _),
435 GeneralizedEigenvalue::Finite(e_j, _),
436 ) = (&geneigvals[*i], &geneigvals[*j])
437 {
438 ComplexFloat::abs(*e_i)
439 .partial_cmp(&ComplexFloat::abs(*e_j))
440 .unwrap()
441 } else {
442 panic!("Unable to compare some eigenvalues.")
443 }
444 });
445 }
446 EigenvalueComparisonMode::Real => {
447 indices.sort_by(|i, j| {
448 if let (
449 GeneralizedEigenvalue::Finite(e_i, _),
450 GeneralizedEigenvalue::Finite(e_j, _),
451 ) = (&geneigvals[*i], &geneigvals[*j])
452 {
453 e_i.re().partial_cmp(&e_j.re()).unwrap()
454 } else {
455 panic!("Unable to compare some eigenvalues.")
456 }
457 });
458 }
459 }
460
461 let eigvals_sorted = geneigvals.select(Axis(0), &indices).map(|gv| {
462 if let GeneralizedEigenvalue::Finite(v, _) = gv {
463 *v
464 } else {
465 panic!("Unexpected indeterminate eigenvalue.")
466 }
467 });
468 let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
469
470 let eigvecs_sorted_normalised = normalise_eigenvectors_complex(
472 &eigvecs_sorted.view(),
473 &smat.view(),
474 complex_symmetric,
475 thresh_offdiag,
476 )?;
477
478 let eigvecs_sorted_normalised_regularised =
480 regularise_eigenvectors(&eigvecs_sorted_normalised.view(), thresh_offdiag);
481
482 Ok(GeneralisedEigenvalueResult {
483 eigenvalues: eigvals_sorted,
484 eigenvectors: eigvecs_sorted_normalised_regularised,
485 })
486 }
487}
488
489fn sort_eigenvalues_eigenvectors<T: ComplexFloat>(
505 eigvals: &ArrayView1<T>,
506 eigvecs: &ArrayView2<T>,
507 eigenvalue_comparison_mode: &EigenvalueComparisonMode,
508) -> (Array1<T>, Array2<T>) {
509 let mut indices = (0..eigvals.len()).collect_vec();
510 match eigenvalue_comparison_mode {
511 EigenvalueComparisonMode::Modulus => {
512 indices.sort_by(|i, j| {
513 ComplexFloat::abs(eigvals[*i])
514 .partial_cmp(&ComplexFloat::abs(eigvals[*j]))
515 .unwrap()
516 });
517 }
518 EigenvalueComparisonMode::Real => {
519 indices.sort_by(|i, j| eigvals[*i].re().partial_cmp(&eigvals[*j].re()).unwrap());
520 }
521 }
522 let eigvals_sorted = eigvals.select(Axis(0), &indices);
523 let eigvecs_sorted = eigvecs.select(Axis(1), &indices);
524 (eigvals_sorted, eigvecs_sorted)
525}
526
527fn regularise_eigenvectors<T>(eigvecs: &ArrayView2<T>, thresh: T::Real) -> Array2<T>
539where
540 T: ComplexFloat + One,
541 T::Real: Float,
542{
543 let eigvecs_sgn = stack!(
544 Axis(0),
545 eigvecs
546 .row(0)
547 .map(|v| {
548 if Float::abs(ComplexFloat::re(*v)) > thresh {
549 T::from(v.re().signum()).expect("Unable to convert a signum to the right type.")
550 } else if Float::abs(ComplexFloat::im(*v)) > thresh {
551 T::from(v.im().signum()).expect("Unable to convert a signum to the right type.")
552 } else {
553 T::one()
554 }
555 })
556 .view()
557 );
558 let eigvecs_regularised = eigvecs * eigvecs_sgn;
559 eigvecs_regularised
560}
561
562fn normalise_eigenvectors_real<T>(
574 eigvecs: &ArrayView2<T>,
575 smat: &ArrayView2<T>,
576 thresh: T,
577) -> Result<Array2<T>, anyhow::Error>
578where
579 T: LinalgScalar + Float,
580{
581 let sq_norm = einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
582 .map_err(|err| format_err!(err))?
583 .into_dimensionality::<Ix2>()
584 .map_err(|err| format_err!(err))?;
585 let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
586 .iter()
587 .map(|x| x.abs())
588 .max_by(|x, y| {
589 x.partial_cmp(y)
590 .expect("Unable to compare two `abs` values.")
591 })
592 .ok_or_else(|| {
593 format_err!(
594 "Unable to determine the maximum off-diagonal element of the C^T.S.C matrix."
595 )
596 })?;
597
598 ensure!(
599 max_diff <= thresh,
600 "The C^T.S.C matrix is not a diagonal matrix."
601 );
602 ensure!(
603 sq_norm.diag().iter().all(|v| *v > T::zero()),
604 "Some eigenvectors have negative squared norms and cannot be normalised over the reals."
605 );
606 let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
607 Ok(eigvecs_normalised)
608}
609
610fn normalise_eigenvectors_complex<T>(
623 eigvecs: &ArrayView2<T>,
624 smat: &ArrayView2<T>,
625 complex_symmetric: bool,
626 thresh: T::Real,
627) -> Result<Array2<T>, anyhow::Error>
628where
629 T: LinalgScalar + ComplexFloat + std::fmt::Display,
630 T::Real: Float,
631{
632 let sq_norm = if complex_symmetric {
633 einsum("ji,jk,kl->il", &[eigvecs, smat, eigvecs])
634 .map_err(|err| format_err!(err))?
635 .into_dimensionality::<Ix2>()
636 .map_err(|err| format_err!(err))?
637 } else {
638 einsum(
639 "ji,jk,kl->il",
640 &[&eigvecs.map(|v| v.conj()).view(), smat, eigvecs],
641 )
642 .map_err(|err| format_err!(err))?
643 .into_dimensionality::<Ix2>()
644 .map_err(|err| format_err!(err))?
645 };
646 let max_diff = (&sq_norm - &Array2::from_diag(&sq_norm.diag()))
647 .iter()
648 .map(|x| ComplexFloat::abs(*x))
649 .max_by(|x, y| {
650 x.partial_cmp(y)
651 .expect("Unable to compare two `abs` values.")
652 })
653 .ok_or_else(|| {
654 if complex_symmetric {
655 format_err!(
656 "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
657 )
658 } else {
659 format_err!(
660 "Unable to determine the maximum off-diagonal element of the C^†.S.C matrix."
661 )
662 }
663 })?;
664
665 ensure!(
666 max_diff <= thresh,
667 if complex_symmetric {
668 "The C^T.S.C matrix is not a diagonal matrix."
669 } else {
670 "The C^†.S.C matrix is not a diagonal matrix."
671 }
672 );
673 let eigvecs_normalised = eigvecs / sq_norm.diag().map(|v| v.sqrt());
674 Ok(eigvecs_normalised)
675}