tmrl.memory module
- class tmrl.memory.Memory(device, nb_steps, sample_preprocessor: callable | None = None, memory_size=1000000, batch_size=256, dataset_path='', crc_debug=False)
Bases: ABC
Interface implementing the replay buffer.
Note
When overriding __init__, don’t forget to call super().__init__ in the subclass. Your __init__ method needs to take at least all the arguments of the superclass.
- Parameters:
device (str) – output tensors will be collated to this device
nb_steps (int) – number of steps per round
sample_preprocessor (callable) – can be used for data augmentation
memory_size (int) – size of the circular buffer
batch_size (int) – batch size of the output tensors
dataset_path (str) – an offline dataset may be provided here to initialize the memory
crc_debug (bool) – usually False; set to True to enable CRC debugging of the pipeline
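A minimal sketch of the subclassing pattern the note describes: the subclass accepts every superclass argument (plus its own) and forwards them to super().__init__ before doing its own setup. Because tmrl may not be installed where this runs, StubMemory is a hypothetical stand-in that only mirrors the constructor signature of tmrl.memory.Memory; the attribute names it stores and the extra obs_shape argument are assumptions for illustration.

```python
from abc import ABC


class StubMemory(ABC):
    """Hypothetical stand-in for tmrl.memory.Memory with the same constructor signature.

    The attribute names below are assumptions, not tmrl's actual internals.
    """

    def __init__(self, device, nb_steps, sample_preprocessor=None,
                 memory_size=1000000, batch_size=256, dataset_path='',
                 crc_debug=False):
        self.device = device
        self.nb_steps = nb_steps
        self.sample_preprocessor = sample_preprocessor
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.dataset_path = dataset_path
        self.crc_debug = crc_debug


class MyMemory(StubMemory):
    # Takes at least all the superclass arguments, plus a hypothetical extra one.
    def __init__(self, device, nb_steps, sample_preprocessor=None,
                 memory_size=1000000, batch_size=256, dataset_path='',
                 crc_debug=False, obs_shape=(64, 64)):
        # Forward every superclass argument before subclass-specific setup.
        super().__init__(device, nb_steps,
                         sample_preprocessor=sample_preprocessor,
                         memory_size=memory_size,
                         batch_size=batch_size,
                         dataset_path=dataset_path,
                         crc_debug=crc_debug)
        self.obs_shape = obs_shape  # hypothetical subclass-specific attribute


mem = MyMemory('cpu', nb_steps=10, batch_size=32)
```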
- abstract append_buffer(buffer)
Must append a Buffer object to the memory.
- Parameters:
buffer (tmrl.networking.Buffer) – the buffer of samples to append.
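A sketch of one way append_buffer could maintain a circular buffer holding at most memory_size samples. For illustration it assumes the incoming buffer exposes its samples as a list attribute named memory; ListBuffer is a hypothetical stand-in for tmrl.networking.Buffer, and CircularMemory is not part of tmrl.

```python
from collections import deque


class ListBuffer:
    """Hypothetical stand-in for tmrl.networking.Buffer: holds a list of samples."""

    def __init__(self, samples):
        self.memory = list(samples)


class CircularMemory:
    def __init__(self, memory_size=1000000):
        # deque(maxlen=...) silently evicts the oldest entries once full.
        self.data = deque(maxlen=memory_size)

    def append_buffer(self, buffer):
        # Assumes the buffer's samples live in buffer.memory (see lead-in).
        self.data.extend(buffer.memory)

    def __len__(self):
        return len(self.data)


mem = CircularMemory(memory_size=3)
mem.append_buffer(ListBuffer(['s0', 's1']))
mem.append_buffer(ListBuffer(['s2', 's3']))
print(list(mem.data))  # ['s1', 's2', 's3']
```

A deque with maxlen evicts old samples automatically, but indexing into the middle of a deque is O(n); for large memories sampled at random, a preallocated list with a wrapping write index may be a better fit.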
- abstract collate(batch, device)
Must collate batch onto device.
batch is a list of training samples. The length of batch is batch_size. Each training sample in the list is of the form (prev_obs, new_act, rew, new_obs, terminated, truncated). These samples must be collated into 6 tensors of batch dimension batch_size. These tensors should be collated onto the device indicated by the device argument. Then, your implementation must return a single tuple containing these 6 tensors.
- Parameters:
batch (list) – list of (prev_obs, new_act, rew, new_obs, terminated, truncated) tuples
device – device onto which the samples are collated into tensors of batch dimension batch_size
- Returns:
(prev_obs_tens, new_act_tens, rew_tens, new_obs_tens, terminated_tens, truncated_tens) collated onto the device given by device, each of batch dimension batch_size
- Return type:
Tuple of tensors
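The transposition that collate describes can be sketched with zip(*batch), which turns a list of batch_size 6-tuples into 6 columns of length batch_size. To stay dependency-free, this sketch returns plain lists; to_tensor is a hypothetical hook where a PyTorch implementation might convert each column, for example with torch.as_tensor(column, device=device). Observations that are themselves tuples would need per-component collation, which this sketch does not handle.

```python
def _to_list(column, device):
    # Stdlib fallback: ignores device and returns a plain list.
    return list(column)


def collate(batch, device, to_tensor=None):
    """Turn a list of (prev_obs, new_act, rew, new_obs, terminated, truncated)
    tuples into a tuple of 6 batched columns.

    to_tensor is a hypothetical hook, e.g. with PyTorch:
        lambda col, dev: torch.as_tensor(col, device=dev)
    """
    if to_tensor is None:
        to_tensor = _to_list
    # zip(*batch) transposes the batch: one tuple per field, each of length len(batch).
    return tuple(to_tensor(column, device) for column in zip(*batch))


batch = [(0, 1, 0.5, 1, False, False),
         (1, 0, 1.0, 2, True, False)]
prev_obs, new_act, rew, new_obs, terminated, truncated = collate(batch, 'cpu')
```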
- abstract get_transition(item)
Must return the transition stored at index item.
The info dictionary is required in each sample for CRC debugging; when this feature is enabled, the ‘crc’ key is the one that matters.
- Parameters:
item (int) – the index of the transition to sample
- Returns:
(prev_obs, prev_act, rew, obs, terminated, truncated, info)
- Return type:
Tuple
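The sketch below illustrates the 7-tuple contract of get_transition and one way a sampler might strip info before building the list of 6-tuples that collate expects. TinyMemory and make_batch are hypothetical, and computing the ‘crc’ value with zlib.crc32 over the transition's repr is an assumption for illustration, not tmrl's actual CRC scheme.

```python
import zlib


class TinyMemory:
    """Hypothetical memory storing transitions as 6-tuples."""

    def __init__(self, transitions, crc_debug=False):
        self.transitions = transitions
        self.crc_debug = crc_debug

    def get_transition(self, item):
        prev_obs, prev_act, rew, obs, terminated, truncated = self.transitions[item]
        info = {}
        if self.crc_debug:
            # Assumed checksum for illustration; tmrl's own CRC scheme may differ.
            info['crc'] = zlib.crc32(repr(self.transitions[item]).encode())
        return prev_obs, prev_act, rew, obs, terminated, truncated, info


def make_batch(memory, indices):
    # Drop the trailing info dict so each sample is the 6-tuple collate expects.
    return [memory.get_transition(i)[:6] for i in indices]


mem = TinyMemory([(0, 1, 0.5, 1, False, False),
                  (1, 0, 1.0, 2, True, False)], crc_debug=True)
batch = make_batch(mem, [1, 0])
```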
- class tmrl.memory.TorchMemory(device, nb_steps, sample_preprocessor: callable | None = None, memory_size=1000000, batch_size=256, dataset_path='', crc_debug=False)
Bases: Memory, ABC
Partial implementation of the Memory class that collates samples into batched torch tensors.
Note
When overriding __init__, don’t forget to call super().__init__ in the subclass. Your __init__ method needs to take at least all the arguments of the superclass.
- Parameters:
device (str) – output tensors will be collated to this device
nb_steps (int) – number of steps per round
sample_preprocessor (callable) – can be used for data augmentation
memory_size (int) – size of the circular buffer
batch_size (int) – batch size of the output tensors
dataset_path (str) – an offline dataset may be provided here to initialize the memory
crc_debug (bool) – usually False; set to True to enable CRC debugging of the pipeline
- collate(batch, device)
Must collate batch onto device.
batch is a list of training samples. The length of batch is batch_size. Each training sample in the list is of the form (prev_obs, new_act, rew, new_obs, terminated, truncated). These samples must be collated into 6 tensors of batch dimension batch_size. These tensors should be collated onto the device indicated by the device argument. Then, your implementation must return a single tuple containing these 6 tensors.
- Parameters:
batch (list) – list of (prev_obs, new_act, rew, new_obs, terminated, truncated) tuples
device – device onto which the samples are collated into tensors of batch dimension batch_size
- Returns:
(prev_obs_tens, new_act_tens, rew_tens, new_obs_tens, terminated_tens, truncated_tens) collated onto the device given by device, each of batch dimension batch_size
- Return type:
Tuple of tensors