Skip to content

Commit

Permalink
Fix venv wrapper reset retval error with gym env (#712)
Browse files Browse the repository at this point in the history
* Fix venv wrapper reset retval error with gym env

* fix lint
  • Loading branch information
Trinkle23897 authored Jul 31, 2022
1 parent f270e88 commit 0f59e38
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
29 changes: 22 additions & 7 deletions test/base/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from tianshou.utils import RunningMeanStd

if __name__ == '__main__':
if __name__ == "__main__":
from env import MyTestEnv, NXEnv
else: # pytest
from test.base.env import MyTestEnv, NXEnv
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
spent_time = time.time()
while current_idx_start < len(action_list):
A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
env_ids = b.info.env_id
o.append(b)
current_idx_start += len(act)
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for info in infos:
assert recurse_comp(infos[0], info)

if __name__ == '__main__':
if __name__ == "__main__":
t = [0] * len(venv)
for i, e in enumerate(venv):
t[i] = time.time()
Expand All @@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
print(f"{type(v)}: {t[i]:.6f}s")

def assert_get(v, expected):
assert v.get_env_attr("size") == expected
Expand Down Expand Up @@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
assert isinstance(info[0], dict)


def test_venv_wrapper_gym(num_envs: int = 4):
    """Check ``reset`` return values of an obs-normalizing wrapper over gym envs.

    Regression test for issue 697. A ``DummyVectorEnv`` of ``num_envs``
    ``CartPole-v1`` environments is wrapped in ``VectorEnvNormObs`` and reset
    twice: once with ``return_info=False`` and once with ``return_info=True``.

    :param num_envs: number of sub-environments to create.
    """
    env_fns = [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
    envs = VectorEnvNormObs(DummyVectorEnv(env_fns))

    obs_ref = envs.reset(return_info=False)
    obs, info = envs.reset(return_info=True)

    # Both reset flavours must hand back the observations as an array ...
    for observation in (obs_ref, obs):
        assert isinstance(observation, np.ndarray)
    # ... and the info variant must additionally give a list of dicts.
    assert isinstance(info, list)
    assert isinstance(info[0], dict)
    # Leading dimension of every returned item equals the number of envs.
    assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs


def run_align_norm_obs(raw_env, train_env, test_env, action_list):
eps = np.finfo(np.float32).eps.item()
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
Expand Down Expand Up @@ -309,7 +322,7 @@ def __init__(self):
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
np.array([original_act] * bsz)
np.array([original_act] * bsz),
)
# convert multidiscrete with different action number per
# dimension to discrete action space
Expand All @@ -321,7 +334,7 @@ def __init__(self):
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
np.array([env_m.action_space.nvec - 1] * bsz)
np.array([env_m.action_space.nvec - 1] * bsz),
)


Expand Down Expand Up @@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
assert v.shape[0] == num_envs


if __name__ == '__main__':
if __name__ == "__main__":
test_venv_norm_obs()
test_venv_wrapper_gym()
test_venv_wrapper_envpool()
test_venv_wrapper_envpool_gym_reset_return_info()
test_env_obs_dtype()
test_vecenv()
test_attr_unwrapped()
Expand Down
4 changes: 2 additions & 2 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(**gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
Expand Down Expand Up @@ -173,7 +173,7 @@ def _reset_env_with_ids(
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(global_ids, **gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs_reset, info = rval
Expand Down
20 changes: 10 additions & 10 deletions tianshou/env/venv_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
return self.venv.reset(id, **kwargs)

def step(
Expand Down Expand Up @@ -84,15 +84,15 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
retval = self.venv.reset(id, **kwargs)
reset_returns_info = isinstance(
retval, (tuple, list)
) and len(retval) == 2 and isinstance(retval[1], dict)
if reset_returns_info:
obs, info = retval
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
rval = self.venv.reset(id, **kwargs)
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
else:
obs = retval
obs = rval

if isinstance(obs, tuple):
raise TypeError(
Expand All @@ -103,7 +103,7 @@ def reset(
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
obs = self._norm_obs(obs)
if reset_returns_info:
if returns_info:
return obs, info
else:
return obs
Expand Down
2 changes: 1 addition & 1 deletion tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return
Expand Down

0 comments on commit 0f59e38

Please sign in to comment.