Skip to content

Commit 9623d5c

Browse files
authored
add dm-control
add dm-control
2 parents 1d377d0 + 2a61d17 commit 9623d5c

File tree

15 files changed

+315
-87
lines changed

15 files changed

+315
-87
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,4 @@ api_docs
158158
opponent_pool
159159
!/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
160160
wandb_run
161+
examples/dmc/new.gif

examples/dmc/ppo.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# PPO hyper-parameters for the dm_control example (examples/dmc/ppo.yaml).
episode_length: 25      # rollout length per update
lr: 5e-4                # actor learning rate
critic_lr: 5e-4         # critic learning rate
gamma: 0.99             # discount factor
ppo_epoch: 5            # optimization epochs per rollout batch
use_valuenorm: true     # normalize value targets
entropy_coef: 0.0       # no entropy bonus
hidden_size: 128        # MLP hidden width
layer_N: 4              # number of hidden layers
data_chunk_length: 1    # chunk length for recurrent training (1 = feed-forward)

examples/dmc/train_ppo.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import numpy as np
2+
from gymnasium.wrappers import FlattenObservation
3+
4+
from openrl.configs.config import create_config_parser
5+
from openrl.envs.common import make
6+
from openrl.envs.wrappers.base_wrapper import BaseWrapper
7+
from openrl.envs.wrappers.extra_wrappers import GIFWrapper
8+
from openrl.modules.common import PPONet as Net
9+
from openrl.runners.common import PPOAgent as Agent
10+
11+
12+
class FrameSkip(BaseWrapper):
    """Action-repeat wrapper.

    Replays each agent action for ``num_frames`` consecutive environment
    steps and returns the last observation together with the summed reward.
    Stops early if the episode terminates or truncates mid-repeat, so the
    terminal transition is never overrun.
    """

    def __init__(self, env, num_frames: int = 8):
        """
        Args:
            env: the environment to wrap.
            num_frames: how many times each action is repeated; must be >= 1.

        Raises:
            ValueError: if ``num_frames`` < 1 — the repeat loop would then
                never run and ``step()`` would hit unbound locals.
        """
        super().__init__(env)
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.num_frames = num_frames

    def step(self, action):
        """Apply ``action`` up to ``num_frames`` times, accumulating reward."""
        total_reward = 0.0
        for _ in range(self.num_frames):
            obs, rew, term, trunc, info = super().step(action)
            total_reward += rew
            if term or trunc:
                # Episode ended mid-repeat: return this transition as-is.
                break
        return obs, total_reward, term, trunc, info
28+
29+
30+
# dm_control task used by both train() and evaluation().
# Swap in the walker task below to try a harder environment.
env_name = "dm_control/cartpole-balance-v0"
# env_name = "dm_control/walker-walk-v0"
32+
33+
34+
def train():
    """Train a PPO agent on the dm_control task and save it to ./ppo_agent.

    Returns:
        The trained agent.
    """
    # load hyper-parameters from ppo.yaml
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

    # create the environment, running 64 copies in parallel
    # (asynchronous workers, with frame-skip and flattened observations)
    env = make(
        env_name,
        env_num=64,
        cfg=cfg,
        asynchronous=True,
        env_wrappers=[FrameSkip, FlattenObservation],
    )

    # create the neural network
    # NOTE(review): device is hard-coded to CUDA; this fails on CPU-only hosts.
    net = Net(env, cfg=cfg, device="cuda")
    # initialize the trainer
    agent = Agent(
        net,
    )
    # start training, for a total of 100000 environment steps
    agent.train(total_time_steps=100000)
    agent.save("./ppo_agent")
    env.close()
    return agent
58+
59+
60+
61+
62+
63+
def evaluation():
    """Load the trained agent from ./ppo_agent, roll it out, and save a GIF."""
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
    # Create a test environment with 4 parallel copies. Render to RGB arrays
    # so GIFWrapper can capture frames ("group_human" would instead open an
    # interactive viewer).
    render_mode = "group_rgb_array"
    env = make(
        env_name,
        render_mode=render_mode,
        env_num=4,
        asynchronous=True,
        env_wrappers=[FrameSkip, FlattenObservation],
        cfg=cfg,
    )
    env = GIFWrapper(env, gif_path="./new.gif", fps=5)

    net = Net(env, cfg=cfg, device="cuda")
    # initialize the trainer
    agent = Agent(
        net,
    )
    agent.load("./ppo_agent")

    # The trained agent sets up the interactive environment it needs.
    agent.set_env(env)
    # Initialize the environment and get initial observations and env info.
    obs, info = env.reset()
    done = False
    step = 0
    total_reward = 0.0
    # Roll out until any parallel env reports done, capped at 500 steps.
    while not np.any(done):
        if step > 500:
            break
        # Based on environmental observation input, predict next action.
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        step += 1
        total_reward += np.mean(r)
        if step % 50 == 0:
            print(f"{step}: reward:{np.mean(r)}")
    print("total step:", step, total_reward)
    env.close()
108+
109+
# Guard the script entry so importing this module does not kick off training.
if __name__ == "__main__":
    # Train first, then reload the saved agent and record an evaluation GIF.
    train()
    evaluation()

openrl/algorithms/dqn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def prepare_loss(
167167
)
168168

169169
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
170-
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
170+
q_loss = torch.mean(
171+
F.mse_loss(q_values, q_targets.detach())
172+
) # 均方误差损失函数
171173

172174
loss_list.append(q_loss)
173175

openrl/algorithms/vdn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def prepare_loss(
211211
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
212212
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
213213
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
214-
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach())) # 均方误差损失函数
214+
q_loss = torch.mean(
215+
F.mse_loss(q_values, q_targets.detach())
216+
) # 均方误差损失函数
215217

216218
loss_list.append(q_loss)
217219
return loss_list

openrl/envs/common/registration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,13 @@ def make(
6565
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
6666
)
6767
else:
68-
if id in gym.envs.registry.keys():
68+
if id.startswith("dm_control/"):
69+
from openrl.envs.dmc import make_dmc_envs
70+
71+
env_fns = make_dmc_envs(
72+
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
73+
)
74+
elif id in gym.envs.registry.keys():
6975
from openrl.envs.gymnasium import make_gym_envs
7076

7177
env_fns = make_gym_envs(
@@ -77,6 +83,7 @@ def make(
7783
env_fns = make_mpe_envs(
7884
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
7985
)
86+
8087
elif id in openrl.envs.nlp_all_envs:
8188
from openrl.envs.nlp import make_nlp_envs
8289

openrl/envs/dmc/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import copy
2+
from typing import Callable, List, Optional, Union
3+
4+
import dmc2gym
5+
6+
from openrl.envs.common import build_envs
7+
from openrl.envs.dmc.dmc_env import make
8+
9+
10+
def make_dmc_envs(
    id: str,
    env_num: int = 1,
    render_mode: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    """Build ``env_num`` dm_control environment factories.

    Args:
        id: environment id, e.g. "dm_control/cartpole-balance-v0".
        env_num: number of parallel environment factories to create.
        render_mode: render mode(s) forwarded to each environment.
        **kwargs: extra options forwarded to ``build_envs``; may include
            ``env_wrappers``, an iterable of extra wrapper classes applied
            before the standard dm_control wrappers.

    Returns:
        A list of zero-argument callables, each producing one wrapped env.
    """
    # Imported inside the function (original pattern kept) — presumably to
    # avoid an import cycle at module load time; confirm before hoisting.
    from openrl.envs.wrappers import (  # AutoReset,; DictWrapper,
        RemoveTruncated,
        Single2MultiAgentWrapper,
    )
    from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper

    # Copy the caller's wrappers into a fresh list: never mutate the caller's
    # object, and accept any iterable (copy.copy of a tuple would make the
    # += below raise TypeError).
    env_wrappers = list(kwargs.pop("env_wrappers", []))
    env_wrappers += [ConvertEmptyBoxWrapper, RemoveTruncated, Single2MultiAgentWrapper]
    env_fns = build_envs(
        make=make,
        id=id,
        env_num=env_num,
        render_mode=render_mode,
        wrappers=env_wrappers,
        **kwargs,
    )

    return env_fns

openrl/envs/dmc/dmc_env.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import Any, Optional
2+
3+
import dmc2gym
4+
import gymnasium as gym
5+
import numpy as np
6+
7+
# class DmcEnv:
8+
# def __init__(self):
9+
# env = dmc2gym.make(
10+
# domain_name='walker',
11+
# task_name='walk',
12+
# seed=42,
13+
# visualize_reward=False,
14+
# from_pixels='features',
15+
# height=224,
16+
# width=224,
17+
# frame_skip=2
18+
# )
19+
# # self.observation_space = spaces.Box(
20+
# # low=np.array([0, 0, 0, 0]),
21+
# # high=np.array([self.nrow - 1, self.ncol - 1, self.nrow - 1, self.ncol - 1]),
22+
# # dtype=int,
23+
# # ) # current position and target position
24+
# # self.action_space = spaces.Discrete(
25+
# # 5
26+
# # )
27+
28+
29+
def make(
    id: str,
    render_mode: Optional[str] = None,
    **kwargs: Any,
):
    """Create a single dm_control environment through Gymnasium's registry.

    Args:
        id: environment id, e.g. "dm_control/cartpole-balance-v0".
        render_mode: render mode forwarded to ``gym.make``.
        **kwargs: accepted for interface compatibility with the other
            ``make`` helpers; currently unused.

    Returns:
        The environment instance returned by ``gym.make``.
    """
    # NOTE(review): the dm_control ids are assumed to already be registered
    # with gymnasium (the dmc2gym import elsewhere in this package appears
    # unused here) — confirm which package performs the registration.
    env = gym.make(id, render_mode=render_mode)
    return env

openrl/envs/mpe/rendering.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,10 @@
2929
except ImportError:
3030
print(
3131
"Error occured while running `from pyglet.gl import *`",
32-
(
33-
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
34-
" install python-opengl'. If you're running on a server, you may need a"
35-
" virtual frame buffer; something like this should work: 'xvfb-run -s"
36-
' "-screen 0 1400x900x24" python <your_script.py>\''
37-
),
32+
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
33+
" install python-opengl'. If you're running on a server, you may need a"
34+
" virtual frame buffer; something like this should work: 'xvfb-run -s"
35+
' "-screen 0 1400x900x24" python <your_script.py>\'',
3836
)
3937

4038
import math

openrl/envs/vec_env/async_venv.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,8 @@ def reset_send(
233233

234234
if self._state != AsyncState.DEFAULT:
235235
raise AlreadyPendingCallError(
236-
(
237-
"Calling `reset_send` while waiting for a pending call to"
238-
f" `{self._state.value}` to complete"
239-
),
236+
"Calling `reset_send` while waiting for a pending call to"
237+
f" `{self._state.value}` to complete",
240238
self._state.value,
241239
)
242240

@@ -328,10 +326,8 @@ def step_send(self, actions: np.ndarray):
328326
self._assert_is_running()
329327
if self._state != AsyncState.DEFAULT:
330328
raise AlreadyPendingCallError(
331-
(
332-
"Calling `step_send` while waiting for a pending call to"
333-
f" `{self._state.value}` to complete."
334-
),
329+
"Calling `step_send` while waiting for a pending call to"
330+
f" `{self._state.value}` to complete.",
335331
self._state.value,
336332
)
337333

@@ -575,10 +571,8 @@ def call_send(self, name: str, *args, **kwargs):
575571
self._assert_is_running()
576572
if self._state != AsyncState.DEFAULT:
577573
raise AlreadyPendingCallError(
578-
(
579-
"Calling `call_send` while waiting "
580-
f"for a pending call to `{self._state.value}` to complete."
581-
),
574+
"Calling `call_send` while waiting "
575+
f"for a pending call to `{self._state.value}` to complete.",
582576
str(self._state.value),
583577
)
584578

@@ -635,10 +629,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
635629
self._assert_is_running()
636630
if self._state != AsyncState.DEFAULT:
637631
raise AlreadyPendingCallError(
638-
(
639-
"Calling `exec_func_send` while waiting "
640-
f"for a pending call to `{self._state.value}` to complete."
641-
),
632+
"Calling `exec_func_send` while waiting "
633+
f"for a pending call to `{self._state.value}` to complete.",
642634
str(self._state.value),
643635
)
644636

@@ -715,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):
715707

716708
if self._state != AsyncState.DEFAULT:
717709
raise AlreadyPendingCallError(
718-
(
719-
"Calling `set_attr` while waiting "
720-
f"for a pending call to `{self._state.value}` to complete."
721-
),
710+
"Calling `set_attr` while waiting "
711+
f"for a pending call to `{self._state.value}` to complete.",
722712
str(self._state.value),
723713
)
724714

0 commit comments

Comments
 (0)