tmrl.training module

class tmrl.training.TrainingAgent(observation_space, action_space, device)

Bases: ABC

Abstract base class for a training algorithm. Subclasses implement train() and get_actor().

CAUTION: When overriding __init__, don’t forget to call super().__init__ in the subclass.

Parameters:
  • observation_space (gymnasium.spaces.Space) – observation space (here for your convenience)

  • action_space (gymnasium.spaces.Space) – action space (here for your convenience)

  • device (str) – device that should be used for training
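The caution above matters because the base constructor is what records the three arguments. The sketch below illustrates this with a hypothetical stand-in class that mirrors the documented interface, so it runs without tmrl installed. It assumes the base __init__ simply stores observation_space, action_space and device as attributes. MyAgent, ForgetfulAgent and the lr hyperparameter are illustrative names only, and plain strings stand in for gymnasium spaces.

```python
from abc import ABC, abstractmethod


# Hypothetical stand-in mirroring the documented TrainingAgent interface,
# so this sketch runs without tmrl installed. Assumption: the real base
# constructor stores its three arguments as attributes, as done here.
class TrainingAgent(ABC):
    def __init__(self, observation_space, action_space, device):
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    @abstractmethod
    def train(self, batch):
        raise NotImplementedError

    @abstractmethod
    def get_actor(self):
        raise NotImplementedError


class MyAgent(TrainingAgent):
    def __init__(self, observation_space, action_space, device, lr=3e-4):
        # Forward the base arguments first, so the attributes set by the
        # base class exist before the subclass adds its own state.
        super().__init__(observation_space, action_space, device)
        self.lr = lr  # hypothetical hyperparameter, for illustration only

    def train(self, batch):
        return {}  # placeholder: no-op training step

    def get_actor(self):
        return None  # placeholder: a real agent returns its ActorModule


class ForgetfulAgent(MyAgent):
    def __init__(self, observation_space, action_space, device):
        # Deliberate bug: super().__init__ is never called.
        self.lr = 1e-3


agent = MyAgent(observation_space="obs-space", action_space="act-space", device="cpu")
print(agent.device, agent.lr)  # cpu 0.0003

broken = ForgetfulAgent("obs-space", "act-space", "cpu")
print(hasattr(broken, "device"))  # False: the base attributes were never set
```

In real code you would subclass the actual class from tmrl.training instead of defining the stand-in; the point is only that skipping super().__init__ silently leaves the base attributes undefined until something tries to read them.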

abstract get_actor()

Returns the current ActorModule to be broadcast to the RolloutWorkers.

Returns:

current actor to be broadcast

Return type:

ActorModule
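One possible way to implement get_actor is to hand out a snapshot rather than the live object, so that later training steps cannot modify an actor that has already been passed on for broadcasting. This is a minimal sketch under assumptions: TrainingAgent is again a hypothetical stand-in for the documented interface, ToyActor is a hypothetical substitute for an ActorModule, and whether a copy is actually needed depends on how the returned object is consumed downstream.

```python
import copy
from abc import ABC, abstractmethod


# Hypothetical stand-in mirroring the documented TrainingAgent interface.
class TrainingAgent(ABC):
    def __init__(self, observation_space, action_space, device):
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    @abstractmethod
    def train(self, batch):
        raise NotImplementedError

    @abstractmethod
    def get_actor(self):
        raise NotImplementedError


class ToyActor:
    """Hypothetical stand-in for an ActorModule: a linear policy a = w * obs."""

    def __init__(self, weight):
        self.weight = weight

    def act(self, obs):
        return self.weight * obs


class SnapshotAgent(TrainingAgent):
    def __init__(self, observation_space, action_space, device):
        super().__init__(observation_space, action_space, device)
        self.actor = ToyActor(weight=1.0)

    def train(self, batch):
        # Placeholder update: nudge the weight so the effect on snapshots is visible.
        self.actor.weight += 0.5
        return {"weight": self.actor.weight}

    def get_actor(self):
        # Return an independent copy, so later training steps do not mutate
        # the actor that has already been handed out for broadcasting.
        return copy.deepcopy(self.actor)


agent = SnapshotAgent("obs-space", "act-space", "cpu")
broadcast = agent.get_actor()
agent.train(batch=None)
print(broadcast.act(2.0), agent.actor.act(2.0))  # 2.0 3.0
```

The trade-off is one copy per call; if the caller serializes the returned actor immediately, returning the live object may be enough.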

abstract train(batch)

Executes a training step.

Parameters:

batch – tuple of batched tensors (previous observation, action, reward, new observation, terminated, truncated)

Returns:

a dictionary containing one entry per metric you wish to log (e.g. for wandb)

Return type:

dict
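The sketch below shows the expected shape of a train implementation: unpack the batch in the documented order, perform one update, and return a dict with one entry per metric. To stay dependency-free it swaps in a toy technique, a semi-gradient TD(0) update of a linear state-value function V(s) = w * s on scalar observations, using plain Python lists instead of tensors; it is not the algorithm tmrl ships. TrainingAgent is a hypothetical stand-in for the documented interface, and ToyTDAgent, gamma and lr are illustrative names. It also shows one common use of the terminated/truncated split: bootstrapping is cut only on true termination, while a time-limit truncation still bootstraps from the next state.

```python
from abc import ABC, abstractmethod


# Hypothetical stand-in mirroring the documented TrainingAgent interface.
class TrainingAgent(ABC):
    def __init__(self, observation_space, action_space, device):
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    @abstractmethod
    def train(self, batch):
        raise NotImplementedError

    @abstractmethod
    def get_actor(self):
        raise NotImplementedError


class ToyTDAgent(TrainingAgent):
    """Hypothetical agent fitting V(s) = w * s with semi-gradient TD(0)."""

    def __init__(self, observation_space, action_space, device, gamma=0.9, lr=0.1):
        super().__init__(observation_space, action_space, device)
        self.gamma = gamma
        self.lr = lr
        self.w = 0.0

    def train(self, batch):
        # Unpack in the documented order. `act` and `truncated` are unused here:
        # a state-value function ignores actions, and truncation does not stop
        # bootstrapping.
        obs, act, rew, next_obs, terminated, truncated = batch
        n = len(rew)
        td_errors = []
        for s, r, s2, term in zip(obs, rew, next_obs, terminated):
            # Cut the bootstrap term only when the episode truly terminated.
            target = r + self.gamma * (1.0 - float(term)) * self.w * s2
            td_errors.append(target - self.w * s)
        # Gradient of the mean squared TD error w.r.t. w, target held fixed.
        grad = -2.0 * sum(d * s for d, s in zip(td_errors, obs)) / n
        self.w -= self.lr * grad
        # Loss measured before the update; one dict entry per metric to log.
        loss = sum(d * d for d in td_errors) / n
        return {"loss": loss, "mean_reward": sum(rew) / n}

    def get_actor(self):
        return None  # placeholder: this value-only toy has no ActorModule


agent = ToyTDAgent("obs-space", "act-space", "cpu")
batch = (
    [1.0, 2.0],      # previous observations
    [0, 1],          # actions
    [1.0, 0.0],      # rewards
    [2.0, 3.0],      # new observations
    [False, True],   # terminated
    [False, False],  # truncated
)
metrics = agent.train(batch)
print(metrics)  # {'loss': 0.5, 'mean_reward': 0.5}
print(agent.w)  # 0.1
```

The metric keys are arbitrary; whatever scalars you put in the returned dict are the ones available for logging.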