diff --git a/.ci_support/environment.yml b/.ci_support/environment.yml index aec78ee4d..ff5d390ad 100644 --- a/.ci_support/environment.yml +++ b/.ci_support/environment.yml @@ -17,7 +17,7 @@ dependencies: - scikit-learn =1.8.0 - scipy =1.17.1 - spglib =2.7.0 -- sqsgenerator =0.5.4 +- sqsgenerator =0.5.5 - hatchling =1.29.0 - hatch-vcs =0.5.0 - mp-api =0.37.2 diff --git a/src/structuretoolkit/build/sqs.py b/src/structuretoolkit/build/sqs.py deleted file mode 100644 index 255ef3d64..000000000 --- a/src/structuretoolkit/build/sqs.py +++ /dev/null @@ -1,236 +0,0 @@ -import itertools -import random -import warnings -from collections.abc import Iterable -from multiprocessing import cpu_count - -import numpy as np -from ase.atoms import Atoms -from ase.data import atomic_numbers - - -def chemical_formula(atoms: Atoms) -> str: - """ - Generate the chemical formula of an Atoms object. - - Args: - atoms (Atoms): The Atoms object representing the structure. - - Returns: - str: The chemical formula of the structure. - - """ - - def group_symbols(): - for species, same in itertools.groupby(atoms.get_chemical_symbols()): - num_same = len(list(same)) - yield species if num_same == 1 else f"{species}{num_same}" - - return "".join(group_symbols()) - - -def map_dict(f, d: dict) -> dict: - """ - Apply a function to each value in a dictionary. - - Args: - f: The function to apply. - d (Dict): The dictionary to apply the function to. - - Returns: - Dict: The dictionary with the function applied to each value. - - """ - return {k: f(v) for k, v in d.items()} - - -def mole_fractions_to_composition( - mole_fractions: dict[str, float], num_atoms: int -) -> dict[str, int]: - """ - Convert mole fractions to composition. - - Args: - mole_fractions (Dict[str, float]): The mole fractions of each species. - num_atoms (int): The total number of atoms. - - Returns: - Dict[str, int]: The composition of each species. - - Raises: - ValueError: If the sum of mole fractions is not within the range (1 - 1/num_atoms, 1 + 1/num_atoms). - - """ - if not (1.0 - 1 / num_atoms) < sum(mole_fractions.values()) < (1.0 + 1 / num_atoms): - raise ValueError( - f"mole-fractions must sum up to one: {sum(mole_fractions.values())}" - ) - - composition = map_dict(lambda x: x * num_atoms, mole_fractions) - # check to avoid partial occupation -> x_i * num_atoms is not an integer number - if any( - not float.is_integer(round(occupation, 1)) - for occupation in composition.values() - ): - # at least one of the specified species exhibits fractional occupation, we try to fix it by rounding - composition_ = map_dict(lambda occ: int(round(occ)), composition) - warnings.warn( - f"The current mole-fraction specification cannot be applied to {num_atoms} atoms, " - "as it would lead to fractional occupation. Hence, I have changed it from " - f'"{mole_fractions}" -> "{map_dict(lambda n: n / num_atoms, composition_)}"', - stacklevel=2, - ) - composition = composition_ - - # due to rounding errors there might be a difference of one atom - actual_atoms = sum(composition.values()) - diff = actual_atoms - num_atoms - if abs(diff) == 1: - # it is not possible to distribute atoms equally e.g x_a = x_b = x_c = 1/3 on 32 atoms - # we remove one randomly bet we inform the user - removed_species = random.choice(tuple(composition)) - composition[removed_species] -= 1 - warnings.warn( - f'It is not possible to distribute the species properly. Therefore one "{removed_species}" atom was removed. ' - "This changes the input mole-fraction specification. " - f'"{mole_fractions}" -> "{map_dict(lambda n: n / num_atoms, composition)}"', - stacklevel=2, - ) - elif abs(diff) > 1: - # something else is wrong with the mole-fractions input - raise ValueError(f"Cannot interpret mole-fraction dict {mole_fractions}") - - return composition - - -def remap_sro(species: Iterable[str], array: np.ndarray) -> dict[str, list]: - """ - Remap computed short-range order parameters to the style of sqsgenerator=v0.0.5. - - Args: - species (Iterable[str]): The species in the structure. - array (np.ndarray): The computed short-range order parameters. - - Returns: - Dict[str, list]: The remapped short-range order parameters. - - """ - species = tuple(sorted(species, key=lambda abbr: atomic_numbers[abbr])) - return { - f"{si}-{sj}": array[:, i, j].tolist() - for (i, si), (j, sj) in itertools.product( - enumerate(species), enumerate(species) - ) - if j >= i - } - - -def remap_sqs_results( - result: dict[str, Atoms | np.ndarray], -) -> tuple[Atoms, dict[str, list]]: - """ - Remap the results of SQS optimization. - - Args: - result (Dict[str, Union[Atoms, np.ndarray]]): The result of SQS optimization. - - Returns: - Tuple[Atoms, Dict[str, list]]: The remapped structure and short-range order parameters. - - """ - return result["structure"], remap_sro( - set(result["structure"].get_chemical_symbols()), result["parameters"] - ) - - -def transpose(it: Iterable[Iterable]) -> Iterable[tuple]: - """ - Transpose an iterable of iterables. - - Args: - it (Iterable[Iterable]): The iterable to transpose. - - Returns: - Iterable[tuple]: The transposed iterable. - - """ - return zip(*it, strict=True) - - -def sqs_structures( - structure: Atoms, - mole_fractions: dict[str, float | int], - weights: dict[int, float] | None = None, - objective: float | np.ndarray = 0.0, - iterations: float | int = 1e6, - output_structures: int = 10, - mode: str = "random", - num_threads: int | None = None, - rtol: float | None = None, - atol: float | None = None, - return_statistics: bool | None = False, -) -> Atoms | tuple[Atoms, dict[str, list], int, float]: - """ - Generate SQS structures. - - Args: - structure (Atoms): The initial structure. - mole_fractions (Dict[str, Union[float, int]]): The mole fractions of each species. - weights (Optional[Dict[int, float]]): The weights for each shell. - objective (Union[float, np.ndarray]): The target objective value. - iterations (Union[float, int]): The number of iterations. - output_structures (int): The number of output structures. - mode (str): The mode for selecting configurations. - num_threads (Optional[int]): The number of threads to use. - rtol (Optional[float]): The relative tolerance. - atol (Optional[float]): The absolute tolerance. - return_statistics (Optional[bool]): Whether to return additional statistics. - - Returns: - Union[Atoms, Tuple[Atoms, Dict[str, list], int, float]]: The generated structures or a tuple containing the structures, short-range order parameters breakdown, number of iterations, and average cycle time. - - """ - from sqsgenerator import optimize, to_ase - - composition = mole_fractions_to_composition(mole_fractions, len(structure)) - - settings = { - "atol": atol, - "rtol": rtol, - "iteration_mode": mode, - "structure": { - "lattice": structure.cell.array.tolist(), - "coords": structure.get_scaled_positions().tolist(), - "species": structure.get_chemical_symbols(), - }, - "iterations": int(iterations), - "composition": {k: int(v) for k, v in composition.items()}, - "target_objective": objective, - "thread_config": num_threads or cpu_count(), - "max_results_per_objective": output_structures, - } - if weights is not None: - settings["shell_weights"] = weights - - # not specifying a parameter in settings causes sqsgenerator to choose a "sensible" default, - # hence we remove all entries with a None value - result = optimize({k: v for k, v in settings.items() if v is not None}) - - structures = [] - sro_breakdown: list = [] - finished = False - for r_lst in result: - if not finished: - for s in r_lst[1]: - structures.append(to_ase(s.structure())) - sro_breakdown.append(s.sro()) - if len(structures) == output_structures: - finished = True - break - cycle_time = list(result.statistics.timings.values())[0] - num_iterations = result.config.iterations - - if not return_statistics: - return structures - else: - return structures, sro_breakdown, num_iterations, cycle_time diff --git a/src/structuretoolkit/build/sqs/__init__.py b/src/structuretoolkit/build/sqs/__init__.py new file mode 100644 index 000000000..c40646142 --- /dev/null +++ b/src/structuretoolkit/build/sqs/__init__.py @@ -0,0 +1,33 @@ +from ._types import ( + SqsResultSplit, + SqsResult, + SqsResultInteract, + SublatticeMode, + IterationMode, + Site, + Element, + ShellRadii, + ShellWeights, + Shell, + Composition, + Prec, + SroParameter, +) +from ._interface import sqs_structures + +__all__ = [ + "Composition", + "Element", + "IterationMode", + "Prec", + "Shell", + "ShellRadii", + "ShellWeights", + "Site", + "SroParameter", + "SqsResult", + "SqsResultInteract", + "SqsResultSplit", + "SublatticeMode", + "sqs_structures", +] diff --git a/src/structuretoolkit/build/sqs/_interface.py b/src/structuretoolkit/build/sqs/_interface.py new file mode 100644 index 000000000..f5e9c0020 --- /dev/null +++ b/src/structuretoolkit/build/sqs/_interface.py @@ -0,0 +1,310 @@ +from __future__ import annotations +from ase.atoms import Atoms +from threading import Event, Thread +from typing import overload, Literal, TypeVar, Generic, Any, Iterator + +from ._types import ( + Composition, + ShellWeights, + ShellRadii, + SublatticeMode, + IterationMode, + Prec, + SqsResultInteract, + SqsResultSplit, + LogLevel, +) + +R = TypeVar("R", SqsResultInteract, SqsResultSplit) +T = TypeVar("T") + + +class _SqsResultProxy(Generic[R]): + def __init__(self, result: R): + self._result = result + + def atoms(self) -> Atoms: + from sqsgenerator import to_ase + + return to_ase(self._result.structure()) + + def __getattr__(self, item: str) -> Any: + return getattr(self._result, item) + + def sublattices(self) -> list[_SqsResultProxy[SqsResultInteract]]: + from sqsgenerator.core import SqsResultSplitDouble, SqsResultSplitFloat + + if isinstance(self._result, (SqsResultSplitDouble, SqsResultSplitFloat)): + return [ + _SqsResultProxy[SqsResultInteract](r) + for r in self._result.sublattices() + ] + else: + raise AttributeError( + f"{type(self._result).__name__} has no attribute 'sublattices'" + ) + + +class SqsResultPack(Generic[R]): + def __init__(self, pack): + self._pack = pack + + def __len__(self) -> int: + return len(self._pack) + + def best(self) -> _SqsResultProxy[R]: + return _SqsResultProxy(self._pack.best()) + + def num_objectives(self) -> int: + return self._pack.num_objectives() + + def num_results(self) -> int: + return self._pack.num_results() + + def __iter__(self) -> Iterator[_SqsResultProxy[R]]: + for _, results in self._pack: + for result in results: + yield _SqsResultProxy(result) + + +def _ensure_list(v: T | list[T] | None) -> list[T] | None: + if v is not None: + return v if isinstance(v, list) else [v] + else: + return None + + +@overload +def sqs_structures( + structure: Atoms, + composition: Composition, + supercell: tuple[int, int, int] | None = None, + shell_weights: ShellWeights | None = None, + shell_radii: ShellRadii | None = None, + objective: float | None = None, + iterations: int = 1_000_000, + atol: float | None = None, + rtol: float | None = None, + sublattice_mode: Literal["interact"] = "interact", + iteration_mode: IterationMode = "random", + num_threads: int | None = None, + precision: Prec = "single", + max_results_per_objective: int = 10, + log_level: LogLevel = "warn", + **kwargs: Any, +) -> SqsResultPack[SqsResultInteract]: ... + + +@overload +def sqs_structures( + structure: Atoms, + composition: list[Composition], + supercell: tuple[int, int, int] | None = None, + shell_weights: list[ShellWeights] | None = None, + shell_radii: list[ShellRadii] | None = None, + objective: list[float] | None = None, + iterations: int = 1_000_000, + atol: float | None = None, + rtol: float | None = None, + sublattice_mode: Literal["split"] = "split", + iteration_mode: IterationMode = "random", + num_threads: int | None = None, + precision: Prec = "single", + max_results_per_objective: int = 10, + log_level: LogLevel = "warn", + **kwargs: Any, +) -> SqsResultPack[SqsResultSplit]: ... + + +def sqs_structures( + structure: Atoms, + composition: list[Composition] | Composition, + supercell: tuple[int, int, int] | None = None, + shell_weights: list[ShellWeights] | ShellWeights | None = None, + shell_radii: list[ShellRadii] | ShellRadii | None = None, + objective: list[float] | float | None = None, + iterations: int = 1_000_000, + atol: float | None = None, + rtol: float | None = None, + sublattice_mode: SublatticeMode = "interact", + iteration_mode: IterationMode = "random", + num_threads: int | None = None, + precision: Prec = "single", + max_results_per_objective: int = 10, + log_level: LogLevel = "warn", + **kwargs: Any, +) -> SqsResultPack[SqsResultSplit] | SqsResultPack[SqsResultInteract]: + """ + Generate special quasirandom structures (SQS) using the sqsgenerator package. + The function can handle both single and multiple sublattices, + + Args: + structure: (ase.atoms.Atoms) The initial structure to optimize. + composition (dict[str, int] | list[dict[str, int]]): The target composition(s) for the optimization. + Each dictionary should map element symbols to their desired counts. If a list is provided, each dictionary + corresponds to a different sublattice. A list is expected if sublattice_mode is set to "split". + Use "sites" key to specify the sites that belong to a sublattice, + e.g. {"Cu": 8, "Au": 8, "sites": [0, 1, ..., 15]}. In case you use "sites" key with atomic species, + e.g. {"Cu": 8, "Au": 8, "sites": "Al"}, the "Al" sites refer to the atoms of the {structure} argument. + supercell (tuple[int, int, int] | None): The supercell size to use for the optimization. + If None, the original cell is used. + shell_weights (dict[int, float] | list[dict[int, float]] | None): The weights for each shell in the objective + function. The keys should be the shell numbers (starting from 1) and the values should be the corresponding + weights. If a list is provided, each dictionary corresponds to a different sublattice ("split" mode). + shell_radii (list[float] | list[list[float]] | None): The radii for each shell. Use to manually define + the coordination shell radii. The list should contain the radii for each shell, starting from the first + shell. If a list of lists is provided, each inner list corresponds to a different sublattice ("split" mode). + If set to None 0.0 (=random) will be used in interact and [0.0, 0.0, ...] in split mode. + objective: (float | list[float]) The target objective value(s) for the optimization. If a list is provided, + each value corresponds to a different sublattice ("split" mode). In split mode diverging objectives are + supported, e.g. [0, 1], to enable clustering, ordering, partial ordering or randomization for each sublattice. + iterations: (int) The maximum number of iterations to perform during the optimization. In case iteration_mode + is set to "systematic", this parameter is ignored. + atol (float | None): The absolute tolerance for shell radii detection. If None, no absolute tolerance is used. + rtol (float | None): The relative tolerance for shell radii detection. If None, no relative tolerance is used. + sublattice_mode (str): The mode to use for handling sublattices. Can be either "interact" or "split". + In "interact" mode, the whole cell is treated as a whole. In "split" mode, the cell is split into + sublattices according to the "sites" key in the composition dictionaries, and the optimization is performed + separately for each sublattice. "split" mode does not support iteration_mode "systematic". + iteration_mode (str): The mode to use for iterating through the configuration space. + Can be either "random" or "systematic". + num_threads (int | None): The number of threads to use for the optimization. + If None, the optimization will use the number of hardware threads it detects. + precision (str): The precision to use for the optimization. Can be either "single" or "double". + max_results_per_objective (int): The maximum number of results to return for each objective value. + If the optimization finds more results with the same objective value, only at most + {max_results_per_objective} results will be kept. + log_level (str): The log level to use for the optimization. Can be either "trace", "debug", "info", "warn", + or "error". The log level controls the verbosity of the output during optimization. + **kwargs (Any): Additional keyword arguments to pass to the sqsgenerator optimization function. + + Returns: + SqsResultPack: A pack of optimization results. The type of the results (SqsResultInteract or SqsResultSplit) + depends on the sublattice_mode used for the optimization. + + """ + + from sqsgenerator import parse_config + from sqsgenerator.core import ( + ParseError, + LogLevel as SqsLogLevel, + SqsCallbackContext, + optimize as sqs_optimize, + ) + + config = dict( + prec=precision, + iteration_mode=iteration_mode, + sublattice_mode=sublattice_mode, + structure=dict( + lattice=structure.cell.array.tolist(), + coords=structure.get_scaled_positions().tolist(), + species=structure.get_atomic_numbers().tolist(), + ), + iterations=iterations, + max_results_per_objective=max_results_per_objective, + ) + if atol is not None: + config["atol"] = atol + if rtol is not None: + config["rtol"] = rtol + if supercell is not None: + if all(n > 0 for n in supercell): + config["structure"]["supercell"] = supercell + else: + raise ValueError( + f"Invalid supercell: {supercell}. All dimensions must be positive integers." + ) + + def _preprocess_for_mode(v: T | list[T] | None) -> list[T] | None: + match sublattice_mode: + case "interact": + return v + case "split": + return _ensure_list(v) + case _: + raise ValueError( + f"Invalid sublattice mode: {sublattice_mode}. Use 'interact' or 'split'." + ) + + if (composition := _preprocess_for_mode(composition)) is not None: + config["composition"] = composition + if (shell_weights := _preprocess_for_mode(shell_weights)) is not None: + config["shell_weights"] = shell_weights + if (shell_radii := _preprocess_for_mode(shell_radii)) is not None: + config["shell_radii"] = shell_radii + if objective is None: + objective = 0.0 if sublattice_mode == "interact" else [0.0] * len(composition) + config["target_objective"] = _preprocess_for_mode(objective) + + if num_threads is not None: + if num_threads > 0: + config["thread_config"] = num_threads + else: + raise ValueError( + f"Invalid num_threads: {num_threads}. Must be a positive integer." + ) + + for kwarg, val in kwargs.items(): + if val is not None: + config[kwarg] = val + + config = parse_config(config) + if isinstance(config, ParseError): + raise ValueError( + f"Failed to parse config: parameter {config.key} - {config.msg}" + ) + + stop_gracefully: bool = False + stop_event = Event() + + def _callback(ctx: SqsCallbackContext) -> None: + nonlocal stop_gracefully + if stop_gracefully: + ctx.stop() + + optimization_result: SqsResultPack | None = None + + match log_level: + case "warn": + level = SqsLogLevel.warn + case "info": + level = SqsLogLevel.info + case "debug": + level = SqsLogLevel.debug + case "error": + level = SqsLogLevel.error + case "trace": + level = SqsLogLevel.trace + case _: + raise ValueError( + f"Invalid log level: {log_level}. Use 'trace', 'debug', 'info', 'warn', or 'error'." + ) + + def _optimize(): + result_local = sqs_optimize(config, log_level=level, callback=_callback) + stop_event.set() + nonlocal optimization_result + optimization_result = result_local + + t = Thread(target=_optimize) + t.start() + try: + while t.is_alive() and not stop_event.is_set(): + stop_event.wait(timeout=1.0) + except (KeyboardInterrupt, EOFError): + stop_gracefully = True + finally: + try: + t.join(timeout=5.0) + except TimeoutError: + raise RuntimeError( + "Optimization thread did not finish within the timeout period after requesting it to stop. " + "The optimization may still be running in the background. Try to decrease chunk_size by passing it as a " + "keyword argument to sqs_structures to make the optimization more responsive to stop requests." + ) + + if optimization_result is None: + raise RuntimeError("Optimization failed to produce a result.") + else: + return SqsResultPack(optimization_result) diff --git a/src/structuretoolkit/build/sqs/_types.py b/src/structuretoolkit/build/sqs/_types.py new file mode 100644 index 000000000..8abd1b473 --- /dev/null +++ b/src/structuretoolkit/build/sqs/_types.py @@ -0,0 +1,195 @@ +import numpy as np +from ase import Atoms +from typing import Literal, TypeAlias, Protocol, overload, Union + +Shell: TypeAlias = int + +SublatticeMode = Literal["split", "interact"] +IterationMode = Literal["random", "systematic"] +Element = Literal[ + "0", + "H", + "He", + "Li", + "Be", + "B", + "C", + "N", + "O", + "F", + "Ne", + "Na", + "Mg", + "Al", + "Si", + "P", + "S", + "Cl", + "Ar", + "K", + "Ca", + "Sc", + "Ti", + "V", + "Cr", + "Mn", + "Fe", + "Co", + "Ni", + "Cu", + "Zn", + "Ga", + "Ge", + "As", + "Se", + "Br", + "Kr", + "Rb", + "Sr", + "Y", + "Zr", + "Nb", + "Mo", + "Tc", + "Ru", + "Rh", + "Pd", + "Ag", + "Cd", + "In", + "Sn", + "Sb", + "Te", + "I", + "Xe", + "Cs", + "Ba", + "La", + "Ce", + "Pr", + "Nd", + "Pm", + "Sm", + "Eu", + "Gd", + "Tb", + "Dy", + "Ho", + "Er", + "Tm", + "Yb", + "Lu", + "Hf", + "Ta", + "W", + "Re", + "Os", + "Ir", + "Pt", + "Au", + "Hg", + "Tl", + "Pb", + "Bi", + "Po", + "At", + "Rn", + "Fr", + "Ra", + "Ac", + "Th", + "Pa", + "U", + "Np", + "Pu", + "Am", + "Cm", + "Bk", + "Cf", + "Es", + "Fm", + "Md", + "No", + "Lr", + "Rf", + "Db", + "Sg", + "Bh", + "Hs", + "Mt", + "Ds", + "Rg", + "Cn", + "Uut", + "Fl", +] + +Site = Union[str, list[int]] + +Prec = Literal["single", "double"] + +Composition = dict[Element | Literal["sites"], Union[int, Site]] + +ShellWeights = dict[Shell, float] + +ShellRadii = list[float] + +LogLevel = Literal["warn", "info", "debug", "error", "trace"] + + +class SroParameter: + @property + def i(self) -> int: ... + + @property + def j(self) -> int: ... + + @property + def shell(self) -> int: ... + + def __float__(self) -> float: ... + + @property + def value(self) -> float: ... + + +class SqsResultInteract(Protocol): + def shell_index(self, shell: int) -> int: ... + + def species_index(self, species: int) -> int: ... + + def rank(self) -> str: ... + + @overload + def sro( + self, + ) -> ( + np.ndarray[tuple[int, int, int], np.dtype[np.float32]] + | np.ndarray[tuple[int, int, int], np.dtype[np.float64]] + ): ... + + @overload + def sro(self, shell: int) -> list[SroParameter]: ... + + @overload + def sro(self, i: int, j: int) -> list[SroParameter]: ... + + @overload + def sro(self, shell: int, i: int, j: int) -> SroParameter: ... + + @property + def objective(self) -> float: ... + + def atoms(self) -> Atoms: ... + + +class SqsResultSplit(Protocol): + @property + def objective(self) -> float: ... + + def atoms(self) -> Atoms: ... + + def sublattices(self) -> list[SqsResultInteract]: ... + + +SqsResult = SqsResultSplit | SqsResultInteract diff --git a/tests/test_sqs.py b/tests/test_sqs.py index 897aa1ed3..6ff7d3f09 100644 --- a/tests/test_sqs.py +++ b/tests/test_sqs.py @@ -17,51 +17,223 @@ "sqsgenerator is not available, so the sqsgenerator related unittests are skipped.", ) class SQSTestCase(unittest.TestCase): - def test_sqs_structures_no_stats(self): - structures_lst = stk.build.sqs_structures( + + def test_errors_simple(self): + config = dict( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + ) + with self.assertRaises(ValueError): + stk.build.sqs_structures(supercell=(-1, 0, 0), **config) + + with self.assertRaises(ValueError): + stk.build.sqs_structures(sublattice_mode="test", **config) + + with self.assertRaises(ValueError): + stk.build.sqs_structures(num_threads=-2, **config) + + with self.assertRaises(AttributeError): + # ensure proxy works as expected + stk.build.sqs_structures( + structure=bulk("Au", cubic=True), + composition=dict(Cu=16, Au=16), + supercell=(2, 2, 2), + ).best().not_defined + + def test_sqs_structures_simple_supercell(self): + result = stk.build.sqs_structures( + structure=bulk("Au", cubic=True), + composition=dict(Cu=16, Au=16), + supercell=(2, 2, 2), + ).best() + symbols = result.atoms().get_chemical_symbols() + + self.assertEqual(len(symbols), 32) + for el in ["Au", "Cu"]: + self.assertAlmostEqual(symbols.count(el) / len(symbols), 0.5) + + self.assertEqual((5, 2, 2), result.sro().shape) + + def test_sqs_structures_simple_shell_weights(self): + NSHELLS = 3 + SPECIES = ["Au", "Cu", "Al", "Mg"] + result = stk.build.sqs_structures( + structure=bulk("Au", cubic=True), + composition={specie: 8 for specie in SPECIES}, + shell_weights={i + 1: float(i + 1) for i in range(NSHELLS)}, + supercell=(2, 2, 2), + ).best() + symbols = result.atoms().get_chemical_symbols() + + self.assertEqual(len(symbols), 32) + for el in SPECIES: + self.assertAlmostEqual(symbols.count(el) / len(symbols), 0.25) + + self.assertEqual((NSHELLS, len(SPECIES), len(SPECIES)), result.sro().shape) + + def test_sqs_structures_simple(self): + result = stk.build.sqs_structures( structure=bulk("Au", cubic=True).repeat([2, 2, 2]), - mole_fractions={"Cu": 0.5, "Au": 0.5}, - weights=None, - objective=0.0, - iterations=1e6, - output_structures=10, - mode="random", - num_threads=None, - rtol=None, - atol=None, - return_statistics=False, + composition=dict(Cu=16, Au=16), + ).best() + symbols = result.atoms().get_chemical_symbols() + + self.assertEqual(len(symbols), 32) + for el in ["Au", "Cu"]: + self.assertAlmostEqual(symbols.count(el) / len(symbols), 0.5) + + self.assertEqual((5, 2, 2), result.sro().shape) + + def test_sqs_structures_multiple_sublattices(self): + result = stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=[ + dict(Cu=8, Au=8, sites=list(range(16))), + dict(Al=8, Mg=8, sites=list(range(16, 32))), + ], + objective=[0, 0], + sublattice_mode="split", + ).best() + symbols = result.atoms().get_chemical_symbols() + + self.assertEqual(len(symbols), 32) + for el in ["Au", "Cu", "Al", "Mg"]: + self.assertAlmostEqual(symbols.count(el) / len(symbols), 0.25) + cu_au, al_mg = result.sublattices() + self.assertEqual(len(cu_au.atoms()), len(al_mg.atoms())) + + def test_sqs_structures_simple_many(self): + results = stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), ) - self.assertEqual(len(structures_lst), 10) - symbols_lst = [s.get_chemical_symbols() for s in structures_lst] - for s in symbols_lst: - self.assertEqual(len(s), 32) + + last_objective: float | None = None + # sqsgenerator yields structures in order of increasing objective, so the objective should be non-decreasing. + objectives: set[float] = set() + num_solutions = 0 + for result in results: + symbols = result.atoms().get_chemical_symbols() + self.assertEqual(len(symbols), 32) for el in ["Au", "Cu"]: - self.assertAlmostEqual(s.count(el) / len(s), 0.5) + self.assertAlmostEqual(symbols.count(el) / len(symbols), 0.5) + self.assertEqual((5, 2, 2), result.sro().shape) + if last_objective is not None: + self.assertGreaterEqual(result.objective, last_objective) + last_objective = result.objective + objectives.add(result.objective) + num_solutions += 1 + + self.assertEqual(num_solutions, results.num_results()) + self.assertEqual(len(objectives), results.num_objectives()) + self.assertEqual(len(results), results.num_objectives()) + + def test_sqs_structures_tolerances_and_radii(self): + # test atol and rtol + stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + atol=1e-5, + rtol=1e-5, + iterations=10, + ) - def test_sqs_structures_with_stats(self): - structures_lst, sro_breakdown, num_iterations, cycle_time = ( + # test shell_radii in interact mode + stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + shell_radii=[2.5, 4.0], + iterations=10, + ) + + # test shell_radii in split mode + stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=[ + dict(Cu=8, Au=8, sites=list(range(16))), + dict(Al=8, Mg=8, sites=list(range(16, 32))), + ], + shell_radii=[[2.5, 4.0], [2.5, 4.0]], + sublattice_mode="split", + iterations=10, + ) + + def test_sqs_structures_log_levels_and_kwargs(self): + config = dict( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + iterations=10, + ) + for level in ["info", "debug", "error", "trace"]: + stk.build.sqs_structures(log_level=level, **config) + + # test invalid log level + with self.assertRaises(ValueError): + stk.build.sqs_structures(log_level="invalid", **config) + + # test kwargs + stk.build.sqs_structures(chunk_size=1, **config) + + def test_sqs_structures_errors(self): + config = dict( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + iterations=10, + ) + + # test ParseError from sqsgenerator + with self.assertRaises(ValueError): + stk.build.sqs_structures(composition=dict(InvalidElement=16, Au=16), **config) + + def test_sqs_result_proxy_sublattices_error(self): + result = stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + iterations=10, + ).best() + + with self.assertRaises(AttributeError): + result.sublattices() + + def test_sqs_keyboard_interrupt(self): + from unittest.mock import patch + + # We mock time.sleep to raise KeyboardInterrupt to simulate it during the wait loop + # However sqs_structures uses stop_event.wait(timeout=1.0) + # Let's mock stop_event.wait instead, but carefully. + + # We need to make sure we only mock the wait call inside sqs_structures loop + # but since we are mocking the class Event in the module, it might be safer to mock the instance + # but we don't have access to the instance easily. + + # Alternatively, we can mock Thread.is_alive to raise it. + # However, stk.build.sqs_structures catches KeyboardInterrupt and sets stop_gracefully = True + # but it DOES NOT re-raise it if a result is already available or if it finishes. + # Actually it should probably re-raise it or return what it has. + # In the current implementation, it catches it and proceeds to join the thread and return results. + with patch("structuretoolkit.build.sqs._interface.Thread.is_alive", side_effect=[True, KeyboardInterrupt, False]): stk.build.sqs_structures( structure=bulk("Au", cubic=True).repeat([2, 2, 2]), - mole_fractions={"Cu": 0.5, "Au": 0.5}, - weights=None, - objective=0.0, - iterations=1e6, - output_structures=10, - mode="random", - num_threads=None, - rtol=None, - atol=None, - return_statistics=True, + composition=dict(Cu=16, Au=16), + iterations=10, ) + + def test_sqs_num_threads(self): + stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + iterations=10, + num_threads=2 ) - self.assertEqual(len(structures_lst), 10) - symbols_lst = [s.get_chemical_symbols() for s in structures_lst] - for s in symbols_lst: - self.assertEqual(len(s), 32) - for el in ["Au", "Cu"]: - self.assertAlmostEqual(s.count(el) / len(s), 0.5) - self.assertEqual(len(sro_breakdown), len(structures_lst)) - for sro in sro_breakdown: - self.assertEqual((5, 2, 2), sro.shape) - self.assertEqual(num_iterations, 1000000) - self.assertTrue(cycle_time < 100000000000) + + def test_sqs_optimization_failed(self): + from unittest.mock import patch + # sqs_optimize is imported inside the function, so we need to mock it where it is used. + # But wait, it's imported from sqsgenerator.core. + # So we should patch 'sqsgenerator.core.optimize' + with patch("sqsgenerator.core.optimize", return_value=None): + with self.assertRaises(RuntimeError): + stk.build.sqs_structures( + structure=bulk("Au", cubic=True).repeat([2, 2, 2]), + composition=dict(Cu=16, Au=16), + iterations=10, + )