1use std::fmt::{Display, LowerExp};
2use std::marker::PhantomData;
3
4use anyhow::{self, ensure, format_err};
5use derive_builder::Builder;
6use itertools::Itertools;
7use ndarray::{Array2, ArrayView2, ArrayView4, Axis, ScalarOperand};
8use ndarray_linalg::types::Lapack;
9use num::FromPrimitive;
10use num_complex::ComplexFloat;
11
12use crate::angmom::spinor_rotation_3d::StructureConstraint;
13use crate::symmetry::symmetry_transformation::SymmetryTransformable;
14use crate::target::determinant::SlaterDeterminant;
15use crate::target::noci::backend::nonortho::{
16 calc_lowdin_pairing, calc_o0_matrix_element, calc_o1_matrix_element, calc_o2_matrix_element,
17};
18
19use super::OrbitMatrix;
20
21#[cfg(test)]
22#[path = "hamiltonian_tests.rs"]
23mod hamiltonian_tests;
24
25#[derive(Builder)]
27pub struct HamiltonianAO<'a, T, SC, F>
28where
29 T: ComplexFloat + Lapack,
30 SC: StructureConstraint + Clone,
31 F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error> + Clone,
32{
33 enuc: T,
35
36 onee: ArrayView2<'a, T>,
39
40 #[builder(default = "None")]
43 twoe: Option<ArrayView4<'a, T>>,
44
45 #[builder(default = "None")]
46 get_jk: Option<F>,
47
48 #[builder(setter(skip), default = "PhantomData")]
50 structure_constraint: PhantomData<SC>,
51}
52
53impl<'a, T, SC, F> HamiltonianAO<'a, T, SC, F>
54where
55 T: ComplexFloat + Lapack,
56 SC: StructureConstraint + Clone,
57 F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error> + Clone,
58{
59 pub fn builder() -> HamiltonianAOBuilder<'a, T, SC, F> {
61 HamiltonianAOBuilder::<'a, T, SC, F>::default()
62 }
63
64 pub fn enuc(&self) -> T {
66 self.enuc
67 }
68
69 pub fn onee(&'a self) -> &'a ArrayView2<'a, T> {
72 &self.onee
73 }
74
75 pub fn twoe(&self) -> Option<&ArrayView4<'a, T>> {
78 self.twoe.as_ref()
79 }
80
81 pub fn get_jk(&self) -> Option<&F> {
82 self.get_jk.as_ref()
83 }
84}
85
86impl<'a, T, SC, F> HamiltonianAO<'a, T, SC, F>
87where
88 T: ComplexFloat + Lapack + ScalarOperand + FromPrimitive,
89 <T as ComplexFloat>::Real: LowerExp,
90 SC: StructureConstraint + Display + PartialEq + Clone,
91 F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error> + Clone,
92{
93 pub fn calc_hamiltonian_matrix_element_contributions(
118 &self,
119 det_w: &SlaterDeterminant<T, SC>,
120 det_x: &SlaterDeterminant<T, SC>,
121 sao: &ArrayView2<T>,
122 thresh_offdiag: <T as ComplexFloat>::Real,
123 thresh_zeroov: <T as ComplexFloat>::Real,
124 ) -> Result<(T, T, T), anyhow::Error> {
125 ensure!(
126 det_w.structure_constraint() == det_x.structure_constraint(),
127 "Inconsistent spin constraints: {} != {}.",
128 det_w.structure_constraint(),
129 det_x.structure_constraint(),
130 );
131 let sc = det_w.structure_constraint();
132
133 if det_w.complex_symmetric() != det_x.complex_symmetric() {
134 return Err(format_err!(
135 "The `complex_symmetric` booleans of the specified determinants do not match: `det_w` (`{}`) != `det_x` (`{}`).",
136 det_w.complex_symmetric(),
137 det_x.complex_symmetric(),
138 ));
139 }
140 let complex_symmetric = det_w.complex_symmetric();
141
142 let lowdin_paired_coefficientss = det_w
143 .coefficients()
144 .iter()
145 .zip(det_w.occupations().iter())
146 .zip(det_x.coefficients().iter().zip(det_x.occupations().iter()))
147 .map(|((cw, occw), (cx, occx))| {
148 let occw_indices = occw
149 .iter()
150 .enumerate()
151 .filter_map(|(i, occ_i)| {
152 if occ_i.abs() >= det_w.threshold() {
153 Some(i)
154 } else {
155 None
156 }
157 })
158 .collect::<Vec<_>>();
159 let cw_occ = cw.select(Axis(1), &occw_indices);
160 let occx_indices = occx
161 .iter()
162 .enumerate()
163 .filter_map(|(i, occ_i)| {
164 if occ_i.abs() >= det_x.threshold() {
165 Some(i)
166 } else {
167 None
168 }
169 })
170 .collect::<Vec<_>>();
171 let cx_occ = cx.select(Axis(1), &occx_indices);
172 calc_lowdin_pairing(
173 &cw_occ.view(),
174 &cx_occ.view(),
175 sao,
176 complex_symmetric,
177 thresh_offdiag,
178 thresh_zeroov,
179 )
180 })
181 .collect::<Result<Vec<_>, _>>()?;
182
183 let zeroe_h_wx = calc_o0_matrix_element(&lowdin_paired_coefficientss, self.enuc, sc)?;
184 let onee_h_wx = calc_o1_matrix_element(&lowdin_paired_coefficientss, &self.onee, sc)?;
185 let twoe_h_wx =
186 calc_o2_matrix_element(&lowdin_paired_coefficientss, self.twoe(), self.get_jk(), sc)?;
187 Ok((zeroe_h_wx, onee_h_wx, twoe_h_wx))
188 }
189
190 pub fn calc_hamiltonian_matrix(
214 &self,
215 dets: &[&SlaterDeterminant<T, SC>],
216 sao: &ArrayView2<T>,
217 thresh_offdiag: <T as ComplexFloat>::Real,
218 thresh_zeroov: <T as ComplexFloat>::Real,
219 ) -> Result<Array2<T>, anyhow::Error> {
220 let dim = dets.len();
221 let mut hmat = Array2::<T>::zeros((dim, dim));
222 for pair in dets.iter().enumerate().combinations_with_replacement(2) {
223 let (w, det_w) = &pair[0];
224 let (x, det_x) = &pair[1];
225 let (zeroe_wx, onee_wx, twoe_wx) = self.calc_hamiltonian_matrix_element_contributions(
226 det_w,
227 det_x,
228 sao,
229 thresh_offdiag,
230 thresh_zeroov,
231 )?;
232 hmat[(*w, *x)] = zeroe_wx + onee_wx + twoe_wx;
233 if *w != *x {
234 let (zeroe_xw, onee_xw, twoe_xw) = self
235 .calc_hamiltonian_matrix_element_contributions(
236 det_x,
237 det_w,
238 sao,
239 thresh_offdiag,
240 thresh_zeroov,
241 )?;
242 hmat[(*x, *w)] = zeroe_xw + onee_xw + twoe_xw;
243 }
244 }
245 Ok(hmat)
246 }
247}
248
249impl<'a, T, SC, F> OrbitMatrix<'a, T, SC> for &HamiltonianAO<'a, T, SC, F>
250where
251 T: ComplexFloat + Lapack + ScalarOperand + FromPrimitive,
252 <T as ComplexFloat>::Real: LowerExp,
253 SC: StructureConstraint + Clone + Display + PartialEq,
254 SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
255 F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error> + Clone,
256{
257 type MatrixElement = T;
258
259 fn calc_matrix_element(
260 &self,
261 det_w: &SlaterDeterminant<T, SC>,
262 det_x: &SlaterDeterminant<T, SC>,
263 sao: &ArrayView2<T>,
264 thresh_offdiag: <T as ComplexFloat>::Real,
265 thresh_zeroov: <T as ComplexFloat>::Real,
266 ) -> Result<T, anyhow::Error> {
267 let (zeroe, onee, twoe) = self.calc_hamiltonian_matrix_element_contributions(
268 det_w,
269 det_x,
270 sao,
271 thresh_offdiag,
272 thresh_zeroov,
273 )?;
274 Ok(zeroe + onee + twoe)
275 }
276
277 fn t(x: &T) -> T {
278 *x
279 }
280
281 fn conj(x: &T) -> T {
282 <T as ComplexFloat>::conj(*x)
283 }
284
285 fn zero(&self) -> T {
286 T::zero()
287 }
288}
289
290impl<'a, T, SC, F> OrbitMatrix<'a, T, SC> for HamiltonianAO<'a, T, SC, F>
291where
292 T: ComplexFloat + Lapack + ScalarOperand + FromPrimitive,
293 <T as ComplexFloat>::Real: LowerExp,
294 SC: StructureConstraint + Clone + Display + PartialEq,
295 SlaterDeterminant<'a, T, SC>: SymmetryTransformable,
296 F: Fn(&Array2<T>) -> Result<(Array2<T>, Array2<T>), anyhow::Error> + Clone,
297{
298 type MatrixElement = T;
299
300 fn calc_matrix_element(
301 &self,
302 det_w: &SlaterDeterminant<T, SC>,
303 det_x: &SlaterDeterminant<T, SC>,
304 sao: &ArrayView2<T>,
305 thresh_offdiag: <T as ComplexFloat>::Real,
306 thresh_zeroov: <T as ComplexFloat>::Real,
307 ) -> Result<T, anyhow::Error> {
308 (&self).calc_matrix_element(det_w, det_x, sao, thresh_offdiag, thresh_zeroov)
309 }
310
311 fn t(x: &T) -> T {
312 *x
313 }
314
315 fn conj(x: &T) -> T {
316 <T as ComplexFloat>::conj(*x)
317 }
318
319 fn zero(&self) -> T {
320 T::zero()
321 }
322}