[go: up, home]

Skip to content

Commit

Permalink
Merge pull request #895 from msmith93:master
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 591000415
Change-Id: Ia167f93021c40461080a66228421478ab3e8b716
  • Loading branch information
Copybara-Service committed Dec 14, 2023
2 parents 80572d0 + 441df4b commit 13e0550
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tf_agents/environments/suite_atari.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function

from typing import Dict, Optional, Sequence, Text
from typing import Dict, Optional, Sequence, Text, Any

import ale_py # pylint: disable=unused-import
import gin
Expand Down Expand Up @@ -84,13 +84,15 @@ def load(
] = DEFAULT_ATARI_GYM_WRAPPERS,
env_wrappers: Sequence[types.PyEnvWrapper] = (),
spec_dtype_map: Optional[Dict[gym.Space, np.dtype]] = None,
gym_kwargs: Optional[Dict[str, Any]] = None,
) -> py_environment.PyEnvironment:
"""Loads the selected environment and wraps it with the specified wrappers."""
if spec_dtype_map is None:
spec_dtype_map = {gym.spaces.Box: np.uint8}

gym_kwargs = gym_kwargs if gym_kwargs else {}
gym_spec = gym.spec(environment_name)
gym_env = gym_spec.make()
gym_env = gym_spec.make(**gym_kwargs)

if max_episode_steps is None and gym_spec.max_episode_steps is not None:
max_episode_steps = gym_spec.max_episode_steps
Expand Down

0 comments on commit 13e0550

Please sign in to comment.