diff --git a/howso/engine/tests/test_engine.py b/howso/engine/tests/test_engine.py index d5060628..28ef18ef 100644 --- a/howso/engine/tests/test_engine.py +++ b/howso/engine/tests/test_engine.py @@ -163,7 +163,7 @@ def test_save_load_good(self, trainee, file_path_type): # Load if file_path_type == 'directory_only': file_path = file_path + 'save_load_trainee.caml' - load_example_trainee = load_trainee(file_path=file_path) + load_example_trainee = load_trainee(path_or_bytes=file_path) load_training_cases = load_example_trainee.get_num_training_cases() @@ -200,7 +200,7 @@ def test_save_load_warning(self, trainee): # Set to correct path file_path = f"{cwd}/save_load_trainee.caml" - load_example_trainee = load_trainee(file_path=file_path) + load_example_trainee = load_trainee(path_or_bytes=file_path) load_training_cases = load_example_trainee.get_num_training_cases() # Delete @@ -235,7 +235,7 @@ def test_save_load_bad_load(self): HowsoError, match='A `.caml` file must be provided.' ): - load_trainee(file_path=file_path) + load_trainee(path_or_bytes=file_path) @pytest.mark.parametrize("status_msg, expected_msg", [ ("This is a test", "This is a test"), @@ -255,7 +255,7 @@ def test_load_status_message(self, mocker, monkeypatch, status_msg, expected_msg HowsoError, match=f'Failed to load Trainee file "{file_path.as_posix()}": {expected_msg}' ): - load_trainee(file_path=file_path) + load_trainee(path_or_bytes=file_path) def test_always_persist_load(self, tmp_path: Path, data, features): """Test that an auto-persist trainee can be reloaded.""" @@ -268,7 +268,7 @@ def test_always_persist_load(self, tmp_path: Path, data, features): finally: trainee.delete() - load_example_trainee = load_trainee(file_path=save_path, persistence="always") + load_example_trainee = load_trainee(path_or_bytes=save_path, persistence="always") try: assert load_example_trainee.get_num_training_cases() == 150 finally: @@ -415,4 +415,10 @@ def test_react_aggregate(self, data: DataFrame, trainee: 
Trainee): assert isinstance(value, DataFrame) for feature in data.columns: - assert feature in total_df.columns \ No newline at end of file + assert feature in total_df.columns + + def test_to_memory(self, trainee): + """ + Test the passthrough to `to_memory` in the Trainee class. + """ + assert isinstance(trainee.to_memory(), bytes) \ No newline at end of file diff --git a/howso/engine/trainee.py b/howso/engine/trainee.py index 3d9ed479..71f9cb71 100644 --- a/howso/engine/trainee.py +++ b/howso/engine/trainee.py @@ -1,17 +1,18 @@ from __future__ import annotations +import typing as t +import uuid +import warnings from collections.abc import ( Callable, Collection, Generator, + Iterable, Mapping, MutableMapping, ) from copy import deepcopy from pathlib import Path -import typing as t -import uuid -import warnings from pandas import ( DataFrame, @@ -28,15 +29,16 @@ LocalSaveableProtocol, ProjectClient, ) -from howso.client.schemas import AggregateReaction, GroupReaction -from howso.client.schemas import Project as BaseProject -from howso.client.schemas import Reaction -from howso.client.schemas import Session as BaseSession -from howso.client.schemas import Trainee as BaseTrainee from howso.client.schemas import ( + AggregateReaction, + GroupReaction, + Reaction, TraineeRuntime, TraineeRuntimeOptions, ) +from howso.client.schemas import Project as BaseProject +from howso.client.schemas import Session as BaseSession +from howso.client.schemas import Trainee as BaseTrainee from howso.client.typing import ( AblationThresholdMap, CaseIndices, @@ -57,6 +59,7 @@ TargetedModel, ValueMasses, ) +from howso.direct.client import HowsoDirectClient from howso.engine.client import get_client from howso.engine.project import Project from howso.engine.session import Session @@ -4469,6 +4472,34 @@ def from_schema( return schema return cls.from_dict(dict(schema.to_dict(), client=client)) + def to_memory( + self, + *, + file_type: t.Literal["amlg", "caml"] = "amlg", + trainee_path: 
Iterable[str] | None = None, + ) -> bytes | None: + """ + Get the Trainee file data as bytes. + + Parameters + ---------- + file_type : {"amlg", "caml"}, default "amlg" + The type of byte data to return. + trainee_path : Iterable of str, optional + The hierarchy path to a sub-Trainee, starting from this Trainee as the root. + + Returns + ------- + bytes or None + The Trainee file data as bytes, or None if this Trainee and/or `trainee_path` does not refer to + a valid Trainee. + """ + if not isinstance(self.client, LocalSaveableProtocol): + raise HowsoError("The current client does not support saving a Trainee to memory.") + + return self.client.trainee_to_memory(self.id, file_type=file_type, trainee_path=trainee_path) + + @classmethod def from_dict(cls, schema: Mapping) -> Trainee: """ @@ -4573,32 +4604,45 @@ def delete_trainee( else: client.delete_trainee(trainee_id=str(name_or_id)) +@t.overload +def load_trainee( + path_or_bytes: PathLike, + client: AbstractHowsoClient | None = ..., + *, + persistence: Persistence = ..., +) -> Trainee: ... + +@t.overload +def load_trainee( + path_or_bytes: bytes, + client: AbstractHowsoClient | None = ..., + *, + persistence: Persistence = ..., +) -> Trainee: ... + def load_trainee( - file_path: PathLike, + path_or_bytes: PathLike | bytes, client: t.Optional[AbstractHowsoClient] = None, *, - persistence: Persistence = 'allow', + persistence: Persistence = "allow", ) -> Trainee: """ Load an existing trainee from disk. Parameters ---------- - file_path : str or bytes or os.PathLike - The path of the file to load the Trainee from. This path can contain - an absolute path, a relative path or simply a file name. A ``.caml`` file name - must be always be provided if file paths are provided. - - If ``file_path`` is a relative path the absolute path will be computed - appending the ``file_path`` to the CWD. - - If ``file_path`` is an absolute path, this is the absolute path that - will be used.
- - If ``file_path`` is just a filename, then the absolute path will be computed - appending the filename to the CWD. - + path_or_bytes : str | bytes | os.PathLike + The path or binary data to load the Trainee from. + + If a path, it can be an absolute path, a relative path, or simply a file name. + The file must be a ``.caml`` file. If the path is relative the absolute path will + be computed appending it to the CWD. If the path is an absolute path, + this is the absolute path that will be used. If the path is just a filename, then + the absolute path will be computed appending the filename to the CWD. + + If binary data, that data will be loaded as a :class:`Trainee` directly using + :meth:`howso.direct.HowsoDirectClient.create_trainee_from_memory`. client : AbstractHowsoClient, optional The Howso client instance to use. Must have local disk access. persistence : {"allow", "always", "never"}, default "allow" @@ -4609,25 +4653,30 @@ def load_trainee( Returns ------- Trainee - The trainee instance. + The :class:`Trainee` instance. """ client = client or get_client() if not isinstance(client, LocalSaveableProtocol): - raise HowsoError("The current client does not support loading a Trainee from file.") - - if not isinstance(file_path, Path): - file_path = Path(file_path) + raise HowsoError("The current client does not support loading a Trainee from file or memory.") trainee_id = str(uuid.uuid4()) - file_path = file_path.expanduser().resolve() + if isinstance(path_or_bytes, bytes): + # Bytes are always Trainee data: ``Path`` rejects bytes, so they + # must be handled before any path is built.
+ base_trainee = client.create_trainee_from_memory(trainee_id, path_or_bytes) + client.amlg.set_entity_permissions(trainee_id, json_permissions='{"load":true,"store":true}') + client.trainee_cache.set(base_trainee) + trainee = Trainee.from_schema(base_trainee, client=client) + return trainee + path = Path(path_or_bytes).expanduser().resolve() # It is decided that if the file contains a suffix then it contains a # file name. - if file_path.suffix: + if path.suffix: # Check to make sure sure `.caml` file is provided - if file_path.suffix.lower() != '.caml': + if path.suffix.lower() != '.caml': raise HowsoError( 'Filepath with a non `.caml` extension was provided.' ) @@ -4636,18 +4685,18 @@ def load_trainee( raise HowsoError('A `.caml` file must be provided.') # If path is not absolute, append it to the default directory. - if not file_path.is_absolute(): - file_path = client.default_persist_path.joinpath(file_path) + if not path.is_absolute(): + path = client.default_persist_path.joinpath(path) # Ensure the path exists - if not file_path.exists(): + if not path.exists(): raise HowsoError( - f'The specified Trainee file "{file_path.as_posix()}" does not exist.') + f'The specified Trainee file "{path.as_posix()}" does not exist.') if persistence == 'always': status = client.amlg.load_entity( handle=trainee_id, - file_path=str(file_path), + file_path=str(path), persist=True, json_file_params=('{"transactional":true,"flatten":true,"execute_on_load":true,' '"require_version_compatibility":true}') @@ -4655,18 +4704,18 @@ def load_trainee( else: status = client.amlg.load_entity( handle=trainee_id, - file_path=str(file_path) + file_path=str(path) ) if not status.loaded: status_msg = status.message or "An unknown error occurred" - raise HowsoError(f'Failed to load Trainee file "{file_path.as_posix()}": {status_msg}') + raise HowsoError(f'Failed to load Trainee file "{path.as_posix()}": {status_msg}') client.amlg.set_entity_permissions(trainee_id, json_permissions='{"load":true,"store":true}') base_trainee = 
client._get_trainee_from_engine(trainee_id) # type: ignore reportPrivateUsage client.trainee_cache.set(base_trainee) trainee = Trainee.from_schema(base_trainee, client=client) - setattr(trainee, '_custom_save_path', file_path) + setattr(trainee, '_custom_save_path', path) return trainee