qsym2/bindings/python/
integrals.rs

1//! Python bindings for QSym² atomic-orbital integral evaluations.
2
3use anyhow::{self, bail, ensure, format_err};
4use lazy_static::lazy_static;
5#[cfg(feature = "integrals")]
6use nalgebra::{Point3, Vector3};
7#[cfg(feature = "integrals")]
8use num_complex::Complex;
9#[cfg(feature = "integrals")]
10use numpy::{IntoPyArray, PyArray2, PyArray4};
11use periodic_table;
12#[cfg(feature = "integrals")]
13use pyo3::exceptions::PyValueError;
14use pyo3::prelude::*;
15use pyo3::types::PyType;
16#[cfg(feature = "qchem")]
17use regex::Regex;
18use serde::{Deserialize, Serialize};
19
20use crate::angmom::spinor_rotation_3d::{SpinConstraint, SpinOrbitCoupled};
21use crate::auxiliary::molecule::Molecule;
22use crate::basis::ao::{
23    BasisAngularOrder, BasisAtom, BasisShell, CartOrder, PureOrder, ShellOrder, SpinorOrder,
24};
25#[cfg(feature = "integrals")]
26use crate::basis::ao_integrals::{BasisSet, BasisShellContraction, GaussianContraction};
27#[cfg(feature = "integrals")]
28use crate::integrals::shell_tuple::build_shell_tuple_collection;
29#[cfg(feature = "qchem")]
30use crate::io::format::{log_title, qsym2_output, QSym2Output};
31
// Lazily compiled pattern used to pick out single-point (`sp`) jobs among the member names of
// the `.counters` group in a Q-Chem HDF5 archive. A name matches when it ends with a literal
// backslash followed by `energy_function`, and the part before that backslash ends in `sp`.
// The first capture group holds everything up to and including that `sp`.
#[cfg(feature = "qchem")]
lazy_static! {
    static ref SP_PATH_RE: Regex =
        Regex::new(r"(.*sp)\\energy_function$").expect("Regex pattern invalid.");
}
37
/// Python-exposed enumerated type to handle the union type
/// `tuple[bool, bool] | tuple[list[int], bool]` in Python for specifying pure-spherical-harmonic
/// order or spinor order.
///
/// Each variant wraps a two-element tuple whose second element is a boolean indicating if the
/// shell is even with respect to spatial inversion.
#[derive(Clone, FromPyObject)]
pub enum PyPureSpinorOrder {
    /// Variant for standard pure or spinor shell order. The first associated boolean indicates if
    /// the functions are arranged in increasing-$`m`$ order, and the second associated boolean
    /// indicates if the shell is even with respect to spatial inversion.
    ///
    /// Python type: `tuple[bool, bool]`.
    Standard((bool, bool)),

    /// Variant for custom pure or spinor shell order. The associated vector contains a sequence of
    /// integers specifying the order of $`m`$ values for pure or $`2m`$ values for spinor in the
    /// shell, and the associated boolean indicates if the shell is even with respect to spatial
    /// inversion.
    ///
    /// Python type: `tuple[list[int], bool]`.
    Custom((Vec<i32>, bool)),
}
53
/// Python-exposed enumerated type to handle the shell-order union type
/// `tuple[bool, bool] | tuple[list[int], bool] | Optional[list[tuple[int, int, int]]]` in Python.
#[derive(Clone, FromPyObject)]
pub enum PyShellOrder {
    /// Variant for pure or spinor shell order. The associated value is a [`PyPureSpinorOrder`],
    /// which is a two-element tuple: the first element is either a boolean indicating if the
    /// functions are arranged in increasing-$`m`$ order, or a sequence of integers specifying a
    /// custom $`m`$-order for pure or $`2m`$-order for spinor; the second element is a boolean
    /// indicating if the shell is even with respect to spatial inversion.
    ///
    /// Python type: `tuple[bool, bool] | tuple[list[int], bool]`.
    PureSpinorOrder(PyPureSpinorOrder),

    /// Variant for Cartesian shell order. If the associated `Option` is `None`, the order will be
    /// taken to be lexicographic. Otherwise, the order will be as specified by the $`(x, y, z)`$
    /// exponent tuples.
    ///
    /// Python type: `Optional[list[tuple[int, int, int]]]`.
    CartOrder(Option<Vec<(u32, u32, u32)>>),
}
72
/// Python-exposed enumerated type indicating the type of the functions in a basis shell.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[pyclass(eq, eq_int)]
pub enum ShellType {
    /// Variant indicating that the shell consists of pure (solid-harmonic) functions. Such a
    /// shell must be paired with a [`PyShellOrder::PureSpinorOrder`] shell order.
    Pure,

    /// Variant indicating that the shell consists of spinor functions.
    // NOTE(review): presumably also paired with `PyShellOrder::PureSpinorOrder` (whose custom
    // order lists $`2m`$ values for spinors) -- confirm against `create_basis_shell`.
    Spinor,

    /// Variant indicating that the shell consists of Cartesian functions. Such a shell must be
    /// paired with a [`PyShellOrder::CartOrder`] shell order.
    Cartesian,
}
88
/// Python-exposed structure to marshal basis angular order information between Python and Rust.
///
/// # Constructor arguments
///
/// * `basis_atoms` - A vector of tuples, each of which provides information for one basis
/// atom in the form `(element, basis_shells)`. Here:
///   * `element` is a string giving the element symbol of the atom, and
///   * `basis_shells` is a vector of tuples, each of which provides information for one basis
///   shell on the atom in the form `(angmom, shell_type, order)`. Here:
///     * `angmom` is a non-negative integer giving the angular momentum of the shell (*e.g.*
///     `0` for an S shell or `1` for a P shell),
///     * `shell_type` is a `ShellType` value indicating if the functions in the shell are pure,
///     spinor, or Cartesian, and
///     * `order` specifies how the functions in the shell are ordered:
///       * for a Cartesian shell, `order` can be `None` for lexicographic order, or a list of
///       tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
///       `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents, respectively;
///       * for a pure or spinor shell, `order` is a tuple `(m_order, even)` where `m_order` is
///       `true` for increasing-$`m`$ order, `false` for decreasing-$`m`$ order, or a list of
///       $`m`$ values (pure) or $`2m`$ values (spinor) for custom order, and `even` is a boolean
///       indicating if the shell is even with respect to spatial inversion.
///
///   Python type:
///   `list[tuple[str, list[tuple[int, ShellType, Optional[list[tuple[int, int, int]]] | tuple[bool, bool] | tuple[list[int], bool]]]]]`.
#[pyclass]
pub struct PyBasisAngularOrder {
    /// A vector of basis atoms. Each item in the vector is a tuple consisting of an atomic symbol
    /// and a vector of basis shell triplets whose components give:
    /// - the angular momentum of the shell as a non-negative integer,
    /// - the [`ShellType`] of the shell (pure, spinor, or Cartesian),
    /// - the order of the functions in the shell as a [`PyShellOrder`]:
    ///   - for a Cartesian shell, either `None` if the Cartesian functions are in lexicographic
    ///   order, or `Some(vec![(lx, ly, lz), ...])` to specify a custom Cartesian order;
    ///   - for a pure or spinor shell, a [`PyPureSpinorOrder`] tuple whose first element is
    ///   either a boolean `increasingm` or a list of $`m`$ (or $`2m`$) values specifying a custom
    ///   order, and whose second element indicates if the shell is even with respect to spatial
    ///   inversion.
    ///
    /// Python type: `list[tuple[str, list[tuple[int, ShellType, Optional[list[tuple[int, int, int]]] | tuple[bool, bool] | tuple[list[int], bool]]]]]`.
    basis_atoms: Vec<(String, Vec<(u32, ShellType, PyShellOrder)>)>,
}
126
127#[pymethods]
128impl PyBasisAngularOrder {
129    /// Constructs a new `PyBasisAngularOrder` structure.
130    ///
131    /// # Arguments
132    ///
133    /// * `basis_atoms` - A vector of tuples, each of which provides information for one basis
134    /// atom in the form `(element, basis_shells)`. Here:
135    ///   * `element` is a string giving the element symbol of the atom, and
136    ///   * `basis_shells` is a vector of tuples, each of which provides information for one basis
137    ///   shell on the atom in the form `(angmom, cart, order)`. Here:
138    ///     * `angmom` is a symbol such as `"S"` or `"P"` for the angular momentum of the shell,
139    ///     * `cart` is a boolean indicating if the functions in the shell are Cartesian (`true`)
140    ///     or pure / solid harmonics (`false`), and
141    ///     * `order` specifies how the functions in the shell are ordered:
142    ///       * if `cart` is `true`, `order` can be `None` for lexicographic order, or a list of
143    ///       tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
144    ///       `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents, respectively;
145    ///       * if `cart` is `false`, `order` can be `true` for increasing-$`m`$ order, `false` for
146    ///       decreasing-$`m`$ order, or a list of $`m`$ values for custom order.
147    ///
148    ///   Python type:
149    ///   `list[tuple[str, list[tuple[str, bool, bool | Optional[list[tuple[int, int, int]]]]]]]`.
150    #[new]
151    fn new(basis_atoms: Vec<(String, Vec<(u32, ShellType, PyShellOrder)>)>) -> Self {
152        Self { basis_atoms }
153    }
154
155    /// Extracts basis angular order information from a Q-Chem HDF5 archive file.
156    ///
157    /// # Arguments
158    ///
159    /// * `filename` - A path to a Q-Chem HDF5 archive file. Python type: `str`.
160    ///
161    /// # Returns
162    ///
163    /// A sequence of `PyBasisAngularOrder` objects, one for each Q-Chem calculation found in the
164    /// HDF5 archive file. Python type: `list[PyBasisAngularOrder]`.
165    ///
166    /// A summary showing how the `PyBasisAngularOrder` objects map onto the Q-Chem calculations in
167    /// the HDF5 archive file is also logged at the `INFO` level.
168    #[cfg(feature = "qchem")]
169    #[classmethod]
170    fn from_qchem_archive(_cls: &Bound<'_, PyType>, filename: &str) -> PyResult<Vec<Self>> {
171        use hdf5;
172        use indexmap::IndexMap;
173        use num::ToPrimitive;
174
175        let f = hdf5::File::open(filename).map_err(|err| PyValueError::new_err(err.to_string()))?;
176        let mut sp_paths = f
177            .group(".counters")
178            .map_err(|err| PyValueError::new_err(err.to_string()))?
179            .member_names()
180            .map_err(|err| PyValueError::new_err(err.to_string()))?
181            .iter()
182            .filter_map(|path| {
183                if SP_PATH_RE.is_match(path) {
184                    let path = path.replace("\\", "/");
185                    Some(path.replace("/energy_function", ""))
186                } else {
187                    None
188                }
189            })
190            .collect::<Vec<_>>();
191        sp_paths.sort_by(|path_a, path_b| numeric_sort::cmp(path_a, path_b));
192
193        let elements = periodic_table::periodic_table();
194
195        log_title(&format!(
196            "Basis angular order extraction from Q-Chem HDF5 archive files",
197        ));
198        let pybaos = sp_paths
199            .iter()
200            .map(|sp_path| {
201                let sp_group = f
202                    .group(sp_path)
203                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
204                let shell_types = sp_group
205                    .dataset("aobasis/shell_types")
206                    .map_err(|err| PyValueError::new_err(err.to_string()))?
207                    .read_1d::<i32>()
208                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
209                let shell_to_atom_map = sp_group
210                    .dataset("aobasis/shell_to_atom_map")
211                    .map_err(|err| PyValueError::new_err(err.to_string()))?
212                    .read_1d::<usize>()
213                    .map_err(|err| PyValueError::new_err(err.to_string()))?
214                    .iter()
215                    .zip(shell_types.iter())
216                    .flat_map(|(&idx, shell_type)| {
217                        if *shell_type == -1 {
218                            vec![idx, idx]
219                        } else {
220                            vec![idx]
221                        }
222                    })
223                    .collect::<Vec<_>>();
224                let nuclei = sp_group
225                    .dataset("structure/nuclei")
226                    .map_err(|err| PyValueError::new_err(err.to_string()))?
227                    .read_1d::<usize>()
228                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
229
230                let mut basis_atoms_map: IndexMap<usize, Vec<(u32, ShellType, PyShellOrder)>> =
231                    IndexMap::new();
232                shell_types.iter().zip(shell_to_atom_map.iter()).for_each(
233                    |(shell_type, atom_idx)| {
234                        if *shell_type == 0 {
235                            // S shell
236                            basis_atoms_map.entry(*atom_idx).or_insert(vec![]).push((
237                                0,
238                                ShellType::Cartesian,
239                                PyShellOrder::CartOrder(Some(CartOrder::qchem(0).cart_tuples)),
240                            ));
241                        } else if *shell_type == 1 {
242                            // P shell
243                            basis_atoms_map.entry(*atom_idx).or_insert(vec![]).push((
244                                1,
245                                ShellType::Cartesian,
246                                PyShellOrder::CartOrder(Some(CartOrder::qchem(1).cart_tuples)),
247                            ));
248                        } else if *shell_type == -1 {
249                            // SP shell
250                            basis_atoms_map
251                                .entry(*atom_idx)
252                                .or_insert(vec![])
253                                .extend_from_slice(&[
254                                    (
255                                        0,
256                                        ShellType::Cartesian,
257                                        PyShellOrder::CartOrder(Some(
258                                            CartOrder::qchem(0).cart_tuples,
259                                        )),
260                                    ),
261                                    (
262                                        1,
263                                        ShellType::Cartesian,
264                                        PyShellOrder::CartOrder(Some(
265                                            CartOrder::qchem(1).cart_tuples,
266                                        )),
267                                    ),
268                                ]);
269                        } else if *shell_type < 0 {
270                            // Cartesian D shell or higher
271                            let l = shell_type.unsigned_abs();
272                            // let l_usize = l
273                            //     .to_usize()
274                            //     .unwrap_or_else(|| panic!("Unable to convert the angular momentum value `|{shell_type}|` to `usize`."));
275                            basis_atoms_map.entry(*atom_idx).or_insert(vec![]).push((
276                                l,
277                                ShellType::Cartesian,
278                                PyShellOrder::CartOrder(Some(CartOrder::qchem(l).cart_tuples)),
279                            ));
280                        } else {
281                            // Pure D shell or higher
282                            let l = shell_type.unsigned_abs();
283                            // let l_usize = l
284                            //     .to_usize()
285                            //     .unwrap_or_else(|| panic!("Unable to convert the angular momentum value `|{shell_type}|` to `usize`."));
286                            basis_atoms_map.entry(*atom_idx).or_insert(vec![]).push((
287                                l,
288                                ShellType::Pure,
289                                PyShellOrder::PureSpinorOrder(PyPureSpinorOrder::Standard((
290                                    true,
291                                    l % 2 == 0,
292                                ))),
293                            ));
294                        }
295                    },
296                );
297                let pybao = basis_atoms_map
298                    .into_iter()
299                    .map(|(atom_idx, v)| {
300                        let element = elements
301                            .get(nuclei[atom_idx])
302                            .map(|el| el.symbol.to_string())
303                            .ok_or_else(|| {
304                                PyValueError::new_err(format!(
305                                    "Unable to identify an element for atom index `{atom_idx}`."
306                                ))
307                            })?;
308                        Ok((element, v))
309                    })
310                    .collect::<Result<Vec<_>, _>>()
311                    .map(|basis_atoms| Self::new(basis_atoms));
312                pybao
313            })
314            .collect::<Result<Vec<_>, _>>();
315
316        let idx_width = sp_paths.len().ilog10().to_usize().unwrap_or(4).max(4) + 1;
317        let sp_path_width = sp_paths
318            .iter()
319            .map(|sp_path| sp_path.chars().count())
320            .max()
321            .unwrap_or(10)
322            .max(10);
323        let table_width = idx_width + sp_path_width + 4;
324        qsym2_output!("");
325        "Each single-point calculation has associated with it a `PyBasisAngularOrder` object.\n\
326        The table below shows the `PyBasisAngularOrder` index in the generated list and the\n\
327        corresponding single-point calculation."
328            .log_output_display();
329        qsym2_output!("{}", "┈".repeat(table_width));
330        qsym2_output!(" {:<idx_width$}  {:<}", "Index", "Q-Chem job");
331        qsym2_output!("{}", "┈".repeat(table_width));
332        sp_paths.iter().enumerate().for_each(|(i, sp_path)| {
333            qsym2_output!(" {:<idx_width$}  {:<}", i, sp_path);
334        });
335        qsym2_output!("{}", "┈".repeat(table_width));
336        qsym2_output!("");
337
338        pybaos
339    }
340}
341
342impl PyBasisAngularOrder {
343    /// Extracts the information in the [`PyBasisAngularOrder`] structure into `QSym2`'s native
344    /// [`BasisAngularOrder`] structure.
345    ///
346    /// # Arguments
347    ///
348    /// * `mol` - The molecule with which the basis set information is associated.
349    ///
350    /// # Returns
351    ///
352    /// The [`BasisAngularOrder`] structure with the same information.
353    ///
354    /// # Errors
355    ///
356    /// Errors if the number of atoms or the atom elements in `mol` do not match the number of
357    /// atoms and atom elements in `self`, or if incorrect shell order types are specified.
358    pub fn to_qsym2<'b, 'a: 'b>(
359        &'b self,
360        mol: &'a Molecule,
361    ) -> Result<BasisAngularOrder<'b>, anyhow::Error> {
362        ensure!(
363            self.basis_atoms.len() == mol.atoms.len(),
364            "The number of basis atoms does not match the number of ordinary atoms."
365        );
366        let basis_atoms = self
367            .basis_atoms
368            .iter()
369            .zip(mol.atoms.iter())
370            .flat_map(|((element, basis_shells), atom)| {
371                ensure!(
372                    *element == atom.atomic_symbol,
373                    "Expected element `{element}`, but found atom `{}`.",
374                    atom.atomic_symbol
375                );
376                let bss = basis_shells
377                    .iter()
378                    .flat_map(|(angmom, cart, shell_order)| {
379                        create_basis_shell(*angmom, cart, shell_order)
380                    })
381                    .collect::<Vec<_>>();
382                Ok(BasisAtom::new(atom, &bss))
383            })
384            .collect::<Vec<_>>();
385        Ok(BasisAngularOrder::new(&basis_atoms))
386    }
387}
388
/// Python-exposed enumerated type to marshal basis spin constraint information between Rust and
/// Python.
///
/// Each variant converts to the corresponding [`SpinConstraint`] variant with two spin spaces.
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum PySpinConstraint {
    /// Variant for restricted spin constraint. Only two spin spaces are exposed.
    ///
    /// Corresponds to `SpinConstraint::Restricted(2)`.
    Restricted,

    /// Variant for unrestricted spin constraint. Only two spin spaces arranged in decreasing-$`m`$
    /// order (*i.e.* $`(\alpha, \beta)`$) are exposed.
    ///
    /// Corresponds to `SpinConstraint::Unrestricted(2, false)`.
    Unrestricted,

    /// Variant for generalised spin constraint. Only two spin spaces arranged in decreasing-$`m`$
    /// order (*i.e.* $`(\alpha, \beta)`$) are exposed.
    ///
    /// Corresponds to `SpinConstraint::Generalised(2, false)`.
    Generalised,
}
405
406impl From<PySpinConstraint> for SpinConstraint {
407    fn from(pysc: PySpinConstraint) -> Self {
408        match pysc {
409            PySpinConstraint::Restricted => SpinConstraint::Restricted(2),
410            PySpinConstraint::Unrestricted => SpinConstraint::Unrestricted(2, false),
411            PySpinConstraint::Generalised => SpinConstraint::Generalised(2, false),
412        }
413    }
414}
415
416impl TryFrom<SpinConstraint> for PySpinConstraint {
417    type Error = anyhow::Error;
418
419    fn try_from(sc: SpinConstraint) -> Result<Self, Self::Error> {
420        match sc {
421            SpinConstraint::Restricted(2) => Ok(PySpinConstraint::Restricted),
422            SpinConstraint::Unrestricted(2, false) => Ok(PySpinConstraint::Unrestricted),
423            SpinConstraint::Generalised(2, false) => Ok(PySpinConstraint::Generalised),
424            _ => Err(format_err!(
425                "`PySpinConstraint` can only support two spin spaces."
426            )),
427        }
428    }
429}
430
/// Python-exposed enumerated type to marshal basis spin--orbit-coupled layout in the coupled
/// treatment of spin and spatial degrees of freedom between Rust and Python.
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum PySpinOrbitCoupled {
    /// Variant for $`j`$-adapted basis functions. Only two relativistic components are exposed.
    ///
    /// Corresponds to `SpinOrbitCoupled::JAdapted(2)`.
    JAdapted,
}
439
440impl From<PySpinOrbitCoupled> for SpinOrbitCoupled {
441    fn from(pysoc: PySpinOrbitCoupled) -> Self {
442        match pysoc {
443            PySpinOrbitCoupled::JAdapted => SpinOrbitCoupled::JAdapted(2),
444        }
445    }
446}
447
448impl TryFrom<SpinOrbitCoupled> for PySpinOrbitCoupled {
449    type Error = anyhow::Error;
450
451    fn try_from(soc: SpinOrbitCoupled) -> Result<Self, Self::Error> {
452        match soc {
453            SpinOrbitCoupled::JAdapted(2) => Ok(PySpinOrbitCoupled::JAdapted),
454            _ => Err(format_err!(
455                "`PySpinOrbitCoupled` can only support two relativistic components."
456            )),
457        }
458    }
459}
460
/// Python-exposed enumerated type to handle the union type `PySpinConstraint | PySpinOrbitCoupled`
/// in Python.
///
/// Conversions to and from the native [`SpinConstraint`] and [`SpinOrbitCoupled`] types are
/// fallible: converting a variant into the native type of the other kind is an error.
#[derive(FromPyObject, Clone, PartialEq, Eq, Hash)]
pub enum PyStructureConstraint {
    /// Variant for Python-exposed spin constraint layout.
    SpinConstraint(PySpinConstraint),

    /// Variant for Python-exposed spin--orbit-coupled layout.
    SpinOrbitCoupled(PySpinOrbitCoupled),
}
471
472impl TryFrom<SpinConstraint> for PyStructureConstraint {
473    type Error = anyhow::Error;
474
475    fn try_from(sc: SpinConstraint) -> Result<Self, Self::Error> {
476        match sc {
477            SpinConstraint::Restricted(2) => Ok(PyStructureConstraint::SpinConstraint(
478                PySpinConstraint::Restricted,
479            )),
480            SpinConstraint::Unrestricted(2, false) => Ok(PyStructureConstraint::SpinConstraint(
481                PySpinConstraint::Unrestricted,
482            )),
483            SpinConstraint::Generalised(2, false) => Ok(PyStructureConstraint::SpinConstraint(
484                PySpinConstraint::Generalised,
485            )),
486            _ => Err(format_err!(
487                "`PySpinConstraint` can only support two spin spaces."
488            )),
489        }
490    }
491}
492
493impl TryFrom<PyStructureConstraint> for SpinConstraint {
494    type Error = anyhow::Error;
495
496    fn try_from(py_sc: PyStructureConstraint) -> Result<Self, Self::Error> {
497        match py_sc {
498            PyStructureConstraint::SpinConstraint(py_sc) => Ok(py_sc.into()),
499            PyStructureConstraint::SpinOrbitCoupled(_) => Err(format_err!(
500                "`SpinConstraint` cannot be created from `PySpinOrbitCoupled`."
501            )),
502        }
503    }
504}
505
506impl TryFrom<SpinOrbitCoupled> for PyStructureConstraint {
507    type Error = anyhow::Error;
508
509    fn try_from(soc: SpinOrbitCoupled) -> Result<Self, Self::Error> {
510        match soc {
511            SpinOrbitCoupled::JAdapted(2) => Ok(PyStructureConstraint::SpinOrbitCoupled(
512                PySpinOrbitCoupled::JAdapted,
513            )),
514            _ => Err(format_err!(
515                "`PySpinOrbitCoupled` can only support two relativistic components."
516            )),
517        }
518    }
519}
520
521impl TryFrom<PyStructureConstraint> for SpinOrbitCoupled {
522    type Error = anyhow::Error;
523
524    fn try_from(py_sc: PyStructureConstraint) -> Result<Self, Self::Error> {
525        match py_sc {
526            PyStructureConstraint::SpinOrbitCoupled(py_soc) => Ok(py_soc.into()),
527            PyStructureConstraint::SpinConstraint(_) => Err(format_err!(
528                "`SpinOrbitCoupled` cannot be created from `PySpinConstraint`."
529            )),
530        }
531    }
532}
533
#[cfg(feature = "integrals")]
#[pyclass]
#[derive(Clone)]
/// Python-exposed structure to marshal basis shell contraction information between Rust and
/// Python.
///
/// # Constructor arguments
///
/// * `basis_shell` - A triplet of the form `(angmom, shell_type, order)` where:
///     * `angmom` is a non-negative integer giving the angular momentum of the shell (*e.g.*
///     `0` for an S shell or `1` for a P shell),
///     * `shell_type` is a `ShellType` value indicating if the functions in the shell are pure,
///     spinor, or Cartesian, and
///     * `order` specifies how the functions in the shell are ordered:
///       * for a Cartesian shell, `order` can be `None` for lexicographic order, or a list of
///       tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
///       `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents;
///       * for a pure or spinor shell, `order` is a tuple `(m_order, even)` where `m_order` is
///       `true` for increasing-$`m`$ order, `false` for decreasing-$`m`$ order, or a list of
///       $`m`$ values (pure) or $`2m`$ values (spinor) for custom order, and `even` is a boolean
///       indicating if the shell is even with respect to spatial inversion.
///
///     Python type: `tuple[int, ShellType, Optional[list[tuple[int, int, int]]] | tuple[bool, bool] | tuple[list[int], bool]]`.
/// * `primitives` - A list of tuples, each of which contains the exponent and the contraction
/// coefficient of a Gaussian primitive in this shell. Python type: `list[tuple[float, float]]`.
/// * `cart_origin` - A fixed-size list of length 3 containing the Cartesian coordinates of the
/// origin $`\mathbf{R}`$ of this shell in Bohr radii. Python type: `list[float]`.
/// * `k` - An optional fixed-size list of length 3 containing the Cartesian components of the
/// $`\mathbf{k}`$ vector of this shell that appears in the complex phase factor
/// $`\exp[i\mathbf{k} \cdot (\mathbf{r} - \mathbf{R})]`$. Python type: `Optional[list[float]]`.
pub struct PyBasisShellContraction {
    /// A triplet of the form `(angmom, shell_type, order)` where:
    ///     * `angmom` is a non-negative integer giving the angular momentum of the shell,
    ///     * `shell_type` is a [`ShellType`] value indicating if the functions in the shell are
    ///     pure, spinor, or Cartesian, and
    ///     * `order` specifies how the functions in the shell are ordered:
    ///       * for a Cartesian shell, `order` can be `None` for lexicographic order, or a list of
    ///       tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
    ///       `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents;
    ///       * for a pure or spinor shell, `order` is a tuple `(m_order, even)` where `m_order`
    ///       is `true` for increasing-$`m`$ order, `false` for decreasing-$`m`$ order, or a list
    ///       of $`m`$ values (pure) or $`2m`$ values (spinor) for custom order, and `even` is a
    ///       boolean indicating if the shell is even with respect to spatial inversion.
    ///
    /// Python type: `tuple[int, ShellType, Optional[list[tuple[int, int, int]]] | tuple[bool, bool] | tuple[list[int], bool]]`.
    pub basis_shell: (u32, ShellType, PyShellOrder),

    /// A list of tuples, each of which contains the exponent and the contraction coefficient of a
    /// Gaussian primitive in this shell.
    ///
    /// Python type: `list[tuple[float, float]]`.
    pub primitives: Vec<(f64, f64)>,

    /// A fixed-size list of length 3 containing the Cartesian coordinates of the origin
    /// $`\mathbf{R}`$ of this shell in Bohr radii.
    ///
    /// Python type: `list[float]`.
    pub cart_origin: [f64; 3],

    /// An optional fixed-size list of length 3 containing the Cartesian components of the
    /// $`\mathbf{k}`$ vector of this shell that appears in the complex phase factor
    /// $`\exp[i\mathbf{k} \cdot (\mathbf{r} - \mathbf{R})]`$.
    ///
    /// Python type: `Optional[list[float]]`.
    pub k: Option<[f64; 3]>,
}
595
#[cfg(feature = "integrals")]
#[pymethods]
impl PyBasisShellContraction {
    /// Creates a new `PyBasisShellContraction` structure. The arguments are stored as given,
    /// without validation.
    ///
    /// # Arguments
    ///
    /// * `basis_shell` - A triplet of the form `(angmom, shell_type, order)` where:
    ///     * `angmom` is a non-negative integer giving the angular momentum of the shell (*e.g.*
    ///     `0` for an S shell or `1` for a P shell),
    ///     * `shell_type` is a `ShellType` value indicating if the functions in the shell are
    ///     pure, spinor, or Cartesian, and
    ///     * `order` specifies how the functions in the shell are ordered:
    ///       * for a Cartesian shell, `order` can be `None` for lexicographic order, or a list of
    ///       tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
    ///       `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents;
    ///       * for a pure or spinor shell, `order` is a tuple `(m_order, even)` where `m_order`
    ///       is `true` for increasing-$`m`$ order, `false` for decreasing-$`m`$ order, or a list
    ///       of $`m`$ values (pure) or $`2m`$ values (spinor) for custom order, and `even` is a
    ///       boolean indicating if the shell is even with respect to spatial inversion.
    ///
    ///     Python type: `tuple[int, ShellType, Optional[list[tuple[int, int, int]]] | tuple[bool, bool] | tuple[list[int], bool]]`.
    /// * `primitives` - A list of tuples, each of which contains the exponent and the contraction
    /// coefficient of a Gaussian primitive in this shell. Python type: `list[tuple[float, float]]`.
    /// * `cart_origin` - A fixed-size list of length 3 containing the Cartesian coordinates of the
    /// origin of this shell. Python type: `list[float]`.
    /// * `k` - An optional fixed-size list of length 3 containing the Cartesian components of the
    /// $`\mathbf{k}`$ vector of this shell. Defaults to `None` when omitted. Python type:
    /// `Optional[list[float]]`.
    #[new]
    #[pyo3(signature = (basis_shell, primitives, cart_origin, k=None))]
    pub fn new(
        basis_shell: (u32, ShellType, PyShellOrder),
        primitives: Vec<(f64, f64)>,
        cart_origin: [f64; 3],
        k: Option<[f64; 3]>,
    ) -> Self {
        Self {
            basis_shell,
            primitives,
            cart_origin,
            k,
        }
    }
}
637
#[cfg(feature = "integrals")]
impl TryFrom<PyBasisShellContraction> for BasisShellContraction<f64, f64> {
    type Error = anyhow::Error;

    /// Converts the Python-exposed shell contraction into the native
    /// [`BasisShellContraction`].
    ///
    /// # Errors
    ///
    /// Errors if a basis shell cannot be created from the `basis_shell` triplet.
    fn try_from(pybsc: PyBasisShellContraction) -> Result<Self, Self::Error> {
        let PyBasisShellContraction {
            basis_shell: (angmom, shell_type, shell_order),
            primitives,
            cart_origin: origin,
            k: k_opt,
        } = pybsc;

        Ok(Self {
            basis_shell: create_basis_shell(angmom, &shell_type, &shell_order)?,
            contraction: GaussianContraction::<f64, f64> { primitives },
            cart_origin: Point3::from_slice(&origin),
            k: k_opt.map(|components| Vector3::from_row_slice(&components)),
        })
    }
}
658
659// ================
660// Helper functions
661// ================
662
663/// Creates a [`BasisShell`] structure from the `(angmom, cart, shell_order)` triplet.
664///
665/// # Arguments
666/// * `order` is an integer indicating the order of the shell,
667/// * `cart` is a boolean indicating if the functions in the shell are Cartesian (`true`)
668/// or pure / solid harmonics (`false`), and
669/// * `shell_order` specifies how the functions in the shell are ordered:
670///   * if `cart` is `true`, `order` can be `None` for lexicographic order, or a list of
671///   tuples `(lx, ly, lz)` specifying a custom order for the Cartesian functions where
672///   `lx`, `ly`, and `lz` are the $`x`$-, $`y`$-, and $`z`$-exponents;
673///   * if `cart` is `false`, `order` can be `true` for increasing-$`m`$ or `false` for
674///   decreasing-$`m`$ order.
675///
676/// # Returns
677///
678/// A [`BasisShell`] structure.
679///
680/// # Errors
681///
682/// Errors if `angmom` is not a valid angular momentum, or if there is a mismatch between `cart`
683/// and `shell_order`.
684fn create_basis_shell(
685    order: u32,
686    shell_type: &ShellType,
687    shell_order: &PyShellOrder,
688) -> Result<BasisShell, anyhow::Error> {
689    let shl_ord = match shell_type {
690        ShellType::Cartesian => {
691            let cart_order = match shell_order {
692                PyShellOrder::CartOrder(cart_tuples_opt) => {
693                    if let Some(cart_tuples) = cart_tuples_opt {
694                        CartOrder::new(cart_tuples)?
695                    } else {
696                        CartOrder::lex(order)
697                    }
698                }
699                PyShellOrder::PureSpinorOrder(_) => {
700                    log::error!(
701                        "Cartesian shell order expected, but specification for pure/spinor shell order found."
702                    );
703                    bail!(
704                        "Cartesian shell order expected, but specification for pure/spinor shell order found."
705                    )
706                }
707            };
708            ShellOrder::Cart(cart_order)
709        }
710        ShellType::Pure => match shell_order {
711            PyShellOrder::PureSpinorOrder(pypureorder) => match pypureorder {
712                PyPureSpinorOrder::Standard((increasingm, _even)) => {
713                    if *increasingm {
714                        ShellOrder::Pure(PureOrder::increasingm(order))
715                    } else {
716                        ShellOrder::Pure(PureOrder::decreasingm(order))
717                    }
718                }
719                PyPureSpinorOrder::Custom((mls, _even)) => ShellOrder::Pure(PureOrder::new(mls)?),
720            },
721            PyShellOrder::CartOrder(_) => {
722                log::error!(
723                    "Pure shell order expected, but specification for Cartesian shell order found."
724                );
725                bail!(
726                    "Pure shell order expected, but specification for Cartesian shell order found."
727                )
728            }
729        },
730        ShellType::Spinor => match shell_order {
731            PyShellOrder::PureSpinorOrder(pyspinororder) => match pyspinororder {
732                PyPureSpinorOrder::Standard((increasingm, even)) => {
733                    if *increasingm {
734                        ShellOrder::Spinor(SpinorOrder::increasingm(order, *even))
735                    } else {
736                        ShellOrder::Spinor(SpinorOrder::decreasingm(order, *even))
737                    }
738                }
739                PyPureSpinorOrder::Custom((two_mjs, even)) => {
740                    ShellOrder::Spinor(SpinorOrder::new(two_mjs, *even)?)
741                }
742            },
743            PyShellOrder::CartOrder(_) => {
744                log::error!(
745                    "Spinor shell order expected, but specification for Cartesian shell order found."
746                );
747                bail!(
748                    "Spinor shell order expected, but specification for Cartesian shell order found."
749                )
750            }
751        },
752    };
753    Ok::<_, anyhow::Error>(BasisShell::new(order, shl_ord))
754}
755
756// =================
757// Exposed functions
758// =================
759
#[cfg(feature = "integrals")]
#[pyfunction]
/// Computes the real-valued two-centre overlap matrix of a basis set.
///
/// # Arguments
///
/// * `basis_set` - The basis set as a list of lists of [`PyBasisShellContraction`], where each
/// inner list holds the shells on one atom. Python type: `list[list[PyBasisShellContraction]]`.
///
/// # Returns
///
/// A two-dimensional array containing the real two-centre overlap values.
///
/// # Errors
///
/// Returns a `ValueError` if any shell cannot be converted into a native shell contraction.
///
/// # Panics
///
/// Panics if any shell contains a finite $`\mathbf{k}`$ vector.
pub fn calc_overlap_2c_real<'py>(
    py: Python<'py>,
    basis_set: Vec<Vec<PyBasisShellContraction>>,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    // Convert atom by atom, bailing out at the first shell that fails to convert.
    let mut atom_shells = Vec::with_capacity(basis_set.len());
    for py_shells in basis_set {
        let shells = py_shells
            .into_iter()
            .map(|py_shell| BasisShellContraction::<f64, f64>::try_from(py_shell))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        atom_shells.push(shells);
    }
    let bscs = BasisSet::new(atom_shells);

    // The integral evaluation runs inside `allow_threads`, i.e. without holding the GIL.
    let sao_2c = py.allow_threads(|| {
        let stc = build_shell_tuple_collection![
            <s1, s2>;
            false, false;
            &bscs, &bscs;
            f64
        ];
        let mut overlaps = stc.overlap([0, 0]);
        overlaps
            .pop()
            .expect("Unable to retrieve the two-centre overlap matrix.")
    });
    Ok(sao_2c.into_pyarray(py))
}
806
#[cfg(feature = "integrals")]
#[pyfunction]
/// Computes the complex-valued two-centre overlap matrix of a basis set.
///
/// # Arguments
///
/// * `basis_set` - The basis set as a list of lists of [`PyBasisShellContraction`], where each
/// inner list holds the shells on one atom. Python type: `list[list[PyBasisShellContraction]]`.
/// * `complex_symmetric` - A boolean indicating if the complex-symmetric overlap is to be
/// calculated.
///
/// # Returns
///
/// A two-dimensional array containing the complex two-centre overlap values.
///
/// # Errors
///
/// Returns a `ValueError` if any shell cannot be converted into a native shell contraction.
pub fn calc_overlap_2c_complex<'py>(
    py: Python<'py>,
    basis_set: Vec<Vec<PyBasisShellContraction>>,
    complex_symmetric: bool,
) -> PyResult<Bound<'py, PyArray2<Complex<f64>>>> {
    // Convert atom by atom, bailing out at the first shell that fails to convert.
    let mut atom_shells = Vec::with_capacity(basis_set.len());
    for py_shells in basis_set {
        let shells = py_shells
            .into_iter()
            .map(|py_shell| BasisShellContraction::<f64, f64>::try_from(py_shell))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        atom_shells.push(shells);
    }
    let bscs = BasisSet::new(atom_shells);

    // The integral evaluation runs inside `allow_threads`, i.e. without holding the GIL.
    let sao_2c = py.allow_threads(|| {
        // NOTE(review): the first flag is switched off when the complex-symmetric form is
        // requested; presumably it controls complex conjugation of `s1` -- confirm against the
        // macro definition.
        let stc = build_shell_tuple_collection![
            <s1, s2>;
            !complex_symmetric, false;
            &bscs, &bscs;
            Complex<f64>
        ];
        let mut overlaps = stc.overlap([0, 0]);
        overlaps
            .pop()
            .expect("Unable to retrieve the two-centre overlap matrix.")
    });
    Ok(sao_2c.into_pyarray(py))
}
852
#[cfg(feature = "integrals")]
#[pyfunction]
/// Computes the real-valued four-centre overlap tensor of a basis set.
///
/// # Arguments
///
/// * `basis_set` - The basis set as a list of lists of [`PyBasisShellContraction`], where each
/// inner list holds the shells on one atom. Python type: `list[list[PyBasisShellContraction]]`.
///
/// # Returns
///
/// A four-dimensional array containing the real four-centre overlap values.
///
/// # Errors
///
/// Returns a `ValueError` if any shell cannot be converted into a native shell contraction.
///
/// # Panics
///
/// Panics if any shell contains a finite $`\mathbf{k}`$ vector.
pub fn calc_overlap_4c_real<'py>(
    py: Python<'py>,
    basis_set: Vec<Vec<PyBasisShellContraction>>,
) -> PyResult<Bound<'py, PyArray4<f64>>> {
    // Convert atom by atom, bailing out at the first shell that fails to convert.
    let mut atom_shells = Vec::with_capacity(basis_set.len());
    for py_shells in basis_set {
        let shells = py_shells
            .into_iter()
            .map(|py_shell| BasisShellContraction::<f64, f64>::try_from(py_shell))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        atom_shells.push(shells);
    }
    let bscs = BasisSet::new(atom_shells);

    // The integral evaluation runs inside `allow_threads`, i.e. without holding the GIL.
    let sao_4c = py.allow_threads(|| {
        let stc = build_shell_tuple_collection![
            <s1, s2, s3, s4>;
            false, false, false, false;
            &bscs, &bscs, &bscs, &bscs;
            f64
        ];
        let mut overlaps = stc.overlap([0, 0, 0, 0]);
        overlaps
            .pop()
            .expect("Unable to retrieve the four-centre overlap tensor.")
    });
    Ok(sao_4c.into_pyarray(py))
}
899
#[cfg(feature = "integrals")]
#[pyfunction]
/// Computes the complex-valued four-centre overlap tensor of a basis set.
///
/// # Arguments
///
/// * `basis_set` - The basis set as a list of lists of [`PyBasisShellContraction`], where each
/// inner list holds the shells on one atom. Python type: `list[list[PyBasisShellContraction]]`.
/// * `complex_symmetric` - A boolean indicating if the complex-symmetric overlap tensor is to be
/// calculated.
///
/// # Returns
///
/// A four-dimensional array containing the complex four-centre overlap values.
///
/// # Errors
///
/// Returns a `ValueError` if any shell cannot be converted into a native shell contraction.
pub fn calc_overlap_4c_complex<'py>(
    py: Python<'py>,
    basis_set: Vec<Vec<PyBasisShellContraction>>,
    complex_symmetric: bool,
) -> PyResult<Bound<'py, PyArray4<Complex<f64>>>> {
    // Convert atom by atom, bailing out at the first shell that fails to convert.
    let mut atom_shells = Vec::with_capacity(basis_set.len());
    for py_shells in basis_set {
        let shells = py_shells
            .into_iter()
            .map(|py_shell| BasisShellContraction::<f64, f64>::try_from(py_shell))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        atom_shells.push(shells);
    }
    let bscs = BasisSet::new(atom_shells);

    // The integral evaluation runs inside `allow_threads`, i.e. without holding the GIL.
    let sao_4c = py.allow_threads(|| {
        // NOTE(review): the first two flags are switched off when the complex-symmetric form is
        // requested; presumably they control complex conjugation of `s1` and `s2` -- confirm
        // against the macro definition.
        let stc = build_shell_tuple_collection![
            <s1, s2, s3, s4>;
            !complex_symmetric, !complex_symmetric, false, false;
            &bscs, &bscs, &bscs, &bscs;
            Complex<f64>
        ];
        let mut overlaps = stc.overlap([0, 0, 0, 0]);
        overlaps
            .pop()
            .expect("Unable to retrieve the four-centre overlap tensor.")
    });
    Ok(sao_4c.into_pyarray(py))
}