1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
//! Implementation of symmetry transformations for electron densities.

use ndarray::{concatenate, s, Array2, Axis, LinalgScalar, ScalarOperand};
use ndarray_linalg::types::Lapack;
use num_complex::{Complex, ComplexFloat};

use crate::permutation::{IntoPermutation, PermutableCollection, Permutation};
use crate::symmetry::symmetry_element::SymmetryOperation;
use crate::symmetry::symmetry_transformation::{
    assemble_sh_rotation_3d_matrices, permute_array_by_atoms, ComplexConjugationTransformable,
    DefaultTimeReversalTransformable, SpatialUnitaryTransformable, SpinUnitaryTransformable,
    SymmetryTransformable, TimeReversalTransformable, TransformationError,
};
use crate::target::density::Density;

// ---------------------------
// SpatialUnitaryTransformable
// ---------------------------
impl<'a, T> SpatialUnitaryTransformable for Density<'a, T>
where
    T: ComplexFloat + LinalgScalar + ScalarOperand + Copy + Lapack,
    f64: Into<T>,
{
    /// Performs a spatial transformation of the density in-place.
    ///
    /// If `perm` is given, the density matrix is first permuted by atoms along both of its axes.
    /// Each shell's block of rows is then left-multiplied by that shell's transformation matrix,
    /// and each shell's block of columns is right-multiplied by the transpose of the same matrix.
    /// Together these steps give `T D T^T`, where `T` is block-diagonal over the shells of the
    /// (possibly permuted) basis. The per-shell matrices are real-valued, so `T^T` is also the
    /// conjugate transpose of `T`.
    ///
    /// # Arguments
    ///
    /// * `rmat` - The spatial transformation matrix, passed on to
    ///   `assemble_sh_rotation_3d_matrices` to build the per-shell transformation matrices.
    /// * `perm` - An optional permutation of atoms accompanying the spatial transformation.
    ///
    /// # Returns
    ///
    /// A mutable reference to the transformed density.
    ///
    /// # Errors
    ///
    /// Returns a [`TransformationError`] in any of these cases:
    /// * the per-shell transformation matrices cannot be assembled;
    /// * the basis angular order cannot be permuted;
    /// * the number of per-shell matrices does not match the number of shells;
    /// * the transformed shell blocks cannot be concatenated back together.
    fn transform_spatial_mut(
        &mut self,
        rmat: &Array2<f64>,
        perm: Option<&Permutation<usize>>,
    ) -> Result<&mut Self, TransformationError> {
        // One real-valued transformation matrix per shell, converted to the density's scalar type.
        let tmats: Vec<Array2<T>> = assemble_sh_rotation_3d_matrices(self.bao, rmat, perm)
            .map_err(|err| TransformationError(err.to_string()))?
            .iter()
            .map(|tmat| tmat.map(|&x| x.into()))
            .collect();

        // Shell boundaries of the (possibly permuted) basis. Only the boundaries are needed, so
        // the basis angular order is not cloned when there is no permutation.
        let shell_boundaries = match perm {
            Some(p) => self
                .bao
                .permute(p)
                .map_err(|err| TransformationError(err.to_string()))?
                .shell_boundary_indices()
                .into_iter()
                .collect::<Vec<_>>(),
            None => self
                .bao
                .shell_boundary_indices()
                .into_iter()
                .collect::<Vec<_>>(),
        };

        // `zip` below would silently truncate on a mismatch and produce a wrongly-shaped matrix.
        if shell_boundaries.len() != tmats.len() {
            return Err(TransformationError(format!(
                "Mismatched numbers of shells ({}) and shell transformation matrices ({}).",
                shell_boundaries.len(),
                tmats.len()
            )));
        }

        // Permute the density matrix by atoms along both axes if required. Otherwise borrow it
        // as-is to avoid copying the whole matrix.
        let permuted_denmat;
        let denmat = match perm {
            Some(p) => {
                permuted_denmat = permute_array_by_atoms(
                    &self.density_matrix,
                    p,
                    &[Axis(0), Axis(1)],
                    self.bao,
                );
                permuted_denmat.view()
            }
            None => self.density_matrix.view(),
        };

        // Row transformation: left-multiply each shell's block of rows by its matrix (T D).
        let row_blocks = shell_boundaries
            .iter()
            .zip(tmats.iter())
            .map(|(&(shl_start, shl_end), tmat)| {
                tmat.dot(&denmat.slice(s![shl_start..shl_end, ..]))
            })
            .collect::<Vec<_>>();
        let trow_denmat = concatenate(
            Axis(0),
            &row_blocks
                .iter()
                .map(|block| block.view())
                .collect::<Vec<_>>(),
        )
        .map_err(|err| {
            TransformationError(format!(
                "Unable to concatenate the transformed rows for the various shells: {err}"
            ))
        })?;

        // Column transformation: right-multiply each shell's block of columns by the transpose
        // of its matrix ((T D) T^T). The matrices are real-valued, so no conjugation is needed.
        let col_blocks = shell_boundaries
            .iter()
            .zip(tmats.iter())
            .map(|(&(shl_start, shl_end), tmat)| {
                trow_denmat
                    .slice(s![.., shl_start..shl_end])
                    .dot(&tmat.t())
            })
            .collect::<Vec<_>>();
        self.density_matrix = concatenate(
            Axis(1),
            &col_blocks
                .iter()
                .map(|block| block.view())
                .collect::<Vec<_>>(),
        )
        .map_err(|err| {
            TransformationError(format!(
                "Unable to concatenate the transformed columns for the various shells: {err}"
            ))
        })?;
        Ok(self)
    }
}

// ------------------------
// SpinUnitaryTransformable
// ------------------------

impl<'a, T: ComplexFloat + Lapack> SpinUnitaryTransformable for Density<'a, T> {
    /// Applies a spin transformation in-place; for a density this changes nothing.
    ///
    /// Densities are entirely spatial, so a spin transformation has no effect on them. The
    /// supplied spin matrix is ignored and the density is handed back untouched.
    fn transform_spin_mut(
        &mut self,
        _spin_mat: &Array2<Complex<f64>>,
    ) -> Result<&mut Self, TransformationError> {
        // No spin part to act on: return the density as it is.
        let untouched = self;
        Ok(untouched)
    }
}

// -------------------------------
// ComplexConjugationTransformable
// -------------------------------

impl<'a, T> ComplexConjugationTransformable for Density<'a, T>
where
    T: ComplexFloat + Lapack,
{
    fn transform_cc_mut(&mut self) -> Result<&mut Self, TransformationError> {
        self.density_matrix.mapv_inplace(|x| x.conj());
        self.complex_conjugated = !self.complex_conjugated;
        Ok(self)
    }
}

// --------------------------------
// DefaultTimeReversalTransformable
// --------------------------------
// Marker implementation: densities opt into the default time-reversal behaviour supplied by
// `DefaultTimeReversalTransformable`. Presumably this yields `TimeReversalTransformable` through
// a blanket impl elsewhere in the crate — confirm against `symmetry_transformation`.
impl<'a, T> DefaultTimeReversalTransformable for Density<'a, T>
where
    T: ComplexFloat + Lapack,
{
}

// ---------------------
// SymmetryTransformable
// ---------------------
impl<'a, T> SymmetryTransformable for Density<'a, T>
where
    T: ComplexFloat + Lapack,
    Density<'a, T>: SpatialUnitaryTransformable + TimeReversalTransformable,
{
    /// Determines the permutation of the molecule's ordinary atoms induced by a symmetry
    /// operation.
    ///
    /// A warning is logged, but no error is raised, when the threshold of the operation's
    /// generating element and the molecule's threshold differ by three or more orders of
    /// magnitude.
    ///
    /// # Arguments
    ///
    /// * `symop` - The symmetry operation whose action on the ordinary atoms is sought.
    ///
    /// # Returns
    ///
    /// The atom permutation produced by `symop.act_permute` on the molecule's ordinary atoms.
    ///
    /// # Errors
    ///
    /// Returns a [`TransformationError`] if no atom permutation corresponding to `symop` can be
    /// determined.
    fn sym_permute_sites_spatial(
        &self,
        symop: &SymmetryOperation,
    ) -> Result<Permutation<usize>, TransformationError> {
        let symop_thresh = symop.generating_element.threshold();
        let mol_thresh = self.mol.threshold;
        if (symop_thresh.log10() - mol_thresh.log10()).abs() >= 3.0 {
            log::warn!(
                "Symmetry operation threshold ({:.3e}) and molecule threshold ({:.3e}) \
                differ by three or more orders of magnitude.",
                symop_thresh,
                mol_thresh
            );
        }
        // `ok_or_else` defers building the error message to the failure path only.
        symop
            .act_permute(&self.mol.molecule_ordinary_atoms())
            .ok_or_else(|| {
                TransformationError(format!(
                    "Unable to determine the atom permutation corresponding to the operation \
                    `{symop}`."
                ))
            })
    }
}