from abc import ABC, abstractmethod
class TrainingAgent(ABC):
    """
    Abstract base class for a training algorithm.

    Subclasses must implement `train` and `get_actor`; the class cannot be
    instantiated until both are overridden.

    CAUTION: When overriding `__init__`, don't forget to call
    `super().__init__` in the subclass, otherwise the attributes below are
    never set.

    Attributes:
        observation_space: observation space passed at construction.
        action_space: action space passed at construction.
        device: device passed at construction.
    """

    def __init__(self,
                 observation_space,
                 action_space,
                 device):
        """
        Stores the spaces and the training device on the instance.

        Args:
            observation_space (gymnasium.spaces.Space): observation space (here for your convenience)
            action_space (gymnasium.spaces.Space): action space (here for your convenience)
            device (str): device that should be used for training
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device

    @abstractmethod
    def train(self, batch):
        """
        Executes a training step.

        Args:
            batch: tuple or batched tensors (previous observation, action, reward, new observation, terminated, truncated)

        Returns:
            dict: a dictionary containing one entry per metric you wish to log (e.g. for wandb)

        Raises:
            NotImplementedError: if a subclass delegates to this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def get_actor(self):
        """
        Returns the current ActorModule to be broadcast to the RolloutWorkers.

        Returns:
            ActorModule: current actor to be broadcast

        Raises:
            NotImplementedError: if a subclass delegates to this base implementation.
        """
        raise NotImplementedError