
Restore default factory on Head._task_weights #819

Open
shaun0927 wants to merge 1 commit into NVIDIA-Merlin:main from shaun0927:fix/813-task-weights-default-factory

@shaun0927

Goals ⚽

Restore the documented "unset task weight defaults to 1.0" behavior that regressed in PR #802.

Fixes #813.

Implementation Details 🚧

PR #802 (ab7207cf) changed self._task_weights = defaultdict(lambda: 1.0) to self._task_weights = defaultdict() in transformers4rec/torch/model/base.py:272. A defaultdict constructed without a factory has default_factory set to None, so missing keys raise KeyError exactly as they would on a plain dict. The previous factory returned 1.0 for unset task weights, which matches the docstring for Head(task_weights=...):

task_weights: optional per-task loss weight. Missing entries default to 1.0.

One usage inside Head.forward was patched to .get(name, 1.0) as a workaround, but any other indexing site (future internal code, or external subclasses that read head._task_weights[task_name]) now raises KeyError where it previously returned 1.0.
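The gap between the patched call site and direct indexing can be reproduced with the standard library alone. The snippet below is a minimal sketch, not code from the repository: the task names "click" and "purchase" and the variable names are made up for illustration.

```python
from collections import defaultdict

# Stand-in for the regressed state: a defaultdict with no factory.
# The task names are hypothetical.
task_weights = defaultdict()
task_weights["click"] = 0.5  # an explicitly configured task

# A call site patched to use .get keeps working, because .get supplies
# its own fallback and never consults default_factory.
patched_read = task_weights.get("purchase", 1.0)

# A call site that indexes directly has no fallback and raises KeyError.
try:
    task_weights["purchase"]
    direct_read_failed = False
except KeyError:
    direct_read_failed = True
```

Under these assumptions the .get read yields 1.0 while the direct index fails, which is the inconsistency this PR is meant to remove.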

Minimal fix — restore the factory:

-self._task_weights = defaultdict()
+self._task_weights = defaultdict(lambda: 1.0)

This is the smallest change that re-establishes the defaultdict contract and avoids the need for scattered .get(name, 1.0) workarounds.
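As a rough illustration of what the restored contract buys a caller, the sketch below combines per-task losses by direct indexing. It is an assumption-laden stand-in rather than the library's loss code: the task names are invented and plain floats stand in for loss tensors.

```python
from collections import defaultdict

# Restored behavior: unset task weights fall back to 1.0.
task_weights = defaultdict(lambda: 1.0)
task_weights.update({"click": 0.5})  # hypothetical user-supplied weight

# Hypothetical per-task losses; plain floats stand in for tensors.
losses = {"click": 2.0, "purchase": 3.0}

# Direct indexing works for configured and unconfigured tasks alike,
# so no call site needs its own .get(name, 1.0) fallback.
total = sum(task_weights[name] * loss for name, loss in losses.items())
```

Here "click" uses its configured weight and "purchase" picks up the default, with no special handling at the read site.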

Testing Details 🔍

Pre-fix regression:

>>> from collections import defaultdict
>>> d = defaultdict()
>>> d["anything"]
KeyError: 'anything'

Post-fix:

>>> d = defaultdict(lambda: 1.0)
>>> d["anything"]
1.0

Existing tests that touch Head(...) construction still pass, since the factory change only affects reads of unset keys. No new test is added because the repro relies only on standard-library semantics; happy to add an explicit regression test if reviewers would prefer.
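If reviewers do want a regression test, it could look roughly like the sketch below. HeadSketch is a hypothetical stand-in that models only the weight storage. A real test would construct the actual Head, whose constructor arguments are not shown in this PR. The second test records a standard defaultdict side effect worth knowing about: indexing a missing key stores the default.

```python
from collections import defaultdict


class HeadSketch:
    """Hypothetical stand-in for Head; only the weight storage is modeled."""

    def __init__(self, task_weights=None):
        self._task_weights = defaultdict(lambda: 1.0)
        if task_weights:
            self._task_weights.update(task_weights)


def test_unset_task_weight_defaults_to_one():
    head = HeadSketch(task_weights={"click": 0.5})
    assert head._task_weights["click"] == 0.5
    # An unset task must read as 1.0 instead of raising KeyError.
    assert head._task_weights["purchase"] == 1.0


def test_indexing_a_missing_key_stores_the_default():
    head = HeadSketch()
    assert "purchase" not in head._task_weights
    _ = head._task_weights["purchase"]
    # defaultdict inserts the factory's value on a missing-key lookup.
    assert "purchase" in head._task_weights
```

Both functions follow the pytest naming convention, so they would be collected automatically if placed in a test module.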

PR NVIDIA-Merlin#802 replaced defaultdict(lambda: 1.0) with defaultdict(), which
has the same runtime semantics as a plain dict - missing keys raise
KeyError. The documented behavior ('1.0 when unset') was preserved at
only one call site via .get(name, 1.0); any other direct indexing
regresses to a crash.

Restoring the lambda factory is a one-line change that preserves the
original API contract and keeps downstream code that reads
head._task_weights[task_name] working.

Fixes NVIDIA-Merlin#813
@copy-pr-bot

copy-pr-bot Bot commented Apr 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@shaun0927
Author

FYI — I have read CLA.md and agree to its terms for this submission. The changes in this PR are entirely my original work, made on my own behalf (not in the course of employment by any other party), and are offered under the Apache 2.0 license of the project. Happy to re-state this in any additional form the maintainers prefer.

Development

Successfully merging this pull request may close these issues.

[BUG] Head._task_weights = defaultdict() silently behaves like a plain dict
