tmrl.networking module

class tmrl.networking.Buffer(maxlen=cfg.BUFFERS_MAXLEN)

Bases: object

Buffer of training samples.

Server, RolloutWorker and Trainer all have their own Buffer to store and send training samples.

Samples are tuples of the form (act, new_obs, rew, terminated, truncated, info).

Parameters:

maxlen (int) – buffer length

append_sample(sample)

Appends sample to the buffer.

Parameters:

sample (Tuple) – a training sample of the form (act, new_obs, rew, terminated, truncated, info)

clear()

Clears the buffer but keeps train and test returns.
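The sketch below shows the sample format and the append/clear behaviour described above. It is a hypothetical stand-in, not tmrl's implementation. It assumes the buffer is a bounded deque that discards its oldest samples when full, and the return attribute names (stat_train_return, stat_test_return) are illustrative.

```python
from collections import deque


class ToyBuffer:
    """Hypothetical stand-in for tmrl.networking.Buffer, for illustration only."""

    def __init__(self, maxlen=10):
        # Assumption: a bounded deque that discards the oldest samples when full.
        self.memory = deque(maxlen=maxlen)
        # Hypothetical names for the train and test returns kept across clear().
        self.stat_train_return = 0.0
        self.stat_test_return = 0.0

    def append_sample(self, sample):
        # sample = (act, new_obs, rew, terminated, truncated, info)
        if len(sample) != 6:
            raise ValueError("expected (act, new_obs, rew, terminated, truncated, info)")
        self.memory.append(sample)

    def clear(self):
        # Drop the samples but keep the recorded returns.
        self.memory.clear()

    def __len__(self):
        return len(self.memory)


buf = ToyBuffer(maxlen=2)
for t in range(3):
    buf.append_sample((0.0, t, 1.0, False, False, {}))
print(len(buf))  # 2: the sample with new_obs=0 was discarded
```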

class tmrl.networking.RolloutWorker(env_cls, actor_module_cls, sample_compressor: callable | None = None, device='cpu', max_samples_per_episode=np.inf, model_path=cfg.MODEL_PATH_WORKER, obs_preprocessor: callable | None = None, crc_debug=False, model_path_history=cfg.MODEL_PATH_SAVE_HISTORY, model_history=cfg.MODEL_HISTORY, standalone=False, server_ip=None, server_port=cfg.PORT, password=cfg.PASSWORD, local_port=cfg.LOCAL_PORT_WORKER, header_size=cfg.HEADER_SIZE, max_buf_len=cfg.BUFFER_SIZE, security=cfg.SECURITY, keys_dir=cfg.CREDENTIALS_DIRECTORY, hostname=cfg.HOSTNAME)

Bases: object

Actor.

A RolloutWorker deploys the current policy in the environment. A RolloutWorker may connect to a Server to which it sends buffered experience. Alternatively, it may exist in standalone mode for deployment.

Parameters:
  • env_cls (type) – class of the Gymnasium environment (subclass of tmrl.envs.GenericGymEnv)

  • actor_module_cls (type) – class of the module containing the policy (subclass of tmrl.actor.ActorModule)

  • sample_compressor (callable) – compressor for sending samples over the Internet; when not None, sample_compressor must be a function that takes the arguments (prev_act, obs, rew, terminated, truncated, info) and returns them, possibly modified, in the same order. A sample_compressor must be paired with a corresponding decompression scheme in the Memory class

  • device (str) – device on which the policy is running

  • max_samples_per_episode (int) – if an episode gets longer than this, it is reset

  • model_path (str) – path where a local copy of the policy will be stored

  • obs_preprocessor (callable) – utility for modifying observations retrieved from the environment; when not None, obs_preprocessor must be a function that takes an observation as input and outputs the modified observation

  • crc_debug (bool) – useful for debugging custom pipelines; leave it set to False otherwise

  • model_path_history (str) – path to the saved history of policies, including the filename but without the .tmod extension; we recommend leaving this at the default

  • model_history (int) – a policy is saved every model_history new policies (0: policies are not saved)

  • standalone (bool) – if True, the worker will not try to connect to a server

  • server_ip (str) – IP address of the central server

  • server_port (int) – public port of the central server

  • password (str) – tlspyo password

  • local_port (int) – tlspyo local communication port; usually, leave this at the default

  • header_size (int) – tlspyo header size (bytes)

  • max_buf_len (int) – tlspyo max number of messages in buffer

  • security (str) – tlspyo security type (None or “TLS”)

  • keys_dir (str) – tlspyo credentials directory; usually, leave this at the default

  • hostname (str) – tlspyo hostname; usually, leave this at the default
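The obs_preprocessor and sample_compressor contracts above can be illustrated with a pair of plain functions. This is a minimal sketch under assumed conventions: the observation layout (speed, frames), with frames being a history of images stored as lists of pixel values, is hypothetical, and the matching decompression scheme in the Memory class is not shown.

```python
def obs_preprocessor(obs):
    """Hypothetical preprocessor: rescale uint8-like pixel values to [0, 1].

    Assumes obs = (speed, frames), where frames is a sequence of images,
    each image being a list of pixel values.
    """
    speed, frames = obs
    scaled = tuple([px / 255.0 for px in frame] for frame in frames)
    return speed, scaled


def sample_compressor(prev_act, obs, rew, terminated, truncated, info):
    """Hypothetical compressor: keep only the newest frame of the image history.

    Older frames were already sent with earlier samples, so a matching
    decompression scheme on the Trainer side would need to rebuild the history.
    The arguments are returned in the same order they were received.
    """
    speed, frames = obs
    compressed_obs = (speed, frames[-1])
    return prev_act, compressed_obs, rew, terminated, truncated, info
```

Keeping only the newest frame trades bandwidth for extra work on the receiving side; any compressor of this kind only makes sense together with a decompressor that can undo it.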

act(obs, test=False)

Selects an action based on observation obs.

Parameters:
  • obs (nested structure) – observation

  • test (bool) – directly passed to the act() method of the ActorModule

Returns:

action computed by the ActorModule

Return type:

numpy.ndarray

collect_train_episode(max_samples=None)

Collects a maximum of max_samples training transitions (from reset to terminated or truncated).

This method stores the episode and the training return in the local Buffer of the worker for sending to the Server.

Parameters:

max_samples (int) – if the environment is not terminated after max_samples time steps, it is forcefully reset and truncated is set to True.
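The truncation rule above can be sketched as a simple episode loop. ToyEnv and collect_episode are hypothetical names, and this is not tmrl's code. The loop stores samples in the (act, new_obs, rew, terminated, truncated, info) format and forces truncated to True on the last allowed step when the episode has not terminated.

```python
import random


class ToyEnv:
    """Hypothetical Gymnasium-style environment that terminates at random."""

    def __init__(self, p_end=0.1, seed=0):
        self.rng = random.Random(seed)
        self.p_end = p_end
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t, {}

    def step(self, act):
        self.t += 1
        terminated = self.rng.random() < self.p_end
        return self.t, 1.0, terminated, False, {}


def collect_episode(env, policy, max_samples):
    """Sketch: run one episode, forcing truncation after max_samples steps."""
    samples = []
    obs, info = env.reset()
    for i in range(max_samples):
        act = policy(obs)
        new_obs, rew, terminated, truncated, info = env.step(act)
        last_step = i == max_samples - 1
        if last_step and not terminated:
            truncated = True  # episode too long: cut it short
        samples.append((act, new_obs, rew, terminated, truncated, info))
        if terminated or truncated:
            break
        obs = new_obs
    return samples
```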

ignore_actor_weights()

Clears the buffer of weights received from the Server.

This is useful for expert RolloutWorkers: all RolloutWorkers receive weights from the Server, but experts do not update their model.

Returns:

number of new (ignored) actor models received from the Server.

Return type:

int

reset(collect_samples)

Starts a new episode.

Parameters:

collect_samples (bool) – if True, samples are buffered and sent to the Server

Returns:

(nested structure: observation retrieved from the environment, dict: information retrieved from the environment)

Return type:

Tuple

run(test_episode_interval=0, nb_episodes=np.inf, verbose=True, expert=False)

Runs the worker for nb_episodes episodes.

This method sends episodes continuously to the Server and checks for new weights between episodes. For synchronous or more fine-grained sampling, use run_synchronous or lower-level APIs such as step. For deployment, use run_episodes rather than run.

Parameters:
  • test_episode_interval (int) – a test episode is collected for every test_episode_interval train episodes; set to 0 to not collect test episodes.

  • nb_episodes (int) – maximum number of train episodes to collect.

  • verbose (bool) – whether to log INFO messages.

  • expert (bool) – experts send training samples without updating their model or running test episodes.
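One way to read test_episode_interval and expert is as a schedule of train and test episodes. The helper below is hypothetical, and the exact placement of the test episode within each cycle (here, right after every test_episode_interval-th train episode) is an assumption; tmrl may interleave them differently.

```python
def episode_schedule(nb_train_episodes, test_episode_interval, expert=False):
    """Sketch: list the order of train and test episodes.

    Assumptions: a test episode follows every test_episode_interval-th train
    episode; 0 disables test episodes; experts never run test episodes.
    """
    schedule = []
    for i in range(1, nb_train_episodes + 1):
        schedule.append("train")
        if not expert and test_episode_interval > 0 and i % test_episode_interval == 0:
            schedule.append("test")
    return schedule


print(episode_schedule(4, 2))  # ['train', 'train', 'test', 'train', 'train', 'test']
```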

run_env_benchmark(nb_steps, test=False, verbose=True)

Benchmarks the environment.

This method is only compatible with rtgym environments. Furthermore, the “benchmark” option of the rtgym configuration dictionary must be set to True.

Parameters:
  • nb_steps (int) – number of steps to perform to compute the benchmark

  • test (bool) – whether the actor is called in test or train mode

  • verbose (bool) – whether to log INFO messages

run_episode(max_samples=None, train=False)

Collects a maximum of max_samples test transitions (from reset to terminated or truncated).

Parameters:
  • max_samples (int) – At most max_samples samples are collected per episode. If the episode is longer, it is forcefully reset and truncated is set to True.

  • train (bool) – whether the episode is a training or a test episode. step is called with test=not train.

run_episodes(max_samples_per_episode=None, nb_episodes=np.inf, train=False)

Runs nb_episodes episodes.

Parameters:
  • max_samples_per_episode (int) – same as max_samples in run_episode

  • nb_episodes (int) – total number of episodes to collect

  • train (bool) – same as train in run_episode

run_synchronous(test_episode_interval=0, nb_steps=np.inf, initial_steps=1, max_steps_per_update=np.inf, end_episodes=True, verbose=False)

Collects nb_steps steps while synchronizing with the Trainer.

This method is useful for traditional (non-real-time) environments that can be stepped fast. It also works for rtgym environments with wait_on_done enabled; in that case, set end_episodes to True.

Note: This method does not collect test episodes. Periodically use run_episode(train=False) if you wish to.

Parameters:
  • test_episode_interval (int) – a test episode is collected for every test_episode_interval train episodes; set to 0 to not collect test episodes. NB: end_episodes must be True to collect test episodes.

  • nb_steps (int) – total number of steps to collect (after initial_steps).

  • initial_steps (int) – initial number of steps to collect before waiting for the first model update.

  • max_steps_per_update (float) – maximum number of steps to collect per model received from the Server (this can be a non-integer ratio).

  • end_episodes (bool) – when True (default), waits for episodes to end before sending samples and waiting for updates. When False, pauses whenever the max_steps_per_update ratio is exceeded.

  • verbose (bool) – whether to log INFO messages.
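The non-integer max_steps_per_update ratio can be pictured as a step budget that grows with the number of models received. The helpers below are hypothetical. The budget formula (initial_steps plus the ratio times the number of models received, rounded down) is an assumption made for illustration, and tmrl's internal bookkeeping may differ.

```python
import math


def step_budget(models_received, initial_steps=1, max_steps_per_update=math.inf):
    """Sketch: total number of steps the worker may have collected so far.

    Assumption: initial_steps are allowed up front, then max_steps_per_update
    more steps per model received; the ratio may be fractional
    (e.g. 0.5 = one step every two models).
    """
    if math.isinf(max_steps_per_update):
        return math.inf  # no cap: the worker never pauses for updates
    return initial_steps + int(models_received * max_steps_per_update)


def must_pause(steps_collected, models_received, initial_steps=1,
               max_steps_per_update=math.inf):
    """True when the worker should wait for a new model before stepping again."""
    return steps_collected >= step_budget(models_received, initial_steps,
                                          max_steps_per_update)


print(step_budget(3, initial_steps=10, max_steps_per_update=0.5))  # 11
```

The infinite-ratio case is handled separately because 0 * math.inf is NaN in Python, which would otherwise make the budget undefined before the first model arrives.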

send_and_clear_buffer()

Sends the buffered samples to the Server and clears the buffer.

step(obs, test, collect_samples, last_step=False)

Performs a full RL transition.

A full RL transition is obs -> act -> new_obs, rew, terminated, truncated, info. Note that, in the Real-Time RL setting, act is appended to a buffer which is part of new_obs. This is because it does not directly affect the new observation, due to real-time delays.

Parameters:
  • obs (nested structure) – previous observation

  • test (bool) – passed to the act() method of the ActorModule

  • collect_samples (bool) – if True, samples are buffered and sent to the Server

  • last_step (bool) – if True and terminated is False, truncated will be set to True

Returns:

(nested structure: new observation, float: new reward, bool: episode termination signal, bool: episode truncation signal, dict: information dictionary)

Return type:

Tuple
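The real-time remark in step can be illustrated with a toy environment in which each action takes effect one step late. ToyRealTimeEnv is hypothetical and is not rtgym. It only shows why new_obs carries an action buffer: the most recent action is already known to the policy but has not yet changed the rest of the observation.

```python
from collections import deque


class ToyRealTimeEnv:
    """Hypothetical real-time environment where each action takes effect one step late.

    Observations are (position, action_buffer): the action buffer exposes the
    actions already sent whose effect is not yet visible in position.
    """

    def __init__(self, act_buf_len=2):
        self.act_buf_len = act_buf_len
        self.reset()

    def reset(self):
        self.position = 0.0
        self.pending = 0.0  # action sent at the previous step, not applied yet
        self.act_buf = deque([0.0] * self.act_buf_len, maxlen=self.act_buf_len)
        return (self.position, tuple(self.act_buf)), {}

    def step(self, act):
        # The previously sent action is applied now; the new one is only queued.
        self.position += self.pending
        self.pending = act
        self.act_buf.append(act)
        new_obs = (self.position, tuple(self.act_buf))
        return new_obs, 0.0, False, False, {}


env = ToyRealTimeEnv(act_buf_len=2)
obs, info = env.reset()
new_obs, rew, terminated, truncated, info = env.step(1.0)
print(new_obs)  # (0.0, (0.0, 1.0)): the action is visible only in the buffer
```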

update_actor_weights(verbose=True, blocking=False)

Updates the actor with new weights received from the Server when available.

Parameters:
  • verbose (bool) – whether to log INFO messages.

  • blocking (bool) – if True, blocks until a model is received; otherwise, can be a no-op.

Returns:

number of new actor models received from the Server (the latest is used).

Return type:

int
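The return value of update_actor_weights (number of models received, only the latest used) corresponds to a "drain the queue, keep the newest" pattern. The sketch below uses a standard queue.Queue as a hypothetical stand-in for the tlspyo message buffer, and load_weights is a hypothetical callback.

```python
import queue


def update_actor_weights(weights_queue, load_weights, blocking=False):
    """Sketch: drain all received models, load only the newest, return the count."""
    received = []
    if blocking:
        received.append(weights_queue.get())  # wait for at least one model
    while True:
        try:
            received.append(weights_queue.get_nowait())
        except queue.Empty:
            break
    if received:
        load_weights(received[-1])  # only the latest model is used
    return len(received)
```

Dropping the load_weights call from this sketch gives the behaviour described for ignore_actor_weights: the buffer is emptied and the count is returned, but the model is left unchanged.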

class tmrl.networking.Server(port=cfg.PORT, password=cfg.PASSWORD, local_port=cfg.LOCAL_PORT_SERVER, header_size=cfg.HEADER_SIZE, security=cfg.SECURITY, keys_dir=cfg.CREDENTIALS_DIRECTORY, max_workers=cfg.NB_WORKERS)

Bases: object

Central server.

The Server lets 1 Trainer and n RolloutWorkers connect. It buffers experiences sent by workers and periodically sends these to the trainer. It also receives the weights from the trainer and broadcasts these to the connected workers.

Parameters:
  • port (int) – tlspyo public port

  • password (str) – tlspyo password

  • local_port (int) – tlspyo local communication port

  • header_size (int) – tlspyo header size (bytes)

  • security (Union[str, None]) – tlspyo security type (None or “TLS”)

  • keys_dir (str) – tlspyo credentials directory

  • max_workers (int) – max number of accepted workers
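The Server's relay role (buffering worker experience for the Trainer and broadcasting Trainer weights to workers) can be sketched in-process without any networking. ToyServer and its methods are hypothetical. In particular, refusing connections beyond max_workers with an exception is an assumption about how the limit might be enforced.

```python
from collections import deque


class ToyServer:
    """Hypothetical in-process relay mimicking the Server's role (no networking)."""

    def __init__(self, max_workers=2):
        self.max_workers = max_workers
        self.worker_inboxes = {}        # worker id -> deque of received weights
        self.samples_for_trainer = []   # experience buffered from all workers

    def connect_worker(self, worker_id):
        if len(self.worker_inboxes) >= self.max_workers:
            raise ConnectionRefusedError("max_workers reached")
        self.worker_inboxes[worker_id] = deque()

    def receive_samples(self, samples):
        # Called when a worker sends its buffered experience.
        self.samples_for_trainer.extend(samples)

    def flush_to_trainer(self):
        # Periodically hand everything buffered so far to the Trainer.
        batch, self.samples_for_trainer = self.samples_for_trainer, []
        return batch

    def broadcast_weights(self, weights):
        # Weights from the Trainer go to every connected worker.
        for inbox in self.worker_inboxes.values():
            inbox.append(weights)
```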

class tmrl.networking.Trainer(training_cls=cfg_obj.TRAINER, server_ip=cfg.SERVER_IP_FOR_TRAINER, server_port=cfg.PORT, password=cfg.PASSWORD, local_com_port=cfg.LOCAL_PORT_TRAINER, header_size=cfg.HEADER_SIZE, max_buf_len=cfg.BUFFER_SIZE, security=cfg.SECURITY, keys_dir=cfg.CREDENTIALS_DIRECTORY, hostname=cfg.HOSTNAME, model_path=cfg.MODEL_PATH_TRAINER, checkpoint_path=cfg.CHECKPOINT_PATH, dump_run_instance_fn: callable | None = None, load_run_instance_fn: callable | None = None, updater_fn: callable | None = None)

Bases: object

Training entity.

The Trainer object is where RL training happens. It is typically located on an HPC cluster.

Parameters:
  • training_cls (type) – training class (subclass of tmrl.training_offline.TrainingOffline)

  • server_ip (str) – ip of the central Server

  • server_port (int) – public port of the central Server

  • password (str) – password of the central Server

  • local_com_port (int) – port used by tlspyo for local communication

  • header_size (int) – number of bytes used for tlspyo headers

  • max_buf_len (int) – maximum number of messages queued by tlspyo

  • security (str) – tlspyo security type (None or “TLS”)

  • keys_dir (str) – custom credentials directory for tlspyo TLS security

  • hostname (str) – custom TLS hostname

  • model_path (str) – path where a local copy of the model will be saved

  • checkpoint_path (str) – path where the Trainer will be checkpointed (None = no checkpointing)

  • dump_run_instance_fn (callable) – custom serializer (None = pickle.dump)

  • load_run_instance_fn (callable) – custom deserializer (None = pickle.load)

  • updater_fn (callable) – custom updater (None = no updater). If provided, this must be a function that takes a checkpoint and training_cls as arguments and returns an updated checkpoint. The updater is called after a checkpoint is loaded, e.g., to update your checkpoint with new arguments.
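A custom serializer, deserializer, and updater might look like the functions below. This is a sketch under several assumptions. The (run_instance, checkpoint_path) and (checkpoint_path) signatures for the dump and load functions are assumed rather than taken from this page. training_cls is assumed to be a functools.partial whose keyword arguments hold the new settings. The epochs attribute is a hypothetical example of an argument to refresh.

```python
import pickle
from functools import partial
from types import SimpleNamespace


def my_dump(run_instance, checkpoint_path):
    """Hypothetical serializer; the (obj, path) signature is an assumption."""
    with open(checkpoint_path, "wb") as f:
        pickle.dump(run_instance, f)


def my_load(checkpoint_path):
    """Hypothetical deserializer matching my_dump."""
    with open(checkpoint_path, "rb") as f:
        return pickle.load(f)


def my_updater(checkpoint, training_cls):
    """Hypothetical updater: copy a changed argument into a loaded checkpoint.

    Assumes training_cls is a functools.partial and that 'epochs' is one of
    its keyword arguments (illustrative only).
    """
    new_epochs = training_cls.keywords.get("epochs")
    if new_epochs is not None:
        checkpoint.epochs = new_epochs
    return checkpoint


# Usage: a checkpoint saved with epochs=10 is resumed with epochs=20.
ckpt = my_updater(SimpleNamespace(epochs=10), partial(dict, epochs=20))
print(ckpt.epochs)  # 20
```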

run()

Runs training.

run_with_wandb(entity=cfg.WANDB_ENTITY, project=cfg.WANDB_PROJECT, run_id=cfg.WANDB_RUN_ID, key=None)

Runs training while logging metrics to wandb.

Parameters:
  • entity (str) – wandb entity

  • project (str) – wandb project

  • run_id (str) – name of the run

  • key (str) – wandb API key