Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions howso/engine/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_save_load_good(self, trainee, file_path_type):
# Load
if file_path_type == 'directory_only':
file_path = file_path + 'save_load_trainee.caml'
load_example_trainee = load_trainee(file_path=file_path)
load_example_trainee = load_trainee(path_or_bytes=file_path)

load_training_cases = load_example_trainee.get_num_training_cases()

Expand Down Expand Up @@ -200,7 +200,7 @@ def test_save_load_warning(self, trainee):
# Set to correct path
file_path = f"{cwd}/save_load_trainee.caml"

load_example_trainee = load_trainee(file_path=file_path)
load_example_trainee = load_trainee(path_or_bytes=file_path)
load_training_cases = load_example_trainee.get_num_training_cases()

# Delete
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_save_load_bad_load(self):
HowsoError,
match='A `.caml` file must be provided.'
):
load_trainee(file_path=file_path)
load_trainee(path_or_bytes=file_path)

@pytest.mark.parametrize("status_msg, expected_msg", [
("This is a test", "This is a test"),
Expand All @@ -255,7 +255,7 @@ def test_load_status_message(self, mocker, monkeypatch, status_msg, expected_msg
HowsoError,
match=f'Failed to load Trainee file "{file_path.as_posix()}": {expected_msg}'
):
load_trainee(file_path=file_path)
load_trainee(path_or_bytes=file_path)

def test_always_persist_load(self, tmp_path: Path, data, features):
"""Test that an auto-persist trainee can be reloaded."""
Expand All @@ -268,7 +268,7 @@ def test_always_persist_load(self, tmp_path: Path, data, features):
finally:
trainee.delete()

load_example_trainee = load_trainee(file_path=save_path, persistence="always")
load_example_trainee = load_trainee(path_or_bytes=save_path, persistence="always")
try:
assert load_example_trainee.get_num_training_cases() == 150
finally:
Expand Down Expand Up @@ -415,4 +415,10 @@ def test_react_aggregate(self, data: DataFrame, trainee: Trainee):
assert isinstance(value, DataFrame)

for feature in data.columns:
assert feature in total_df.columns
assert feature in total_df.columns

def test_to_memory(self, trainee):
    """Check that `Trainee.to_memory` hands back the Trainee as raw bytes."""
    # Called with defaults only; this exercises the passthrough on the
    # Trainee class rather than any particular file type.
    exported = trainee.to_memory()
    assert isinstance(exported, bytes)
129 changes: 89 additions & 40 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import typing as t
import uuid
import warnings
from collections.abc import (
Callable,
Collection,
Generator,
Iterable,
Mapping,
MutableMapping,
)
from copy import deepcopy
from pathlib import Path
import typing as t
import uuid
import warnings

from pandas import (
DataFrame,
Expand All @@ -28,15 +29,16 @@
LocalSaveableProtocol,
ProjectClient,
)
from howso.client.schemas import AggregateReaction, GroupReaction
from howso.client.schemas import Project as BaseProject
from howso.client.schemas import Reaction
from howso.client.schemas import Session as BaseSession
from howso.client.schemas import Trainee as BaseTrainee
from howso.client.schemas import (
AggregateReaction,
GroupReaction,
Reaction,
TraineeRuntime,
TraineeRuntimeOptions,
)
from howso.client.schemas import Project as BaseProject
from howso.client.schemas import Session as BaseSession
from howso.client.schemas import Trainee as BaseTrainee
from howso.client.typing import (
AblationThresholdMap,
CaseIndices,
Expand All @@ -57,6 +59,7 @@
TargetedModel,
ValueMasses,
)
from howso.direct.client import HowsoDirectClient
from howso.engine.client import get_client
from howso.engine.project import Project
from howso.engine.session import Session
Expand Down Expand Up @@ -4469,6 +4472,34 @@ def from_schema(
return schema
return cls.from_dict(dict(schema.to_dict(), client=client))

def to_memory(
    self,
    *,
    file_type: t.Literal["amlg", "caml"] = "amlg",
    trainee_path: Iterable[str] | None = None,
) -> bytes | None:
    """
    Get the Trainee file data as bytes.

    Parameters
    ----------
    file_type : {"amlg", "caml"}, default "amlg"
        The type of byte data to return.
    trainee_path : Iterable of str, optional
        The hierarchy path to a sub-Trainee, starting from this Trainee.

    Returns
    -------
    bytes or None
        The Trainee file data as bytes. Or None if this Trainee and/or
        `trainee_path` does not refer to a valid Trainee.

    Raises
    ------
    HowsoError
        If the current client is not a ``LocalSaveableProtocol`` instance.
    """
    if not isinstance(self.client, LocalSaveableProtocol):
        # This is the export direction, so the message must not talk about
        # loading (that wording belongs to `load_trainee`).
        raise HowsoError("The current client does not support saving a Trainee to memory.")

    # This Trainee's own id is the root from which `trainee_path` is resolved.
    return self.client.trainee_to_memory(self.id, file_type=file_type, trainee_path=trainee_path)


@classmethod
def from_dict(cls, schema: Mapping) -> Trainee:
"""
Expand Down Expand Up @@ -4573,32 +4604,45 @@ def delete_trainee(
else:
client.delete_trainee(trainee_id=str(name_or_id))

@t.overload
def load_trainee(
    path_or_bytes: PathLike,
    client: AbstractHowsoClient | None = ...,
    *,
    persistence: Persistence = ...,
) -> Trainee:
    """Typing-only overload: `path_or_bytes` given as a path-like value."""


@t.overload
def load_trainee(
    path_or_bytes: bytes,
    client: AbstractHowsoClient | None = ...,
    *,
    persistence: Persistence = ...,
) -> Trainee:
    """Typing-only overload: `path_or_bytes` given as a `bytes` object."""


def load_trainee(
file_path: PathLike,
path_or_bytes: PathLike | bytes,
client: t.Optional[AbstractHowsoClient] = None,
*,
persistence: Persistence = 'allow',
persistence: Persistence = "allow",
) -> Trainee:
"""
Load an existing trainee from disk or from in-memory bytes.

Parameters
----------
file_path : str or bytes or os.PathLike
The path of the file to load the Trainee from. This path can contain
an absolute path, a relative path or simply a file name. A ``.caml`` file name
must be always be provided if file paths are provided.

If ``file_path`` is a relative path the absolute path will be computed
appending the ``file_path`` to the CWD.

If ``file_path`` is an absolute path, this is the absolute path that
will be used.

If ``file_path`` is just a filename, then the absolute path will be computed
appending the filename to the CWD.

path_or_bytes : str | bytes | os.PathLike
The path or binary data to load the Trainee from.

If a path, it can be an absolute path, a relative path, or simply a file name.
The file must be a ``.caml`` file. If the path is relative, the absolute path will
be computed by appending the path to the CWD. If the path is absolute, it is used
as given. If the path is just a file name, the absolute path will be computed by
appending the file name to the CWD.

If binary data, that data will be loaded as a :class:`Trainee` directly using
:meth:`Howso.direct.HowsoDirectClient.create_trainee_from_memory`.
client : AbstractHowsoClient, optional
The Howso client instance to use. Must have local disk access.
persistence : {"allow", "always", "never"}, default "allow"
Expand All @@ -4609,25 +4653,30 @@ def load_trainee(
Returns
-------
Trainee
The trainee instance.
The :class:`Trainee` instance.
"""
client = client or get_client()

if not isinstance(client, LocalSaveableProtocol):
raise HowsoError("The current client does not support loading a Trainee from file.")

if not isinstance(file_path, Path):
file_path = Path(file_path)
raise HowsoError("The current client does not support loading a Trainee from file or memory.")

trainee_id = str(uuid.uuid4())

file_path = file_path.expanduser().resolve()
path = Path(path_or_bytes).expanduser().resolve()
if isinstance(path_or_bytes, bytes) and not path.exists():
# The bytes are not a string representing a path, try to load
# the trainee.
base_trainee = create_trainee_from_memory(trainee_id, path_or_bytes)
client.amlg.set_entity_permissions(trainee_id, json_permissions='{"load":true,"store":true}')
client.trainee_cache.set(base_trainee)
trainee = Trainee.from_schema(base_trainee, client=client)
return trainee

# It is decided that if the file contains a suffix then it contains a
# file name.
if file_path.suffix:
if path.suffix:
# Check to make sure a `.caml` file is provided
if file_path.suffix.lower() != '.caml':
if path.suffix.lower() != '.caml':
raise HowsoError(
'Filepath with a non `.caml` extension was provided.'
)
Expand All @@ -4636,37 +4685,37 @@ def load_trainee(
raise HowsoError('A `.caml` file must be provided.')

# If path is not absolute, append it to the default directory.
if not file_path.is_absolute():
file_path = client.default_persist_path.joinpath(file_path)
if not path.is_absolute():
path_or_bytes = client.default_persist_path.joinpath(path)

# Ensure the path exists
if not file_path.exists():
if not path.exists():
raise HowsoError(
f'The specified Trainee file "{file_path.as_posix()}" does not exist.')
f'The specified Trainee file "{path.as_posix()}" does not exist.')

if persistence == 'always':
status = client.amlg.load_entity(
handle=trainee_id,
file_path=str(file_path),
file_path=str(path_or_bytes),
persist=True,
json_file_params=('{"transactional":true,"flatten":true,"execute_on_load":true,'
'"require_version_compatibility":true}')
)
else:
status = client.amlg.load_entity(
handle=trainee_id,
file_path=str(file_path)
file_path=str(path_or_bytes)
)
if not status.loaded:
status_msg = status.message or "An unknown error occurred"
raise HowsoError(f'Failed to load Trainee file "{file_path.as_posix()}": {status_msg}')
raise HowsoError(f'Failed to load Trainee file "{path.as_posix()}": {status_msg}')

client.amlg.set_entity_permissions(trainee_id, json_permissions='{"load":true,"store":true}')

base_trainee = client._get_trainee_from_engine(trainee_id) # type: ignore reportPrivateUsage
client.trainee_cache.set(base_trainee)
trainee = Trainee.from_schema(base_trainee, client=client)
setattr(trainee, '_custom_save_path', file_path)
setattr(trainee, '_custom_save_path', path_or_bytes)

return trainee

Expand Down
Loading