qsym2/chartab/
modular_linalg.rs

1//! Modular linear algebra.
2
3use std::collections::{HashMap, HashSet};
4use std::error::Error;
5use std::fmt::{self, Debug, Display};
6use std::hash::Hash;
7use std::ops::Div;
8use std::panic;
9
10use itertools::Itertools;
11use log;
12use ndarray::{s, Array1, Array2, Axis, LinalgScalar, ShapeBuilder, Zip};
13use num_modular::ModularInteger;
14use num_traits::{Inv, Pow, ToPrimitive, Zero};
15use rayon::prelude::*;
16
17use crate::auxiliary::misc::GramSchmidtError;
18
19#[cfg(test)]
20#[path = "modular_linalg_tests.rs"]
21mod modular_linalg_tests;
22
23/// Calculates the determinant of a square matrix over a finite integer field
24/// using the Bereiss algorithm.
25///
26/// For more information, see
27/// <https://stackoverflow.com/questions/66192894/precise-determinant-of-integer-nxn-matrix>.
28///
29/// # Arguments
30///
31/// * `mat` - A square matrix.
32///
33/// # Returns
34///
35/// The determinant of `mat` in the same field.
36///
37/// # Panics
38///
39/// Panics if `mat` is not a square matrix.
40fn modular_determinant<T>(mat: &Array2<T>) -> T
41where
42    T: Clone + LinalgScalar + ModularInteger<Base = u32> + Div<Output = T>,
43{
44    assert_eq!(mat.ncols(), mat.nrows(), "A square matrix is expected.");
45    let mut mat = mat.clone();
46    let rep = mat
47        .first()
48        .expect("Unable to obtain the first element of `mat`.");
49    let dim = mat.ncols();
50    let mut sign = rep.convert(1u32);
51    let mut prev = rep.convert(1u32);
52    let zero = rep.convert(0u32);
53
54    for i in 0..(dim - 1) {
55        if mat[(i, i)] == zero {
56            // Swap with another row having non-zero i-th element.
57            let rel_swapto = mat.slice(s![(i + 1).., i]).iter().position(|x| *x != zero);
58            if let Some(rel_index) = rel_swapto {
59                let (mut mat_above, mut mat_below) = mat.view_mut().split_at(Axis(0), i + 1);
60                let row_from = mat_above.slice_mut(s![i, ..]);
61                let row_to = mat_below.slice_mut(s![rel_index, ..]);
62                Zip::from(row_from).and(row_to).for_each(std::mem::swap);
63                sign = -sign;
64            } else {
65                // All mat[.., i] are zero => zero determinant.
66                return zero;
67            }
68        }
69        for (j, k) in ((i + 1)..dim).cartesian_product((i + 1)..dim) {
70            let numerator = mat[(j, k)] * mat[(i, i)] - mat[(j, i)] * mat[(i, k)];
71            mat[(j, k)] = numerator / prev;
72        }
73        prev = mat[(i, i)];
74    }
75    sign * *mat
76        .last()
77        .expect("Unable to obtain the last element of `mat`.")
78}
79
80/// Converts an array into its unique reduced row echelon form using Gaussian
81/// elimination over a finite integer field.
82///
83/// # Arguments
84///
85/// * `mat` - A rectangular matrix.
86///
87/// # Returns
88///
89/// * The reduced row echelon form of `mat`.
90/// * The nullity of `mat`.
91///
92/// # Panics
93///
94/// Panics when the pivoting values are not unity.
95fn modular_rref<T>(mat: &Array2<T>) -> (Array2<T>, usize)
96where
97    T: Clone + Copy + Debug + ModularInteger<Base = u32> + Div<Output = T>,
98{
99    let mut mat = mat.clone();
100    let nrows = mat.nrows();
101    let ncols = mat.ncols();
102    let rep = mat
103        .first()
104        .expect("Unable to obtain the first element in `mat`.");
105    let zero = rep.convert(0);
106    let one = rep.convert(1);
107    let mut rank = 0usize;
108
109    let mut pivot_row = 0usize;
110    let mut pivot_col = 0usize;
111
112    while pivot_row < nrows && pivot_col < ncols {
113        // Find the pivot in column pivot_col
114        let rel_i_nonzero_option = mat
115            .slice(s![pivot_row.., pivot_col])
116            .iter()
117            .position(|x| *x != zero);
118        if let Some(rel_i_nonzero) = rel_i_nonzero_option {
119            if rel_i_nonzero > 0 {
120                // Possible pivot in this column at row (pivot_row + rel_i_nonzero)
121                // Swap row pivot_row with row (pivot_row + rel_i_nonzero)
122                let (mut mat_above, mut mat_below) =
123                    mat.view_mut().split_at(Axis(0), pivot_row + 1);
124                let row_from = mat_above.slice_mut(s![pivot_row, ..]);
125                let row_to = mat_below.slice_mut(s![rel_i_nonzero - 1, ..]);
126                Zip::from(row_from).and(row_to).for_each(std::mem::swap);
127            }
128
129            // Scale all elements in pivot row to make the pivot element equal to one
130            let pivot_val = mat[(pivot_row, pivot_col)];
131            for j in (pivot_col)..ncols {
132                mat[(pivot_row, j)] = mat[(pivot_row, j)] / pivot_val;
133            }
134
135            // Eliminate below the pivot
136            for i in (pivot_row + 1)..nrows {
137                assert_eq!(mat[(pivot_row, pivot_col)], one);
138                let f = mat[(i, pivot_col)];
139                // row_i -= f * pivot_row
140                // Fill with zeros the lower part of pivot column
141                // This is essentially a subtraction but has been optimised away.
142                mat[(i, pivot_col)] = zero;
143                // Subtract all remaining elements in current row
144                for j in (pivot_col + 1)..ncols {
145                    let a = mat[(pivot_row, j)];
146                    mat[(i, j)] = mat[(i, j)] - a * f;
147                }
148            }
149
150            // Eliminate above the pivot
151            for i in (0..pivot_row).rev() {
152                assert_eq!(mat[(pivot_row, pivot_col)], one);
153                let f = mat[(i, pivot_col)];
154                // row_i -= f * pivot_row
155                // Fill with zeros the upper part of pivot column
156                mat[(i, pivot_col)] = zero;
157                // Subtract all remaining elements in current row
158                for j in (pivot_col + 1)..ncols {
159                    let a = mat[(pivot_row, j)];
160                    mat[(i, j)] = mat[(i, j)] - a * f;
161                }
162            }
163
164            // Increase pivot row and column for the next while iteration
165            pivot_row += 1;
166            pivot_col += 1;
167
168            // Pivot column increases rank.
169            rank += 1;
170        } else {
171            // No pivot in this column; pass to next column.
172            pivot_col += 1;
173        }
174    }
175    (mat, ncols - rank)
176}
177
178/// Determines a set of basis vectors for the kernel of a matrix via Gaussian
179/// elimination over a finite integer field.
180///
181/// The kernel of an $`m \times n`$ matrix $`\mathbf{M}`$ is the space of
182/// the solutions to the equation
183///
184/// ```math
185///     \mathbf{M} \mathbf{x} = \mathbf{0},
186/// ```
187///
188/// where $`\mathbf{x}`$ is an $`n \times 1`$ column vector.
189///
190/// # Arguments
191///
192/// * `mat` - A rectangular matrix.
193///
194/// # Returns
195///
196/// A vector of basis vectors for the kernel of `mat`.
197fn modular_kernel<T>(mat: &Array2<T>) -> Vec<Array1<T>>
198where
199    T: Clone + Copy + Debug + ModularInteger<Base = u32> + Div<Output = T>,
200{
201    let (mat_rref, nullity) = modular_rref(mat);
202    let ncols = mat.ncols();
203    let rep = mat
204        .first()
205        .expect("Unable to obtain the first element in `mat`.");
206    let zero = rep.convert(0);
207    let one = rep.convert(1);
208    let pivot_cols: Vec<usize> = mat_rref
209        .axis_iter(Axis(0))
210        .filter_map(|row| row.iter().position(|&x| x != zero))
211        .collect();
212    let rank = ncols - nullity;
213    assert_eq!(rank, pivot_cols.len());
214    log::debug!("Rank: {}", rank);
215    log::debug!("Kernel dim: {}", nullity);
216
217    let pivot_cols_set: HashSet<usize> = pivot_cols.iter().copied().collect::<HashSet<_>>();
218    let non_pivot_cols = (0..ncols).collect::<HashSet<_>>();
219    let non_pivot_cols = non_pivot_cols.difference(&pivot_cols_set);
220    non_pivot_cols
221        .map(|&non_pivot_col| {
222            let mut kernel_basis_vec = Array1::from_elem((ncols,), zero);
223            kernel_basis_vec[non_pivot_col] = one;
224
225            for (i, &pivot_col) in pivot_cols.iter().enumerate() {
226                kernel_basis_vec[pivot_col] = -mat_rref[(i, non_pivot_col)];
227            }
228            let first_nonzero_pos = kernel_basis_vec
229                .iter()
230                .position(|&x| x != zero)
231                .expect("Kernel basis vector cannot be zero.");
232            let first_nonzero = kernel_basis_vec[first_nonzero_pos];
233            kernel_basis_vec
234                .iter_mut()
235                .for_each(|x| *x = *x / first_nonzero);
236            kernel_basis_vec
237        })
238        .collect()
239}
240
241#[derive(Debug, Clone)]
242pub(crate) struct ModularEigError<'a, T> {
243    mat: &'a Array2<T>,
244}
245
246impl<'a, T: Display + Debug> fmt::Display for ModularEigError<'a, T> {
247    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
248        write!(f, "Unable to diagonalise {}.", self.mat)
249    }
250}
251
252impl<'a, T: Display + Debug> Error for ModularEigError<'a, T> {}
253
254/// Determines the eigenvalues and eigenvector of a square matrix over a finite
255/// integer field.
256///
257/// # Arguments
258///
259/// * `mat` - A square matrix of modular integers.
260///
261/// # Returns
262///
263/// A hashmap containing the eigenvalues and the associated eigenvectors.
264/// One eigenvalue can be associated with multiple eigenvectors in cases of
265/// degeneracy.
266///
267/// # Panics
268///
269/// Panics when inconsistent ring moduli between matrix elements are encountered.
270pub(crate) fn modular_eig<T>(
271    mat: &Array2<T>,
272) -> Result<HashMap<T, Vec<Array1<T>>>, ModularEigError<'_, T>>
273where
274    T: Clone
275        + LinalgScalar
276        + Display
277        + Debug
278        + ModularInteger<Base = u32>
279        + Eq
280        + Hash
281        + panic::UnwindSafe
282        + panic::RefUnwindSafe
283        + Sync
284        + Send,
285{
286    assert!(mat.is_square(), "Only square matrices are supported.");
287    let dim = mat.nrows();
288    let modulus_set: HashSet<u32> = mat
289        .iter()
290        .filter_map(|x| panic::catch_unwind(|| x.modulus()).ok())
291        .collect();
292    assert_eq!(
293        modulus_set.len(),
294        1,
295        "Inconsistent ring moduli between matrix elements."
296    );
297    let modulus = *modulus_set
298        .iter()
299        .next()
300        .expect("Unexpected empty `modulus_set`.");
301    let rep = mat
302        .iter()
303        .find(|x| panic::catch_unwind(|| x.modulus()).is_ok())
304        .expect("At least one modular integer with a known modulus should have been found.");
305    let zero = T::zero();
306    log::debug!("Diagonalising in GF({})...", modulus);
307
308    let results: HashMap<T, Vec<Array1<T>>> = (0..modulus)
309        .par_bridge()
310        .filter_map(|lam| {
311            let lamb = rep.convert(lam);
312            let char_mat = mat - Array2::from_diag_elem(dim, lamb);
313            let det = modular_determinant(&char_mat);
314            if det == zero {
315                let vecs = modular_kernel(&char_mat);
316                log::debug!(
317                    "{} is an eigenvalue with multiplicity {}.",
318                    lamb,
319                    vecs.len()
320                );
321                Some((lamb, vecs))
322            } else {
323                None
324            }
325        })
326        .collect();
327    let eigen_dim = results.values().fold(0usize, |acc, vecs| acc + vecs.len());
328    if eigen_dim != dim {
329        log::warn!(
330            "Found {} / {} eigenvector{}. The matrix is not diagonalisable in GF({}).",
331            eigen_dim,
332            dim,
333            if dim > 1 { "s" } else { "" },
334            modulus
335        );
336        Err(ModularEigError { mat })
337    } else {
338        log::debug!(
339            "Found {} / {} eigenvector{}. Eigensolver done in GF({}).",
340            eigen_dim,
341            dim,
342            if dim > 1 { "s" } else { "" },
343            modulus
344        );
345
346        Ok(results)
347    }
348}
349
350/// Calculates the weighted Hermitian inner product between two vectors defined
351/// as:
352///
353/// ```math
354/// \langle \mathbf{u}, \mathbf{w} \rangle
355/// = \lvert G \rvert^{-1} \sum_i
356///     \frac{u_i \bar{w}_i}{\lvert K_i \rvert},
357/// ```
358///
359/// where $`K_i`$ is the i-th conjugacy class of the group, and
360/// $`\bar{w_i}`$ the character in $`\mathbf{w}`$ corresponding to the
361/// inverse conjugacy class of $`K_i`$.
362///
363/// Note that, in $`\mathbb{C}`$, $`\bar{w}_i = w_i^*`$, but this is not true
364/// in $`\mathrm{GF}(p)`$.
365///
366/// # Arguments
367///
368/// * `vec_pair` - A pair of vectors for which the Hermitian inner product is to be
369/// calculated.
370/// * `class_sizes` - The sizes of the conjugacy classes.
371/// * `perm_for_conj` - The permutation indices to take a vector into its conjugate.
372///
373/// # Returns
374/// The weighted Hermitian inner product.
375///
376/// # Panics
377///
378/// Panics when inconsistent ring moduli between vector elements are encountered.
379#[must_use]
380pub(crate) fn weighted_hermitian_inprod<T>(
381    vec_pair: (&Array1<T>, &Array1<T>),
382    class_sizes: &[usize],
383    perm_for_conj: Option<&Vec<usize>>,
384) -> T
385where
386    T: Display
387        + Debug
388        + LinalgScalar
389        + ModularInteger<Base = u32>
390        + panic::UnwindSafe
391        + panic::RefUnwindSafe,
392{
393    let (vec_u, vec_w) = vec_pair;
394    assert_eq!(vec_u.len(), vec_w.len());
395    assert_eq!(vec_u.len(), class_sizes.len());
396
397    let modulus_set: HashSet<u32> = vec_u
398        .iter()
399        .chain(vec_w.iter())
400        .filter_map(|x| panic::catch_unwind(|| x.modulus()).ok())
401        .collect();
402    assert_eq!(
403        modulus_set.len(),
404        1,
405        "Inconsistent ring moduli between vector elements."
406    );
407
408    let rep = vec_u
409        .iter()
410        .chain(vec_w.iter())
411        .find(|x| panic::catch_unwind(|| x.modulus()).is_ok())
412        .expect("No known modulus found.");
413
414    let vec_w_conj = if let Some(indices) = perm_for_conj {
415        vec_w.select(Axis(0), indices)
416    } else {
417        vec_w.clone()
418    };
419
420    Zip::from(vec_u)
421        .and(&vec_w_conj)
422        .and(class_sizes)
423        .fold(T::zero(), |acc, &u, &w_conj, &k| {
424            acc + (u * w_conj)
425                / rep.convert(
426                    u32::try_from(k)
427                        .unwrap_or_else(|_| panic!("Unable to convert `{k}` to `u32`.")),
428                )
429        })
430        / rep.convert(
431            u32::try_from(class_sizes.iter().sum::<usize>())
432                .expect("Unable to convert the group order to `u32`."),
433        )
434}
435
436/// Performs Gram--Schmidt orthogonalisation (but not normalisation) on a set of vectors with
437/// respect to the inner product defined in [`self::weighted_hermitian_inprod`].
438///
439/// # Arguments
440///
441/// * `vs` - Vectors forming a basis for a subspace.
442/// * `class_sizes` - Sizes for the conjugacy classes. This is required to compute the inner
443/// product.
444/// * `perm_for_conj` - The permutation indices to take a vector into its conjugate. This is
445/// required to compute the inner product.
446///
447/// # Returns
448///
449/// The orthogonal vectors forming a basis for the same subspace.
450///
451/// # Errors
452///
453/// Errors when the orthogonalisation procedure fails, which occurs when there is linear dependency
454/// between the basis vectors.
455fn modular_gram_schmidt<'a, T>(
456    vs: &'a [Array1<T>],
457    class_sizes: &[usize],
458    perm_for_conj: Option<&Vec<usize>>,
459) -> Result<Vec<Array1<T>>, GramSchmidtError<'a, T>>
460where
461    T: Display
462        + Debug
463        + LinalgScalar
464        + ModularInteger<Base = u32>
465        + panic::UnwindSafe
466        + panic::RefUnwindSafe,
467{
468    let mut us: Vec<Array1<T>> = Vec::with_capacity(vs.len());
469    let mut us_sq_norm: Vec<T> = Vec::with_capacity(vs.len());
470    for (i, vi) in vs.iter().enumerate() {
471        // u[i] now initialised with v[i]
472        us.push(vi.clone());
473
474        // Project vi onto all uj (0 <= j < i)
475        for j in 0..i {
476            if Zero::is_zero(&us_sq_norm[j]) {
477                log::error!("A zero-norm vector found: {}", us[j]);
478                return Err(GramSchmidtError {
479                    vecs: Some(vs),
480                    mat: None,
481                });
482            }
483            let p_uj_vi =
484                weighted_hermitian_inprod((vi, &us[j]), class_sizes, perm_for_conj) / us_sq_norm[j];
485            us[i] = &us[i] - us[j].map(|&x| x * p_uj_vi);
486        }
487
488        // Evaluate the squared norm of ui which will no longer be changed after this iteration.
489        // us_sq_norm[i] now available.
490        us_sq_norm.push(weighted_hermitian_inprod(
491            (&us[i], &us[i]),
492            class_sizes,
493            perm_for_conj,
494        ));
495    }
496
497    debug_assert!({
498        us.iter().enumerate().all(|(i, ui)| {
499            us.iter().enumerate().all(|(j, uj)| {
500                let ov_ij = weighted_hermitian_inprod((ui, uj), class_sizes, perm_for_conj);
501                i == j || Zero::is_zero(&ov_ij)
502            })
503        })
504    });
505
506    Ok(us)
507}
508
509#[derive(Debug, Clone)]
510pub(crate) struct SplitSpaceError<'a, T> {
511    mat: &'a Array2<T>,
512    vecs: &'a [Array1<T>],
513}
514
515impl<'a, T: Display + Debug> fmt::Display for SplitSpaceError<'a, T> {
516    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
517        write!(
518            f,
519            "Unable to split the degenerate subspace spanned by {:#?} with {}.",
520            self.vecs, self.mat
521        )
522    }
523}
524
525impl<'a, T: Display + Debug> Error for SplitSpaceError<'a, T> {}
526
527/// Splits a space into smaller subspaces under the action of a matrix.
528///
529/// # Arguments
530///
531/// * `mat` - A matrix to act on the specified space.
532/// * `vecs` - The basis vectors specifying the space.
533/// * `class_sizes` - Sizes for the conjugacy classes. This is required to compute the inner
534/// product defined in [`self::weighted_hermitian_inprod`].
535/// * `perm_for_conj` - The permutation indices to take a vector into its conjugate. This is
536/// required to compute the inner product defined in [`self::weighted_hermitian_inprod`].
537///
538/// # Returns
539///
540/// A vector of vectors of vectors, where each inner vector contains the basis
541/// vectors for an $`n`$-dimensional subspace, $`n \ge 1`$.
542///
543/// # Panics
544///
545/// Panics when inconsistent ring moduli between vector and matrix elements are found.
546///
547/// # Errors
548///
549/// Errors when the degeneracy subspace cannot be split, which occurs when any of the
550/// orthogonalised vectors spanning the subspace is a null vector.
551#[allow(clippy::too_many_lines)]
552pub(crate) fn split_space<'a, T>(
553    mat: &'a Array2<T>,
554    vecs: &'a [Array1<T>],
555    class_sizes: &[usize],
556    perm_for_conj: Option<&Vec<usize>>,
557) -> Result<Vec<Vec<Array1<T>>>, SplitSpaceError<'a, T>>
558where
559    T: Display
560        + LinalgScalar
561        + Debug
562        + ModularInteger<Base = u32>
563        + Eq
564        + Hash
565        + Zero
566        + Inv
567        + panic::UnwindSafe
568        + panic::RefUnwindSafe
569        + Sync
570        + Send,
571{
572    let modulus_set: HashSet<u32> = vecs
573        .iter()
574        .flatten()
575        .chain(mat.iter())
576        .filter_map(|x| panic::catch_unwind(|| x.modulus()).ok())
577        .collect();
578    assert_eq!(
579        modulus_set.len(),
580        1,
581        "Inconsistent ring moduli between vector and matrix elements."
582    );
583
584    let dim = vecs.len();
585    log::debug!("Dimensionality of space to be split: {}", dim);
586    let split_subspaces = if dim <= 1 {
587        log::debug!("Nothing to do.");
588        vec![Vec::from(vecs)]
589    } else {
590        // Orthogonalise the subspace basis
591        let ortho_vecs = modular_gram_schmidt(vecs, class_sizes, perm_for_conj).map_err(|err| {
592            log::warn!("{err}");
593            SplitSpaceError { mat, vecs }
594        })?;
595        let ortho_vecs_mat = Array2::from_shape_vec(
596            (class_sizes.len(), dim).f(),
597            ortho_vecs.iter().flatten().copied().collect::<Vec<_>>(),
598        )
599        .expect("Unable to construct a two-dimensional matrix of the orthogonal vectors.");
600
601        // Find the representation matrix of the action of `mat` on the basis vectors
602        let ortho_vecs_mag = Array2::from_shape_vec(
603            (dim, 1),
604            ortho_vecs
605                .iter()
606                .map(|col_i| weighted_hermitian_inprod((col_i, col_i), class_sizes, perm_for_conj))
607                .collect(),
608        )
609        .expect(
610            "Unable to construct a column vector of the magnitudes of the orthogonalised vectors.",
611        );
612        if ortho_vecs_mag.iter().any(|x| Zero::is_zero(x)) {
613            return Err(SplitSpaceError { mat, vecs });
614        }
615
616        // The division below is correct: `ortho_vecs_mag` (dim × 1) is broadcast to (dim × dim),
617        // hence every row of the dividend is divided by the corresponding element of
618        // `ortho_vecs_mag`.
619        let nv_mat = mat.dot(&ortho_vecs_mat);
620        let mut rep_mat_unnorm: Array2<T> = Array2::zeros((dim, dim));
621        for (i, v) in ortho_vecs.iter().enumerate() {
622            for (j, nv) in nv_mat.columns().into_iter().enumerate() {
623                rep_mat_unnorm[[i, j]] =
624                    weighted_hermitian_inprod((v, &nv.to_owned()), class_sizes, perm_for_conj);
625            }
626        }
627        let rep_mat = rep_mat_unnorm / ortho_vecs_mag;
628
629        // Diagonalise the representation matrix
630        // Then use the eigenvectors to form linear combinations of the original
631        // basis vectors and split the subspace
632        let eigs = modular_eig(&rep_mat).map_err(|_| SplitSpaceError { mat, vecs })?;
633        let n_subspaces = eigs.len();
634        if n_subspaces == dim {
635            log::debug!("{dim}-D space is completely split into {n_subspaces} 1-D subspaces.",);
636        } else {
637            log::debug!(
638                "{dim}-D space is incompletely split into {n_subspaces} subspace{}.",
639                if n_subspaces == 1 { "" } else { "s" }
640            );
641        }
642
643        // Each eigenvalue of the representation matrix corresponds to one sub-subspace.
644        eigs.iter().fold(vec![], |mut acc, (eigval, eigvecs)| {
645            log::debug!(
646                "Handling eigenvalue {} of the representation matrix...",
647                eigval
648            );
649            acc.push(
650                eigvecs
651                    .iter()
652                    .map(|vec| {
653                        // Form linear combinations of the original basis vectors
654                        let transformed_vec = ortho_vecs_mat.dot(vec);
655
656                        // Normalise so that the first non-zero element is one
657                        let first_non_zero = transformed_vec
658                            .iter()
659                            .find(|&x| !Zero::is_zero(x))
660                            .expect("Unexpected zero eigenvector.");
661                        Array1::from_vec(
662                            transformed_vec
663                                .iter()
664                                .map(|x| *x / *first_non_zero)
665                                .collect(),
666                        )
667                    })
668                    .collect::<Vec<_>>(),
669            );
670            acc
671        })
672    };
673    Ok(split_subspaces)
674}
675
676#[derive(Debug, Clone)]
677pub(crate) struct Split2dSpaceError<'a, T> {
678    vecs: &'a [Array1<T>],
679}
680
681impl<'a, T: Display + Debug> fmt::Display for Split2dSpaceError<'a, T> {
682    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
683        writeln!(
684            f,
685            "Unable to greedily split the two-dimensional degenerate subspace spanned by",
686        )?;
687        for vec in self.vecs {
688            writeln!(f, "  {vec}")?;
689        }
690        fmt::Result::Ok(())
691    }
692}
693
694impl<'a, T: Display + Debug> Error for Split2dSpaceError<'a, T> {}
695
696/// Splits a two-dimensional space using the trial-and-error approach suggested by
697/// Schneider, G. J. A. Dixon's character table algorithm revisited.
698/// *Journal of Symbolic Computation* **9**, 601–606 (1990),
699/// [DOI](https://doi.org/10.1016/S0747-7171(08)80077-6).
700///
701/// In cases of ambiguity, the Frobenius--Schur indicators of the prospective irreps are computed
702/// to help rule out invalid cases.
703///
704/// # Arguments
705///
706/// * `vecs` - The two basis vectors specifying the space.
707/// * `class_sizes` - Sizes for the conjugacy classes. This is required to compute the inner
708/// product defined in [`self::weighted_hermitian_inprod`].
709/// * `perm_for_conj` - The permutation indices to take a vector into its conjugate. This is
710/// required to compute the inner product defined in [`self::weighted_hermitian_inprod`].
711///
712/// # Returns
713///
714/// A vector of two vectors of vectors, where each inner vector contains the basis
715/// vectors for a one-dimensional subspace.
716///
717/// # Panics
718///
719/// Panics when inconsistent ring moduli between vector and matrix elements are found.
720///
721/// # Errors
722///
723/// Errors when the two-dimensional subspace cannot be split using this approach. This occurs when
724/// the Frobenius--Schur indicators fail to rule out all prospective cases.
725pub(crate) fn split_2d_space<'a, T>(
726    vecs: &'a [Array1<T>],
727    class_sizes: &[usize],
728    sq_indices: &[usize],
729    perm_for_conj: Option<&Vec<usize>>,
730) -> Result<Vec<Vec<Array1<T>>>, Split2dSpaceError<'a, T>>
731where
732    T: Display
733        + LinalgScalar
734        + Debug
735        + ModularInteger<Base = u32>
736        + Eq
737        + Hash
738        + Zero
739        + Inv
740        + Pow<u32, Output = T>
741        + panic::UnwindSafe
742        + panic::RefUnwindSafe,
743{
744    assert_eq!(vecs.len(), 2, "Only two-dimensional spaces are allowed.");
745    let rep = vecs
746        .iter()
747        .flatten()
748        .find(|x| panic::catch_unwind(|| x.modulus()).is_ok())
749        .expect("No known modulus found.");
750
751    // Echelonise the basis so that v0 has first entry of 1, while v1 has first entry of 0.
752    // This then ensures that v0 + ai*v1 (i = 0, 1) always has first entry of 1, thus satisfying
753    // the requirement for 'normalised' eigenvectors of class matrices.
754    let v_flat: Vec<T> = vecs.iter().flatten().cloned().collect();
755    let shape = (vecs.len(), vecs[0].dim());
756    let (v_mat, _) = modular_rref(&Array2::from_shape_vec(shape, v_flat).unwrap());
757    let vs = v_mat
758        .rows()
759        .into_iter()
760        .map(|v| v.to_owned())
761        .collect::<Vec<_>>();
762    let v0 = vs[0].clone();
763    let v1 = vs[1].clone();
764
765    let v00 = weighted_hermitian_inprod((&v0, &v0), class_sizes, perm_for_conj);
766    let v11 = weighted_hermitian_inprod((&v1, &v1), class_sizes, perm_for_conj);
767    let v01 = weighted_hermitian_inprod((&v0, &v1), class_sizes, perm_for_conj);
768    let v10 = weighted_hermitian_inprod((&v1, &v0), class_sizes, perm_for_conj);
769    let group_order = class_sizes.iter().sum::<usize>();
770    let group_order_u32 = u32::try_from(group_order)
771        .unwrap_or_else(|_| panic!("Unable to convert the group order {group_order} to `u32`."));
772    let sqrt_group_order = group_order
773        .to_f64()
774        .expect("Unable to convert the group order to `f64`.")
775        .sqrt()
776        .floor()
777        .to_u32()
778        .expect("Unable to convert the square root of the group order to `u32`.");
779    let one = rep.convert(1);
780    let p = rep.modulus();
781    let results = (1..=sqrt_group_order)
782        .filter_map(|d0_u32| {
783            if group_order.rem_euclid(usize::try_from(d0_u32).unwrap_or_else(|_| {
784                panic!("Unable to convert the trial dimension {d0_u32} to `usize`.")
785            })) != 0 {
786                None
787            } else {
788                let res_a0 = (0..p).filter_map(|a0_u32| {
789                    let a0 = rep.convert(a0_u32);
790                    if Zero::is_zero(
791                        &(a0 * (a0 * v11 + v01 + v10) + v00
792                            - one / rep.convert(d0_u32).square()),
793                    ) {
794                        let denom = a0 * v11 + v10;
795                        if Zero::is_zero(&denom) {
796                            None
797                        } else {
798                            let a1 = -(v00 + a0 * v01) / denom;
799                            let d1p2 = one / (a1 * (a1 * v11 + v01 + v10) + v00);
800                            let res_d1 = (1..=sqrt_group_order).filter_map(|d1_u32| {
801                                if group_order.rem_euclid(usize::try_from(d1_u32).unwrap_or_else(|_| {
802                                    panic!("Unable to convert the trial dimension {d1_u32} to `usize`.")
803                                })) == 0 && rep.convert(d1_u32).square() == d1p2 {
804
805                                    let v0_split = Array1::from_vec(
806                                        v0.iter()
807                                            .zip(v1.iter())
808                                            .map(|(&v0_x, &v1_x)| v0_x + a0 * v1_x)
809                                            .collect_vec(),
810                                    );
811                                    let v1_split = Array1::from_vec(
812                                        v0.iter()
813                                            .zip(v1.iter())
814                                            .map(|(&v0_x, &v1_x)| v0_x + a1 * v1_x)
815                                            .collect_vec(),
816                                    );
817
818                                    let d0 = rep.convert(d0_u32);
819                                    let d1 = rep.convert(d1_u32);
820                                    let char0 = v0_split
821                                        .iter()
822                                        .zip(class_sizes.iter())
823                                        .map(|(&x, &k)|
824                                            d0 * x / rep.convert(k as u32)
825                                        ).collect::<Vec<_>>();
826                                    let char1 = v1_split
827                                        .iter()
828                                        .zip(class_sizes.iter())
829                                        .map(|(&x, &k)|
830                                            d1 * x / rep.convert(k as u32)
831                                        ).collect::<Vec<_>>();
832
833                                    // Frobenius--Schur indicator calculation in GF(p)
834                                    let fs0 = sq_indices
835                                        .iter()
836                                        .zip(class_sizes.iter())
837                                        .fold(T::zero(), |acc, (&sq_idx, &k)| {
838                                            let k_u32 = u32::try_from(k).unwrap_or_else(|_| {
839                                                panic!("Unable to convert the class size {k} to `u32`.");
840                                            });
841                                            acc + rep.convert(k_u32) * char0[sq_idx]
842                                        }) / rep.convert(group_order_u32);
843                                    let fs0_good = fs0.is_one() || Zero::is_zero(&fs0) || fs0 == rep.convert(p - 1);
844                                    let fs1 = sq_indices
845                                        .iter()
846                                        .zip(class_sizes.iter())
847                                        .fold(T::zero(), |acc, (&sq_idx, &k)| {
848                                            let k_u32 = u32::try_from(k).unwrap_or_else(|_| {
849                                                panic!("Unable to convert the class size {k} to `u32`.");
850                                            });
851                                            acc + rep.convert(k_u32) * char1[sq_idx]
852                                        }) / rep.convert(group_order_u32);
853                                    let fs1_good = fs1.is_one() || Zero::is_zero(&fs1) || fs1 == rep.convert(p - 1);
854
855                                    if fs0_good && fs1_good && d0_u32 <= d1_u32 {
856                                        Some((
857                                            (v0_split, v1_split),
858                                            (d0_u32, d1_u32),
859                                        ))
860                                    } else {
861                                        None
862                                    }
863                                } else {
864                                    None
865                                }
866                            }).collect_vec();
867                            Some(res_d1)
868                        }
869                    } else {
870                        None
871                    }
872                })
873                .flatten()
874                .collect_vec();
875                Some(res_a0)
876            }
877        })
878        .flatten()
879        .collect_vec();
880
881    if results.len() == 1 {
882        // Unique solution found.
883        log::debug!(
884            "Greedy Schneider splitting algorithm for 2-D subspace found a unique solution."
885        );
886        let (v0_split, v1_split) = results[0].0.clone();
887        Ok(vec![vec![v0_split], vec![v1_split]])
888    } else {
889        // Multiple solutions found. The algorithm has failed.
890        log::debug!(
891            "Greedy Schneider splitting algorithm for 2-D subspace found {} solutions.",
892            results.len()
893        );
894        for (i, (_, (d0_u32, d1_u32))) in results.iter().enumerate() {
895            log::debug!("Irrep dimensionalities of solution {i}: ({d0_u32}, {d1_u32})");
896        }
897        Err(Split2dSpaceError { vecs })
898    }
899}