tmrl.training_offline module

class tmrl.training_offline.TorchTrainingOffline(env_cls: type | None = None, memory_cls: type | None = None, training_agent_cls: type | None = None, epochs: int = 10, rounds: int = 50, steps: int = 2000, update_model_interval: int = 100, update_buffer_interval: int = 100, max_training_steps_per_env_step: float = 1.0, sleep_between_buffer_retrieval_attempts: float = 1.0, profiling: bool = False, agent_scheduler: callable | None = None, start_training: int = 0, device: str | None = None)

Bases: TrainingOffline

TrainingOffline for trainers based on PyTorch.

This class implements automatic device selection with PyTorch.

Takes the same arguments as TrainingOffline, except that when device is None a device is selected automatically for torch.

Parameters:
  • env_cls (type) – class of a dummy environment, used only to retrieve observation and action spaces if needed. Alternatively, this can be a tuple of the form (observation_space, action_space).

  • memory_cls (type) – class of the replay memory

  • training_agent_cls (type) – class of the training agent

  • epochs (int) – total number of epochs; the agent is saved every epoch

  • rounds (int) – number of rounds per epoch; statistics are generated every round

  • steps (int) – number of training steps per round

  • update_model_interval (int) – number of training steps between model broadcasts

  • update_buffer_interval (int) – number of training steps between retrieving buffered samples

  • max_training_steps_per_env_step (float) – training pauses when the ratio of training steps to environment steps exceeds this value

  • sleep_between_buffer_retrieval_attempts (float) – amount of time the algorithm sleeps between attempts while waiting for required incoming samples

  • profiling (bool) – if True, run_epoch is profiled and the profiling results are printed at the end of each epoch

  • agent_scheduler (callable) – if not None, a callable of the form f(Agent, epoch) that is called at the beginning of each epoch

  • start_training (int) – minimum number of samples in the replay buffer before starting training

  • device (str) – device on which the memory will collate training samples (None for automatic)
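
The automatic selection used when device is None can be pictured with a minimal sketch. It assumes the common PyTorch idiom of preferring CUDA when it is available and falling back to the CPU otherwise; the helper name select_device is hypothetical, and the logic actually used by tmrl may differ.

```python
def select_device(device=None):
    """Return `device` unchanged, or pick one automatically when it is None.

    Hypothetical helper: it illustrates a usual PyTorch idiom and is not
    taken from the tmrl source code.
    """
    if device is not None:
        return device  # an explicit choice always wins
    try:
        import torch
    except ImportError:
        return "cpu"  # torch is not installed, so there is nothing to probe
    # Assumption: prefer the GPU whenever PyTorch reports one as usable.
    return "cuda" if torch.cuda.is_available() else "cpu"


print(select_device("cpu"))  # the explicit device is returned as given
print(select_device())       # "cuda" or "cpu", depending on the machine
```

Passing an explicit string such as "cpu" skips the probe entirely, which can be useful for reproducing results on a machine that has a GPU.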

class tmrl.training_offline.TrainingOffline(env_cls: type | None = None, memory_cls: type | None = None, training_agent_cls: type | None = None, epochs: int = 10, rounds: int = 50, steps: int = 2000, update_model_interval: int = 100, update_buffer_interval: int = 100, max_training_steps_per_env_step: float = 1.0, sleep_between_buffer_retrieval_attempts: float = 1.0, profiling: bool = False, agent_scheduler: callable | None = None, start_training: int = 0, device: str | None = None)

Bases: object

Training wrapper for off-policy algorithms.

Parameters:
  • env_cls (type) – class of a dummy environment, used only to retrieve observation and action spaces if needed. Alternatively, this can be a tuple of the form (observation_space, action_space).

  • memory_cls (type) – class of the replay memory

  • training_agent_cls (type) – class of the training agent

  • epochs (int) – total number of epochs; the agent is saved every epoch

  • rounds (int) – number of rounds per epoch; statistics are generated every round

  • steps (int) – number of training steps per round

  • update_model_interval (int) – number of training steps between model broadcasts

  • update_buffer_interval (int) – number of training steps between retrieving buffered samples

  • max_training_steps_per_env_step (float) – training pauses when the ratio of training steps to environment steps exceeds this value

  • sleep_between_buffer_retrieval_attempts (float) – amount of time the algorithm sleeps between attempts while waiting for required incoming samples

  • profiling (bool) – if True, run_epoch is profiled and the profiling results are printed at the end of each epoch

  • agent_scheduler (callable) – if not None, a callable of the form f(Agent, epoch) that is called at the beginning of each epoch

  • start_training (int) – minimum number of samples in the replay buffer before starting training

  • device (str) – device on which the memory will collate training samples
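
How epochs, rounds, steps and the two intervals relate can be illustrated with a small, self-contained sketch. It assumes the three counters nest as plain loops and that each interval is checked against a global count of training steps; the function names count_events and should_pause are hypothetical, and the real run_epoch additionally waits for incoming samples and honours start_training, which is not modelled here.

```python
def count_events(epochs=10, rounds=50, steps=2000,
                 update_model_interval=100, update_buffer_interval=100):
    """Count training steps, model broadcasts and buffer retrievals.

    Illustrative only: assumes nested loops and a global step counter.
    """
    total_steps = broadcasts = retrievals = 0
    for _epoch in range(epochs):          # the agent is saved once per epoch
        for _round in range(rounds):      # statistics are generated per round
            for _step in range(steps):    # one training step
                total_steps += 1
                if total_steps % update_model_interval == 0:
                    broadcasts += 1       # model weights would be broadcast
                if total_steps % update_buffer_interval == 0:
                    retrievals += 1       # buffered samples would be retrieved
    return total_steps, broadcasts, retrievals


def should_pause(training_steps, env_steps, max_training_steps_per_env_step=1.0):
    """Return True when training has run ahead of sample collection.

    Assumption: the ratio is training steps divided by environment steps,
    written as a product so that zero environment steps needs no special case.
    """
    return training_steps > max_training_steps_per_env_step * env_steps


print(count_events(epochs=2, rounds=3, steps=10,
                   update_model_interval=5, update_buffer_interval=20))
print(should_pause(training_steps=150, env_steps=100))
```

Under these assumptions the default arguments amount to 10 × 50 × 2000 = 1,000,000 training steps in total, with a broadcast and a buffer retrieval every 100 steps.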