Skip to content

Commit 0abdce9

Browse files
authored
add custom env examples
add custom env examples
2 parents 72bc4f7 + ff78415 commit 0abdce9

File tree

14 files changed

+547
-7
lines changed

14 files changed

+547
-7
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ Currently, the features supported by OpenRL include:
6363

6464
- Importing models and datasets from [Hugging Face](https://huggingface.co/)
6565

66+
- [Tutorial](https://openrl-docs.readthedocs.io/en/latest/custom_env/index.html) on how to integrate user-defined environments into OpenRL.
67+
6668
- Support for models such as LSTM, GRU, Transformer etc.
6769

6870
- Multiple training acceleration methods including automatic mixed precision training and data collecting wth half

README_zh.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ OpenRL基于PyTorch进行开发,目标是为强化学习研究社区提供一
5353
- 支持自然语言任务(如对话任务)的强化学习训练
5454
- 支持[竞技场](https://openrl-docs.readthedocs.io/zh/latest/arena/index.html)功能,可以在多智能体对抗性环境中方便地对各种智能体进行评测。
5555
- 支持从[Hugging Face](https://huggingface.co/)上导入模型和数据
56+
- 提供用户自有环境接入OpenRL的[详细教程](https://openrl-docs.readthedocs.io/zh/latest/custom_env/index.html).
5657
- 支持LSTM,GRU,Transformer等模型
5758
- 支持多种训练加速,例如:自动混合精度训练,半精度策略网络收集数据等
5859
- 支持用户自定义训练模型、奖励模型、训练数据以及环境

examples/custom_env/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Integrate user-defined environments into OpenRL
2+
3+
4+
Here, we provide several toy examples to show how to add user-defined environments into OpenRL.
5+
6+
- `gymnasium_env.py`: a simple example to show how to create a Gymnasium environment and integrate it into OpenRL.
7+
- `openai_gym_env.py`: a simple example to show how to create a OpenAI Gym environment and integrate it into OpenRL.
8+
- `pettingzoo_env.py`: a simple example to show how to create a PettingZoo environment and integrate it into OpenRL.

examples/custom_env/gymnasium_env.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
from typing import Any, Dict, Optional
20+
21+
import gymnasium as gym
22+
from gymnasium import spaces
23+
from gymnasium.envs.registration import EnvSpec, register
24+
from gymnasium.utils import seeding
25+
from train_and_test import train_and_test
26+
27+
from openrl.envs.common import make
28+
29+
30+
class IdentityEnv(gym.Env):
31+
spec = EnvSpec("IdentityEnv")
32+
33+
def __init__(self, **kwargs):
34+
self.dim = 2
35+
self.observation_space = spaces.Discrete(1)
36+
self.action_space = spaces.Discrete(self.dim)
37+
self.ep_length = 5
38+
self.current_step = 0
39+
40+
def reset(
41+
self,
42+
*,
43+
seed: Optional[int] = None,
44+
options: Optional[Dict[str, Any]] = None,
45+
):
46+
if seed is not None:
47+
self.seed(seed)
48+
self.current_step = 0
49+
self.generate_state()
50+
return self.state, {}
51+
52+
def step(self, action):
53+
reward = 1
54+
self.generate_state()
55+
self.current_step += 1
56+
done = self.current_step >= self.ep_length
57+
return self.state, reward, done, {}
58+
59+
def generate_state(self) -> None:
60+
self.state = [self._np_random.integers(0, self.dim)]
61+
62+
def render(self, mode: str = "human") -> None:
63+
pass
64+
65+
def seed(self, seed: Optional[int] = None) -> None:
66+
if seed is not None:
67+
self._np_random, seed = seeding.np_random(seed)
68+
69+
def close(self):
70+
pass
71+
72+
73+
register(
74+
id="Custom_Env/IdentityEnv",
75+
entry_point="gymnasium_env:IdentityEnv",
76+
)
77+
78+
env = make("Custom_Env/IdentityEnv", env_num=10)
79+
80+
train_and_test(env)

examples/custom_env/openai_gym_env.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
from typing import Any, Dict, Optional
20+
21+
import gym
22+
from gym import spaces
23+
from gym.envs.registration import EnvSpec, register
24+
from gym.utils import seeding
25+
from train_and_test import train_and_test
26+
27+
from openrl.envs.common import make
28+
29+
30+
class IdentityEnv(gym.Env):
31+
spec = EnvSpec("IdentityEnv-v1")
32+
33+
def __init__(self, **kwargs):
34+
self.dim = 2
35+
self.observation_space = spaces.Discrete(1)
36+
self.action_space = spaces.Discrete(self.dim)
37+
self.ep_length = 5
38+
self.current_step = 0
39+
40+
def reset(
41+
self,
42+
*,
43+
seed: Optional[int] = None,
44+
options: Optional[Dict[str, Any]] = None,
45+
):
46+
if seed is not None:
47+
self.seed(seed)
48+
self.current_step = 0
49+
self.generate_state()
50+
return self.state
51+
52+
def step(self, action):
53+
reward = 1
54+
self.generate_state()
55+
self.current_step += 1
56+
done = self.current_step >= self.ep_length
57+
return self.state, reward, done, {}
58+
59+
def generate_state(self) -> None:
60+
self.state = [self._np_random.randint(0, self.dim - 1)]
61+
62+
def render(self, mode: str = "human") -> None:
63+
pass
64+
65+
def seed(self, seed: Optional[int] = None) -> None:
66+
if seed is not None:
67+
self._np_random, seed = seeding.np_random(seed)
68+
69+
def close(self):
70+
pass
71+
72+
73+
register(
74+
id="Custom_Env/IdentityEnv-v1",
75+
entry_point="openai_gym_env:IdentityEnv",
76+
)
77+
78+
env = make("GymV21Environment-v0:Custom_Env/IdentityEnv-v1", env_num=10)
79+
80+
train_and_test(env)

examples/custom_env/pettingzoo_env.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Copyright 2023 The OpenRL Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
""""""
18+
19+
20+
from rock_paper_scissors import RockPaperScissors
21+
from train_and_test import train_and_test
22+
23+
from openrl.envs.common import make
24+
from openrl.envs.PettingZoo.registration import register
25+
from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper
26+
27+
register("RockPaperScissors", RockPaperScissors)
28+
env = make(
29+
"RockPaperScissors",
30+
env_num=10,
31+
opponent_wrappers=[RandomOpponentWrapper],
32+
)
33+
obs, info = env.reset()
34+
35+
train_and_test(env)

0 commit comments

Comments
 (0)