Source code for tmrl.networking

# standard library imports
import datetime
import os
import socket
import time
import atexit
import json
import shutil
import tempfile
import itertools
from os.path import exists

# third-party imports
import numpy as np
from requests import get
from tlspyo import Relay, Endpoint

# local imports
from tmrl.actor import ActorModule
from tmrl.util import dump, load, partial_to_dict
import tmrl.config.config_constants as cfg
import tmrl.config.config_objects as cfg_obj

import logging


__docformat__ = "google"


# PRINT: ============================================


def print_with_timestamp(s):
    """
    Logs `s` at INFO level, prefixed with the current date and time.

    Args:
        s: object to log; it is converted with `str()`
    """
    # "%x %X " is the locale's date, a space, the locale's time and a trailing space
    stamp = datetime.datetime.now().strftime("%x %X ")
    logging.info(f"{stamp}{s!s}")


def print_ip():
    """
    Logs the public and local IP addresses of this machine.

    The public IP is the text returned by an HTTP GET on api.ipify.org;
    the local IP is resolved from the machine's hostname.
    """
    response = get('http://api.ipify.org')
    public_ip = response.text
    hostname = socket.gethostname()
    local_ip = socket.gethostbyname(hostname)
    print_with_timestamp(f"public IP: {public_ip}, local IP: {local_ip}")


# BUFFER: ===========================================


class Buffer:
    """
    Buffer of training samples.

    `Server`, `RolloutWorker` and `Trainer` all have their own `Buffer` to store and send training samples.

    Samples are tuples of the form (`act`, `new_obs`, `rew`, `terminated`, `truncated`, `info`)
    """
    def __init__(self, maxlen=cfg.BUFFERS_MAXLEN):
        """
        Args:
            maxlen (int): buffer length
        """
        self.memory = []
        self.maxlen = maxlen
        self.stat_train_return = 0.0  # stores the train return
        self.stat_test_return = 0.0  # stores the test return
        self.stat_train_steps = 0  # stores the number of steps per training episode
        self.stat_test_steps = 0  # stores the number of steps per test episode

    def clip_to_maxlen(self):
        """
        Discards the oldest samples when the buffer holds more than `maxlen` samples.
        """
        overflow = len(self.memory) - self.maxlen
        if overflow > 0:
            print_with_timestamp("buffer overflow. Discarding old samples.")
            # keep only the `maxlen` most recent samples
            self.memory = self.memory[overflow:]

    def append_sample(self, sample):
        """
        Appends `sample` to the buffer.

        Args:
            sample (Tuple): a training sample of the form (`act`, `new_obs`, `rew`, `terminated`, `truncated`, `info`)
        """
        self.memory.append(sample)
        self.clip_to_maxlen()

    def clear(self):
        """
        Clears the buffer but keeps train and test returns.
        """
        self.memory = []

    def __len__(self):
        """Number of samples currently stored."""
        return len(self.memory)

    def __iadd__(self, other):
        """
        Appends the samples of `other`, clips to `maxlen`, and adopts the statistics of `other`.
        """
        self.memory.extend(other.memory)
        self.clip_to_maxlen()
        # the statistics of the right-hand buffer replace ours
        for stat in ("stat_train_return", "stat_test_return", "stat_train_steps", "stat_test_steps"):
            setattr(self, stat, getattr(other, stat))
        return self


# SERVER: ===========================================
class Server:
    """
    Central server.

    The `Server` lets 1 `Trainer` and n `RolloutWorkers` connect.
    It buffers experiences sent by workers and periodically sends these to the trainer.
    It also receives the weights from the trainer and broadcasts these to the connected workers.
    """
    def __init__(self,
                 port=cfg.PORT,
                 password=cfg.PASSWORD,
                 local_port=cfg.LOCAL_PORT_SERVER,
                 header_size=cfg.HEADER_SIZE,
                 security=cfg.SECURITY,
                 keys_dir=cfg.CREDENTIALS_DIRECTORY,
                 max_workers=cfg.NB_WORKERS):
        """
        Args:
            port (int): tlspyo public port
            password (str): tlspyo password
            local_port (int): tlspyo local communication port
            header_size (int): tlspyo header size (bytes)
            security (Union[str, None]): tlspyo security type (None or "TLS")
            keys_dir (str): tlspyo credentials directory
            max_workers (int): max number of accepted workers
        """
        # a single trainer and at most `max_workers` workers are accepted
        trainers_group = {'max_count': 1, 'max_consumables': None}
        workers_group = {'max_count': max_workers, 'max_consumables': None}

        self.__relay = Relay(port=port,
                             password=password,
                             accepted_groups={'trainers': trainers_group,
                                              'workers': workers_group},
                             local_com_port=local_port,
                             header_size=header_size,
                             security=security,
                             keys_dir=keys_dir)


# TRAINER: ==========================================


class TrainerInterface:
    """
    This is the trainer's network interface
    This connects to the server
    This receives samples batches and sends new weights
    """
    def __init__(self,
                 server_ip=None,
                 server_port=cfg.PORT,
                 password=cfg.PASSWORD,
                 local_com_port=cfg.LOCAL_PORT_TRAINER,
                 header_size=cfg.HEADER_SIZE,
                 max_buf_len=cfg.BUFFER_SIZE,
                 security=cfg.SECURITY,
                 keys_dir=cfg.CREDENTIALS_DIRECTORY,
                 hostname=cfg.HOSTNAME,
                 model_path=cfg.MODEL_PATH_TRAINER):
        """
        Args:
            server_ip (str): ip of the central server (`None` = '127.0.0.1')
            server_port (int): port passed to the tlspyo `Endpoint`
            password (str): password passed to the tlspyo `Endpoint`
            local_com_port (int): local communication port passed to the tlspyo `Endpoint`
            header_size (int): header size passed to the tlspyo `Endpoint`
            max_buf_len (int): max buffer length passed to the tlspyo `Endpoint`
            security (str): security type passed to the tlspyo `Endpoint`
            keys_dir (str): credentials directory passed to the tlspyo `Endpoint`
            hostname (str): hostname passed to the tlspyo `Endpoint`
            model_path (str): path where `broadcast_model` saves the model before sending it
        """
        self.model_path = model_path
        self.server_ip = server_ip if server_ip is not None else '127.0.0.1'
        self.__endpoint = Endpoint(ip_server=self.server_ip,
                                   port=server_port,
                                   password=password,
                                   groups="trainers",
                                   local_com_port=local_com_port,
                                   header_size=header_size,
                                   max_buf_len=max_buf_len,
                                   security=security,
                                   keys_dir=keys_dir,
                                   hostname=hostname)

        print_with_timestamp(f"server IP: {self.server_ip}")

        self.__endpoint.notify(groups={'trainers': -1})  # retrieve everything

    def broadcast_model(self, model: ActorModule):
        """
        model must be an ActorModule
        broadcasts the model's weights to all connected RolloutWorkers
        """
        # the model is first saved to disk, then the raw bytes of that file are broadcast
        model.save(self.model_path)
        with open(self.model_path, 'rb') as f:
            weights = f.read()
        self.__endpoint.broadcast(weights, "workers")

    def retrieve_buffer(self):
        """
        returns the TrainerInterface's buffer of training samples

        All buffers received by the endpoint are merged, in order, into a fresh `Buffer`.
        """
        buffers = self.__endpoint.receive_all()
        res = Buffer()
        for buf in buffers:
            res += buf
        self.__endpoint.notify(groups={'trainers': -1})  # retrieve everything
        return res


def log_environment_variables():
    """
    add certain relevant environment variables to our config
    usage: `LOG_VARIABLES='HOME JOBID' python ...`

    Returns:
        dict: maps each name listed in `LOG_VARIABLES` to its value ('' when unset)
    """
    names = os.environ.get('LOG_VARIABLES', '').strip().split()
    return {k: os.environ.get(k, '') for k in names}


def load_run_instance(checkpoint_path):
    """
    Default function used to load trainers from checkpoint path
    Args:
        checkpoint_path: the path where instances of run_cls are checkpointed
    Returns:
        An instance of run_cls loaded from checkpoint_path
    """
    return load(checkpoint_path)


def dump_run_instance(run_instance, checkpoint_path):
    """
    Default function used to dump trainers to checkpoint path
    Args:
        run_instance: the instance of run_cls to checkpoint
        checkpoint_path: the path where instances of run_cls are checkpointed
    """
    dump(run_instance, checkpoint_path)


def iterate_epochs(run_cls,
                   interface: TrainerInterface,
                   checkpoint_path: str,
                   dump_run_instance_fn=dump_run_instance,
                   load_run_instance_fn=load_run_instance,
                   epochs_between_checkpoints=1,
                   updater_fn=None):
    """
    Main training loop (remote)

    If nothing exists at `checkpoint_path`, a new `run_cls()` instance is created and checkpointed;
    otherwise the checkpoint is loaded (and passed through `updater_fn` when provided).
    The run instance is then checkpointed whenever its epoch counter is a multiple of
    `epochs_between_checkpoints`.

    This is a generator: it yields the value returned by `run_instance.run_epoch(interface=interface)`
    after each epoch, until `run_instance.epoch` reaches `run_instance.epochs`.

    Args:
        run_cls: callable creating a new run instance
        interface (TrainerInterface): passed to `run_epoch`
        checkpoint_path (str): checkpoint location; when empty or `None`, a temporary
            checkpoint is used and removed when the generator finishes
        dump_run_instance_fn (callable): called as `fn(run_instance, checkpoint_path)`
        load_run_instance_fn (callable): called as `fn(checkpoint_path)`
        epochs_between_checkpoints (int): checkpointing period, in epochs
        updater_fn (callable): called as `fn(run_instance, run_cls)` after loading a checkpoint
    """
    temp_dir = None
    if not checkpoint_path:
        # tempfile.mktemp() is deprecated because another process can create the returned path
        # before we do; a not-yet-existing file inside a private mkdtemp() directory is safe
        # and still makes the `exists` check below start from a fresh run instance.
        temp_dir = tempfile.mkdtemp()
        checkpoint_path = os.path.join(temp_dir, "checkpoint_remove_on_exit")

    try:
        logging.debug(f"checkpoint_path: {checkpoint_path}")
        if not exists(checkpoint_path):
            logging.info("=== specification ".ljust(70, "="))
            run_instance = run_cls()
            dump_run_instance_fn(run_instance, checkpoint_path)
            logging.info("")
        else:
            logging.info("Loading checkpoint...")
            t1 = time.time()
            run_instance = load_run_instance_fn(checkpoint_path)
            logging.info(f" Loaded checkpoint in {time.time() - t1} seconds.")
            if updater_fn is not None:
                logging.info("Updating checkpoint...")
                t1 = time.time()
                run_instance = updater_fn(run_instance, run_cls)
                logging.info(f"Checkpoint updated in {time.time() - t1} seconds.")

        while run_instance.epoch < run_instance.epochs:
            # yielding the epoch statistics is what makes this function a generator
            yield run_instance.run_epoch(interface=interface)
            if run_instance.epoch % epochs_between_checkpoints == 0:
                logging.info(" saving checkpoint...")
                t1 = time.time()
                dump_run_instance_fn(run_instance, checkpoint_path)
                logging.info(f" saved checkpoint in {time.time() - t1} seconds.")

    finally:
        if temp_dir is not None:
            # temporary checkpoint created above: remove it together with its directory
            shutil.rmtree(temp_dir, ignore_errors=True)
        elif checkpoint_path.endswith("_remove_on_exit") and exists(checkpoint_path):
            # caller-provided paths with this suffix are also removed on exit
            os.remove(checkpoint_path)


def run_with_wandb(entity,
                   project,
                   run_id,
                   interface,
                   run_cls,
                   checkpoint_path: str = None,
                   dump_run_instance_fn=None,
                   load_run_instance_fn=None,
                   updater_fn=None):
    """
    Main training loop (remote).

    saves config and stats to https://wandb.com

    `wandb.init` is retried every 10 seconds; after the 11th failure the process exits
    (`SystemExit` is raised).

    Args:
        entity: wandb entity
        project: wandb project
        run_id: wandb run id
        interface: passed to `iterate_epochs`
        run_cls: passed to `iterate_epochs`; also converted to the wandb config with `partial_to_dict`
        checkpoint_path (str): passed to `iterate_epochs`; the wandb run is resumed when it exists
        dump_run_instance_fn (callable): `None` = `dump_run_instance`
        load_run_instance_fn (callable): `None` = `load_run_instance`
        updater_fn (callable): passed to `iterate_epochs`
    """
    dump_run_instance_fn = dump_run_instance_fn or dump_run_instance
    load_run_instance_fn = load_run_instance_fn or load_run_instance
    wandb_dir = tempfile.mkdtemp()  # prevent wandb from polluting the home directory
    atexit.register(shutil.rmtree, wandb_dir, ignore_errors=True)  # clean up after wandb atexit handler finishes
    import wandb
    logging.debug(f" run_cls: {run_cls}")
    config = partial_to_dict(run_cls)
    config['environ'] = log_environment_variables()
    # config['git'] = git_info()  # TODO: check this for bugs
    resume = checkpoint_path and exists(checkpoint_path)
    wandb_initialized = False
    err_cpt = 0
    while not wandb_initialized:
        try:
            wandb.init(dir=wandb_dir, entity=entity, project=project, id=run_id, resume=resume, config=config)
            wandb_initialized = True
        except Exception as e:
            err_cpt += 1
            logging.warning(f"wandb error {err_cpt}: {e}")
            if err_cpt > 10:
                logging.warning("Could not connect to wandb, aborting.")
                # same exit status as the former `exit()` call, without relying on the
                # `site` helper that only exists for interactive use
                raise SystemExit
            time.sleep(10.0)
    for stats in iterate_epochs(run_cls, interface, checkpoint_path, dump_run_instance_fn, load_run_instance_fn, 1, updater_fn):
        for s in stats:
            wandb.log(json.loads(s.to_json()))


def run(interface,
        run_cls,
        checkpoint_path: str = None,
        dump_run_instance_fn=None,
        load_run_instance_fn=None,
        updater_fn=None):
    """
    Main training loop (remote).

    Exhausts `iterate_epochs` (checkpointing every epoch) and discards the yielded statistics.

    Args:
        interface: passed to `iterate_epochs`
        run_cls: passed to `iterate_epochs`
        checkpoint_path (str): passed to `iterate_epochs`
        dump_run_instance_fn (callable): `None` = `dump_run_instance`
        load_run_instance_fn (callable): `None` = `load_run_instance`
        updater_fn (callable): passed to `iterate_epochs`
    """
    dump_run_instance_fn = dump_run_instance_fn or dump_run_instance
    load_run_instance_fn = load_run_instance_fn or load_run_instance
    for _ in iterate_epochs(run_cls, interface, checkpoint_path, dump_run_instance_fn, load_run_instance_fn, 1, updater_fn):
        pass


class Trainer:
    """
    Training entity.

    The `Trainer` object is where RL training happens.
    Typically, it can be located on a HPC cluster.
    """
    def __init__(self,
                 training_cls=cfg_obj.TRAINER,
                 server_ip=cfg.SERVER_IP_FOR_TRAINER,
                 server_port=cfg.PORT,
                 password=cfg.PASSWORD,
                 local_com_port=cfg.LOCAL_PORT_TRAINER,
                 header_size=cfg.HEADER_SIZE,
                 max_buf_len=cfg.BUFFER_SIZE,
                 security=cfg.SECURITY,
                 keys_dir=cfg.CREDENTIALS_DIRECTORY,
                 hostname=cfg.HOSTNAME,
                 model_path=cfg.MODEL_PATH_TRAINER,
                 checkpoint_path=cfg.CHECKPOINT_PATH,
                 dump_run_instance_fn: callable = None,
                 load_run_instance_fn: callable = None,
                 updater_fn: callable = None):
        """
        Args:
            training_cls (type): training class (subclass of tmrl.training_offline.TrainingOffline)
            server_ip (str): ip of the central `Server`
            server_port (int): public port of the central `Server`
            password (str): password of the central `Server`
            local_com_port (int): port used by `tlspyo` for local communication
            header_size (int): number of bytes used for `tlspyo` headers
            max_buf_len (int): maximum number of messages queued by `tlspyo`
            security (str): `tlspyo security type` (None or "TLS")
            keys_dir (str): custom credentials directory for `tlspyo` TLS security
            hostname (str): custom TLS hostname
            model_path (str): path where a local copy of the model will be saved
            checkpoint_path: path where the `Trainer` will be checkpointed (`None` = no checkpointing)
            dump_run_instance_fn (callable): custom serializer (`None` = pickle.dump)
            load_run_instance_fn (callable): custom deserializer (`None` = pickle.load)
            updater_fn (callable): custom updater (`None` = no updater). If provided, this must be a function \
            that takes a checkpoint and training_cls as argument and returns an updated checkpoint. \
            The updater is called after a checkpoint is loaded, e.g., to update your checkpoint with new arguments.
        """
        self.checkpoint_path = checkpoint_path
        self.dump_run_instance_fn = dump_run_instance_fn
        self.load_run_instance_fn = load_run_instance_fn
        self.updater_fn = updater_fn
        self.training_cls = training_cls

        # network settings are only needed to build the interface, so they are not stored
        interface_settings = dict(server_ip=server_ip,
                                  server_port=server_port,
                                  password=password,
                                  local_com_port=local_com_port,
                                  header_size=header_size,
                                  max_buf_len=max_buf_len,
                                  security=security,
                                  keys_dir=keys_dir,
                                  hostname=hostname,
                                  model_path=model_path)
        self.interface = TrainerInterface(**interface_settings)

    def _run_arguments(self):
        """Keyword arguments shared by the module-level `run` and `run_with_wandb` functions."""
        return dict(interface=self.interface,
                    run_cls=self.training_cls,
                    checkpoint_path=self.checkpoint_path,
                    dump_run_instance_fn=self.dump_run_instance_fn,
                    load_run_instance_fn=self.load_run_instance_fn,
                    updater_fn=self.updater_fn)

    def run(self):
        """
        Runs training.
        """
        # `run` here is the module-level function, not this method
        run(**self._run_arguments())

    def run_with_wandb(self,
                       entity=cfg.WANDB_ENTITY,
                       project=cfg.WANDB_PROJECT,
                       run_id=cfg.WANDB_RUN_ID,
                       key=None):
        """
        Runs training while logging metrics to wandb_.

        .. _wandb: https://wandb.ai

        Args:
            entity (str): wandb entity
            project (str): wandb project
            run_id (str): name of the run
            key (str): wandb API key
        """
        if key is not None:
            # the key is exported through the WANDB_API_KEY environment variable
            os.environ['WANDB_API_KEY'] = key

        # `run_with_wandb` here is the module-level function, not this method
        run_with_wandb(entity=entity,
                       project=project,
                       run_id=run_id,
                       **self._run_arguments())


# ROLLOUT WORKER: ===================================
class RolloutWorker:
    """Actor.

    A `RolloutWorker` deploys the current policy in the environment.
    A `RolloutWorker` may connect to a `Server` to which it sends buffered experience.
    Alternatively, it may exist in standalone mode for deployment.
    """
    def __init__(
            self,
            env_cls,
            actor_module_cls,
            sample_compressor: callable = None,
            device="cpu",
            max_samples_per_episode=np.inf,
            model_path=cfg.MODEL_PATH_WORKER,
            obs_preprocessor: callable = None,
            crc_debug=False,
            model_path_history=cfg.MODEL_PATH_SAVE_HISTORY,
            model_history=cfg.MODEL_HISTORY,
            standalone=False,
            server_ip=None,
            server_port=cfg.PORT,
            password=cfg.PASSWORD,
            local_port=cfg.LOCAL_PORT_WORKER,
            header_size=cfg.HEADER_SIZE,
            max_buf_len=cfg.BUFFER_SIZE,
            security=cfg.SECURITY,
            keys_dir=cfg.CREDENTIALS_DIRECTORY,
            hostname=cfg.HOSTNAME
    ):
        """
        Args:
            env_cls (type): class of the Gymnasium environment (subclass of tmrl.envs.GenericGymEnv)
            actor_module_cls (type): class of the module containing the policy (subclass of tmrl.actor.ActorModule)
            sample_compressor (callable): compressor for sending samples over the Internet; \
            when not `None`, `sample_compressor` must be a function that takes the following arguments: \
            (prev_act, obs, rew, terminated, truncated, info), and that returns them (modified) in the same order: \
            when not `None`, a `sample_compressor` works with a corresponding decompression scheme in the `Memory` class
            device (str): device on which the policy is running
            max_samples_per_episode (int): if an episode gets longer than this, it is reset
            model_path (str): path where a local copy of the policy will be stored
            obs_preprocessor (callable): utility for modifying observations retrieved from the environment; \
            when not `None`, `obs_preprocessor` must be a function that takes an observation as input and outputs the \
            modified observation
            crc_debug (bool): useful for debugging custom pipelines; leave to False otherwise
            model_path_history (str): (include the filename but omit .tmod) path to the saved history of policies; \
            we recommend you leave this to the default
            model_history (int): policies are saved every `model_history` new policies (0: not saved)
            standalone (bool): if True, the worker will not try to connect to a server
            server_ip (str): ip of the central server
            server_port (int): public port of the central server
            password (str): tlspyo password
            local_port (int): tlspyo local communication port; usually, leave this to the default
            header_size (int): tlspyo header size (bytes)
            max_buf_len (int): tlspyo max number of messages in buffer
            security (str): tlspyo security type (None or "TLS")
            keys_dir (str): tlspyo credentials directory; usually, leave this to the default
            hostname (str): tlspyo hostname; usually, leave this to the default
        """
        self.obs_preprocessor = obs_preprocessor
        self.get_local_buffer_sample = sample_compressor
        self.env = env_cls()
        obs_space = self.env.observation_space
        act_space = self.env.action_space
        self.model_path = model_path
        self.model_path_history = model_path_history
        self.device = device
        self.actor = actor_module_cls(observation_space=obs_space, action_space=act_space).to_device(self.device)
        self.standalone = standalone
        if os.path.isfile(self.model_path):
            # a local copy of the policy already exists: start from it
            logging.debug(f"Loading model from {self.model_path}")
            self.actor = self.actor.load(self.model_path, device=self.device)
        else:
            logging.debug(f"No model found at {self.model_path}")
        self.buffer = Buffer()
        self.max_samples_per_episode = max_samples_per_episode
        self.crc_debug = crc_debug
        self.model_history = model_history
        self._cur_hist_cpt = 0  # number of models received since the last history save
        self.model_cpt = 0
        self.debug_ts_cpt = 0  # crc_debug: total number of collected samples
        self.debug_ts_res_cpt = 0  # crc_debug: number of samples since the last reset

        self.server_ip = server_ip if server_ip is not None else '127.0.0.1'

        print_with_timestamp(f"server IP: {self.server_ip}")

        if not self.standalone:
            self.__endpoint = Endpoint(ip_server=self.server_ip,
                                       port=server_port,
                                       password=password,
                                       groups="workers",
                                       local_com_port=local_port,
                                       header_size=header_size,
                                       max_buf_len=max_buf_len,
                                       security=security,
                                       keys_dir=keys_dir,
                                       hostname=hostname,
                                       deserializer_mode="synchronous")
        else:
            # standalone workers have no endpoint: methods that talk to the Server cannot be used
            self.__endpoint = None

    @staticmethod
    def _bounded_iterator(limit):
        """Returns `range(limit)`, or an endless `itertools.count()` when `limit` is `np.inf`."""
        return range(limit) if limit != np.inf else itertools.count()

    def _store_sample(self, act, new_obs, rew, terminated, truncated, info):
        """Passes the transition through the sample compressor (when provided) and buffers the result."""
        if self.get_local_buffer_sample:
            sample = self.get_local_buffer_sample(act, new_obs, rew, terminated, truncated, info)
        else:
            sample = act, new_obs, rew, terminated, truncated, info
        self.buffer.append_sample(sample)  # CAUTION: in the buffer, act is for the PREVIOUS transition (act, obs(act))

    def act(self, obs, test=False):
        """
        Select an action based on observation `obs`

        Args:
            obs (nested structure): observation
            test (bool): directly passed to the `act()` method of the `ActorModule`

        Returns:
            numpy.array: action computed by the `ActorModule`
        """
        # NB: the observation preprocessor is applied in `reset` and `step`, not here
        action = self.actor.act_(obs, test=test)
        return action

    def reset(self, collect_samples):
        """
        Starts a new episode.

        Args:
            collect_samples (bool): if True, samples are buffered and sent to the `Server`

        Returns:
            Tuple:
            (nested structure: observation retrieved from the environment,
            dict: information retrieved from the environment)
        """
        obs = None
        try:
            # Faster than hasattr() in real-time environments
            act = self.env.unwrapped.default_action  # .astype(np.float32)
        except AttributeError:
            # In non-real-time environments, act is None on reset
            act = None
        new_obs, info = self.env.reset()
        if self.obs_preprocessor is not None:
            new_obs = self.obs_preprocessor(new_obs)
        rew = 0.0
        terminated, truncated = False, False
        if collect_samples:
            if self.crc_debug:
                self.debug_ts_cpt += 1
                self.debug_ts_res_cpt = 0
                info['crc_sample'] = (obs, act, new_obs, rew, terminated, truncated)
                info['crc_sample_ts'] = (self.debug_ts_cpt, self.debug_ts_res_cpt)
            self._store_sample(act, new_obs, rew, terminated, truncated, info)
        return new_obs, info

    def step(self, obs, test, collect_samples, last_step=False):
        """
        Performs a full RL transition.

        A full RL transition is `obs` -> `act` -> `new_obs`, `rew`, `terminated`, `truncated`, `info`.
        Note that, in the Real-Time RL setting, `act` is appended to a buffer which is part of `new_obs`.
        This is because it does not directly affect the new observation, due to real-time delays.

        Args:
            obs (nested structure): previous observation
            test (bool): passed to the `act()` method of the `ActorModule`
            collect_samples (bool): if True, samples are buffered and sent to the `Server`
            last_step (bool): if True and `terminated` is False, `truncated` will be set to True
                (only applied when `collect_samples` is True)

        Returns:
            Tuple:
            (nested structure: new observation,
            float: new reward,
            bool: episode termination signal,
            bool: episode truncation signal,
            dict: information dictionary)
        """
        act = self.act(obs, test=test)
        new_obs, rew, terminated, truncated, info = self.env.step(act)
        if self.obs_preprocessor is not None:
            new_obs = self.obs_preprocessor(new_obs)
        if collect_samples:
            if last_step and not terminated:
                truncated = True
            if self.crc_debug:
                self.debug_ts_cpt += 1
                self.debug_ts_res_cpt += 1
                info['crc_sample'] = (obs, act, new_obs, rew, terminated, truncated)
                info['crc_sample_ts'] = (self.debug_ts_cpt, self.debug_ts_res_cpt)
            self._store_sample(act, new_obs, rew, terminated, truncated, info)
        return new_obs, rew, terminated, truncated, info

    def collect_train_episode(self, max_samples=None):
        """
        Collects a maximum of `max_samples` training transitions (from reset to terminated or truncated)

        This method stores the episode and the training return in the local `Buffer` of the worker
        for sending to the `Server`.

        Args:
            max_samples (int): if the environment is not `terminated` after `max_samples` time steps,
                it is forcefully reset and `truncated` is set to True.
        """
        if max_samples is None:
            max_samples = self.max_samples_per_episode

        ret = 0.0
        steps = 0
        obs, info = self.reset(collect_samples=True)
        for i in self._bounded_iterator(max_samples):
            obs, rew, terminated, truncated, info = self.step(obs=obs,
                                                              test=False,
                                                              collect_samples=True,
                                                              last_step=i == max_samples - 1)
            ret += rew
            steps += 1
            if terminated or truncated:
                break
        self.buffer.stat_train_return = ret
        self.buffer.stat_train_steps = steps

    def run_episodes(self, max_samples_per_episode=None, nb_episodes=np.inf, train=False):
        """
        Runs `nb_episodes` episodes.

        Args:
            max_samples_per_episode (int): same as run_episode
            nb_episodes (int): total number of episodes to collect
            train (bool): same as run_episode
        """
        if max_samples_per_episode is None:
            max_samples_per_episode = self.max_samples_per_episode

        for _ in self._bounded_iterator(nb_episodes):
            self.run_episode(max_samples_per_episode, train=train)

    def run_episode(self, max_samples=None, train=False):
        """
        Collects a maximum of `max_samples` test transitions (from reset to terminated or truncated).

        Samples are not buffered; the return and number of steps are stored as test statistics
        in the local `Buffer`.

        Args:
            max_samples (int): At most `max_samples` samples are collected per episode.
                If the episode is longer, it is interrupted.
            train (bool): whether the episode is a training or a test episode.
                `step` is called with `test=not train`.
        """
        if max_samples is None:
            max_samples = self.max_samples_per_episode

        ret = 0.0
        steps = 0
        obs, info = self.reset(collect_samples=False)
        for _ in self._bounded_iterator(max_samples):
            obs, rew, terminated, truncated, info = self.step(obs=obs, test=not train, collect_samples=False)
            ret += rew
            steps += 1
            if terminated or truncated:
                break
        self.buffer.stat_test_return = ret
        self.buffer.stat_test_steps = steps

    def run(self, test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=False):
        """
        Runs the worker for `nb_episodes` episodes.

        This method sends episodes continuously to the Server, and checks for new weights between episodes.
        For synchronous or more fine-grained sampling, use synchronous or lower-level APIs.
        For deployment, use `run_episodes` rather than `run`.

        Args:
            test_episode_interval (int): a test episode is collected for every `test_episode_interval` train episodes;
                set to 0 to not collect test episodes.
            nb_episodes (int): maximum number of train episodes to collect.
            verbose (bool): whether to log INFO messages.
            expert (bool): experts send training samples without updating their model nor running test episodes.
        """
        verbose = bool(verbose)
        for episode in self._bounded_iterator(nb_episodes):
            # test episodes are skipped for experts and in crc_debug mode
            if not expert and test_episode_interval \
                    and episode % test_episode_interval == 0 and not self.crc_debug:
                if verbose:
                    print_with_timestamp("running test episode")
                self.run_episode(self.max_samples_per_episode, train=False)

            if verbose:
                print_with_timestamp("collecting expert episode" if expert else "collecting train episode")
            self.collect_train_episode(self.max_samples_per_episode)

            if verbose:
                print_with_timestamp("copying buffer for sending")
            self.send_and_clear_buffer()

            if expert:
                # experts keep their own policy: received weights are only flushed
                self.ignore_actor_weights()
            else:
                if verbose:
                    print_with_timestamp("checking for new weights")
                self.update_actor_weights(verbose=verbose)

    def run_synchronous(self,
                        test_episode_interval=0,
                        nb_steps=np.inf,
                        initial_steps=1,
                        max_steps_per_update=np.inf,
                        end_episodes=True,
                        verbose=False):
        """
        Collects `nb_steps` steps while synchronizing with the Trainer.

        This method is useful for traditional (non-real-time) environments that can be stepped fast.
        It also works for rtgym environments with `wait_on_done` enabled, just set `end_episodes` to `True`.

        Note: test episodes are only collected when `test_episode_interval` > 0 and `end_episodes` is `True`.

        Args:
            test_episode_interval (int): a test episode is collected for every `test_episode_interval` train episodes;
                set to 0 to not collect test episodes. NB: `end_episodes` must be `True` to collect test episodes.
            nb_steps (int): total number of steps to collect (after `initial_steps`).
            initial_steps (int): initial number of steps to collect before waiting for the first model update.
            max_steps_per_update (float): maximum number of steps to collect per model received from the Server
                (this can be a non-integer ratio).
            end_episodes (bool): when True (default), waits for episodes to end before sending samples and
                waiting for updates. When False, pauses whenever the max_steps_per_update ratio is exceeded.
            verbose (bool): whether to log INFO messages.
        """

        # collect initial samples

        if verbose:
            logging.info(f"Collecting {initial_steps} initial steps")

        iteration = 0
        done = False
        while iteration < initial_steps:
            steps = 0
            ret = 0.0
            # reset
            obs, info = self.reset(collect_samples=True)
            done = False
            iteration += 1
            # episode
            while not done and (end_episodes or iteration < initial_steps):
                # step
                obs, rew, terminated, truncated, info = self.step(obs=obs,
                                                                  test=False,
                                                                  collect_samples=True,
                                                                  last_step=steps == self.max_samples_per_episode - 1)
                iteration += 1
                steps += 1
                ret += rew
                done = terminated or truncated
            # send the collected samples to the Server
            self.buffer.stat_train_return = ret
            self.buffer.stat_train_steps = steps
            if verbose:
                logging.info(f"Sending buffer (initial steps)")
            self.send_and_clear_buffer()

        i_model = 1

        # wait for the first updated model if required here
        ratio = (iteration + 1) / i_model
        while ratio > max_steps_per_update:
            if verbose:
                logging.info(f"Ratio {ratio} > {max_steps_per_update}, sending buffer checking updates")
            self.send_and_clear_buffer()
            i_model += self.update_actor_weights(verbose=verbose, blocking=True)
            ratio = (iteration + 1) / i_model

        # collect further samples while synchronizing with the Trainer

        iteration = 0
        episode = 0
        steps = 0
        ret = 0.0

        while iteration < nb_steps:

            if done:
                # test episode
                if test_episode_interval > 0 and episode % test_episode_interval == 0 and end_episodes:
                    if verbose:
                        print_with_timestamp("running test episode")
                    self.run_episode(self.max_samples_per_episode, train=False)
                # reset
                obs, info = self.reset(collect_samples=True)
                done = False
                iteration += 1
                steps = 0
                ret = 0.0
                episode += 1

            while not done and (end_episodes or ratio <= max_steps_per_update):

                # step
                obs, rew, terminated, truncated, info = self.step(obs=obs,
                                                                  test=False,
                                                                  collect_samples=True,
                                                                  last_step=steps == self.max_samples_per_episode - 1)
                iteration += 1
                steps += 1
                ret += rew
                done = terminated or truncated

                if not end_episodes:
                    # check model and send samples after each step
                    ratio = (iteration + 1) / i_model
                    while ratio > max_steps_per_update:
                        if verbose:
                            logging.info(f"Ratio {ratio} > {max_steps_per_update}, sending buffer checking updates (no eoe)")
                        if not done:
                            if verbose:
                                logging.info(f"Sending buffer (no eoe)")
                            self.send_and_clear_buffer()
                        i_model += self.update_actor_weights(verbose=verbose, blocking=True)
                        ratio = (iteration + 1) / i_model

            if end_episodes:
                # check model and send samples only after episodes end
                ratio = (iteration + 1) / i_model
                while ratio > max_steps_per_update:
                    if verbose:
                        logging.info(
                            f"Ratio {ratio} > {max_steps_per_update}, sending buffer checking updates (eoe)")
                    if not done:
                        if verbose:
                            logging.info(f"Sending buffer (eoe)")
                        self.send_and_clear_buffer()
                    i_model += self.update_actor_weights(verbose=verbose, blocking=True)
                    ratio = (iteration + 1) / i_model

            self.buffer.stat_train_return = ret
            self.buffer.stat_train_steps = steps
            if verbose:
                logging.info(f"Sending buffer - DEBUG ratio {ratio} iteration {iteration} i_model {i_model}")
            self.send_and_clear_buffer()

    def run_env_benchmark(self, nb_steps, test=False, verbose=True):
        """
        Benchmarks the environment.

        This method is only compatible with rtgym_ environments.
        Furthermore, the `"benchmark"` option of the rtgym configuration dictionary must be set to `True`.

        .. _rtgym: https://github.com/yannbouteiller/rtgym

        Args:
            nb_steps (int): number of steps to perform to compute the benchmark
            test (bool): whether the actor is called in test or train mode
            verbose (bool): whether to log INFO messages

        Returns:
            the value returned by the `benchmarks()` method of the unwrapped environment

        Raises:
            RuntimeError: if `nb_steps` is infinite or negative
        """
        if nb_steps == np.inf or nb_steps < 0:
            raise RuntimeError(f"Invalid number of steps: {nb_steps}")

        obs, info = self.reset(collect_samples=False)
        for _ in range(nb_steps):
            obs, rew, terminated, truncated, info = self.step(obs=obs, test=test, collect_samples=False)
            if terminated or truncated:
                break
        res = self.env.unwrapped.benchmarks()
        if verbose:
            print_with_timestamp(f"Benchmark results:\n{res}")
        return res

    def send_and_clear_buffer(self):
        """
        Sends the buffered samples to the `Server`.
        """
        self.__endpoint.produce(self.buffer, "trainers")
        self.buffer.clear()

    def update_actor_weights(self, verbose=True, blocking=False):
        """
        Updates the actor with new weights received from the `Server` when available.

        Args:
            verbose (bool): whether to log INFO messages.
            blocking (bool): if True, blocks until a model is received; otherwise, can be a no-op.

        Returns:
            int: number of new actor models received from the Server (the latest is used).
        """
        weights_list = self.__endpoint.receive_all(blocking=blocking)
        nb_received = len(weights_list)
        if nb_received > 0:
            # only the most recent model is kept
            weights = weights_list[-1]
            with open(self.model_path, 'wb') as f:
                f.write(weights)
            if self.model_history:
                # a timestamped copy is saved once every `model_history` updates
                self._cur_hist_cpt += 1
                if self._cur_hist_cpt == self.model_history:
                    stamp = datetime.datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
                    with open(self.model_path_history + stamp + ".tmod", 'wb') as f:
                        f.write(weights)
                    self._cur_hist_cpt = 0
                    if verbose:
                        print_with_timestamp("model weights saved in history")
            self.actor = self.actor.load(self.model_path, device=self.device)
            if verbose:
                print_with_timestamp("model weights have been updated")
        return nb_received

    def ignore_actor_weights(self):
        """
        Clears the buffer of weights received from the `Server`.

        This is useful for expert RolloutWorkers, because all RolloutWorkers receive weights.

        Returns:
            int: number of new (ignored) actor models received from the Server.
        """
        weights_list = self.__endpoint.receive_all(blocking=False)
        nb_received = len(weights_list)
        return nb_received