## Summary

When using `make_ppo_networks()` with the default `distribution_type='tanh_normal'`, the `init_noise_std` parameter is accepted without any warning but has no effect.
## Root Cause

In `make_policy_network()` (networks.py:393-399), the `tanh_normal` branch creates a plain `MLP`:

```python
if distribution_type == 'tanh_normal':
  policy_module = MLP(
      layer_sizes=list(hidden_layer_sizes) + [param_size],
      activation=activation,
      kernel_init=kernel_init,
      layer_norm=layer_norm,
  )
```
`init_noise_std`, `noise_std_type`, and `state_dependent_std` are all accepted by the function signature but never passed to this branch. The std is entirely determined by the network's output (the second half of the `2*action_size` parameters, passed through `softplus` in `NormalTanhDistribution.create_dist()`), so its initial value depends only on random weight initialization.
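To illustrate why this makes `init_noise_std` dead under `tanh_normal`: with freshly initialized weights the std logits are near zero, so the initial std is approximately `softplus(0)`, whatever the user passed in. A minimal sketch of that parameterization (not Brax's actual code, which may also apply a small std floor):

```python
import math

def softplus(x):
    # softplus(x) = log(1 + e^x); the tanh_normal path derives std
    # from the network's raw std logits through this function.
    return math.log1p(math.exp(x))

init_noise_std = 0.8  # hypothetical user setting; unused below
raw_std_logit = 0.0   # typical output of a freshly initialized MLP
print(softplus(raw_std_logit))  # ~0.693, independent of init_noise_std
```

No matter what value `init_noise_std` takes, the initial std stays near `log(2) ≈ 0.693` because the parameter never reaches the computation.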
Only the `normal` branch creates `PolicyModuleWithStd`, where `init_noise_std` actually initializes the learnable `LogParam`/`Param`.
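By contrast, the `normal` branch's behavior can be sketched as a state-independent, learnable log-std seeded from `init_noise_std`. The class below is a hypothetical stand-in for `PolicyModuleWithStd`, not its real code:

```python
import math

class PolicyWithLearnableStd:
    """Sketch of a state-independent std initialized from init_noise_std."""

    def __init__(self, action_size, init_noise_std=1.0):
        # Store log(std) so the learned std stays positive during training;
        # this mirrors the LogParam variant described in the report.
        self.log_std = [math.log(init_noise_std)] * action_size

    def std(self):
        return [math.exp(s) for s in self.log_std]

policy = PolicyWithLearnableStd(action_size=2, init_noise_std=0.5)
print(policy.std())  # starts at init_noise_std, i.e. ~[0.5, 0.5]
```

Here the parameter visibly takes effect, which is exactly what a sweep over `init_noise_std` assumes.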
## Impact

- Users who set `init_noise_std` while using the default `tanh_normal` get zero feedback that the parameter is being ignored
- Any hyperparameter sweep over `init_noise_std` under `tanh_normal` produces identical results, wasting compute
- learner.py default flags combine `tanh_normal` with `init_noise_std=1.0` (dead code)
- train_test.py tests `tanh_normal` + `init_noise_std=0.8` without verifying it has any effect
We discovered this while training dexterous manipulation policies with MJX. We had been tuning `init_noise_std` under the default `tanh_normal` for some time before realizing it had zero effect.
## Suggested Fix (any of)

- Raise a warning when `init_noise_std` is explicitly set with `tanh_normal`
- Document that `init_noise_std` only applies to `distribution_type='normal'`
- Consider changing the PPO default to `distribution_type='normal'`: most PPO implementations in the community (IsaacLab, rl_games, rsl_rl, CleanRL, Stable-Baselines3) use a state-independent std without tanh squashing, consistent with Schulman's original PPO
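The first option could be as small as a guard at the top of the factory function. This is a hypothetical sketch (the abbreviated signature here is for illustration, not the real one in networks.py):

```python
import warnings

def make_policy_network(distribution_type='tanh_normal',
                        init_noise_std=None, **kwargs):
    # Hypothetical guard: surface the silent no-op described in this report.
    if distribution_type == 'tanh_normal' and init_noise_std is not None:
        warnings.warn(
            "init_noise_std has no effect with "
            "distribution_type='tanh_normal'; use "
            "distribution_type='normal' for a configurable initial std.",
            UserWarning,
            stacklevel=2,
        )
    # ... build the policy network as before ...

make_policy_network(distribution_type='tanh_normal', init_noise_std=0.8)
```

Defaulting `init_noise_std` to `None` (rather than a number) is what lets the guard distinguish "explicitly set" from "left at the default".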