Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions pyrit/scenario/core/dataset_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class DatasetConfiguration:
scenario_strategies (Optional[Sequence[ScenarioStrategy]]): The scenario
strategies being executed. Subclasses can use this to filter or customize
which seed groups are loaded based on the selected strategies.

Subclassing notes:
Memoization lives in ``get_seed_groups()`` and ``get_all_seeds()`` —
the two methods that call ``random.sample``. Overrides of those, or
new resolution methods that introduce their own randomness, must
memoize explicitly to preserve lifetime-stable sampling.
"""

def __init__(
Expand Down Expand Up @@ -75,14 +81,37 @@ def __init__(
"or 'dataset_names' to load from memory."
)

if max_dataset_size is not None and max_dataset_size < 1:
raise ValueError("'max_dataset_size' must be a positive integer (>= 1).")

# Store private attributes
# Caches must exist before the max_dataset_size setter runs.
self._seed_groups = list(seed_groups) if seed_groups is not None else None
self.max_dataset_size = max_dataset_size
self._dataset_names = list(dataset_names) if dataset_names is not None else None
self._scenario_strategies = scenario_strategies
self._resolved_groups_cache: Optional[dict[str, list[SeedGroup]]] = None
self._resolved_seeds_cache: Optional[list[Seed]] = None
self._max_dataset_size: Optional[int] = None
Copy link
Copy Markdown
Contributor

@rlundeen2 rlundeen2 May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we simplify this?

Instead of a cache, what if we added a baseline scenario technique that is just PromptSending. We get rid of this in initialize

        if self._include_baseline:
            baseline_attack = self._get_baseline()
            self._atomic_attacks.insert(0, baseline_attack)

and

    def _get_baseline(self) -> AtomicAttack:

And instead add a tag in _get_attack_technique_factories that adds a PromptSending technique as baseline?

_build_display_group would also likely need to be updated to support baseline?

There might be some hiccups, but it feels like a more natural place to include it as an additional technique vs trying to cache the datasets

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this design change, and I think it is the right direction. My only concern is with doing this instead of the caching / memoization. Many of our scenarios never call _get_attack_technique_factories, which would mean migrating those to the factory pattern. I can certainly add those changes here, but going forward EncodingDatasetConfiguration.get_all_seed_attack_groups() still makes its own call to random.sample, which would bypass the factory loop and reintroduce the issue. I think making both changes here makes sense; I just don't want to increase scope and leave the underlying cause of the bug latent.

I could certainly be misreading the underlying architecture so feel free to push back on my framing of the issue if the baseline change alone would be sufficient to resolve this bug.

self.max_dataset_size = max_dataset_size # validates via setter

@property
def max_dataset_size(self) -> Optional[int]:
    """
    Per-dataset cap on the number of SeedGroups sampled, or ``None`` for no cap.

    Once a random subset has been drawn under this cap, that same subset
    is reused on every later resolution for the lifetime of this object.
    Assigning the attribute again discards the cached subset, so the next
    resolution draws a fresh sample against the new cap.
    """
    return self._max_dataset_size

@max_dataset_size.setter
def max_dataset_size(self, value: Optional[int]) -> None:
    """
    Set the per-dataset sampling cap and invalidate any cached sample.

    Note that every assignment — including reassigning the current value —
    clears the caches, so the next resolution re-samples.

    Args:
        value (Optional[int]): New cap, or None to disable capping.

    Raises:
        ValueError: If value is not None and is less than 1.
    """
    if value is not None and value < 1:
        raise ValueError("'max_dataset_size' must be a positive integer (>= 1).")
    self._max_dataset_size = value
    # Invalidate any previously resolved sample so the next call
    # re-samples against the new cap.
    self._resolved_groups_cache = None
    self._resolved_seeds_cache = None

def get_seed_groups(self) -> dict[str, list[SeedGroup]]:
"""
Expand All @@ -94,6 +123,11 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]:

In all cases, max_dataset_size is applied **per dataset** if set.

The resolved sample is cached for the lifetime of the configuration
(until ``max_dataset_size`` is reassigned). A defensive container
copy is returned on each call so the cache survives caller-side
mutation of the dict or per-dataset lists.

Subclasses can override this to filter or customize which seed groups
are loaded based on the stored scenario_composites.

Expand All @@ -106,6 +140,9 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]:
Raises:
ValueError: If no seed groups could be resolved from the configuration.
"""
if self._resolved_groups_cache is not None:
return {name: list(groups) for name, groups in self._resolved_groups_cache.items()}

result: dict[str, list[SeedGroup]] = {}

if self._seed_groups is not None:
Expand All @@ -129,7 +166,9 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]:
if not result:
raise ValueError("DatasetConfiguration has no seed_groups. Set seed_groups or dataset_names.")

return result
self._resolved_groups_cache = result
# Defensive copy: caller must not be able to mutate the cache.
return {name: list(groups) for name, groups in result.items()}

def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> list[SeedGroup]:
"""
Expand Down Expand Up @@ -256,6 +295,11 @@ def get_all_seeds(self) -> list[Seed]:
from memory for all configured datasets. If max_dataset_size is set, randomly
samples up to that many prompts per dataset (without replacement).

The resolved sample is cached for the lifetime of the configuration
(until ``max_dataset_size`` is reassigned). A defensive list copy
is returned on each call so the cache survives caller-side
mutation.

Returns:
List[SeedPrompt]: List of SeedPrompt objects from all configured datasets.
Returns an empty list if no prompts are found.
Expand All @@ -266,6 +310,9 @@ def get_all_seeds(self) -> list[Seed]:
if self._dataset_names is None:
raise ValueError("No dataset names configured. Set dataset_names to use get_all_seed_prompts.")

if self._resolved_seeds_cache is not None:
return list(self._resolved_seeds_cache)

memory = CentralMemory.get_memory_instance()
all_seeds: list[Seed] = []

Expand All @@ -277,4 +324,5 @@ def get_all_seeds(self) -> list[Seed]:
seeds = random.sample(seeds, self.max_dataset_size)
all_seeds.extend(seeds)

return all_seeds
self._resolved_seeds_cache = all_seeds
return list(all_seeds)
194 changes: 194 additions & 0 deletions tests/unit/scenario/test_dataset_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,197 @@ def test_get_all_seeds_returns_empty_list_when_no_seeds_in_memory(self) -> None:
result = config.get_all_seeds()

assert result == []


@pytest.mark.usefixtures("patch_central_database")
class TestDatasetConfigurationMemoization:
    """Tests for memoization of resolved seed groups and seeds.

    Pins the contract that the random subset selected when ``max_dataset_size``
    is set is stable for the lifetime of the configuration object. ADO 9012
    regression tests live here; flakiness is avoided by patching
    ``random.sample`` rather than relying on RNG seeds.
    """

    def _make_seed_groups(self, count: int) -> list[SeedGroup]:
        # Helper: build `count` distinct single-objective groups so samples
        # can be compared by identity/equality in assertions.
        return [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(count)]

    def test_get_seed_groups_is_stable_across_calls_with_max_dataset_size(self) -> None:
        seed_groups = self._make_seed_groups(10)
        config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3)

        first_sample = seed_groups[:3]
        second_sample = seed_groups[3:6]
        # side_effect provides a *different* result for a hypothetical second
        # sample; the test passes only if that second draw never happens.
        with patch(
            "pyrit.scenario.core.dataset_configuration.random.sample",
            side_effect=[first_sample, second_sample],
        ) as mock_sample:
            first = config.get_seed_groups()
            second = config.get_seed_groups()

        # Both calls must return the first draw, and sampling happened once.
        assert first[EXPLICIT_SEED_GROUPS_KEY] == first_sample
        assert second[EXPLICIT_SEED_GROUPS_KEY] == first_sample
        assert mock_sample.call_count == 1

    def test_get_seed_groups_is_stable_across_multi_dataset(self) -> None:
        ds1 = self._make_seed_groups(10)
        ds2 = self._make_seed_groups(10)

        def mock_load(*, dataset_name: str) -> list[SeedGroup]:
            # Route each dataset name to its own pool of groups.
            return ds1 if dataset_name == "ds1" else ds2

        config = DatasetConfiguration(dataset_names=["ds1", "ds2"], max_dataset_size=3)

        ds1_sample = ds1[:3]
        ds2_sample = ds2[:3]
        # Four queued draws, but only the first two (one per dataset) should
        # ever be consumed if memoization works.
        with (
            patch.object(config, "_load_seed_groups_for_dataset", side_effect=mock_load),
            patch(
                "pyrit.scenario.core.dataset_configuration.random.sample",
                side_effect=[ds1_sample, ds2_sample, ds1[3:6], ds2[3:6]],
            ) as mock_sample,
        ):
            first = config.get_seed_groups()
            second = config.get_seed_groups()

        assert first["ds1"] == ds1_sample
        assert first["ds2"] == ds2_sample
        assert second["ds1"] == ds1_sample
        assert second["ds2"] == ds2_sample
        assert mock_sample.call_count == 2  # one per dataset, on the first call only

    def test_get_all_seed_attack_groups_is_stable_across_calls(self) -> None:
        seed_groups = self._make_seed_groups(10)
        config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3)

        with patch(
            "pyrit.scenario.core.dataset_configuration.random.sample",
            side_effect=[seed_groups[:3], seed_groups[3:6]],
        ):
            first = config.get_all_seed_attack_groups()
            second = config.get_all_seed_attack_groups()

        # Compare by objective value: the flattened attack groups may be new
        # wrapper objects, but must be built from the same sampled seeds.
        first_objectives = [g.objective.value for g in first]
        second_objectives = [g.objective.value for g in second]
        assert first_objectives == second_objectives

    def test_get_all_seeds_is_stable_across_calls(self) -> None:
        seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)]

        with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class:
            mock_memory = MagicMock()
            mock_memory.get_seeds.return_value = seeds
            mock_memory_class.get_memory_instance.return_value = mock_memory

            config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3)

            with patch(
                "pyrit.scenario.core.dataset_configuration.random.sample",
                side_effect=[seeds[:3], seeds[3:6]],
            ) as mock_sample:
                first = config.get_all_seeds()
                second = config.get_all_seeds()

            assert first == seeds[:3]
            assert second == seeds[:3]
            assert mock_sample.call_count == 1

    def test_returned_dict_can_be_mutated_without_poisoning_cache(self) -> None:
        seed_groups = self._make_seed_groups(10)
        config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3)

        with patch(
            "pyrit.scenario.core.dataset_configuration.random.sample",
            return_value=seed_groups[:3],
        ):
            first = config.get_seed_groups()
            # Mutate both the inner list and the outer dict; a defensive copy
            # must shield the cache from both.
            first[EXPLICIT_SEED_GROUPS_KEY].clear()
            first.pop(EXPLICIT_SEED_GROUPS_KEY, None)
            second = config.get_seed_groups()

        assert second[EXPLICIT_SEED_GROUPS_KEY] == seed_groups[:3]

    def test_returned_seeds_list_can_be_mutated_without_poisoning_cache(self) -> None:
        seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)]

        with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class:
            mock_memory = MagicMock()
            mock_memory.get_seeds.return_value = seeds
            mock_memory_class.get_memory_instance.return_value = mock_memory

            config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3)

            with patch(
                "pyrit.scenario.core.dataset_configuration.random.sample",
                return_value=seeds[:3],
            ):
                first = config.get_all_seeds()
                # Emptying the returned list must not empty the cached sample.
                first.clear()
                second = config.get_all_seeds()

            assert second == seeds[:3]


@pytest.mark.usefixtures("patch_central_database")
class TestDatasetConfigurationMaxDatasetSizeSetter:
    """Behavioral tests for assigning ``max_dataset_size`` after construction.

    Covers cache invalidation on reassignment (even with an unchanged value)
    and validation of the new value.
    """

    def test_setter_invalidates_groups_cache(self) -> None:
        groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)]
        config = DatasetConfiguration(seed_groups=groups, max_dataset_size=3)

        initial_draw = groups[:3]
        redraw = groups[5:8]
        with patch(
            "pyrit.scenario.core.dataset_configuration.random.sample",
            side_effect=[initial_draw, redraw],
        ):
            before = config.get_seed_groups()
            # Reassigning—even with the same value—must force a fresh sample.
            config.max_dataset_size = 3
            after = config.get_seed_groups()

        assert before[EXPLICIT_SEED_GROUPS_KEY] == initial_draw
        assert after[EXPLICIT_SEED_GROUPS_KEY] == redraw

    def test_setter_invalidates_seeds_cache(self) -> None:
        prompts = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)]

        with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as memory_cls:
            fake_memory = MagicMock()
            fake_memory.get_seeds.return_value = prompts
            memory_cls.get_memory_instance.return_value = fake_memory

            config = DatasetConfiguration(dataset_names=["d1"], max_dataset_size=3)

            with patch(
                "pyrit.scenario.core.dataset_configuration.random.sample",
                side_effect=[prompts[:3], prompts[5:8]],
            ):
                before = config.get_all_seeds()
                config.max_dataset_size = 3
                after = config.get_all_seeds()

            assert before == prompts[:3]
            assert after == prompts[5:8]

    def test_setter_rejects_zero(self) -> None:
        config = DatasetConfiguration(seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])])

        with pytest.raises(ValueError, match="must be a positive integer"):
            config.max_dataset_size = 0

    def test_setter_rejects_negative(self) -> None:
        config = DatasetConfiguration(seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])])

        with pytest.raises(ValueError, match="must be a positive integer"):
            config.max_dataset_size = -1

    def test_setter_accepts_none(self) -> None:
        config = DatasetConfiguration(
            seed_groups=[SeedGroup(seeds=[SeedObjective(value="obj")])],
            max_dataset_size=5,
        )

        config.max_dataset_size = None

        assert config.max_dataset_size is None
31 changes: 31 additions & 0 deletions tests/unit/scenario/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,34 @@ def test_encoding_dataset_config_can_be_initialized_with_dataset_names(self):

assert config._dataset_names == ["garak_slur_terms_en", "garak_web_html_js"]
assert config.max_dataset_size == 5

def test_get_all_seed_attack_groups_is_stable_across_calls_with_max_dataset_size(self):
    """Regression test for ADO 9012 (Path 2).

    EncodingDatasetConfiguration.get_all_seed_attack_groups overrides the
    base method and routes through get_all_seeds, which has its own
    random.sample. Memoizing only get_seed_groups would not catch this
    path; this test pins that the override is stable across calls.
    """
    from unittest.mock import patch

    seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)]

    with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class:
        mock_memory = MagicMock()
        mock_memory.get_seeds.return_value = seeds
        mock_memory_class.get_memory_instance.return_value = mock_memory

        config = EncodingDatasetConfiguration(dataset_names=["d1"], max_dataset_size=3)

        # Queue a second, different draw; it must never be consumed if the
        # override memoizes correctly (call_count stays at 1).
        with patch(
            "pyrit.scenario.core.dataset_configuration.random.sample",
            side_effect=[seeds[:3], seeds[3:6]],
        ) as mock_sample:
            first = config.get_all_seed_attack_groups()
            second = config.get_all_seed_attack_groups()

        # Compare by objective value since the groups may be new wrappers
        # around the same sampled seeds.
        first_objectives = [g.objective.value for g in first]
        second_objectives = [g.objective.value for g in second]
        assert first_objectives == second_objectives
        assert mock_sample.call_count == 1
Loading
Loading