Source code for tmrl.envs

# standard library imports
from dataclasses import InitVar, dataclass

# third-party imports
import gymnasium

# local imports
from tmrl.wrappers import (AffineObservationWrapper, Float64ToFloat32)


__docformat__ = "google"


[docs] class GenericGymEnv(gymnasium.Wrapper): def __init__(self, id: str = "Pendulum-v0", gym_kwargs=None, obs_scale: float = 0., to_float32=False): """ Use this wrapper when using the framework with arbitrary environments. Args: id (str): gymnasium id gym_kwargs (dict): keyword arguments of the gymnasium environment (i.e. between -1.0 and 1.0 when the actual action space is something else) obs_scale (float): change this if wanting to rescale actions by a scalar to_float32 (bool): set this to True if wanting observations to be converted to numpy.float32 """ if gym_kwargs is None: gym_kwargs = {} env = gymnasium.make(id, **gym_kwargs, disable_env_checker=True) if obs_scale: env = AffineObservationWrapper(env, 0, obs_scale) if to_float32: env = Float64ToFloat32(env) # assert isinstance(env.action_space, gymnasium.spaces.Box), f"{env.action_space}" # env = NormalizeActionWrapper(env) super().__init__(env)
if __name__ == '__main__': pass