1use std::fmt;
23use std::path::PathBuf;
24
25use anyhow::{ensure, format_err};
26use derive_builder::Builder;
27use itertools::Itertools;
28use nalgebra::Point3;
29use ndarray::{Array2, Axis};
30use num_traits::ToPrimitive;
31use rayon::prelude::*;
32use serde::{Deserialize, Serialize};
33
34use crate::auxiliary::geometry::Transform;
35use crate::auxiliary::molecule::Molecule;
36use crate::drivers::symmetry_group_detection::{
37    SymmetryGroupDetectionDriver, SymmetryGroupDetectionParams,
38};
39use crate::drivers::QSym2Driver;
40use crate::io::format::{log_subtitle, log_title, nice_bool, qsym2_output, QSym2Output};
41use crate::io::QSym2FileType;
42use crate::permutation::IntoPermutation;
43use crate::symmetry::symmetry_core::{PreSymmetry, Symmetry};
44
45#[cfg(test)]
46#[path = "molecule_symmetrisation_bootstrap_tests.rs"]
47mod molecule_symmetrisation_bootstrap_tests;
48
/// Returns `true`; used as the `serde` default for boolean flags that are on unless stated
/// otherwise (e.g. `reorientate_molecule`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Returns the `serde` default for `max_iterations`.
fn default_max_iterations() -> usize {
    const MAX_ITERATIONS: usize = 50;
    MAX_ITERATIONS
}
/// Returns the `serde` default for `consistent_target_symmetry_iterations`.
fn default_consistent_iterations() -> usize {
    const CONSISTENT_ITERATIONS: usize = 10;
    CONSISTENT_ITERATIONS
}
/// Returns the `serde` default for the loose MoI and distance thresholds.
fn default_loose_threshold() -> f64 {
    const LOOSE_THRESHOLD: f64 = 1.0e-2;
    LOOSE_THRESHOLD
}
/// Returns the `serde` default for the target (tight) MoI and distance thresholds.
fn default_tight_threshold() -> f64 {
    const TIGHT_THRESHOLD: f64 = 1.0e-7;
    TIGHT_THRESHOLD
}
72
73#[derive(Clone, Builder, Debug, Serialize, Deserialize)]
75pub struct MoleculeSymmetrisationBootstrapParams {
76    #[builder(default = "true")]
81    #[serde(default = "default_true")]
82    pub reorientate_molecule: bool,
83
84    #[builder(default = "1e-2")]
87    #[serde(default = "default_loose_threshold")]
88    pub loose_moi_threshold: f64,
89
90    #[builder(default = "1e-2")]
93    #[serde(default = "default_loose_threshold")]
94    pub loose_distance_threshold: f64,
95
96    #[builder(default = "1e-7")]
98    #[serde(default = "default_tight_threshold")]
99    pub target_moi_threshold: f64,
100
101    #[builder(default = "1e-7")]
103    #[serde(default = "default_tight_threshold")]
104    pub target_distance_threshold: f64,
105
106    #[builder(default = "50")]
108    #[serde(default = "default_max_iterations")]
109    pub max_iterations: usize,
110
111    #[builder(default = "10")]
115    #[serde(default = "default_consistent_iterations")]
116    pub consistent_target_symmetry_iterations: usize,
117
118    #[builder(default = "None")]
121    #[serde(default)]
122    pub infinite_order_to_finite: Option<u32>,
123
124    #[builder(default = "false")]
127    #[serde(default)]
128    pub use_magnetic_group: bool,
129
130    #[builder(default = "0")]
132    #[serde(default)]
133    pub verbose: u8,
134
135    #[builder(default = "None")]
138    #[serde(default)]
139    pub symmetrised_result_xyz: Option<PathBuf>,
140
141    #[builder(default = "None")]
145    #[serde(default)]
146    pub symmetrised_result_save_name: Option<PathBuf>,
147}
148
149impl MoleculeSymmetrisationBootstrapParams {
150    pub fn builder() -> MoleculeSymmetrisationBootstrapParamsBuilder {
152        MoleculeSymmetrisationBootstrapParamsBuilder::default()
153    }
154}
155
156impl Default for MoleculeSymmetrisationBootstrapParams {
157    fn default() -> Self {
158        Self::builder()
159            .build()
160            .expect("Unable to construct a default `MoleculeSymmetrisationBootstrapParams`.")
161    }
162}
163
164impl fmt::Display for MoleculeSymmetrisationBootstrapParams {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        writeln!(f, "Loose MoI threshold: {:.3e}", self.loose_moi_threshold)?;
167        writeln!(
168            f,
169            "Loose geo threshold: {:.3e}",
170            self.loose_distance_threshold
171        )?;
172        writeln!(f, "Target MoI threshold: {:.3e}", self.target_moi_threshold)?;
173        writeln!(
174            f,
175            "Target geo threshold: {:.3e}",
176            self.target_distance_threshold
177        )?;
178        writeln!(f)?;
179        writeln!(
180            f,
181            "Group used for symmetrisation: {}",
182            if self.use_magnetic_group {
183                "magnetic group"
184            } else {
185                "unitary group"
186            }
187        )?;
188        if let Some(finite_order) = self.infinite_order_to_finite {
189            writeln!(f, "Infinite order to finite: {finite_order}")?;
190        }
191        writeln!(
192            f,
193            "Maximum symmetrisation iterations: {}",
194            self.max_iterations
195        )?;
196        writeln!(
197            f,
198            "Target symmetry consistent iterations: {}",
199            self.consistent_target_symmetry_iterations
200        )?;
201        writeln!(f, "Output level: {}", self.verbose)?;
202        writeln!(
203            f,
204            "Save symmetrised molecule to XYZ file: {}",
205            if let Some(name) = self.symmetrised_result_xyz.as_ref() {
206                let mut path = name.clone();
207                path.set_extension("xyz");
208                path.display().to_string()
209            } else {
210                nice_bool(false)
211            }
212        )?;
213        writeln!(
214            f,
215            "Save symmetry-group detection results of symmetrised system to file: {}",
216            if let Some(name) = self.symmetrised_result_save_name.as_ref() {
217                let mut path = name.clone();
218                path.set_extension(QSym2FileType::Sym.ext());
219                path.display().to_string()
220            } else {
221                nice_bool(false)
222            }
223        )?;
224        writeln!(f)?;
225
226        Ok(())
227    }
228}
229
/// Outcome of molecule symmetrisation by bootstrapping.
#[derive(Clone, Builder, Debug)]
pub struct MoleculeSymmetrisationBootstrapResult<'a> {
    /// The control parameters used to obtain this result.
    parameters: &'a MoleculeSymmetrisationBootstrapParams,

    /// The symmetrised molecule (the final trial geometry after the bootstrapping loop).
    pub symmetrised_molecule: Molecule,
}
243
244impl<'a> MoleculeSymmetrisationBootstrapResult<'a> {
245    fn builder() -> MoleculeSymmetrisationBootstrapResultBuilder<'a> {
246        MoleculeSymmetrisationBootstrapResultBuilder::default()
247    }
248}
249
/// Driver that symmetrises a molecule by iterative bootstrapping.
///
/// The driver is constructed through its builder, whose `build` step runs
/// `MoleculeSymmetrisationBootstrapDriverBuilder::validate` on the supplied parameters.
#[derive(Clone, Builder)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct MoleculeSymmetrisationBootstrapDriver<'a> {
    /// The control parameters for symmetrisation.
    parameters: &'a MoleculeSymmetrisationBootstrapParams,

    /// The molecule to be symmetrised. It is only borrowed immutably; the symmetrised copy is
    /// stored in the result.
    molecule: &'a Molecule,

    /// The symmetrisation result. It cannot be set through the builder and starts as `None`
    /// until the driver has run.
    #[builder(setter(skip), default = "None")]
    result: Option<MoleculeSymmetrisationBootstrapResult<'a>>,
}
268
269impl<'a> MoleculeSymmetrisationBootstrapDriverBuilder<'a> {
270    fn validate(&self) -> Result<(), String> {
271        let params = self
272            .parameters
273            .ok_or("No molecule symmetrisation parameters found.".to_string())?;
274        if params.consistent_target_symmetry_iterations > params.max_iterations {
275            return Err(format!(
276                "The number of consistent target-symmetry iterations, `{}`, cannot exceed the \
277                    maximum number of iterations, `{}`.",
278                params.consistent_target_symmetry_iterations, params.max_iterations,
279            ));
280        }
281        if params.target_moi_threshold < 0.0
282            || params.loose_moi_threshold < 0.0
283            || params.target_distance_threshold < 0.0
284            || params.loose_distance_threshold < 0.0
285        {
286            return Err("The thresholds cannot be negative.".to_string());
287        }
288        if params.target_moi_threshold > params.loose_moi_threshold {
289            return Err(format!(
290                "The target MoI threshold, `{:.3e}`, cannot be larger than the \
291                    loose MoI threshold, `{:.3e}`.",
292                params.target_moi_threshold, params.loose_moi_threshold
293            ));
294        }
295        if params.target_distance_threshold > params.loose_distance_threshold {
296            return Err(format!(
297                "The target distance threshold, `{:.3e}`, cannot be larger than the \
298                    loose distance threshold, `{:.3e}`.",
299                params.target_distance_threshold, params.loose_distance_threshold
300            ));
301        }
302        Ok(())
303    }
304}
305
306impl<'a> MoleculeSymmetrisationBootstrapDriver<'a> {
307    pub fn builder() -> MoleculeSymmetrisationBootstrapDriverBuilder<'a> {
309        MoleculeSymmetrisationBootstrapDriverBuilder::default()
310    }
311
312    fn symmetrise_molecule(&mut self) -> Result<(), anyhow::Error> {
314        log_title("Molecule Symmetrisation by Bootstrapping");
315        qsym2_output!("");
316        let params = self.parameters;
317        params.log_output_display();
318
319        let mut trial_mol = self.molecule.recentre();
320
321        if params.verbose >= 1 {
322            let orig_mol = self
323                .molecule
324                .adjust_threshold(params.target_distance_threshold);
325            qsym2_output!("Unsymmetrised original molecule:");
326            orig_mol.log_output_display();
327            qsym2_output!("");
328
329            qsym2_output!("Unsymmetrised recentred molecule:");
330            trial_mol.log_output_display();
331            qsym2_output!("");
332        }
333
334        if params.reorientate_molecule {
335            trial_mol.reorientate_mut(params.target_moi_threshold);
341            qsym2_output!("Unsymmetrised recentred and reoriented molecule:");
342            trial_mol.log_output_display();
343            qsym2_output!("");
344        };
345
346        log_subtitle("Iterative molecule symmetry bootstrapping");
347        qsym2_output!("");
348        qsym2_output!("Thresholds:");
349        qsym2_output!(
350            "  Loose : {:.3e} (MoI) - {:.3e} (distance)",
351            params.loose_distance_threshold,
352            params.loose_moi_threshold,
353        );
354        qsym2_output!(
355            "  Target: {:.3e} (MoI) - {:.3e} (distance)",
356            params.target_moi_threshold,
357            params.target_distance_threshold
358        );
359        qsym2_output!("");
360        qsym2_output!("Convergence criteria:");
361        qsym2_output!(
362            "  either: (1) when the loose-threshold symmetry agrees with the target-threshold symmetry,",
363        );
364        qsym2_output!(
365            "  or    : (2) when the target-threshold symmetry contains more elements than the loose-threshold symmetry and has been consistently identified for {} consecutive iteration{}.",
366            params.consistent_target_symmetry_iterations,
367            if params.consistent_target_symmetry_iterations == 1 { "" } else { "s" }
368        );
369        qsym2_output!("");
370
371        let count_length = usize::try_from(params.max_iterations.ilog10() + 1).map_err(|_| {
372            format_err!(
373                "Unable to convert `{}` to `usize`.",
374                params.max_iterations.ilog10() + 1
375            )
376        })?;
377        qsym2_output!("{}", "┈".repeat(count_length + 101));
378        qsym2_output!(
379            " {:>count_length$} {:>22} {:>19}  {:>22} {:>19}  {:>10}",
380            "#",
381            "Rot. sym. (loose)",
382            "Group (loose)",
383            "Rot. sym. (target)",
384            "Group (target)",
385            "Converged?",
386        );
387        qsym2_output!("{}", "┈".repeat(count_length + 101));
388
389        let mut symmetrisation_count = 0;
390        let mut consistent_target_sym_count = 0;
391        let mut loose_ops = vec![];
392        let mut prev_target_sym_group_name: Option<String> = None;
393        let mut converged = false;
394        while symmetrisation_count == 0
395            || (!converged && symmetrisation_count < params.max_iterations)
396        {
397            symmetrisation_count += 1;
398
399            let mut loose_mol =
403                trial_mol.adjust_threshold(self.parameters.loose_distance_threshold);
404            let loose_presym = PreSymmetry::builder()
405                .moi_threshold(self.parameters.loose_moi_threshold)
406                .molecule(&loose_mol)
407                .build()
408                .map_err(|_| {
409                    format_err!("Cannot construct a loose-threshold pre-symmetry structure.")
410                })?;
411
412            let mut loose_sym = Symmetry::new();
413
414            let _loose_res = loose_sym.analyse(&loose_presym, self.parameters.use_magnetic_group);
416
417            loose_ops.extend_from_slice(
422                &loose_sym.generate_all_operations(self.parameters.infinite_order_to_finite),
423            );
424            let n_ops_f64 = loose_ops.len().to_f64().ok_or_else(|| {
425                format_err!("Unable to convert the number of operations to `f64`.")
426            })?;
427
428            let ts = loose_ops
430                .into_par_iter()
431                .flat_map(|op| {
432                    let tmat = op
433                        .get_3d_spatial_matrix()
434                        .select(Axis(0), &[2, 0, 1])
435                        .select(Axis(1), &[2, 0, 1])
436                        .reversed_axes();
437
438                    let ord_perm = op
439                        .act_permute(&loose_mol.molecule_ordinary_atoms())
440                        .ok_or_else(|| {
441                            format_err!(
442                                "Unable to determine the ordinary-atom permutation corresponding to `{op}`."
443                            )
444                        })?;
445                    let mag_perm_opt = loose_mol
446                        .molecule_magnetic_atoms()
447                        .as_ref()
448                        .and_then(|loose_mag_mol| op.act_permute(loose_mag_mol));
449                    let elec_perm_opt = loose_mol
450                        .molecule_electric_atoms()
451                        .as_ref()
452                        .and_then(|loose_elec_mol| op.act_permute(loose_elec_mol));
453                    Ok::<_, anyhow::Error>((tmat, ord_perm, mag_perm_opt, elec_perm_opt))
454                })
455                .collect::<Vec<_>>();
456
457            let loose_ord_coords = Array2::from_shape_vec(
459                (loose_mol.atoms.len(), 3),
460                loose_mol
461                    .atoms
462                    .iter()
463                    .flat_map(|atom| atom.coordinates.coords.iter().cloned())
464                    .collect::<Vec<_>>(),
465            )?;
466            let ave_ord_coords = ts.iter().fold(
467                Array2::<f64>::zeros(loose_ord_coords.raw_dim()),
470                |acc, (tmat, ord_perm, _, _)| {
471                    acc + loose_ord_coords.dot(tmat).select(Axis(0), ord_perm.image())
475                },
476            ) / n_ops_f64;
477            loose_mol
478                .atoms
479                .par_iter_mut()
480                .enumerate()
481                .for_each(|(i, atom)| {
482                    atom.coordinates = Point3::<f64>::from_slice(
483                        ave_ord_coords
484                            .row(i)
485                            .as_slice()
486                            .expect("Unable to convert a row of averaged coordinates to a slice."),
487                    )
488                });
489
490            if let Some(mag_atoms) = loose_mol.magnetic_atoms.as_mut() {
492                let loose_mag_coords = Array2::from_shape_vec(
493                    (mag_atoms.len(), 3),
494                    mag_atoms
495                        .iter()
496                        .flat_map(|atom| atom.coordinates.coords.iter().cloned())
497                        .collect::<Vec<_>>(),
498                )?;
499                let ave_mag_coords = ts.iter().fold(
500                    Ok(Array2::<f64>::zeros(loose_mag_coords.raw_dim())),
501                    |acc: Result<Array2<f64>, anyhow::Error>, (tmat, _, mag_perm_opt, _)| {
502                        Ok(acc?
506                            + loose_mag_coords.dot(tmat).select(
507                                Axis(0),
508                                mag_perm_opt
509                                    .as_ref()
510                                    .ok_or_else(|| {
511                                        format_err!("Expected magnetic atom permutation not found.")
512                                    })?
513                                    .image(),
514                            ))
515                    },
516                )? / n_ops_f64;
517                mag_atoms.iter_mut().enumerate().for_each(|(i, atom)| {
518                    atom.coordinates = Point3::<f64>::from_slice(
519                        ave_mag_coords
520                            .row(i)
521                            .as_slice()
522                            .expect("Unable to convert a row of averaged coordinates to a slice."),
523                    )
524                });
525            }
526
527            if let Some(elec_atoms) = loose_mol.electric_atoms.as_mut() {
529                let loose_elec_coords = Array2::from_shape_vec(
530                    (elec_atoms.len(), 3),
531                    elec_atoms
532                        .iter()
533                        .flat_map(|atom| atom.coordinates.coords.iter().cloned())
534                        .collect::<Vec<_>>(),
535                )?;
536                let ave_elec_coords = ts.iter().fold(
537                    Ok(Array2::<f64>::zeros(loose_elec_coords.raw_dim())),
538                    |acc: Result<Array2<f64>, anyhow::Error>, (tmat, _, _, elec_perm_opt)| {
539                        Ok(acc?
543                            + loose_elec_coords.dot(tmat).select(
544                                Axis(0),
545                                elec_perm_opt
546                                    .as_ref()
547                                    .ok_or_else(|| {
548                                        format_err!("Expected electric atom permutation not found.")
549                                    })?
550                                    .image(),
551                            ))
552                    },
553                )? / n_ops_f64;
554                elec_atoms.iter_mut().enumerate().for_each(|(i, atom)| {
555                    atom.coordinates = Point3::<f64>::from_slice(
556                        ave_elec_coords
557                            .row(i)
558                            .as_slice()
559                            .expect("Unable to convert a row of averaged coordinates to a slice."),
560                    )
561                });
562            }
563
564            trial_mol = loose_mol;
565
566            trial_mol.recentre_mut();
568            if params.reorientate_molecule {
569                trial_mol.reorientate_mut(params.target_moi_threshold);
575            };
576
577            let target_mol = trial_mol.adjust_threshold(self.parameters.target_distance_threshold);
581            let target_presym = PreSymmetry::builder()
582                .moi_threshold(self.parameters.target_moi_threshold)
583                .molecule(&target_mol)
584                .build()
585                .map_err(|_| {
586                    format_err!("Cannot construct a target-threshold pre-symmetry structure.")
587                })?;
588
589            let mut target_sym = Symmetry::new();
590
591            let _ = target_sym.analyse(&target_presym, params.use_magnetic_group);
592
593            let target_loose_consistent = target_sym.n_elements() == loose_sym.n_elements()
594                && target_sym.group_name.is_some()
595                && target_sym.group_name == loose_sym.group_name;
596
597            if target_sym.group_name == prev_target_sym_group_name
598                && target_sym.n_elements() >= loose_sym.n_elements()
599            {
600                consistent_target_sym_count += 1;
601            } else {
602                consistent_target_sym_count = 0;
603            }
604            prev_target_sym_group_name = target_sym.group_name.clone();
605            let target_consistent =
606                consistent_target_sym_count >= params.consistent_target_symmetry_iterations;
607
608            converged = target_loose_consistent || target_consistent;
609            let converged_reason = [target_loose_consistent, target_consistent]
610                .iter()
611                .enumerate()
612                .filter_map(|(i, c)| {
613                    if *c {
614                        Some(format!("({})", i + 1))
615                    } else {
616                        None
617                    }
618                })
619                .join("");
620
621            qsym2_output!(
622                " {:>count_length$} {:>22} {:>19}  {:>22} {:>19}  {:>10}",
623                symmetrisation_count,
624                loose_presym.rotational_symmetry.to_string(),
625                format!(
626                    "{} ({})",
627                    loose_sym.group_name.as_ref().unwrap_or(&"--".to_string()),
628                    loose_sym.n_elements()
629                ),
630                target_presym.rotational_symmetry.to_string(),
631                format!(
632                    "{} ({})",
633                    target_sym.group_name.as_ref().unwrap_or(&"--".to_string()),
634                    target_sym.n_elements()
635                ),
636                if converged {
637                    "yes ".to_string() + &converged_reason
638                } else {
639                    "no".to_string()
640                },
641            );
642
643            loose_ops =
644                target_sym.generate_all_operations(self.parameters.infinite_order_to_finite);
645        }
646        qsym2_output!("{}", "┈".repeat(count_length + 101));
647        qsym2_output!("");
648
649        qsym2_output!("Verifying symmetrisation results...");
653        qsym2_output!("");
654        let verifying_pd_params = SymmetryGroupDetectionParams::builder()
655            .moi_thresholds(&[params.target_moi_threshold])
656            .distance_thresholds(&[params.target_distance_threshold])
657            .time_reversal(params.use_magnetic_group)
658            .write_symmetry_elements(true)
659            .result_save_name(params.symmetrised_result_save_name.clone())
660            .build()?;
661        let mut verifying_pd_driver = SymmetryGroupDetectionDriver::builder()
662            .parameters(&verifying_pd_params)
663            .molecule(Some(&trial_mol))
664            .build()?;
665        verifying_pd_driver.run()?;
666        let verifying_pd_res = verifying_pd_driver.result()?;
667        let verifying_group_name = if params.use_magnetic_group {
668            verifying_pd_res
669                .magnetic_symmetry
670                .as_ref()
671                .and_then(|magsym| magsym.group_name.as_ref())
672        } else {
673            verifying_pd_res.unitary_symmetry.group_name.as_ref()
674        };
675        ensure!(
676            prev_target_sym_group_name.as_ref() == verifying_group_name,
677            "Mismatched symmetry: iterative symmetry bootstrapping found {}, but verification found {}.",
678            prev_target_sym_group_name.as_ref().unwrap_or(&"--".to_string()),
679            verifying_group_name.unwrap_or(&"--".to_string()),
680
681        );
682        qsym2_output!("Verifying symmetrisation results... Done.");
683        qsym2_output!("");
684
685        self.result = Some(
689            MoleculeSymmetrisationBootstrapResult::builder()
690                .parameters(self.parameters)
691                .symmetrised_molecule(trial_mol.clone())
692                .build()?,
693        );
694
695        if let Some(xyz_name) = params.symmetrised_result_xyz.as_ref() {
696            let mut path = xyz_name.clone();
697            path.set_extension("xyz");
698            verifying_pd_res
699                .pre_symmetry
700                .recentred_molecule
701                .to_xyz(&path)?;
702            qsym2_output!("Symmetrised molecule written to: {}", path.display());
703            qsym2_output!("");
704        }
705
706        Ok(())
707    }
708}
709
710impl<'a> QSym2Driver for MoleculeSymmetrisationBootstrapDriver<'a> {
711    type Params = MoleculeSymmetrisationBootstrapParams;
712
713    type Outcome = MoleculeSymmetrisationBootstrapResult<'a>;
714
715    fn result(&self) -> Result<&Self::Outcome, anyhow::Error> {
716        self.result
717            .as_ref()
718            .ok_or_else(|| format_err!("No molecule sprucing results found."))
719    }
720
721    fn run(&mut self) -> Result<(), anyhow::Error> {
722        self.symmetrise_molecule()
723    }
724}