Skip to content

Commit 577616c

Browse files
authored
Use schedule class instead of lambdas (#493)
* Use schedule class instead of lambdas
* Add test for linear schedule
1 parent 9a5b7ae commit 577616c

File tree

7 files changed

+64
-23
lines changed

7 files changed

+64
-23
lines changed

CHANGELOG.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
## Release 2.6.1 (WIP)
2+
3+
### Breaking Changes
4+
- Upgraded to SB3 >= 2.6.1
5+
- `linear_schedule` now returns a `SimpleLinearSchedule` object for better portability
6+
7+
### New Features
8+
9+
### Bug fixes
10+
- Docker GPU images are now working again
11+
- Use `ConstantSchedule` and `SimpleLinearSchedule` instead of `constant_fn` and `linear_schedule`
12+
13+
### Documentation
14+
15+
### Other
16+
117
## Release 2.6.0 (2025-03-24)
218

319
### Breaking Changes

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
gym==0.26.2
2-
stable-baselines3[extra,tests,docs]>=2.6.0,<3.0
2+
stable-baselines3[extra,tests,docs]>=2.6.1a1,<3.0
33
box2d-py==2.3.8
44
pybullet_envs_gymnasium>=0.6.0
55
# minigrid

rl_zoo3/exp_manager.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
3333
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
3434
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike # noqa: F401
35-
from stable_baselines3.common.utils import constant_fn
35+
from stable_baselines3.common.utils import ConstantSchedule
3636
from stable_baselines3.common.vec_env import (
3737
DummyVecEnv,
3838
SubprocVecEnv,
@@ -50,7 +50,14 @@
5050
import rl_zoo3.import_envs # noqa: F401
5151
from rl_zoo3.callbacks import SaveVecNormalizeCallback, TrialEvalCallback
5252
from rl_zoo3.hyperparams_opt import HYPERPARAMS_CONVERTER, HYPERPARAMS_SAMPLER
53-
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule
53+
from rl_zoo3.utils import (
54+
ALGOS,
55+
SimpleLinearSchedule,
56+
get_callback_list,
57+
get_class_by_name,
58+
get_latest_run_id,
59+
get_wrapper_class,
60+
)
5461

5562

5663
class ExperimentManager:
@@ -381,12 +388,12 @@ def _preprocess_schedules(hyperparams: dict[str, Any]) -> dict[str, Any]:
381388
if isinstance(hyperparams[key], str):
382389
schedule, initial_value = hyperparams[key].split("_")
383390
initial_value = float(initial_value)
384-
hyperparams[key] = linear_schedule(initial_value)
391+
hyperparams[key] = SimpleLinearSchedule(initial_value)
385392
elif isinstance(hyperparams[key], (float, int)):
386393
# Negative value: ignore (ex: for clipping)
387394
if hyperparams[key] < 0:
388395
continue
389-
hyperparams[key] = constant_fn(float(hyperparams[key]))
396+
hyperparams[key] = ConstantSchedule(float(hyperparams[key]))
390397
else:
391398
raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")
392399
return hyperparams

rl_zoo3/utils.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,25 +292,33 @@ def make_env(**kwargs) -> gym.Env:
292292
return env
293293

294294

295-
def linear_schedule(initial_value: Union[float, str]) -> Callable[[float], float]:
295+
class SimpleLinearSchedule:
296+
"""
297+
Linear learning rate schedule (from initial value to zero),
298+
simpler than sb3 LinearSchedule.
299+
300+
:param initial_value: (float or str) The initial value for the schedule
301+
"""
302+
303+
def __init__(self, initial_value: Union[float, str]) -> None:
304+
# Force conversion to float
305+
self.initial_value = float(initial_value)
306+
307+
def __call__(self, progress_remaining: float) -> float:
308+
return progress_remaining * self.initial_value
309+
310+
def __repr__(self) -> str:
311+
return f"SimpleLinearSchedule(initial_value={self.initial_value})"
312+
313+
314+
def linear_schedule(initial_value: Union[float, str]) -> SimpleLinearSchedule:
296315
"""
297316
Linear learning rate schedule.
298317
299318
:param initial_value: (float or str)
300-
:return: (function)
319+
:return: A `SimpleLinearSchedule` object
301320
"""
302-
# Force conversion to float
303-
initial_value_ = float(initial_value)
304-
305-
def func(progress_remaining: float) -> float:
306-
"""
307-
Progress will decrease from 1 (beginning) to 0
308-
:param progress_remaining: (float)
309-
:return: (float)
310-
"""
311-
return progress_remaining * initial_value_
312-
313-
return func
321+
return SimpleLinearSchedule(initial_value)
314322

315323

316324
def get_trained_models(log_folder: str) -> dict[str, tuple[str, str]]:

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.6.1a0
1+
2.6.1a1

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
See https://github.com/DLR-RM/rl-baselines3-zoo
1616
"""
1717
install_requires = [
18-
"sb3_contrib>=2.6.0,<3.0",
18+
"sb3_contrib>=2.6.1a1,<3.0",
1919
"gymnasium>=0.29.1,<1.2.0",
2020
"huggingface_sb3>=3.0,<4.0",
2121
"tqdm",

tests/test_wrappers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import gymnasium as gym
2+
import numpy as np
23
import pytest
34
import stable_baselines3 as sb3
5+
from sb3_contrib.common.wrappers import TimeFeatureWrapper
46
from stable_baselines3 import A2C
57
from stable_baselines3.common.env_checker import check_env
68
from stable_baselines3.common.env_util import DummyVecEnv
79

810
import rl_zoo3.import_envs
911
import rl_zoo3.wrappers
10-
from rl_zoo3.utils import get_wrapper_class
11-
from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper
12+
from rl_zoo3.utils import SimpleLinearSchedule, get_wrapper_class, linear_schedule
13+
from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper
1214

1315

1416
def test_wrappers():
@@ -55,3 +57,11 @@ def test_get_vec_env_wrapper(vec_env_wrapper):
5557
if wrapper_class is not None:
5658
env = wrapper_class(env)
5759
A2C("MlpPolicy", env).learn(16)
60+
61+
62+
def test_linear_schedule():
63+
schedule = linear_schedule(100)
64+
assert isinstance(schedule, SimpleLinearSchedule)
65+
assert np.allclose(schedule(1.0), 100.0)
66+
assert np.allclose(schedule(0.5), 50.0)
67+
assert np.allclose(schedule(0.0), 0.0)

0 commit comments

Comments
 (0)