# Source code for tmrl.memory

# standard library imports
import os
import pickle
import zlib
from abc import ABC, abstractmethod
from pathlib import Path
from random import randint
import logging

# third-party imports
import numpy as np
# from torch.utils.data import DataLoader, Dataset, Sampler

# local imports
from tmrl.util import collate_torch


# Docstring convention used throughout this module (Google style: "Args:", "Returns:"),
# declared for documentation generators that read the `__docformat__` attribute.
__docformat__ = "google"


def check_samples_crc(original_po, original_a, original_o, original_r, original_d, original_t,
                      rebuilt_po, rebuilt_a, rebuilt_o, rebuilt_r, rebuilt_d, rebuilt_t,
                      debug_ts, debug_ts_res):
    """
    Check that a sample rebuilt by the memory matches the sample recorded by the rollout worker.

    Each field is first compared through its ``str`` representation. Then a CRC32 of the ``str`` of
    the whole ``(act, obs, rew, terminated, truncated)`` tuple is compared. ``str`` of a tuple uses
    the ``repr`` of its items, so the CRC also catches mismatches the field-wise checks miss
    (e.g. ``'1'`` vs ``1``).

    Checks raise explicitly instead of using ``assert``, so they still run under ``python -O``.

    Args:
        original_po: previous observation recorded by the rollout worker, or None to skip that check
        original_a: recorded action
        original_o: recorded observation
        original_r: recorded reward
        original_d: recorded `terminated` flag
        original_t: recorded `truncated` flag
        rebuilt_po: previous observation rebuilt by the memory
        rebuilt_a: rebuilt action
        rebuilt_o: rebuilt observation
        rebuilt_r: rebuilt reward
        rebuilt_d: rebuilt `terminated` flag
        rebuilt_t: rebuilt `truncated` flag
        debug_ts: time step reported in error messages
        debug_ts_res: "since reset" counter reported in error messages

    Raises:
        AssertionError: if any field or the CRC does not match.
    """
    where = f"Time step: {debug_ts}, since reset: {debug_ts_res}"

    # Field-wise checks, in reporting order. The previous observation is only checked when one was recorded.
    field_checks = []
    if original_po is not None:
        field_checks.append(("previous observations", original_po, rebuilt_po))
    field_checks += [
        ("actions", original_a, rebuilt_a),
        ("observations", original_o, rebuilt_o),
        ("rewards", original_r, rebuilt_r),
        ("terminated", original_d, rebuilt_d),
        ("truncated", original_t, rebuilt_t),
    ]
    for label, original, rebuilt in field_checks:
        if str(original) != str(rebuilt):
            raise AssertionError(f"{label} don't match:\noriginal:\n{original}\n!= rebuilt:\n{rebuilt}\n{where}")

    # The CRC covers the full 5-tuple; the error message shows exactly the tuples that were hashed.
    original_sample = (original_a, original_o, original_r, original_d, original_t)
    rebuilt_sample = (rebuilt_a, rebuilt_o, rebuilt_r, rebuilt_d, rebuilt_t)
    original_crc = zlib.crc32(str(original_sample).encode())
    crc = zlib.crc32(str(rebuilt_sample).encode())
    if crc != original_crc:
        raise AssertionError(
            f"CRC failed: new crc:{crc} != old crc:{original_crc}.\n"
            "Either the custom pipeline is corrupted, or crc_debug is False in the rollout worker.\n"
            f"original sample:\n{original_sample}\n!= rebuilt sample:\n{rebuilt_sample}\n{where}")
    print(f"DEBUG: CRC check passed. {where}")


class Memory(ABC):
    """
    Interface implementing the replay buffer.

    .. note::
       When overriding `__init__`, don't forget to call `super().__init__` in the subclass.
       Your `__init__` method needs to take at least all the arguments of the superclass.
    """
    def __init__(self,
                 device,
                 nb_steps,
                 sample_preprocessor: callable = None,
                 memory_size=1000000,
                 batch_size=256,
                 dataset_path="",
                 crc_debug=False):
        """
        Args:
            device (str): output tensors will be collated to this device
            nb_steps (int): number of steps per round
            sample_preprocessor (callable): can be used for data augmentation
            memory_size (int): size of the circular buffer
            batch_size (int): batch size of the output tensors
            dataset_path (str): an offline dataset may be provided here to initialize the memory
            crc_debug (bool): False usually, True when using CRC debugging of the pipeline
        """
        self.nb_steps = nb_steps
        self.device = device
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.sample_preprocessor = sample_preprocessor
        self.crc_debug = crc_debug

        # These stats are here because they reach the trainer along with the buffer:
        self.stat_test_return = 0.0
        self.stat_train_return = 0.0
        self.stat_test_steps = 0
        self.stat_train_steps = 0

        # init memory
        # NOTE: Path("") is Path("."), so with the default dataset_path a `data.pkl` in the
        # current working directory is picked up.
        self.path = Path(dataset_path)
        logging.debug("Memory self.path:%s", self.path)
        data_file = self.path / 'data.pkl'
        if os.path.isfile(data_file):
            # SECURITY: pickle.load can execute arbitrary code; only point dataset_path at trusted data.
            with open(data_file, 'rb') as f:
                self.data = list(pickle.load(f))
        else:
            logging.info("no data found, initializing empty replay memory")
            self.data = []

        # len(self) dispatches to the subclass's __len__, which must already work once self.data is set.
        if len(self) > self.memory_size:
            # TODO: crop to memory_size
            logging.warning("the dataset length (%s) is longer than memory_size (%s)", len(self), self.memory_size)

    def __iter__(self):
        """Yield `nb_steps` collated batches, each produced by `sample()`."""
        for _ in range(self.nb_steps):
            yield self.sample()

    @abstractmethod
    def append_buffer(self, buffer):
        """
        Must append a Buffer object to the memory.

        Args:
            buffer (tmrl.networking.Buffer): the buffer of samples to append.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """
        Must return the length of the memory.

        Returns:
            int: the number of transitions; indices passed to `get_transition` by `sample`
            range from 0 to `len(self) - 1` inclusive.
        """
        raise NotImplementedError

    @abstractmethod
    def get_transition(self, item):
        """
        Must return a transition.

        `info` is required in each sample for CRC debugging: when `crc_debug` is True,
        `__getitem__` reads `info['crc_sample']` (a 6-tuple `(prev_obs, act, obs, rew, terminated, truncated)`)
        and `info['crc_sample_ts']` (a 2-tuple `(debug_ts, debug_ts_res)`).

        Args:
            item (int): the index where to sample

        Returns:
            Tuple: (prev_obs, new_act, rew, new_obs, terminated, truncated, info)
        """
        raise NotImplementedError

    @abstractmethod
    def collate(self, batch, device):
        """
        Must collate `batch` onto `device`.

        `batch` is a list of training samples.
        The length of `batch` is `batch_size`.
        Each training sample in the list is of the form `(prev_obs, new_act, rew, new_obs, terminated, truncated)`.
        These samples must be collated into 6 tensors of batch dimension `batch_size`.
        These tensors should be collated onto the device indicated by the `device` argument.
        Then, your implementation must return a single tuple containing these 6 tensors.

        Args:
            batch (list): list of `(prev_obs, new_act, rew, new_obs, terminated, truncated)` tuples
            device: device onto which the list needs to be collated into batches `batch_size`

        Returns:
            Tuple of tensors:
            (prev_obs_tens, new_act_tens, rew_tens, new_obs_tens, terminated_tens, truncated_tens)
            collated on device `device`, each of batch dimension `batch_size`
        """
        raise NotImplementedError

    def sample(self):
        """Draw `batch_size` random samples and return them collated onto `self.device`."""
        indices = self.sample_indices()
        batch = [self[idx] for idx in indices]
        batch = self.collate(batch, self.device)
        return batch

    def append(self, buffer):
        """
        Copy the buffer's statistics and append its samples, unless the buffer is empty.

        Args:
            buffer (tmrl.networking.Buffer): buffer carrying samples and `stat_*` attributes.
        """
        if len(buffer) > 0:
            self.stat_train_return = buffer.stat_train_return
            self.stat_test_return = buffer.stat_test_return
            self.stat_train_steps = buffer.stat_train_steps
            self.stat_test_steps = buffer.stat_test_steps
            self.append_buffer(buffer)

    def __getitem__(self, item):
        """
        Return the training sample at index `item`.

        The transition comes from `get_transition`. It is CRC-checked when `crc_debug` is set,
        then passed through `sample_preprocessor` if one was given.

        Returns:
            Tuple: (prev_obs, new_act, rew, new_obs, terminated, truncated), with `terminated`
            and `truncated` cast to `np.float32`.
        """
        prev_obs, new_act, rew, new_obs, terminated, truncated, info = self.get_transition(item)
        if self.crc_debug:
            po, a, o, r, d, t = info['crc_sample']
            debug_ts, debug_ts_res = info['crc_sample_ts']
            check_samples_crc(po, a, o, r, d, t, prev_obs, new_act, new_obs, rew, terminated, truncated, debug_ts, debug_ts_res)
        if self.sample_preprocessor is not None:
            prev_obs, new_act, rew, new_obs, terminated, truncated = self.sample_preprocessor(prev_obs, new_act, rew, new_obs, terminated, truncated)
        terminated = np.float32(terminated)  # we don't want bool tensors
        truncated = np.float32(truncated)  # we don't want bool tensors
        return prev_obs, new_act, rew, new_obs, terminated, truncated

    def sample_indices(self):
        """
        Return an iterable of `batch_size` uniformly random indices in `[0, len(self) - 1]`.

        Raises:
            ValueError: if the memory is empty.
        """
        length = len(self)
        if length == 0:
            # Without this check, randint(0, -1) fails with a cryptic "empty range" error.
            raise ValueError("cannot sample from an empty memory")
        return (randint(0, length - 1) for _ in range(self.batch_size))
class TorchMemory(Memory, ABC):
    """
    Partial `Memory` implementation whose `collate` turns samples into batched torch tensors.

    .. note::
       When overriding `__init__`, don't forget to call `super().__init__` in the subclass.
       Your `__init__` method needs to take at least all the arguments of the superclass.
    """
    def __init__(self,
                 device,
                 nb_steps,
                 sample_preprocessor: callable = None,
                 memory_size=1000000,
                 batch_size=256,
                 dataset_path="",
                 crc_debug=False):
        """
        Forward every argument unchanged to `Memory.__init__`.

        Args:
            device (str): output tensors will be collated to this device
            nb_steps (int): number of steps per round
            sample_preprocessor (callable): can be used for data augmentation
            memory_size (int): size of the circular buffer
            batch_size (int): batch size of the output tensors
            dataset_path (str): an offline dataset may be provided here to initialize the memory
            crc_debug (bool): False usually, True when using CRC debugging of the pipeline
        """
        super().__init__(device=device,
                         nb_steps=nb_steps,
                         sample_preprocessor=sample_preprocessor,
                         memory_size=memory_size,
                         batch_size=batch_size,
                         dataset_path=dataset_path,
                         crc_debug=crc_debug)

    def collate(self, batch, device):
        """Collate the list of samples `batch` onto `device` using `tmrl.util.collate_torch`."""
        collated = collate_torch(batch, device)
        return collated