From f756b810067a842f80ae1824fa54dd77a356f7a6 Mon Sep 17 00:00:00 2001 From: Homanga Bharadhwaj Date: Fri, 10 Sep 2021 20:36:05 +0000 Subject: [PATCH] mbrl distractors --- dbc.py | 657 +++++++++++++++++++++++ deepmdp.py | 650 +++++++++++++++++++++++ dreamer_contrastive.py | 658 +++++++++++++++++++++++ dreamer_contrastive_inverse.py | 799 ++++++++++++++++++++++++++++ models.py | 609 +++++++++++++++------ tools.py | 747 ++++++++++++++------------ tools_inv.py | 522 ++++++++++++++++++ wrappers.py | 945 +++++++++++++++++++-------------- 8 files changed, 4692 insertions(+), 895 deletions(-) create mode 100644 dbc.py create mode 100644 deepmdp.py create mode 100644 dreamer_contrastive.py create mode 100644 dreamer_contrastive_inverse.py create mode 100644 tools_inv.py diff --git a/dbc.py b/dbc.py new file mode 100644 index 0000000..93f86ab --- /dev/null +++ b/dbc.py @@ -0,0 +1,657 @@ +import sys +import pathlib +sys.path.append(str(pathlib.Path(__file__).parent)) + +import wrappers +import tools +import models +from tensorflow_probability import distributions as tfd +from tensorflow.keras.mixed_precision import experimental as prec +import tensorflow as tf +import numpy as np +import argparse +import collections +import functools +import json +import os +import pathlib +import sys +import time +import soft_actor_critic +from tensorflow.python.ops.numpy_ops import np_config +np_config.enable_numpy_behavior() +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +# enable headless training on servers for mujoco +#os.environ['MUJOCO_GL'] = 'egl' + +tf.executing_eagerly() + +tf.get_logger().setLevel('ERROR') + + +sys.path.append(str(pathlib.Path(__file__).parent)) + + +def define_config(): + config = tools.AttrDict() + # General. + config.logdir = pathlib.Path('.') + config.seed = 0 + config.steps = 5e6 + config.eval_every = 1e4 + config.log_every = 1e3 + config.log_scalars = True + config.log_images = True + config.gpu_growth = True + config.precision = 16 + # Environment. + config.task = 'dmc_walker_walk' + config.envs = 1 + config.difficulty = 'none' + config.parallel = 'none' + config.action_repeat = 2 + config.time_limit = 1000 + config.prefill = 5000 + config.eval_noise = 0.0 + config.clip_rewards = 'none' + # Model. + config.deter_size = 200 + config.stoch_size = 30 + config.num_units = 400 + config.dense_act = 'elu' + config.cnn_act = 'relu' + config.cnn_depth = 32 + config.pcont = False + config.free_nats = 3.0 + config.kl_scale = 1.0 + config.pcont_scale = 10.0 + config.weight_decay = 0.0 + config.weight_decay_pattern = r'.*' + # Training. + config.batch_size = 50 + config.batch_length = 50 + config.train_every = 1000 + config.train_steps = 100 + config.pretrain = 100 + config.model_lr = 6e-4 + config.value_lr = 8e-5 + config.actor_lr = 8e-5 + config.grad_clip = 100.0 + config.dataset_balance = False + # Behavior. 
+ config.discount = 0.99 + config.disclam = 0.95 + config.horizon = 15 + config.action_dist = 'tanh_normal' + config.action_init_std = 5.0 + config.expl = 'additive_gaussian' + config.expl_amount = 0.3 + config.expl_decay = 0.0 + config.expl_min = 0.0 + config.log_imgs = True + + # natural or not + config.natural = True + config.custom_video = False + + # obs model + config.obs_model = 'dbc' + + + # use trajectory optimization + config.trajectory_opt = False + config.traj_opt_lr = 0.003 + config.num_samples = 20 + return config + + +class Dreamer(tools.Module): + + def __init__(self, config, datadir, actspace, writer): + self._c = config + self._actspace = actspace + self._actdim = actspace.n if hasattr( + actspace, 'n') else actspace.shape[0] + self._writer = writer + self._random = np.random.RandomState(config.seed) + with tf.device('cpu:0'): + self._step = tf.Variable(count_steps( + datadir, config), dtype=tf.int64) + self._should_pretrain = tools.Once() + self._should_train = tools.Every(config.train_every) + self._should_log = tools.Every(config.log_every) + self._last_log = None + self._last_time = time.time() + self._metrics = collections.defaultdict(tf.metrics.Mean) + self._metrics['expl_amount'] # Create variable for checkpoint. + self._float = prec.global_policy().compute_dtype + self._dataset = iter(load_dataset(datadir, self._c)) + self._build_model() + + def __call__(self, obs, reset, state=None, training=True): + step = self._step.numpy().item() + tf.summary.experimental.set_step(step) + if state is not None and reset.any(): + mask = tf.cast(1 - reset, self._float)[:, None] + state = tf.nest.map_structure(lambda x: x * mask, state) + if self._should_train(step): + log = self._should_log(step) + n = self._c.pretrain if self._should_pretrain() else self._c.train_steps + print(f'Training for {n} steps.') + # with self._strategy.scope(): + for train_step in range(n): + log_images = self._c.log_images and log and train_step == 0 + self.train(next(self._dataset), log_images) + if log: + self._write_summaries() + action, state = self.policy(obs, state, training) + if training: + self._step.assign_add(len(reset) * self._c.action_repeat) + return action, state + + @tf.function + def policy(self, obs, state, training): + if state is None: + latent = self._dynamics.initial(len(obs['image'])) + action = tf.zeros((len(obs['image']), self._actdim), self._float) + else: + latent, action = state + embed = self._encode(preprocess(obs, self._c)) + latent, _ = self._dynamics.obs_step(latent, action, embed) + feat = self._dynamics.get_feat(latent) + + if self._c.trajectory_opt: + action = self._trajectory_optimization(latent) + else: + if training: + action = self._actor(feat).sample() + else: + action = self._actor(feat).mode() + + action = self._exploration(action, training) + state = (latent, action) + return action, state + + def load(self, filename): + super().load(filename) + self._should_pretrain() + + @tf.function() + def train(self, data, log_images=True): + self._train(data, log_images) + + def _train(self, data, log_images): + with tf.GradientTape() as model_tape: + embed = self._encode(data) + batch_size = embed.shape[0] + perm = np.random.permutation(batch_size) + embed2 = embed[perm] + + + post, prior = self._dynamics.observe(embed, data['action']) + post2, prior2 = self._dynamics.observe(embed2, data['action']) + + feat = self._dynamics.get_feat(post) + feat2 = self._dynamics.get_feat(post2) + + reward_pred = self._reward(feat) + reward_pred2 = self._reward(feat2) + + likes = 
tools.AttrDict() + likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward'])) + + # if we use the generative observation model, we need to perform observation reconstruction + image_pred = self._decode(feat) + # compute the contrative loss directly + cont_loss = self._contrastive(feat, embed) + + + + if self._c.pcont: + pcont_pred = self._pcont(feat) + pcont_target = self._c.discount * data['discount'] + likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target)) + likes.pcont *= self._c.pcont_scale + + prior_dist = self._dynamics.get_dist(prior) + post_dist = self._dynamics.get_dist(post) + post_dist2 = self._dynamics.get_dist(post2) + mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE) + + + z_dist = tf.math.reduce_mean(mae(embed,embed2), axis = 1) + # r_dist = tfd.kl_divergence(reward_pred,reward_pred2) + r_dist = mae(data['reward'],data['reward'][perm]) + + + # transition_dist = tfd.kl_divergence(post_dist, post_dist2) + + ## Wasserstein distance b/w state transition kernels + transition_dist = tf.math.reduce_mean(tf.math.sqrt(tf.math.square(post['mean'] - post2['mean']) + tf.math.square(post['std'] - post2['std'])),axis=[1,2]) + + + bisim = r_dist + self._c.discount * transition_dist + dbc_loss = tf.math.square(z_dist - bisim) + # tf.print(z_dist) + # tf.print(r_dist) + # tf.print(transition_dist) + # tf.print(dbc_loss) + + # the contrastive / generative implementation of the observation model p(o|s) + if self._c.obs_model == 'generative': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + elif self._c.obs_model == 'contrastive': + likes.image = tf.reduce_mean(cont_loss) + elif self._c.obs_model == 'dbc': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + likes.image1 = tf.reduce_mean(dbc_loss) + + div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) + div = tf.maximum(div, self._c.free_nats) + model_loss = self._c.kl_scale * div - sum(likes.values()) + + + + + with tf.GradientTape() as actor_tape: + imag_feat = self._imagine_ahead(post) + reward = self._reward(imag_feat).mode() + if self._c.pcont: + pcont = self._pcont(imag_feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(imag_feat).mode() + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + discount = tf.stop_gradient(tf.math.cumprod(tf.concat( + [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)) + actor_loss = -tf.reduce_mean(discount * returns) + + with tf.GradientTape() as value_tape: + value_pred = self._value(imag_feat)[:-1] + target = tf.stop_gradient(returns) + value_loss = - \ + tf.reduce_mean(discount * value_pred.log_prob(target)) + + actor_norm = self._actor_opt(actor_tape, actor_loss) + value_norm = self._value_opt(value_tape, value_loss) + + + model_norm = self._model_opt(model_tape, model_loss) + states = tf.concat([post['stoch'], post['deter']], axis=-1) + rewards = data['reward'] + dones = tf.zeros_like(rewards) + actions = data['action'] + + + + if tf.distribute.get_replica_context().replica_id_in_sync_group == 0: + if self._c.log_scalars: + self._scalar_summaries( + data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm) + if tf.equal(log_images, True) and self._c.log_imgs: + self._image_summaries(data, embed, image_pred) + + def _build_model(self): + acts = dict( + elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish, + leaky_relu=tf.nn.leaky_relu) 
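
# The bisimulation ('dbc') objective assembled in _train above pairs every batch
# element with a random permutation of the batch and regresses the embedding
# distance onto a reward-plus-transition target. A minimal NumPy sketch of that
# computation, assuming flat [batch, dim] embeddings and diagonal-Gaussian
# posteriors (the function name and shapes are illustrative, not part of this patch):
import numpy as np

def dbc_loss_sketch(z, z2, r, r2, mean, std, mean2, std2, discount=0.99):
    z_dist = np.mean(np.abs(z - z2), axis=-1)        # L1 distance between encoder embeddings
    r_dist = np.abs(r - r2)                          # reward distance
    # per-dimension W2-style distance between the two transition kernels; the exact
    # closed form for diagonal Gaussians is sqrt(||mean - mean2||^2 + ||std - std2||^2)
    w2 = np.mean(np.sqrt((mean - mean2) ** 2 + (std - std2) ** 2), axis=-1)
    bisim_target = r_dist + discount * w2            # bisimulation target
    return np.mean((z_dist - bisim_target) ** 2)     # squared deviation, as in dbc_loss above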
+ cnn_act = acts[self._c.cnn_act] + act = acts[self._c.dense_act] + self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) + self._dynamics = models.RSSM( + self._c.stoch_size, self._c.deter_size, self._c.deter_size) + self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act) + self._contrastive = models.ContrastiveObsModel(self._c.deter_size, + self._c.deter_size * 2) + self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) + if self._c.pcont: + self._pcont = models.DenseDecoder( + (), 3, self._c.num_units, 'binary', act=act) + self._value = models.DenseDecoder((), 3, self._c.num_units, act=act) + self._Qs = [models.QNetwork(3, self._c.num_units, act=act) for _ in range(self._c.num_Qs)] + self._actor = models.ActionDecoder( + self._actdim, 4, self._c.num_units, self._c.action_dist, + init_std=self._c.action_init_std, act=act) + model_modules = [self._encode, self._dynamics, + self._contrastive, self._reward, self._decode] + if self._c.pcont: + model_modules.append(self._pcont) + Optimizer = functools.partial( + tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip, + wdpattern=self._c.weight_decay_pattern) + self._model_opt = Optimizer('model', model_modules, self._c.model_lr) + self._value_opt = Optimizer('value', [self._value], self._c.value_lr) + self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) + self._q_opts = [Optimizer('qs', [qnet], self._c.value_lr) for qnet in self._Qs] + + + self.train(next(self._dataset)) + + def _exploration(self, action, training): + if training: + amount = self._c.expl_amount + if self._c.expl_decay: + amount *= 0.5 ** (tf.cast(self._step, + tf.float32) / self._c.expl_decay) + if self._c.expl_min: + amount = tf.maximum(self._c.expl_min, amount) + self._metrics['expl_amount'].update_state(amount) + elif self._c.eval_noise: + amount = self._c.eval_noise + else: + return action + if self._c.expl == 'additive_gaussian': + return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) + if self._c.expl == 'completely_random': + return tf.random.uniform(action.shape, -1, 1) + if self._c.expl == 'epsilon_greedy': + indices = tfd.Categorical(0 * action).sample() + return tf.where( + tf.random.uniform(action.shape[:1], 0, 1) < amount, + tf.one_hot(indices, action.shape[-1], dtype=self._float), + action) + raise NotImplementedError(self._c.expl) + + def _imagine_ahead(self, post): + if self._c.pcont: # Last step could be terminal. 
+ post = {k: v[:, :-1] for k, v in post.items()} + + def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in post.items()} + + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + states = tools.static_scan( + lambda prev, _: self._dynamics.img_step(prev, policy(prev)), + tf.range(self._c.horizon), start) + imag_feat = self._dynamics.get_feat(states) + return imag_feat + + def _trajectory_optimization(self, post): + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + + def repeat(x): + return tf.repeat(x, self._c.num_samples, axis=0) + + states, actions = tools.static_scan_action( + lambda prev, action, _: self._dynamics.img_step(prev, action), + lambda prev: policy(prev), + tf.range(self._c.horizon), post) + + feat = self._dynamics.get_feat(states) + reward = self._reward(feat).mode() + + if self._c.pcont: + pcont = self._pcont(feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(feat).mode() + + # compute the accumulated reward + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + + accumulated_reward = returns[0, 0] + + # since the reward and latent dynamics are fully differentiable, we can backprop the gradients to update the actions + grad = tf.gradients(accumulated_reward, actions)[0] + act = actions + grad * self._c.traj_opt_lr + + return act + + + def _scalar_summaries( + self, data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm): + self._metrics['model_grad_norm'].update_state(model_norm) + self._metrics['value_grad_norm'].update_state(value_norm) + self._metrics['actor_grad_norm'].update_state(actor_norm) + self._metrics['prior_ent'].update_state(prior_dist.entropy()) + self._metrics['post_ent'].update_state(post_dist.entropy()) + for name, logprob in likes.items(): + self._metrics[name + '_loss'].update_state(-logprob) + self._metrics['div'].update_state(div) + self._metrics['model_loss'].update_state(model_loss) + self._metrics['value_loss'].update_state(value_loss) + self._metrics['actor_loss'].update_state(actor_loss) + self._metrics['action_ent'].update_state(self._actor(feat).entropy()) + + def _image_summaries(self, data, embed, image_pred): + truth = data['image'][:6] + 0.5 + recon = image_pred.mode()[:6] + init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) + init = {k: v[:, -1] for k, v in init.items()} + prior = self._dynamics.imagine(data['action'][:6, 5:], init) + openl = self._decode(self._dynamics.get_feat(prior)).mode() + model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) + error = (model - truth + 1) / 2 + openl = tf.concat([truth, model, error], 2) + tools.graph_summary( + self._writer, tools.video_summary, 'agent/openl', openl) + + def _write_summaries(self): + step = int(self._step.numpy()) + metrics = [(k, float(v.result())) for k, v in self._metrics.items()] + if self._last_log is not None: + duration = time.time() - self._last_time + self._last_time += duration + metrics.append(('fps', (step - self._last_log) / duration)) + self._last_log = step + [m.reset_states() for m in self._metrics.values()] + with (self._c.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': step, **dict(metrics)}) + '\n') + [tf.summary.scalar('agent/' + k, m) for k, m in metrics] + print(f'[{step}]', ' / '.join(f'{k} 
{v:.1f}' for k, v in metrics)) + self._writer.flush() + + +def preprocess(obs, config): + dtype = prec.global_policy().compute_dtype + obs = obs.copy() + with tf.device('cpu:0'): + obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 + clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[ + config.clip_rewards] + obs['reward'] = clip_rewards(obs['reward']) + return obs + + +def count_steps(datadir, config): + return tools.count_episodes(datadir)[1] * config.action_repeat + + +def load_dataset(directory, config): + episode = next(tools.load_episodes(directory, 1)) + types = {k: v.dtype for k, v in episode.items()} + shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} + + def generator(): return tools.load_episodes( + directory, config.train_steps, config.batch_length, + config.dataset_balance) + dataset = tf.data.Dataset.from_generator(generator, types, shapes) + dataset = dataset.batch(config.batch_size, drop_remainder=True) + dataset = dataset.map(functools.partial(preprocess, config=config)) + dataset = dataset.prefetch(10) + return dataset + + +def summarize_episode(episode, config, datadir, writer, prefix): + episodes, steps = tools.count_episodes(datadir) + length = (len(episode['reward']) - 1) * config.action_repeat + ret = episode['reward'].sum() + print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') + metrics = [ + (f'{prefix}/return', float(episode['reward'].sum())), + (f'{prefix}/length', len(episode['reward']) - 1), + (f'episodes', episodes)] + step = count_steps(datadir, config) + with (config.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps(dict([('step', step)] + metrics)) + '\n') + with writer.as_default(): # Env might run in a different thread. + tf.summary.experimental.set_step(step) + [tf.summary.scalar('sim/' + k, v) for k, v in metrics] + if prefix == 'test': + tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) + + +def make_env(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + + + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env +def make_env_test(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if 
config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + + + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env + + + +def main(config): + print('mainn') + if config.gpu_growth: + for gpu in tf.config.experimental.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) + assert config.precision in (16, 32), config.precision + if config.precision == 16: + prec.set_policy(prec.Policy('mixed_float16')) + config.steps = int(config.steps) + config.logdir.mkdir(parents=True, exist_ok=True) + print('Logdir', config.logdir) + + arg_dict = vars(config).copy() + del arg_dict['logdir'] + + # with open(os.path.join(config.logdir, 'args.json'), 'w') as fout: + # import json + # json.dump(arg_dict, fout) + + # Create environments. + datadir = config.logdir / 'episodes' + datadir.mkdir(parents=True, exist_ok=True) + writer = tf.summary.create_file_writer( + str(config.logdir), max_queue=1000, flush_millis=20000) + writer.set_as_default() + train_envs = [wrappers.Async(lambda: make_env( + config, writer, 'train', datadir, train=True), config.parallel) + for _ in range(config.envs)] + test_envs = [wrappers.Async(lambda: make_env_test( + config, writer, 'test', datadir, train=False), config.parallel) + for _ in range(config.envs)] + actspace = train_envs[0].action_space + + # Prefill dataset with random episodes. + step = count_steps(datadir, config) + prefill = max(0, config.prefill - step) + print(f'Prefill dataset with {prefill} steps.') + def random_agent(o, d, _): return ([actspace.sample() for _ in d], None) + tools.simulate(random_agent, train_envs, prefill / config.action_repeat) + writer.flush() + + # Train and regularly evaluate the agent. 
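
# Example invocation (hypothetical paths/values; every flag is generated from
# define_config() and parsed via tools.args_type, so the exact syntax for booleans
# depends on that helper; any --difficulty other than 'none' switches to
# wrappers.DeepMindControlDistraction):
#
#   python dbc.py --logdir ./logdir/dbc_walker --task dmc_walker_walk \
#       --obs_model dbc --difficulty none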
+ step = count_steps(datadir, config) + print(f'Simulating agent for {config.steps-step} steps.') + agent = Dreamer(config, datadir, actspace, writer) + if (config.logdir / 'variables.pkl').exists(): + print('Load checkpoint.') + agent.load(config.logdir / 'variables.pkl') + state = None + while step < config.steps: + print('Start evaluation.') + tools.simulate( + functools.partial(agent, training=False), test_envs, episodes=1) + writer.flush() + print('Start collection.') + steps = config.eval_every // config.action_repeat + state = tools.simulate(agent, train_envs, steps, state=state) + step = count_steps(datadir, config) + agent.save(config.logdir / 'variables.pkl') + for env in train_envs + test_envs: + env.close() + + +#if __name__ == '__main__': + # try: + # import colored_traceback + # colored_traceback.add_hook() + # except ImportError: + # pass +parser = argparse.ArgumentParser() +for key, value in define_config().items(): + parser.add_argument( + f'--{key}', type=tools.args_type(value), default=value) +args = parser.parse_args() + +print('main') + +main(args) diff --git a/deepmdp.py b/deepmdp.py new file mode 100644 index 0000000..d5f316b --- /dev/null +++ b/deepmdp.py @@ -0,0 +1,650 @@ +import sys +import pathlib +sys.path.append(str(pathlib.Path(__file__).parent)) + +import wrappers +import tools +import models +from tensorflow_probability import distributions as tfd +from tensorflow.keras.mixed_precision import experimental as prec +import tensorflow as tf +import numpy as np +import argparse +import collections +import functools +import json +import os +import pathlib +import sys +import time + +from tensorflow.python.ops.numpy_ops import np_config +np_config.enable_numpy_behavior() +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +# enable headless training on servers for mujoco +#os.environ['MUJOCO_GL'] = 'egl' + +tf.executing_eagerly() + +tf.get_logger().setLevel('ERROR') + + +sys.path.append(str(pathlib.Path(__file__).parent)) + + +def define_config(): + config = tools.AttrDict() + # General. + config.logdir = pathlib.Path('.') + config.seed = 0 + config.steps = 5e6 + config.eval_every = 1e4 + config.log_every = 1e3 + config.log_scalars = True + config.log_images = True + config.gpu_growth = True + config.precision = 16 + # Environment. + config.task = 'dmc_walker_walk' + config.envs = 1 + config.difficulty = 'none' + config.parallel = 'none' + config.action_repeat = 2 + config.time_limit = 1000 + config.prefill = 5000 + config.eval_noise = 0.0 + config.clip_rewards = 'none' + # Model. + config.deter_size = 200 + config.stoch_size = 30 + config.num_units = 400 + config.dense_act = 'elu' + config.cnn_act = 'relu' + config.cnn_depth = 32 + config.pcont = False + config.free_nats = 3.0 + config.kl_scale = 1.0 + config.pcont_scale = 10.0 + config.weight_decay = 0.0 + config.weight_decay_pattern = r'.*' + # Training. + config.batch_size = 50 + config.batch_length = 50 + config.train_every = 1000 + config.train_steps = 100 + config.pretrain = 100 + config.model_lr = 6e-4 + config.value_lr = 8e-5 + config.actor_lr = 8e-5 + config.grad_clip = 100.0 + config.dataset_balance = False + # Behavior. 
+ config.discount = 0.99 + config.disclam = 0.95 + config.horizon = 15 + config.action_dist = 'tanh_normal' + config.action_init_std = 5.0 + config.expl = 'additive_gaussian' + config.expl_amount = 0.3 + config.expl_decay = 0.0 + config.expl_min = 0.0 + config.log_imgs = True + + # natural or not + config.natural = True + config.custom_video = False + + # obs model + config.obs_model = 'dbc' + + + + # use trajectory optimization + config.trajectory_opt = False + config.traj_opt_lr = 0.003 + config.num_samples = 20 + return config + + +class Dreamer(tools.Module): + + def __init__(self, config, datadir, actspace, writer): + self._c = config + self._actspace = actspace + self._actdim = actspace.n if hasattr( + actspace, 'n') else actspace.shape[0] + self._writer = writer + self._random = np.random.RandomState(config.seed) + with tf.device('cpu:0'): + self._step = tf.Variable(count_steps( + datadir, config), dtype=tf.int64) + self._should_pretrain = tools.Once() + self._should_train = tools.Every(config.train_every) + self._should_log = tools.Every(config.log_every) + self._last_log = None + self._last_time = time.time() + self._metrics = collections.defaultdict(tf.metrics.Mean) + self._metrics['expl_amount'] # Create variable for checkpoint. + self._float = prec.global_policy().compute_dtype + self._dataset = iter(load_dataset(datadir, self._c)) + self._build_model() + + def __call__(self, obs, reset, state=None, training=True): + step = self._step.numpy().item() + tf.summary.experimental.set_step(step) + if state is not None and reset.any(): + mask = tf.cast(1 - reset, self._float)[:, None] + state = tf.nest.map_structure(lambda x: x * mask, state) + if self._should_train(step): + log = self._should_log(step) + n = self._c.pretrain if self._should_pretrain() else self._c.train_steps + print(f'Training for {n} steps.') + # with self._strategy.scope(): + for train_step in range(n): + log_images = self._c.log_images and log and train_step == 0 + self.train(next(self._dataset), log_images) + if log: + self._write_summaries() + action, state = self.policy(obs, state, training) + if training: + self._step.assign_add(len(reset) * self._c.action_repeat) + return action, state + + @tf.function + def policy(self, obs, state, training): + if state is None: + latent = self._dynamics.initial(len(obs['image'])) + action = tf.zeros((len(obs['image']), self._actdim), self._float) + else: + latent, action = state + embed = self._encode(preprocess(obs, self._c)) + latent, _ = self._dynamics.obs_step(latent, action, embed) + feat = self._dynamics.get_feat(latent) + + if self._c.trajectory_opt: + action = self._trajectory_optimization(latent) + else: + if training: + action = self._actor(feat).sample() + else: + action = self._actor(feat).mode() + + action = self._exploration(action, training) + state = (latent, action) + return action, state + + def load(self, filename): + super().load(filename) + self._should_pretrain() + + @tf.function() + def train(self, data, log_images=True): + self._train(data, log_images) + + def _train(self, data, log_images): + with tf.GradientTape() as model_tape: + embed = self._encode(data) + batch_size = embed.shape[0] + + + + + post, prior = self._dynamics.observe(embed, data['action']) + + + feat = self._dynamics.get_feat(post) + + + reward_pred = self._reward(feat) + + + likes = tools.AttrDict() + likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward'])) + + # if we use the generative observation model, we need to perform observation reconstruction + image_pred = 
self._decode(feat) + # compute the contrative loss directly + cont_loss = self._contrastive(feat, embed) + + + + if self._c.pcont: + pcont_pred = self._pcont(feat) + pcont_target = self._c.discount * data['discount'] + likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target)) + likes.pcont *= self._c.pcont_scale + + prior_dist = self._dynamics.get_dist(prior) + post_dist = self._dynamics.get_dist(post) + + + + + + + + + + ## Wasserstein distance b/w state transition kernels + transition_dist = tf.math.reduce_mean(tf.math.sqrt(tf.math.square(post['mean'] - prior['mean']) + tf.math.square(post['std'] - prior['std'])),axis=[1,2]) + + + + deepmdp_transition_loss = self._c.discount * transition_dist + # tf.print(z_dist) + # tf.print(r_dist) + # tf.print(transition_dist) + # tf.print(dbc_loss) + + # the contrastive / generative implementation of the observation model p(o|s) + if self._c.obs_model == 'generative': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + elif self._c.obs_model == 'contrastive': + likes.image = tf.reduce_mean(cont_loss) + elif self._c.obs_model == 'dbc': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + likes.image1 = tf.reduce_mean(deepmdp_transition_loss) + + div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) + div = tf.maximum(div, self._c.free_nats) + model_loss = self._c.kl_scale * div - sum(likes.values()) + + with tf.GradientTape() as actor_tape: + imag_feat = self._imagine_ahead(post) + reward = self._reward(imag_feat).mode() + if self._c.pcont: + pcont = self._pcont(imag_feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(imag_feat).mode() + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + discount = tf.stop_gradient(tf.math.cumprod(tf.concat( + [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)) + actor_loss = -tf.reduce_mean(discount * returns) + + with tf.GradientTape() as value_tape: + value_pred = self._value(imag_feat)[:-1] + target = tf.stop_gradient(returns) + value_loss = - \ + tf.reduce_mean(discount * value_pred.log_prob(target)) + + actor_norm = self._actor_opt(actor_tape, actor_loss) + value_norm = self._value_opt(value_tape, value_loss) + + + model_norm = self._model_opt(model_tape, model_loss) + states = tf.concat([post['stoch'], post['deter']], axis=-1) + rewards = data['reward'] + dones = tf.zeros_like(rewards) + actions = data['action'] + + + + if tf.distribute.get_replica_context().replica_id_in_sync_group == 0: + if self._c.log_scalars: + self._scalar_summaries( + data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm) + if tf.equal(log_images, True) and self._c.log_imgs: + self._image_summaries(data, embed, image_pred) + + def _build_model(self): + acts = dict( + elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish, + leaky_relu=tf.nn.leaky_relu) + cnn_act = acts[self._c.cnn_act] + act = acts[self._c.dense_act] + self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) + self._dynamics = models.RSSM( + self._c.stoch_size, self._c.deter_size, self._c.deter_size) + self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act) + self._contrastive = models.ContrastiveObsModel(self._c.deter_size, + self._c.deter_size * 2) + self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) + if self._c.pcont: + self._pcont = models.DenseDecoder( + (), 3, self._c.num_units, 'binary', act=act) + self._value = 
models.DenseDecoder((), 3, self._c.num_units, act=act) + self._Qs = [models.QNetwork(3, self._c.num_units, act=act) for _ in range(self._c.num_Qs)] + self._actor = models.ActionDecoder( + self._actdim, 4, self._c.num_units, self._c.action_dist, + init_std=self._c.action_init_std, act=act) + model_modules = [self._encode, self._dynamics, + self._contrastive, self._reward, self._decode] + if self._c.pcont: + model_modules.append(self._pcont) + Optimizer = functools.partial( + tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip, + wdpattern=self._c.weight_decay_pattern) + self._model_opt = Optimizer('model', model_modules, self._c.model_lr) + self._value_opt = Optimizer('value', [self._value], self._c.value_lr) + self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) + self._q_opts = [Optimizer('qs', [qnet], self._c.value_lr) for qnet in self._Qs] + + self.train(next(self._dataset)) + + def _exploration(self, action, training): + if training: + amount = self._c.expl_amount + if self._c.expl_decay: + amount *= 0.5 ** (tf.cast(self._step, + tf.float32) / self._c.expl_decay) + if self._c.expl_min: + amount = tf.maximum(self._c.expl_min, amount) + self._metrics['expl_amount'].update_state(amount) + elif self._c.eval_noise: + amount = self._c.eval_noise + else: + return action + if self._c.expl == 'additive_gaussian': + return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) + if self._c.expl == 'completely_random': + return tf.random.uniform(action.shape, -1, 1) + if self._c.expl == 'epsilon_greedy': + indices = tfd.Categorical(0 * action).sample() + return tf.where( + tf.random.uniform(action.shape[:1], 0, 1) < amount, + tf.one_hot(indices, action.shape[-1], dtype=self._float), + action) + raise NotImplementedError(self._c.expl) + + def _imagine_ahead(self, post): + if self._c.pcont: # Last step could be terminal. 
+ post = {k: v[:, :-1] for k, v in post.items()} + + def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in post.items()} + + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + states = tools.static_scan( + lambda prev, _: self._dynamics.img_step(prev, policy(prev)), + tf.range(self._c.horizon), start) + imag_feat = self._dynamics.get_feat(states) + return imag_feat + + def _trajectory_optimization(self, post): + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + + def repeat(x): + return tf.repeat(x, self._c.num_samples, axis=0) + + states, actions = tools.static_scan_action( + lambda prev, action, _: self._dynamics.img_step(prev, action), + lambda prev: policy(prev), + tf.range(self._c.horizon), post) + + feat = self._dynamics.get_feat(states) + reward = self._reward(feat).mode() + + if self._c.pcont: + pcont = self._pcont(feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(feat).mode() + + # compute the accumulated reward + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + + accumulated_reward = returns[0, 0] + + # since the reward and latent dynamics are fully differentiable, we can backprop the gradients to update the actions + grad = tf.gradients(accumulated_reward, actions)[0] + act = actions + grad * self._c.traj_opt_lr + + return act + + + def _scalar_summaries( + self, data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm): + self._metrics['model_grad_norm'].update_state(model_norm) + self._metrics['value_grad_norm'].update_state(value_norm) + self._metrics['actor_grad_norm'].update_state(actor_norm) + self._metrics['prior_ent'].update_state(prior_dist.entropy()) + self._metrics['post_ent'].update_state(post_dist.entropy()) + for name, logprob in likes.items(): + self._metrics[name + '_loss'].update_state(-logprob) + self._metrics['div'].update_state(div) + self._metrics['model_loss'].update_state(model_loss) + self._metrics['value_loss'].update_state(value_loss) + self._metrics['actor_loss'].update_state(actor_loss) + self._metrics['action_ent'].update_state(self._actor(feat).entropy()) + + def _image_summaries(self, data, embed, image_pred): + truth = data['image'][:6] + 0.5 + recon = image_pred.mode()[:6] + init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) + init = {k: v[:, -1] for k, v in init.items()} + prior = self._dynamics.imagine(data['action'][:6, 5:], init) + openl = self._decode(self._dynamics.get_feat(prior)).mode() + model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) + error = (model - truth + 1) / 2 + openl = tf.concat([truth, model, error], 2) + tools.graph_summary( + self._writer, tools.video_summary, 'agent/openl', openl) + + def _write_summaries(self): + step = int(self._step.numpy()) + metrics = [(k, float(v.result())) for k, v in self._metrics.items()] + if self._last_log is not None: + duration = time.time() - self._last_time + self._last_time += duration + metrics.append(('fps', (step - self._last_log) / duration)) + self._last_log = step + [m.reset_states() for m in self._metrics.values()] + with (self._c.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': step, **dict(metrics)}) + '\n') + [tf.summary.scalar('agent/' + k, m) for k, m in metrics] + print(f'[{step}]', ' / '.join(f'{k} 
{v:.1f}' for k, v in metrics)) + self._writer.flush() + + +def preprocess(obs, config): + dtype = prec.global_policy().compute_dtype + obs = obs.copy() + with tf.device('cpu:0'): + obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 + clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[ + config.clip_rewards] + obs['reward'] = clip_rewards(obs['reward']) + return obs + + +def count_steps(datadir, config): + return tools.count_episodes(datadir)[1] * config.action_repeat + + +def load_dataset(directory, config): + episode = next(tools.load_episodes(directory, 1)) + types = {k: v.dtype for k, v in episode.items()} + shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} + + def generator(): return tools.load_episodes( + directory, config.train_steps, config.batch_length, + config.dataset_balance) + dataset = tf.data.Dataset.from_generator(generator, types, shapes) + dataset = dataset.batch(config.batch_size, drop_remainder=True) + dataset = dataset.map(functools.partial(preprocess, config=config)) + dataset = dataset.prefetch(10) + return dataset + + +def summarize_episode(episode, config, datadir, writer, prefix): + episodes, steps = tools.count_episodes(datadir) + length = (len(episode['reward']) - 1) * config.action_repeat + ret = episode['reward'].sum() + print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') + metrics = [ + (f'{prefix}/return', float(episode['reward'].sum())), + (f'{prefix}/length', len(episode['reward']) - 1), + (f'episodes', episodes)] + step = count_steps(datadir, config) + with (config.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps(dict([('step', step)] + metrics)) + '\n') + with writer.as_default(): # Env might run in a different thread. + tf.summary.experimental.set_step(step) + [tf.summary.scalar('sim/' + k, v) for k, v in metrics] + if prefix == 'test': + tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) + + +def make_env(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + + + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env +def make_env_test(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if 
config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env + + +def main(config): + print('mainn') + if config.gpu_growth: + for gpu in tf.config.experimental.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) + assert config.precision in (16, 32), config.precision + if config.precision == 16: + prec.set_policy(prec.Policy('mixed_float16')) + config.steps = int(config.steps) + config.logdir.mkdir(parents=True, exist_ok=True) + print('Logdir', config.logdir) + + arg_dict = vars(config).copy() + del arg_dict['logdir'] + + # with open(os.path.join(config.logdir, 'args.json'), 'w') as fout: + # import json + # json.dump(arg_dict, fout) + + # Create environments. + datadir = config.logdir / 'episodes' + datadir.mkdir(parents=True, exist_ok=True) + writer = tf.summary.create_file_writer( + str(config.logdir), max_queue=1000, flush_millis=20000) + writer.set_as_default() + train_envs = [wrappers.Async(lambda: make_env( + config, writer, 'train', datadir, train=True), config.parallel) + for _ in range(config.envs)] + test_envs = [wrappers.Async(lambda: make_env_test( + config, writer, 'test', datadir, train=False), config.parallel) + for _ in range(config.envs)] + actspace = train_envs[0].action_space + + # Prefill dataset with random episodes. + step = count_steps(datadir, config) + prefill = max(0, config.prefill - step) + print(f'Prefill dataset with {prefill} steps.') + def random_agent(o, d, _): return ([actspace.sample() for _ in d], None) + tools.simulate(random_agent, train_envs, prefill / config.action_repeat) + writer.flush() + + # Train and regularly evaluate the agent. 
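
# For contrast with dbc.py: deepmdp.py's _train above keeps only a W2-style distance
# between the posterior and the prior transition kernels (no reward or embedding
# pairing). A NumPy sketch of that extra penalty, with illustrative names and flat shapes:
import numpy as np

def deepmdp_transition_penalty(post_mean, post_std, prior_mean, prior_std, discount=0.99):
    w2 = np.mean(np.sqrt((post_mean - prior_mean) ** 2 + (post_std - prior_std) ** 2))
    return discount * w2   # combined with the reconstruction and KL terms in model_loss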
+ step = count_steps(datadir, config) + print(f'Simulating agent for {config.steps-step} steps.') + agent = Dreamer(config, datadir, actspace, writer) + if (config.logdir / 'variables.pkl').exists(): + print('Load checkpoint.') + agent.load(config.logdir / 'variables.pkl') + state = None + while step < config.steps: + print('Start evaluation.') + tools.simulate( + functools.partial(agent, training=False), test_envs, episodes=1) + writer.flush() + print('Start collection.') + steps = config.eval_every // config.action_repeat + state = tools.simulate(agent, train_envs, steps, state=state) + step = count_steps(datadir, config) + agent.save(config.logdir / 'variables.pkl') + for env in train_envs + test_envs: + env.close() + + +#if __name__ == '__main__': + # try: + # import colored_traceback + # colored_traceback.add_hook() + # except ImportError: + # pass +parser = argparse.ArgumentParser() +for key, value in define_config().items(): + parser.add_argument( + f'--{key}', type=tools.args_type(value), default=value) +args = parser.parse_args() + +print('main') + +main(args) diff --git a/dreamer_contrastive.py b/dreamer_contrastive.py new file mode 100644 index 0000000..864df36 --- /dev/null +++ b/dreamer_contrastive.py @@ -0,0 +1,658 @@ +import sys +import pathlib +sys.path.append(str(pathlib.Path(__file__).parent)) + +import wrappers +import tools +import models +from tensorflow_probability import distributions as tfd +from tensorflow.keras.mixed_precision import experimental as prec +import tensorflow as tf +import numpy as np +import argparse +import collections +import functools +import json +import os +import pathlib +import sys +import time +import soft_actor_critic + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +# enable headless training on servers for mujoco +#os.environ['MUJOCO_GL'] = 'egl' + +tf.executing_eagerly() + +tf.get_logger().setLevel('ERROR') + + +sys.path.append(str(pathlib.Path(__file__).parent)) + + +def define_config(): + config = tools.AttrDict() + # General. + config.logdir = pathlib.Path('.') + config.seed = 0 + config.steps = 5e6 + config.eval_every = 1e4 + config.log_every = 1e3 + config.log_scalars = True + config.log_images = True + config.gpu_growth = True + config.precision = 16 + # Environment. + config.task = 'dmc_walker_walk' + config.envs = 1 + config.difficulty = 'none' + config.parallel = 'none' + config.action_repeat = 2 + config.time_limit = 1000 + config.prefill = 5000 + config.eval_noise = 0.0 + config.clip_rewards = 'none' + # Model. + config.deter_size = 200 + config.stoch_size = 30 + config.num_units = 400 + config.dense_act = 'elu' + config.cnn_act = 'relu' + config.cnn_depth = 32 + config.pcont = False + config.free_nats = 3.0 + config.kl_scale = 1.0 + config.pcont_scale = 10.0 + config.weight_decay = 0.0 + config.weight_decay_pattern = r'.*' + # Training. + config.batch_size = 50 + config.batch_length = 50 + config.train_every = 1000 + config.train_steps = 100 + config.pretrain = 100 + config.model_lr = 6e-4 + config.value_lr = 8e-5 + config.actor_lr = 8e-5 + config.grad_clip = 100.0 + config.dataset_balance = False + # Behavior. 
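
# Note on the observation model selected by config.obs_model a few lines below:
# 'generative' scores pixels under the ConvDecoder, while 'contrastive' replaces
# reconstruction with an InfoNCE-style objective. models.ContrastiveObsModel is
# defined in models.py (not reproduced in this excerpt), so the following NumPy
# sketch only illustrates that style of loss and is not the repository's implementation:
import numpy as np

def infonce_sketch(feat, embed, temperature=1.0):
    # feat, embed: [batch, dim] latent features and encoder embeddings, assumed to
    # already live in a shared space; the score of each (state, observation) pair is
    # a scaled dot product, and the positive pairs sit on the diagonal.
    logits = feat @ embed.T / temperature
    logits = logits - logits.max(axis=1, keepdims=True)      # numerical stability
    log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return np.mean(np.diag(log_softmax))   # average log-likelihood of the positives (maximized)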
+ config.discount = 0.99 + config.disclam = 0.95 + config.horizon = 15 + config.action_dist = 'tanh_normal' + config.action_init_std = 5.0 + config.expl = 'additive_gaussian' + config.expl_amount = 0.3 + config.expl_decay = 0.0 + config.expl_min = 0.0 + config.log_imgs = True + + # natural or not + config.natural = True + config.custom_video = True + + # obs model + config.obs_model = 'generative' # or 'contrastive' + + + + # use trajectory optimization + config.trajectory_opt = False + config.traj_opt_lr = 0.003 + config.num_samples = 20 + return config + + +class Dreamer(tools.Module): + + def __init__(self, config, datadir, actspace, writer): + self._c = config + self._actspace = actspace + self._actdim = actspace.n if hasattr( + actspace, 'n') else actspace.shape[0] + self._writer = writer + self._random = np.random.RandomState(config.seed) + with tf.device('cpu:0'): + self._step = tf.Variable(count_steps( + datadir, config), dtype=tf.int64) + self._should_pretrain = tools.Once() + self._should_train = tools.Every(config.train_every) + self._should_log = tools.Every(config.log_every) + self._last_log = None + self._last_time = time.time() + self._metrics = collections.defaultdict(tf.metrics.Mean) + self._metrics['expl_amount'] # Create variable for checkpoint. + self._float = prec.global_policy().compute_dtype + self._dataset = iter(load_dataset(datadir, self._c)) + self._build_model() + + def __call__(self, obs, reset, state=None, training=True): + step = self._step.numpy().item() + tf.summary.experimental.set_step(step) + if state is not None and reset.any(): + mask = tf.cast(1 - reset, self._float)[:, None] + state = tf.nest.map_structure(lambda x: x * mask, state) + if self._should_train(step): + log = self._should_log(step) + n = self._c.pretrain if self._should_pretrain() else self._c.train_steps + print(f'Training for {n} steps.') + # with self._strategy.scope(): + for train_step in range(n): + log_images = self._c.log_images and log and train_step == 0 + self.train(next(self._dataset), log_images) + if log: + self._write_summaries() + action, state = self.policy(obs, state, training) + if training: + self._step.assign_add(len(reset) * self._c.action_repeat) + return action, state + + @tf.function + def policy(self, obs, state, training): + if state is None: + latent = self._dynamics.initial(len(obs['image'])) + action = tf.zeros((len(obs['image']), self._actdim), self._float) + else: + latent, action = state + embed = self._encode(preprocess(obs, self._c)) + latent, _ = self._dynamics.obs_step(latent, action, embed) + feat = self._dynamics.get_feat(latent) + + if self._c.trajectory_opt: + action = self._trajectory_optimization(latent) + else: + if training: + action = self._actor(feat).sample() + else: + action = self._actor(feat).mode() + + action = self._exploration(action, training) + state = (latent, action) + return action, state + + def load(self, filename): + super().load(filename) + self._should_pretrain() + + @tf.function() + def train(self, data, log_images=True): + self._train(data, log_images) + + def _train(self, data, log_images): + with tf.GradientTape() as model_tape: + embed = self._encode(data) + post, prior = self._dynamics.observe(embed, data['action']) + batch_size = embed.shape[0] + feat = self._dynamics.get_feat(post) + reward_pred = self._reward(feat) + likes = tools.AttrDict() + likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward'])) + + calc_bs = True + bs = None + + if calc_bs == True: + ## for behavioral similarity ## + from scipy.spatial 
import distance + + try: + states_graph = data['position'] + except: + states_graph = data['orientations'] + + latents_graph = feat + assert states_graph.shape[0] == latents_graph.shape[0] + + #K = tf.zeros([batch_size,batch_size],dtype=tf.dtypes.float16) + c = 1e2 + K = tf.constant([0],dtype=tf.dtypes.float16) + for i in range(batch_size): + for j in range(batch_size): + # K[i,j] = max(0,c - abs(distance.cosine(latents_graph[i].numpy(),latents_graph[j].numpy())-distance.cosine(states_graph[i].numpy(),states_graph[j].numpy())) ) + #tf.norm(x1-y1,ord='euclidean') + K = K + (1/(batch_size*batch_size))*tf.math.maximum(tf.constant([0],dtype=tf.dtypes.float16),c - tf.abs(tf.norm(latents_graph[i]-latents_graph[j],ord='euclidean')-tf.norm(states_graph[i]-states_graph[j],ord='euclidean')) ) + ## compute normalized Kernel distance - a number between 0 and 100. 100 is max similaruuty, 0 is min similarity + + bs = K + + ## behavioral similarity end ## + + # if we use the generative observation model, we need to perform observation reconstruction + image_pred = self._decode(feat) + # compute the contrative loss directly + cont_loss = self._contrastive(feat, embed) + + # the contrastive / generative implementation of the observation model p(o|s) + if self._c.obs_model == 'generative': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + elif self._c.obs_model == 'contrastive': + likes.image = tf.reduce_mean(cont_loss) + + if self._c.pcont: + pcont_pred = self._pcont(feat) + pcont_target = self._c.discount * data['discount'] + likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target)) + likes.pcont *= self._c.pcont_scale + + prior_dist = self._dynamics.get_dist(prior) + post_dist = self._dynamics.get_dist(post) + div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) + div = tf.maximum(div, self._c.free_nats) + model_loss = self._c.kl_scale * div - sum(likes.values()) + + + with tf.GradientTape() as actor_tape: + imag_feat = self._imagine_ahead(post) + reward = self._reward(imag_feat).mode() + if self._c.pcont: + pcont = self._pcont(imag_feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(imag_feat).mode() + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + discount = tf.stop_gradient(tf.math.cumprod(tf.concat( + [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)) + actor_loss = -tf.reduce_mean(discount * returns) + + with tf.GradientTape() as value_tape: + value_pred = self._value(imag_feat)[:-1] + target = tf.stop_gradient(returns) + value_loss = - \ + tf.reduce_mean(discount * value_pred.log_prob(target)) + + actor_norm = self._actor_opt(actor_tape, actor_loss) + value_norm = self._value_opt(value_tape, value_loss) + + + model_norm = self._model_opt(model_tape, model_loss) + states = tf.concat([post['stoch'], post['deter']], axis=-1) + rewards = data['reward'] + dones = tf.zeros_like(rewards) + actions = data['action'] + + + if tf.distribute.get_replica_context().replica_id_in_sync_group == 0: + if self._c.log_scalars: + self._scalar_summaries( + data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm,bs) + if tf.equal(log_images, True) and self._c.log_imgs: + self._image_summaries(data, embed, image_pred) + + def _build_model(self): + acts = dict( + elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish, + leaky_relu=tf.nn.leaky_relu) + cnn_act = acts[self._c.cnn_act] + act = acts[self._c.dense_act] 
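
# The double Python loop in _train above averages max(0, c - | ||z_i - z_j|| - ||s_i - s_j|| |)
# over all index pairs (i, j) of the batch as a behavioral-similarity diagnostic ('bs').
# A vectorized NumPy equivalent, flattening each element to a single vector the way
# tf.norm does above (illustrative helper, not part of this patch):
import numpy as np

def behavioral_similarity_sketch(latents, states, c=1e2):
    z = latents.reshape(len(latents), -1)
    s = states.reshape(len(states), -1)
    z_pd = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)   # pairwise latent distances
    s_pd = np.linalg.norm(s[:, None, :] - s[None, :, :], axis=-1)   # pairwise state distances
    return np.mean(np.maximum(0.0, c - np.abs(z_pd - s_pd)))        # in [0, c]; higher means the geometries agree more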
+ self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) + self._dynamics = models.RSSM( + self._c.stoch_size, self._c.deter_size, self._c.deter_size) + self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act) + self._contrastive = models.ContrastiveObsModel(self._c.deter_size, + self._c.deter_size * 2) + self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) + if self._c.pcont: + self._pcont = models.DenseDecoder( + (), 3, self._c.num_units, 'binary', act=act) + self._value = models.DenseDecoder((), 3, self._c.num_units, act=act) + self._Qs = [models.QNetwork(3, self._c.num_units, act=act) for _ in range(self._c.num_Qs)] + self._actor = models.ActionDecoder( + self._actdim, 4, self._c.num_units, self._c.action_dist, + init_std=self._c.action_init_std, act=act) + model_modules = [self._encode, self._dynamics, + self._contrastive, self._reward, self._decode] + if self._c.pcont: + model_modules.append(self._pcont) + Optimizer = functools.partial( + tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip, + wdpattern=self._c.weight_decay_pattern) + self._model_opt = Optimizer('model', model_modules, self._c.model_lr) + self._value_opt = Optimizer('value', [self._value], self._c.value_lr) + self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) + self._q_opts = [Optimizer('qs', [qnet], self._c.value_lr) for qnet in self._Qs] + + + + self.train(next(self._dataset)) + + def _exploration(self, action, training): + if training: + amount = self._c.expl_amount + if self._c.expl_decay: + amount *= 0.5 ** (tf.cast(self._step, + tf.float32) / self._c.expl_decay) + if self._c.expl_min: + amount = tf.maximum(self._c.expl_min, amount) + self._metrics['expl_amount'].update_state(amount) + elif self._c.eval_noise: + amount = self._c.eval_noise + else: + return action + if self._c.expl == 'additive_gaussian': + return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) + if self._c.expl == 'completely_random': + return tf.random.uniform(action.shape, -1, 1) + if self._c.expl == 'epsilon_greedy': + indices = tfd.Categorical(0 * action).sample() + return tf.where( + tf.random.uniform(action.shape[:1], 0, 1) < amount, + tf.one_hot(indices, action.shape[-1], dtype=self._float), + action) + raise NotImplementedError(self._c.expl) + + def _imagine_ahead(self, post): + if self._c.pcont: # Last step could be terminal. 
+ post = {k: v[:, :-1] for k, v in post.items()} + + def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in post.items()} + + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + states = tools.static_scan( + lambda prev, _: self._dynamics.img_step(prev, policy(prev)), + tf.range(self._c.horizon), start) + imag_feat = self._dynamics.get_feat(states) + return imag_feat + + def _trajectory_optimization(self, post): + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + + def repeat(x): + return tf.repeat(x, self._c.num_samples, axis=0) + + states, actions = tools.static_scan_action( + lambda prev, action, _: self._dynamics.img_step(prev, action), + lambda prev: policy(prev), + tf.range(self._c.horizon), post) + + feat = self._dynamics.get_feat(states) + reward = self._reward(feat).mode() + + if self._c.pcont: + pcont = self._pcont(feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(feat).mode() + + # compute the accumulated reward + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + + accumulated_reward = returns[0, 0] + + # since the reward and latent dynamics are fully differentiable, we can backprop the gradients to update the actions + grad = tf.gradients(accumulated_reward, actions)[0] + act = actions + grad * self._c.traj_opt_lr + + return act + + + def _scalar_summaries( + self, data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm,bs=None): + self._metrics['model_grad_norm'].update_state(model_norm) + self._metrics['value_grad_norm'].update_state(value_norm) + self._metrics['actor_grad_norm'].update_state(actor_norm) + self._metrics['prior_ent'].update_state(prior_dist.entropy()) + self._metrics['post_ent'].update_state(post_dist.entropy()) + for name, logprob in likes.items(): + self._metrics[name + '_loss'].update_state(-logprob) + self._metrics['div'].update_state(div) + self._metrics['model_loss'].update_state(model_loss) + self._metrics['value_loss'].update_state(value_loss) + self._metrics['actor_loss'].update_state(actor_loss) + self._metrics['action_ent'].update_state(self._actor(feat).entropy()) + + if bs is not None: + self._metrics['bs'].update_state(bs) + + def _image_summaries(self, data, embed, image_pred): + truth = data['image'][:6] + 0.5 + recon = image_pred.mode()[:6] + init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) + init = {k: v[:, -1] for k, v in init.items()} + prior = self._dynamics.imagine(data['action'][:6, 5:], init) + openl = self._decode(self._dynamics.get_feat(prior)).mode() + model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) + error = (model - truth + 1) / 2 + openl = tf.concat([truth, model, error], 2) + tools.graph_summary( + self._writer, tools.video_summary, 'agent/openl', openl) + + def _write_summaries(self): + step = int(self._step.numpy()) + metrics = [(k, float(v.result())) for k, v in self._metrics.items()] + if self._last_log is not None: + duration = time.time() - self._last_time + self._last_time += duration + metrics.append(('fps', (step - self._last_log) / duration)) + self._last_log = step + [m.reset_states() for m in self._metrics.values()] + with (self._c.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': step, **dict(metrics)}) + '\n') + 
[tf.summary.scalar('agent/' + k, m) for k, m in metrics] + print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics)) + self._writer.flush() + + +def preprocess(obs, config): + dtype = prec.global_policy().compute_dtype + obs = obs.copy() + with tf.device('cpu:0'): + obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 + clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[ + config.clip_rewards] + obs['reward'] = clip_rewards(obs['reward']) + return obs + + +def count_steps(datadir, config): + return tools.count_episodes(datadir)[1] * config.action_repeat + + +def load_dataset(directory, config): + episode = next(tools.load_episodes(directory, 1)) + types = {k: v.dtype for k, v in episode.items()} + shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} + + def generator(): return tools.load_episodes( + directory, config.train_steps, config.batch_length, + config.dataset_balance) + dataset = tf.data.Dataset.from_generator(generator, types, shapes) + dataset = dataset.batch(config.batch_size, drop_remainder=True) + dataset = dataset.map(functools.partial(preprocess, config=config)) + dataset = dataset.prefetch(10) + return dataset + + +def summarize_episode(episode, config, datadir, writer, prefix): + episodes, steps = tools.count_episodes(datadir) + length = (len(episode['reward']) - 1) * config.action_repeat + ret = episode['reward'].sum() + print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') + metrics = [ + (f'{prefix}/return', float(episode['reward'].sum())), + (f'{prefix}/length', len(episode['reward']) - 1), + (f'episodes', episodes)] + step = count_steps(datadir, config) + with (config.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps(dict([('step', step)] + metrics)) + '\n') + with writer.as_default(): # Env might run in a different thread. 
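`load_dataset` above turns the on-disk episode generator into `[batch, time, ...]` tensors. A self-contained sketch of the same `tf.data.Dataset.from_generator` pattern with a fake episode source standing in for `tools.load_episodes` (keys and shapes here are illustrative, matching the 64x64x3 image observations used elsewhere in the patch):

import numpy as np
import tensorflow as tf

def fake_episodes(length=50, act_dim=6):
    # Stand-in for tools.load_episodes: yields dicts of [time, ...] arrays forever.
    while True:
        yield {
            'image': np.random.randint(0, 256, (length, 64, 64, 3), np.uint8),
            'action': np.random.uniform(-1, 1, (length, act_dim)).astype(np.float32),
            'reward': np.zeros((length,), np.float32),
        }

types = {'image': tf.uint8, 'action': tf.float32, 'reward': tf.float32}
shapes = {'image': (None, 64, 64, 3), 'action': (None, 6), 'reward': (None,)}
dataset = tf.data.Dataset.from_generator(fake_episodes, types, shapes)
dataset = dataset.batch(4, drop_remainder=True).prefetch(2)
for batch in dataset.take(1):
    print({k: tuple(v.shape) for k, v in batch.items()})  # image -> (4, 50, 64, 64, 3)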
+ tf.summary.experimental.set_step(step) + [tf.summary.scalar('sim/' + k, v) for k, v in metrics] + if prefix == 'test': + tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) + + +def make_env(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env +def make_env_test(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env + +def load_imgnet(train): + import pickle + name = 'train' if train else 'valid' + + # images_train.
pkl and images_test.pkl to be downloaded from + + with open('images_{}.pkl'.format(name), 'rb') as fin: + imgnet = pickle.load(fin) + + imgnet = np.transpose(imgnet, axes=(0, 1, 3, 4, 2)) + + return imgnet + + +def main(config): + print('mainn') + if config.gpu_growth: + for gpu in tf.config.experimental.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) + assert config.precision in (16, 32), config.precision + if config.precision == 16: + prec.set_policy(prec.Policy('mixed_float16')) + config.steps = int(config.steps) + config.logdir.mkdir(parents=True, exist_ok=True) + print('Logdir', config.logdir) + + arg_dict = vars(config).copy() + del arg_dict['logdir'] + + # with open(os.path.join(config.logdir, 'args.json'), 'w') as fout: + # import json + # json.dump(arg_dict, fout) + + # Create environments. + datadir = config.logdir / 'episodes' + datadir.mkdir(parents=True, exist_ok=True) + writer = tf.summary.create_file_writer( + str(config.logdir), max_queue=1000, flush_millis=20000) + writer.set_as_default() + train_envs = [wrappers.Async(lambda: make_env( + config, writer, 'train', datadir, train=True), config.parallel) + for _ in range(config.envs)] + test_envs = [wrappers.Async(lambda: make_env_test( + config, writer, 'test', datadir, train=False), config.parallel) + for _ in range(config.envs)] + actspace = train_envs[0].action_space + + # Prefill dataset with random episodes. + step = count_steps(datadir, config) + prefill = max(0, config.prefill - step) + print(f'Prefill dataset with {prefill} steps.') + def random_agent(o, d, _): return ([actspace.sample() for _ in d], None) + tools.simulate(random_agent, train_envs, prefill / config.action_repeat) + writer.flush() + + # Train and regularly evaluate the agent. 
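Step accounting is easy to misread in `main`: episodes are stored at the post-`action_repeat` rate, so `count_steps` multiplies stored transitions by `action_repeat`, while the prefill and collection targets are divided by it. A tiny worked example under the defaults above (assuming one stored transition per repeated action):

action_repeat = 2
eval_every = 10_000        # config.eval_every, in environment steps
prefill = 5_000            # config.prefill, in environment steps

stored_transitions = 2_500                        # transitions written to disk so far
env_steps = stored_transitions * action_repeat    # what count_steps() reports: 5000
collect_per_cycle = eval_every // action_repeat   # simulator steps per collection phase: 5000
remaining_prefill = max(0, prefill - env_steps)   # already satisfied here: 0
print(env_steps, collect_per_cycle, remaining_prefill)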
+ step = count_steps(datadir, config) + print(f'Simulating agent for {config.steps-step} steps.') + agent = Dreamer(config, datadir, actspace, writer) + if (config.logdir / 'variables.pkl').exists(): + print('Load checkpoint.') + agent.load(config.logdir / 'variables.pkl') + state = None + while step < config.steps: + print('Start evaluation.') + tools.simulate( + functools.partial(agent, training=False), test_envs, episodes=1) + writer.flush() + print('Start collection.') + steps = config.eval_every // config.action_repeat + state = tools.simulate(agent, train_envs, steps, state=state) + step = count_steps(datadir, config) + agent.save(config.logdir / 'variables.pkl') + for env in train_envs + test_envs: + env.close() + + +#if __name__ == '__main__': + # try: + # import colored_traceback + # colored_traceback.add_hook() + # except ImportError: + # pass +parser = argparse.ArgumentParser() +for key, value in define_config().items(): + parser.add_argument( + f'--{key}', type=tools.args_type(value), default=value) +args = parser.parse_args() + +print('main') + +main(args) diff --git a/dreamer_contrastive_inverse.py b/dreamer_contrastive_inverse.py new file mode 100644 index 0000000..e0bb444 --- /dev/null +++ b/dreamer_contrastive_inverse.py @@ -0,0 +1,799 @@ +import sys +import pathlib +sys.path.append(str(pathlib.Path(__file__).parent)) + +import wrappers +import tools_inv as tools +import models +from tensorflow_probability import distributions as tfd +from tensorflow.keras.mixed_precision import experimental as prec + + +import tensorflow as tf +import numpy as np +import argparse +import collections +import functools +import json +import os +import pathlib +import sys +import time +import soft_actor_critic + +import sklearn +import sklearn.manifold + +import pandas as pd +import io +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +from PIL import Image +import cv2 + +# enable headless training on servers for mujoco +#os.environ['MUJOCO_GL'] = 'egl' + +tf.executing_eagerly() + +tf.get_logger().setLevel('ERROR') + + +sys.path.append(str(pathlib.Path(__file__).parent)) + + +def define_config(): + config = tools.AttrDict() + # General. + config.logdir = pathlib.Path('.') + config.seed = 0 + config.steps = 10e6 + config.eval_every = 1e4 + config.log_every = 1e3 + config.log_scalars = True + config.log_images = True + config.gpu_growth = True + config.precision = 16 + # Environment. + config.task = 'dmc_cup_catch' #'dmc_walker_walk' + config.envs = 1 + config.difficulty = 'none' + config.parallel = 'none' + config.action_repeat = 2 + config.time_limit = 1000 + + config.eval_noise = 0.0 + config.clip_rewards = 'none' + # Model. + config.deter_size = 200 + config.stoch_size = 30 + config.num_units = 400 + config.dense_act = 'elu' + config.cnn_act = 'relu' + config.cnn_depth = 32 + config.pcont = False + config.free_nats = 3.0 + config.kl_scale = 1.0 + config.pcont_scale = 10.0 + config.weight_decay = 0.0 + config.weight_decay_pattern = r'.*' + + # Training. settings for debugging + # config.prefill = 100 #5000 + # config.batch_size = 10 #50 + # config.batch_length = 10 #50 + + # config.train_steps = 1 #100 + # config.pretrain = 1 #100 + + config.prefill = 5000 + config.batch_size = 50 + config.batch_length = 50 + + config.train_steps = 100 + config.pretrain = 100 + + + config.train_every = 1000 + config.model_lr = 6e-4 + config.value_lr = 8e-5 + config.actor_lr = 8e-5 + config.grad_clip = 100.0 + config.dataset_balance = False + # Behavior. 
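Both scripts expose every `define_config()` entry as a command-line flag through `tools.args_type` (as at the end of dbc.py above and again at the end of this file); `args_type` itself is not part of this patch. A minimal sketch of the pattern with a hypothetical `args_type` that parses each flag to the type of its config default (booleans and paths need the special cases hinted at by the config values):

import argparse
import pathlib

def args_type(default):
    # Hypothetical stand-in for tools.args_type: parse flags to the default's type.
    if isinstance(default, bool):
        return lambda x: bool(['False', 'True'].index(x))
    if isinstance(default, pathlib.Path):
        return lambda x: pathlib.Path(x).expanduser()
    return type(default)

config = {'logdir': pathlib.Path('.'), 'seed': 0, 'natural': True, 'kl_scale': 1.0}
parser = argparse.ArgumentParser()
for key, value in config.items():
    parser.add_argument(f'--{key}', type=args_type(value), default=value)
args = parser.parse_args(['--seed', '3', '--natural', 'False'])
print(args.seed, args.natural)  # 3 False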
+ config.discount = 0.99 + config.disclam = 0.95 + config.horizon = 15 + config.action_dist = 'tanh_normal' + config.action_init_std = 5.0 + config.expl = 'additive_gaussian' + config.expl_amount = 0.3 + config.expl_decay = 0.0 + config.expl_min = 0.0 + config.log_imgs = True + config.inv_dyn_bonus = 0.01 + + # natural, custom frame, or none + config.natural = True + config.custom_video = True + + # obs model + config.obs_model = 'contrastive' + + + + + # use trajectory optimization + config.trajectory_opt = True + config.traj_opt_lr = 0.003 + config.num_samples = 20 + return config + + +class Dreamer(tools.Module): + + def __init__(self, config, datadir, actspace, writer): + self._c = config + self._actspace = actspace + self._actdim = actspace.n if hasattr( + actspace, 'n') else actspace.shape[0] + self._writer = writer + self._random = np.random.RandomState(config.seed) + with tf.device('cpu:0'): + self._step = tf.Variable(count_steps( + datadir, config), dtype=tf.int64) + self._should_pretrain = tools.Once() + self._should_train = tools.Every(config.train_every) + self._should_log = tools.Every(config.log_every) + self._should_calc_bs = tools.Every(1e4) + self._last_log = None + self._last_time = time.time() + self._metrics = collections.defaultdict(tf.metrics.Mean) + self._metrics['expl_amount'] # Create variable for checkpoint. + self._float = prec.global_policy().compute_dtype + self._dataset = iter(load_dataset(datadir, self._c)) + self._build_model() + + def __call__(self, obs, reset, state=None, training=True): + step = self._step.numpy().item() + tf.summary.experimental.set_step(step) + + inv_action_loss = 0 + if state is not None and reset.any(): + mask = tf.cast(1 - reset, self._float)[:, None] + state = tf.nest.map_structure(lambda x: x * mask, state) + if self._should_train(step): + log = self._should_log(step) + calc_bs = self._should_calc_bs(step) + n = self._c.pretrain if self._should_pretrain() else self._c.train_steps + print(f'Training for {n} steps.') + # with self._strategy.scope(): + for train_step in range(n): + log_images = self._c.log_images and log and train_step == 0 + self.train(next(self._dataset), log_images, inv_action_loss,calc_bs) + if log: + self._write_summaries() + action, state, old_latent,old_action = self.policy(obs, state, training) + latent, _ = state + + from sklearn.preprocessing import StandardScaler + from sklearn.decomposition import PCA + + + + if self._should_train(step): + log = self._should_log(step) + n = self._c.pretrain if self._should_pretrain() else self._c.train_steps + print(f'Training inverse model for {n} steps.') + # with self._strategy.scope(): + for train_step in range(n): + with tf.GradientTape() as inv_action_tape: + predicted_action = self.predict_action(latent,old_latent) ## predicted action sample + inv_action_loss = tf.reduce_mean(tf.keras.metrics.mean_squared_error(old_action, predicted_action)) + inv_action_norm = self._inv_action_opt(inv_action_tape, inv_action_loss) + + + if training: + self._step.assign_add(len(reset) * self._c.action_repeat) + + return action, state, old_latent,old_action + + @tf.function + def policy(self, obs, state, training): + if state is None: + latent = self._dynamics.initial(len(obs['image'])) + action = tf.zeros((len(obs['image']), self._actdim), self._float) + else: + latent, action = state + embed = self._encode(preprocess(obs, self._c)) + old_latent = latent + old_action = action + latent, _ = self._dynamics.obs_step(latent, action, embed) + feat = self._dynamics.get_feat(latent) + + if 
self._c.trajectory_opt: + action = self._trajectory_optimization(latent) + else: + if training: + action = self._actor(feat).sample() + else: + action = self._actor(feat).mode() + + # ### inv dynamics bonus ### + + # predicted_action = self.predict_action(latent,old_latent) + + # action = action + self._c.inv_dyn_bonus* predicted_action + + action = self._exploration(action, training) + state = (latent, action) + return action, state, old_latent, old_action + + @tf.function + def predict_action(self,latent,old_latent,reverse=False): + + # embed = self._encode(preprocess(obs, self._c)) + # old_embed = self._encode(preprocess(old_obs, self._c)) + # latent, _ = self._dynamics.obs_step(latent, action, embed) + # old_latent, _ = self._dynamics.obs_step(old_latent, action, old_embed) + feat = self._dynamics.get_feat(latent) + old_feat = self._dynamics.get_feat(old_latent) + + predicted_action = self._inverse_model(feat,old_feat) + + + return predicted_action.sample() + + + + def load(self, filename): + super().load(filename) + self._should_pretrain() + + @tf.function() + def train(self, data, log_images=True,inv_action_loss=0,calc_bs=True): + self._train(data, log_images,inv_action_loss,calc_bs) + + def _train(self, data, log_images,inv_action_loss,calc_bs): + with tf.GradientTape() as model_tape: + embed = self._encode(data) + #print(data) + batch_size = embed.shape[0] + post, prior = self._dynamics.observe(embed, data['action']) + feat = self._dynamics.get_feat(post) + reward_pred = self._reward(feat) + likes = tools.AttrDict() + likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward'])) + + calc_bs = True + bs = None + + if calc_bs == True: + ## for behavioral similarity ## + from scipy.spatial import distance + + try: + states_graph = data['position'] + except: + states_graph = data['orientations'] + + latents_graph = feat + assert states_graph.shape[0] == latents_graph.shape[0] + + #K = tf.zeros([batch_size,batch_size],dtype=tf.dtypes.float16) + c = 1e2 + K = tf.constant([0],dtype=tf.dtypes.float16) + for i in range(batch_size): + for j in range(batch_size): + # K[i,j] = max(0,c - abs(distance.cosine(latents_graph[i].numpy(),latents_graph[j].numpy())-distance.cosine(states_graph[i].numpy(),states_graph[j].numpy())) ) + #tf.norm(x1-y1,ord='euclidean') + K = K + (1/(batch_size*batch_size))*tf.math.maximum(tf.constant([0],dtype=tf.dtypes.float16),c - tf.abs(tf.norm(latents_graph[i]-latents_graph[j],ord='euclidean')-tf.norm(states_graph[i]-states_graph[j],ord='euclidean')) ) + ## compute normalized Kernel distance - a number between 0 and 100. 
100 is max similaruuty, 0 is min similarity + + bs = K + + ## behavioral similarity end ## + + # if we use the generative observation model, we need to perform observation reconstruction + image_pred = self._decode(feat) + # compute the contrative loss directly + cont_loss = self._contrastive(feat, embed) + + # the contrastive / generative implementation of the observation model p(o|s) + if self._c.obs_model == 'generative': + likes.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + elif self._c.obs_model == 'contrastive': + print(self._step) + if self._step < 100000: + likes.image = tf.reduce_mean(cont_loss) + else: + #likes.image = 0.001*tf.reduce_mean(cont_loss) + likes.image = tf.reduce_mean(cont_loss) + #likes.image1 = tf.stop_gradient(tf.reduce_mean(image_pred.log_prob(data['image']))) + + + if self._c.pcont: + pcont_pred = self._pcont(feat) + pcont_target = self._c.discount * data['discount'] + likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target)) + likes.pcont *= self._c.pcont_scale + + prior_dist = self._dynamics.get_dist(prior) + post_dist = self._dynamics.get_dist(post) + div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) + div = tf.maximum(div, self._c.free_nats) + model_loss = self._c.kl_scale * div - sum(likes.values()) + + + + with tf.GradientTape() as vae_tape: + likes_vae = tools.AttrDict() + image_pred = self._decode(feat) + likes_vae.image = tf.reduce_mean(image_pred.log_prob(data['image'])) + vae_loss = - sum(likes_vae.values()) + + + with tf.GradientTape() as actor_tape: + imag_feat = self._imagine_ahead(post) + reward = self._reward(imag_feat).mode() + if self._c.pcont: + pcont = self._pcont(imag_feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(imag_feat).mode() + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + discount = tf.stop_gradient(tf.math.cumprod(tf.concat( + [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0)) + actor_loss = -tf.reduce_mean(discount * returns) - self._c.inv_dyn_bonus*inv_action_loss + #actor_loss = -tf.reduce_mean(discount * returns) + + with tf.GradientTape() as value_tape: + value_pred = self._value(imag_feat)[:-1] + target = tf.stop_gradient(returns) + value_loss = - \ + tf.reduce_mean(discount * value_pred.log_prob(target)) + + actor_norm = self._actor_opt(actor_tape, actor_loss) + value_norm = self._value_opt(value_tape, value_loss) + + + model_norm = self._model_opt(model_tape, model_loss) + vae_norm = self._vae_opt(vae_tape,vae_loss) + + states = tf.concat([post['stoch'], post['deter']], axis=-1) + rewards = data['reward'] + dones = tf.zeros_like(rewards) + actions = data['action'] + + + #self._image_summaries(data, embed, image_pred) + + if tf.distribute.get_replica_context().replica_id_in_sync_group == 0: + if self._c.log_scalars: + self._scalar_summaries( + data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm,likes_vae,vae_loss,vae_norm,bs) + #log_images = True + if tf.equal(log_images, True) and self._c.log_imgs: + print('logging reconstructions') + self._image_summaries(data, embed, image_pred) + + def _build_model(self): + acts = dict( + elu=tf.nn.elu, relu=tf.nn.relu, swish=tf.nn.swish, + leaky_relu=tf.nn.leaky_relu) + cnn_act = acts[self._c.cnn_act] + act = acts[self._c.dense_act] + self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) + self._dynamics = models.RSSM( + self._c.stoch_size, 
self._c.deter_size, self._c.deter_size) + self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act) + # self._contrastive = models.ContrastiveObsModel(self._c.deter_size, + # self._c.deter_size * 2) + self._contrastive = models.ContrastiveObsModelNWJ(self._c.deter_size, + self._c.deter_size * 2) + self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) + if self._c.pcont: + self._pcont = models.DenseDecoder( + (), 3, self._c.num_units, 'binary', act=act) + self._value = models.DenseDecoder((), 3, self._c.num_units, act=act) + self._Qs = [models.QNetwork(3, self._c.num_units, act=act) for _ in range(self._c.num_Qs)] + self._actor = models.ActionDecoder( + self._actdim, 4, self._c.num_units, self._c.action_dist, + init_std=self._c.action_init_std, act=act) + + # model_modules = [self._encode, self._dynamics, + # self._contrastive, self._reward, self._decode] + model_modules = [self._encode, self._dynamics, + self._contrastive, self._reward] + if self._c.pcont: + model_modules.append(self._pcont) + Optimizer = functools.partial( + tools.Adam, wd=self._c.weight_decay, clip=self._c.grad_clip, + wdpattern=self._c.weight_decay_pattern) + self._model_opt = Optimizer('model', model_modules, self._c.model_lr) + self._vae_opt = Optimizer('vae', [self._decode], self._c.model_lr) + self._value_opt = Optimizer('value', [self._value], self._c.value_lr) + self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) + self._q_opts = [Optimizer('qs', [qnet], self._c.value_lr) for qnet in self._Qs] + + self._inverse_model = models.InverseActionDecoder( + self._actdim, 4, self._c.num_units, self._c.action_dist, + init_std=self._c.action_init_std, act=act) + self._inv_action_opt = Optimizer('inverse_model', [self._inverse_model], self._c.actor_lr) + + + + self.train(next(self._dataset)) + + def _exploration(self, action, training): + if training: + amount = self._c.expl_amount + if self._c.expl_decay: + amount *= 0.5 ** (tf.cast(self._step, + tf.float32) / self._c.expl_decay) + if self._c.expl_min: + amount = tf.maximum(self._c.expl_min, amount) + self._metrics['expl_amount'].update_state(amount) + elif self._c.eval_noise: + amount = self._c.eval_noise + else: + return action + if self._c.expl == 'additive_gaussian': + return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) + if self._c.expl == 'completely_random': + return tf.random.uniform(action.shape, -1, 1) + if self._c.expl == 'epsilon_greedy': + indices = tfd.Categorical(0 * action).sample() + return tf.where( + tf.random.uniform(action.shape[:1], 0, 1) < amount, + tf.one_hot(indices, action.shape[-1], dtype=self._float), + action) + raise NotImplementedError(self._c.expl) + + def _imagine_ahead(self, post): + if self._c.pcont: # Last step could be terminal. 
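The `calc_bs` branch in `_train` above estimates a behavioral-similarity score with a Python double loop over the batch: K = mean_{i,j} max(0, c - |‖z_i - z_j‖ - ‖s_i - s_j‖|) with c = 100, so higher values mean latent distances track true-state distances more closely. A vectorized NumPy sketch of the same quantity (illustration only; the in-graph version would stay in tf ops and float16 as in the patch):

import numpy as np

def behavioral_similarity(latents, states, c=100.0):
    # Pairwise Euclidean distances in latent space and in true-state space.
    dz = np.linalg.norm(latents[:, None, :] - latents[None, :, :], axis=-1)
    ds = np.linalg.norm(states[:, None, :] - states[None, :, :], axis=-1)
    # Score each pair by how well latent distances track state distances, clipped at zero.
    K = np.maximum(0.0, c - np.abs(dz - ds))
    return K.mean()  # in [0, c]; c means the two distance structures match on average

z = np.random.randn(50, 230)   # e.g. stoch (30) + deter (200) features
s = np.random.randn(50, 14)    # proprioceptive state from data['position'] or data['orientations']
print(behavioral_similarity(z, s))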
+ post = {k: v[:, :-1] for k, v in post.items()} + + def flatten(x): return tf.reshape(x, [-1] + list(x.shape[2:])) + start = {k: flatten(v) for k, v in post.items()} + + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + states = tools.static_scan( + lambda prev, _: self._dynamics.img_step(prev, policy(prev)), + tf.range(self._c.horizon), start) + imag_feat = self._dynamics.get_feat(states) + return imag_feat + + def _trajectory_optimization(self, post): + def policy(state): return self._actor( + tf.stop_gradient(self._dynamics.get_feat(state))).sample() + + def repeat(x): + return tf.repeat(x, self._c.num_samples, axis=0) + + states, actions = tools.static_scan_action( + lambda prev, action, _: self._dynamics.img_step(prev, action), + lambda prev: policy(prev), + tf.range(self._c.horizon), post) + + feat = self._dynamics.get_feat(states) + reward = self._reward(feat).mode() + + if self._c.pcont: + pcont = self._pcont(feat).mean() + else: + pcont = self._c.discount * tf.ones_like(reward) + value = self._value(feat).mode() + + # compute the accumulated reward + returns = tools.lambda_return( + reward[:-1], value[:-1], pcont[:-1], + bootstrap=value[-1], lambda_=self._c.disclam, axis=0) + + accumulated_reward = returns[0, 0] + + # since the reward and latent dynamics are fully differentiable, we can backprop the gradients to update the actions + grad = tf.gradients(accumulated_reward, actions)[0] + act = actions + grad * self._c.traj_opt_lr + + return act + + + def _scalar_summaries( + self, data, feat, prior_dist, post_dist, likes, div, + model_loss, value_loss, actor_loss, model_norm, value_norm, + actor_norm,likes_vae = None,vae_loss=None,vae_norm=None,bs = None): + self._metrics['model_grad_norm'].update_state(model_norm) + self._metrics['value_grad_norm'].update_state(value_norm) + self._metrics['actor_grad_norm'].update_state(actor_norm) + self._metrics['prior_ent'].update_state(prior_dist.entropy()) + self._metrics['post_ent'].update_state(post_dist.entropy()) + for name, logprob in likes.items(): + self._metrics[name + '_loss'].update_state(-logprob) + + for name, logprob in likes_vae.items(): + self._metrics[name + '_loss'].update_state(-logprob) + self._metrics['vae_loss'].update_state(vae_loss) + self._metrics['vae_norm'].update_state(vae_norm) + self._metrics['div'].update_state(div) + self._metrics['model_loss'].update_state(model_loss) + self._metrics['value_loss'].update_state(value_loss) + self._metrics['actor_loss'].update_state(actor_loss) + self._metrics['action_ent'].update_state(self._actor(feat).entropy()) + + if bs is not None: + self._metrics['bs'].update_state(bs) + def _print_bs(self,bs): + print(bs.numpy()) + + + def _image_summaries(self, data, embed, image_pred): + truth = data['image'][:6] + 0.5 + recon = image_pred.mode()[:6] + init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) + init = {k: v[:, -1] for k, v in init.items()} + prior = self._dynamics.imagine(data['action'][:6, 5:], init) + openl = self._decode(self._dynamics.get_feat(prior)).mode() + model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) + error = (model - truth + 1) / 2 + openl = tf.concat([truth, model, error], 2) + + #openl = tf.concat([truth, recon + 0.5, error], 2) + + + # tsne = sklearn.manifold.TSNE( + # n_components=2, + # perplexity=40, + # metric='cosine', + # early_exaggeration=10.0, + # init='pca', + # verbose=True, + # n_iter=400) + # low_dim_embs = tsne.fit_transform(self._dynamics.get_feat(prior)) + # 
plt.figure() + # plt.plot(low_dim_embs) + + # buf = io.BytesIO() + # plt.savefig(buf, format='png') + # buf.seek(0) + + # img = tf.image.decode_png(buf.getvalue(), channels=4) + + # # Add the batch dimension + # img = tf.expand_dims(img, 0) + # #with self._writer.as_default(): + # self._writer.add_summary(tf.summary.image("tsne", img, step=0)) + + + + + tools.graph_summary( + self._writer, tools.video_summary, 'agent/openl', openl) + + def _write_summaries(self): + step = int(self._step.numpy()) + metrics = [(k, float(v.result())) for k, v in self._metrics.items()] + if self._last_log is not None: + duration = time.time() - self._last_time + self._last_time += duration + metrics.append(('fps', (step - self._last_log) / duration)) + self._last_log = step + [m.reset_states() for m in self._metrics.values()] + with (self._c.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps({'step': step, **dict(metrics)}) + '\n') + [tf.summary.scalar('agent/' + k, m) for k, m in metrics] + print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics)) + self._writer.flush() + + + +def preprocess(obs, config): + dtype = prec.global_policy().compute_dtype + obs = obs.copy() + with tf.device('cpu:0'): + obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 + clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[ + config.clip_rewards] + obs['reward'] = clip_rewards(obs['reward']) + return obs + + +def count_steps(datadir, config): + return tools.count_episodes(datadir)[1] * config.action_repeat + + +def load_dataset(directory, config): + episode = next(tools.load_episodes(directory, 1)) + types = {k: v.dtype for k, v in episode.items()} + shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} + + def generator(): return tools.load_episodes( + directory, config.train_steps, config.batch_length, + config.dataset_balance) + dataset = tf.data.Dataset.from_generator(generator, types, shapes) + dataset = dataset.batch(config.batch_size, drop_remainder=True) + dataset = dataset.map(functools.partial(preprocess, config=config)) + dataset = dataset.prefetch(10) + return dataset + + +def summarize_episode(episode, config, datadir, writer, prefix): + episodes, steps = tools.count_episodes(datadir) + length = (len(episode['reward']) - 1) * config.action_repeat + ret = episode['reward'].sum() + print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') + metrics = [ + (f'{prefix}/return', float(episode['reward'].sum())), + (f'{prefix}/length', len(episode['reward']) - 1), + (f'episodes', episodes)] + step = count_steps(datadir, config) + with (config.logdir / 'metrics.jsonl').open('a') as f: + f.write(json.dumps(dict([('step', step)] + metrics)) + '\n') + with writer.as_default(): # Env might run in a different thread. 
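Both `_write_summaries` and `summarize_episode` append one JSON object per line to `metrics.jsonl`. Since pandas is already imported at the top of this file, the log can be loaded for offline plotting with the line-delimited JSON reader (the path below is illustrative, and the test column assumes at least one evaluation episode has been logged):

import pandas as pd

# Each line of metrics.jsonl is one JSON object, e.g. {"step": 12000, "test/return": 310.5, ...}.
df = pd.read_json('logdir/metrics.jsonl', lines=True)
print(df.tail())
print(df[['step', 'test/return']].dropna().tail())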
+ tf.summary.experimental.set_step(step) + [tf.summary.scalar('sim/' + k, v) for k, v in metrics] + if prefix == 'test': + tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) + + +def make_env(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env +def make_env_test(config, writer, prefix, datadir, train): + suite, task = config.task.split('_', 1) + if suite == 'dmc': + if config.difficulty == 'none': + env = wrappers.DeepMindControl(task) + else: + env = wrappers.DeepMindControlDistraction(task,difficulty=config.difficulty) + env = wrappers.ActionRepeat(env, config.action_repeat) + env = wrappers.NormalizeActions(env) + if config.natural: + data = load_imgnet(train) + env = wrappers.NaturalMujoco(env, data) + elif config.custom_video: + import pickle + + with open('custom_video_jaco.pkl', 'rb') as file: + data = pickle.load(file) + env = wrappers.CustomMujoco(env, data) + + + elif suite == 'atari': + env = wrappers.Atari( + task, config.action_repeat, (64, 64), grayscale=False, + life_done=True, sticky_actions=True) + env = wrappers.OneHotAction(env) + else: + raise NotImplementedError(suite) + env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) + callbacks = [] + if train: + callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) + callbacks.append( + lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) + env = wrappers.Collect(env, callbacks, config.precision) + env = wrappers.RewardObs(env) + return env + + +def load_imgnet(train): + import pickle + name = 'train' if train else 'valid' + + # images_train. 
pkl and images_test.pkl to be downloaded from + + with open('images_{}.pkl'.format(name), 'rb') as fin: + imgnet = pickle.load(fin) + + imgnet = np.transpose(imgnet, axes=(0, 1, 3, 4, 2)) + + return imgnet + + +def main(config): + print('mainn') + if config.gpu_growth: + for gpu in tf.config.experimental.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) + assert config.precision in (16, 32), config.precision + if config.precision == 16: + prec.set_policy(prec.Policy('mixed_float16')) + config.steps = int(config.steps) + config.logdir.mkdir(parents=True, exist_ok=True) + print('Logdir', config.logdir) + + arg_dict = vars(config).copy() + del arg_dict['logdir'] + + # with open(os.path.join(config.logdir, 'args.json'), 'w') as fout: + # import json + # json.dump(arg_dict, fout) + + # Create environments. + datadir = config.logdir / 'episodes' + datadir.mkdir(parents=True, exist_ok=True) + writer = tf.summary.create_file_writer( + str(config.logdir), max_queue=1000, flush_millis=20000) + writer.set_as_default() + train_envs = [wrappers.Async(lambda: make_env( + config, writer, 'train', datadir, train=True), config.parallel) + for _ in range(config.envs)] + test_envs = [wrappers.Async(lambda: make_env_test( + config, writer, 'test', datadir, train=False), config.parallel) + for _ in range(config.envs)] + actspace = train_envs[0].action_space + + # Prefill dataset with random episodes. + step = count_steps(datadir, config) + prefill = max(0, config.prefill - step) + print(f'Prefill dataset with {prefill} steps.') + def random_agent(o, d, _): return ([actspace.sample() for _ in d], None, None, None) + tools.simulate(random_agent, train_envs, prefill / config.action_repeat) + writer.flush() + + # Train and regularly evaluate the agent. 
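`load_imgnet` above assumes the pickled background videos are stored channels-first and moves the channel axis last before handing them to the natural-background wrapper. A quick shape check of that `np.transpose(..., axes=(0, 1, 3, 4, 2))`, with made-up dimensions (the original layout is an assumption inferred from the transpose):

import numpy as np

clips = np.zeros((8, 100, 3, 64, 64), np.uint8)         # [clip, frame, C, H, W] as pickled (assumed)
clips_hwc = np.transpose(clips, axes=(0, 1, 3, 4, 2))   # -> [clip, frame, H, W, C]
assert clips_hwc.shape == (8, 100, 64, 64, 3)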
+ step = count_steps(datadir, config) + print(f'Simulating agent for {config.steps-step} steps.') + agent = Dreamer(config, datadir, actspace, writer) + if (config.logdir / 'variables.pkl').exists(): + print('Load checkpoint.') + agent.load(config.logdir / 'variables.pkl') + state = None + while step < config.steps: + print('Start evaluation.') + tools.simulate( + functools.partial(agent, training=False), test_envs, episodes=1) + writer.flush() + print('Start collection.') + steps = config.eval_every // config.action_repeat + state = tools.simulate(agent, train_envs, steps, state=state) + step = count_steps(datadir, config) + agent.save(config.logdir / 'variables.pkl') + for env in train_envs + test_envs: + env.close() + +parser = argparse.ArgumentParser() +for key, value in define_config().items(): + parser.add_argument( + f'--{key}', type=tools.args_type(value), default=value) +args = parser.parse_args() + + +main(args) diff --git a/models.py b/models.py index 0f40316..62155c8 100644 --- a/models.py +++ b/models.py @@ -3,174 +3,487 @@ from tensorflow.keras import layers as tfkl from tensorflow_probability import distributions as tfd from tensorflow.keras.mixed_precision import experimental as prec - import tools + + + +class Inverse(tools.Module): + + def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu): + super().__init__() + self._activation = act + self._stoch_size = stoch + self._deter_size = deter + self._hidden_size = hidden + self._cell = tfkl.GRUCell(self._deter_size) + + def initial(self, batch_size): + dtype = prec.global_policy().compute_dtype + return dict( + mean=tf.zeros([batch_size, self._stoch_size], dtype), + std=tf.zeros([batch_size, self._stoch_size], dtype), + stoch=tf.zeros([batch_size, self._stoch_size], dtype), + deter=self._cell.get_initial_state(None, batch_size, dtype)) + + + + @tf.function + def observe(self, embed, action, state=None): + if state is None: + state = self.initial(tf.shape(action)[0]) + embed = tf.transpose(embed, [1, 0, 2]) + action = tf.transpose(action, [1, 0, 2]) + post, prior = tools.static_scan( + lambda prev, inputs: self.obs_step(prev[0], *inputs), + (action, embed), (state, state)) + post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()} + prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} + return post, prior + + @tf.function + def imagine(self, action, state=None): + if state is None: + state = self.initial(tf.shape(action)[0]) + assert isinstance(state, dict), state + action = tf.transpose(action, [1, 0, 2]) + prior = tools.static_scan(self.img_step, action, state) + prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} + return prior + + def get_feat(self, state): + return tf.concat([state['stoch'], state['deter']], -1) + + def get_dist(self, state): + return tfd.MultivariateNormalDiag(state['mean'], state['std']) + + @tf.function + def obs_step(self, prev_state, prev_action, embed): + prior = self.img_step(prev_state, prev_action) + x = tf.concat([prior['deter'], embed], -1) + x = self.get('obs1', tfkl.Dense, self._hidden_size, + self._activation)(x) + x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x) + mean, std = tf.split(x, 2, -1) + std = tf.nn.softplus(std) + 0.1 + stoch = self.get_dist({'mean': mean, 'std': std}).sample() + post = {'mean': mean, 'std': std, + 'stoch': stoch, 'deter': prior['deter']} + return post, prior + + @tf.function + def img_step(self, prev_state, prev_action): + x = tf.concat([prev_state['stoch'], prev_action], -1) + x = self.get('img1', 
tfkl.Dense, self._hidden_size, + self._activation)(x) + x, deter = self._cell(x, [prev_state['deter']]) + deter = deter[0] # Keras wraps the state in a list. + x = self.get('img2', tfkl.Dense, self._hidden_size, + self._activation)(x) + x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x) + mean, std = tf.split(x, 2, -1) + std = tf.nn.softplus(std) + 0.1 + stoch = self.get_dist({'mean': mean, 'std': std}).sample() + prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter} + return prior + + + + + class RSSM(tools.Module): - def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu): - super().__init__() - self._activation = act - self._stoch_size = stoch - self._deter_size = deter - self._hidden_size = hidden - self._cell = tfkl.GRUCell(self._deter_size) - - def initial(self, batch_size): - dtype = prec.global_policy().compute_dtype - return dict( - mean=tf.zeros([batch_size, self._stoch_size], dtype), - std=tf.zeros([batch_size, self._stoch_size], dtype), - stoch=tf.zeros([batch_size, self._stoch_size], dtype), - deter=self._cell.get_initial_state(None, batch_size, dtype)) - - @tf.function - def observe(self, embed, action, state=None): - if state is None: - state = self.initial(tf.shape(action)[0]) - embed = tf.transpose(embed, [1, 0, 2]) - action = tf.transpose(action, [1, 0, 2]) - post, prior = tools.static_scan( - lambda prev, inputs: self.obs_step(prev[0], *inputs), - (action, embed), (state, state)) - post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()} - prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} - return post, prior - - @tf.function - def imagine(self, action, state=None): - if state is None: - state = self.initial(tf.shape(action)[0]) - assert isinstance(state, dict), state - action = tf.transpose(action, [1, 0, 2]) - prior = tools.static_scan(self.img_step, action, state) - prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} - return prior - - def get_feat(self, state): - return tf.concat([state['stoch'], state['deter']], -1) - - def get_dist(self, state): - return tfd.MultivariateNormalDiag(state['mean'], state['std']) - - @tf.function - def obs_step(self, prev_state, prev_action, embed): - prior = self.img_step(prev_state, prev_action) - x = tf.concat([prior['deter'], embed], -1) - x = self.get('obs1', tfkl.Dense, self._hidden_size, self._activation)(x) - x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x) - mean, std = tf.split(x, 2, -1) - std = tf.nn.softplus(std) + 0.1 - stoch = self.get_dist({'mean': mean, 'std': std}).sample() - post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': prior['deter']} - return post, prior - - @tf.function - def img_step(self, prev_state, prev_action): - x = tf.concat([prev_state['stoch'], prev_action], -1) - x = self.get('img1', tfkl.Dense, self._hidden_size, self._activation)(x) - x, deter = self._cell(x, [prev_state['deter']]) - deter = deter[0] # Keras wraps the state in a list. 
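The inverse-dynamics pieces of this patch (the `InverseActionDecoder` added later in this file and the MSE training loop in dreamer_contrastive_inverse.py) predict the action that connected two consecutive latent features. A simplified Keras sketch of that idea, using a deterministic MLP head instead of the tanh-Normal distribution the patch actually fits, with illustrative dimensions:

import tensorflow as tf

class InverseDynamics(tf.keras.Model):
    # Predicts a_t from (feat_t, feat_{t+1}); the patch samples from a tanh-Normal head instead.
    def __init__(self, act_dim, units=400):
        super().__init__()
        self.net = tf.keras.Sequential([
            tf.keras.layers.Dense(units, activation='elu'),
            tf.keras.layers.Dense(units, activation='elu'),
            tf.keras.layers.Dense(act_dim, activation='tanh'),
        ])

    def call(self, feat, next_feat):
        return self.net(tf.concat([feat, next_feat], axis=-1))

model = InverseDynamics(act_dim=6)
opt = tf.keras.optimizers.Adam(8e-5)
feat, next_feat = tf.random.normal([50, 230]), tf.random.normal([50, 230])
action_taken = tf.random.uniform([50, 6], -1, 1)
with tf.GradientTape() as tape:
    # Same objective as the __call__ loop in dreamer_contrastive_inverse.py: MSE to the taken action.
    loss = tf.reduce_mean(tf.square(model(feat, next_feat) - action_taken))
grads = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))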
- x = self.get('img2', tfkl.Dense, self._hidden_size, self._activation)(x) - x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x) - mean, std = tf.split(x, 2, -1) - std = tf.nn.softplus(std) + 0.1 - stoch = self.get_dist({'mean': mean, 'std': std}).sample() - prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter} - return prior + def __init__(self, stoch=30, deter=200, hidden=200, act=tf.nn.elu): + super().__init__() + self._activation = act + self._stoch_size = stoch + self._deter_size = deter + self._hidden_size = hidden + self._cell = tfkl.GRUCell(self._deter_size) + + def initial(self, batch_size): + dtype = prec.global_policy().compute_dtype + return dict( + mean=tf.zeros([batch_size, self._stoch_size], dtype), + std=tf.zeros([batch_size, self._stoch_size], dtype), + stoch=tf.zeros([batch_size, self._stoch_size], dtype), + deter=self._cell.get_initial_state(None, batch_size, dtype)) + + @tf.function + def observe(self, embed, action, state=None): + if state is None: + state = self.initial(tf.shape(action)[0]) + embed = tf.transpose(embed, [1, 0, 2]) + action = tf.transpose(action, [1, 0, 2]) + post, prior = tools.static_scan( + lambda prev, inputs: self.obs_step(prev[0], *inputs), + (action, embed), (state, state)) + post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()} + prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} + return post, prior + + @tf.function + def imagine(self, action, state=None): + if state is None: + state = self.initial(tf.shape(action)[0]) + assert isinstance(state, dict), state + action = tf.transpose(action, [1, 0, 2]) + prior = tools.static_scan(self.img_step, action, state) + prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} + return prior + + def get_feat(self, state): + return tf.concat([state['stoch'], state['deter']], -1) + + def get_dist(self, state): + return tfd.MultivariateNormalDiag(state['mean'], state['std']) + + + @tf.function + def obs_step(self, prev_state, prev_action, embed): + prior = self.img_step(prev_state, prev_action) + x = tf.concat([prior['deter'], embed], -1) + x = self.get('obs1', tfkl.Dense, self._hidden_size, + self._activation)(x) + x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x) + mean, std = tf.split(x, 2, -1) + std = tf.nn.softplus(std) + 0.1 + stoch = self.get_dist({'mean': mean, 'std': std}).sample() + post = {'mean': mean, 'std': std, + 'stoch': stoch, 'deter': prior['deter']} + return post, prior + + @tf.function + def img_step(self, prev_state, prev_action): + x = tf.concat([prev_state['stoch'], prev_action], -1) + x = self.get('img1', tfkl.Dense, self._hidden_size, + self._activation)(x) + x, deter = self._cell(x, [prev_state['deter']]) + deter = deter[0] # Keras wraps the state in a list. 
+ x = self.get('img2', tfkl.Dense, self._hidden_size, + self._activation)(x) + x = self.get('img3', tfkl.Dense, 2 * self._stoch_size, None)(x) + mean, std = tf.split(x, 2, -1) + std = tf.nn.softplus(std) + 0.1 + stoch = self.get_dist({'mean': mean, 'std': std}).sample() + prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter} + return prior class ConvEncoder(tools.Module): - def __init__(self, depth=32, act=tf.nn.relu): - self._act = act - self._depth = depth + def __init__(self, depth=32, act=tf.nn.relu): + self._act = act + self._depth = depth - def __call__(self, obs): - kwargs = dict(strides=2, activation=self._act) - x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) - x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x) - x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x) - x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x) - x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) - shape = tf.concat([tf.shape(obs['image'])[:-3], [32 * self._depth]], 0) - return tf.reshape(x, shape) + def __call__(self, obs): + kwargs = dict(strides=2, activation=self._act) + x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) + x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x) + x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x) + x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x) + x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) + shape = tf.concat([tf.shape(obs['image'])[:-3], [32 * self._depth]], 0) + return tf.reshape(x, shape) class ConvDecoder(tools.Module): - def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)): - self._act = act - self._depth = depth - self._shape = shape + def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)): + self._act = act + self._depth = depth + self._shape = shape + + def __call__(self, features): + kwargs = dict(strides=2, activation=self._act) + x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) + x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) + x = self.get('h2', tfkl.Conv2DTranspose, + 4 * self._depth, 5, **kwargs)(x) + x = self.get('h3', tfkl.Conv2DTranspose, + 2 * self._depth, 5, **kwargs)(x) + x = self.get('h4', tfkl.Conv2DTranspose, + 1 * self._depth, 6, **kwargs)(x) + x = self.get('h5', tfkl.Conv2DTranspose, + self._shape[-1], 6, strides=2)(x) + mean = tf.reshape(x, tf.concat( + [tf.shape(features)[:-1], self._shape], 0)) + return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)) + + +class ContrastiveObsModel(tools.Module): + """The contrastive observation model + """ + def __init__(self, hz, hx, act=tf.nn.elu): + self.act = act + self.hz = hz + self.hx = hx + + def __call__(self, z, x): + """Both inputs have the shape of [batch_sz, length, dim]. 
For each positive sample, we use the rest of batch_sz * length - 1 samples as negative samples + + Args: + z (tensor): latent state + x (tensor): encoded observation + """ + + x = tf.reshape(x, (-1, x.shape[-1])) + z = tf.reshape(z, (-1, z.shape[-1])) + + # use mixed precision of float32 to avoid overflow + x = self.get('obs_enc1', tfkl.Dense, self.hx, self.act)(x) + x = self.get('obs_enc2', tfkl.Dense, self.hz, self.act, dtype='float32')(x) + + z = self.get('state_merge1', tfkl.Dense, self.hz, self.act)(z) + z = self.get('state_merge2', tfkl.Dense, self.hz, self.act, + dtype='float32')(z) + + weight_mat = tf.matmul(z, x, transpose_b=True) + + positive = tf.linalg.tensor_diag_part(weight_mat) + norm = tf.reduce_logsumexp(weight_mat, axis=1) + + # compute the infonce loss and change the predicion back to float16 + info_nce = tf.cast(positive - norm, 'float16') + + return info_nce + +class ContrastiveObsModelNWJ(tools.Module): + """The contrastive observation model + """ + def __init__(self, hz, hx, act=tf.nn.elu): + self.act = act + self.hz = hz + self.hx = hx + + def __call__(self, z, x): + """Both inputs have the shape of [batch_sz, length, dim]. For each positive sample, we use the rest of batch_sz * length - 1 samples as negative samples + + Args: + z (tensor): latent state + x (tensor): encoded observation + """ + + self.batch_size = z.shape[0] + self.negative_samples = z.shape[0] + print("contrastive obs",x.shape) - def __call__(self, features): - kwargs = dict(strides=2, activation=self._act) - x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) - x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) - x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, 5, **kwargs)(x) - x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, 5, **kwargs)(x) - x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, 6, **kwargs)(x) - x = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1], 6, strides=2)(x) - mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) - return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)) + x = tf.reshape(x, (-1, x.shape[-1])) + z = tf.reshape(z, (-1, z.shape[-1])) + print("contrastive obs",x.shape) + + + + + + # use mixed precision of float32 to avoid overflow + x = self.get('obs_enc1', tfkl.Dense, self.hx, self.act)(x) + x = self.get('obs_enc2', tfkl.Dense, self.hz, self.act, dtype='float32')(x) + + z = self.get('state_merge1', tfkl.Dense, self.hz, self.act)(z) + z = self.get('state_merge2', tfkl.Dense, self.hz, self.act, + dtype='float32')(z) + + score_matrix = tf.matmul(z, x, transpose_b=True) + self.batch_size = score_matrix.shape[0] + self.negative_samples = score_matrix.shape[0] + mask = tf.eye(self.batch_size) + complem_mask = 1 - mask + T_joint = tf.multiply(score_matrix, mask) + T_product = tf.multiply(score_matrix, complem_mask) + + E_joint = 1 / self.batch_size * tf.reduce_sum(T_joint) + E_product = 1 / (np.e * self.batch_size * self.negative_samples) * (tf.reduce_sum(tf.exp(T_product)) - self.batch_size) + mi = tf.cast(E_joint - E_product, 'float16') + + + return mi + + +class ContrastiveObsModelMINE(tools.Module): + """The contrastive observation model + """ + def __init__(self, hz, hx, act=tf.nn.elu): + self.act = act + self.hz = hz + self.hx = hx + + def __call__(self, z, x): + """Both inputs have the shape of [batch_sz, length, dim]. 
For each positive sample, we use the rest of batch_sz * length - 1 samples as negative samples + + Args: + z (tensor): latent state + x (tensor): encoded observation + """ + self.batch_size = z.shape[0] + self.negative_samples = z.shape[0] + print("contrastive obs",x.shape) + + x = tf.reshape(x, (-1, x.shape[-1])) + z = tf.reshape(z, (-1, z.shape[-1])) + print("contrastive obs",x.shape) + + + + + + # use mixed precision of float32 to avoid overflow + x = self.get('obs_enc1', tfkl.Dense, self.hx, self.act)(x) + x = self.get('obs_enc2', tfkl.Dense, self.hz, self.act, dtype='float32')(x) + + z = self.get('state_merge1', tfkl.Dense, self.hz, self.act)(z) + z = self.get('state_merge2', tfkl.Dense, self.hz, self.act, + dtype='float32')(z) + + score_matrix = tf.matmul(z, x, transpose_b=True) + self.batch_size = score_matrix.shape[0] + self.negative_samples = score_matrix.shape[0] + self.ema_decay = 0.99 + mask = tf.eye(self.batch_size) + complem_mask = 1 - mask + T_joint = tf.multiply(score_matrix, mask) + T_product = tf.multiply(score_matrix, complem_mask) + + E_joint = 1 / self.batch_size * tf.reduce_sum(T_joint) + E_product = np.log(1 / (self.batch_size * self.negative_samples)) + tf.math.log(tf.reduce_sum(tf.exp(T_product)) - self.batch_size) + mi = tf.cast(E_joint - E_product, 'float16') + + # ema_denominator = tf.Variable(tf.exp(tf.reduce_logsumexp(T_product))) + # ema_denominator -= (1 - self.ema_decay) * (ema_denominator - tf.exp(tf.reduce_logsumexp(T_product))) + # mi_for_grads = E_joint - 1 / tf.stop_gradient(ema_denominator) * tf.exp(tf.reduce_logsumexp(T_product)) + + # mi_for_grads = tf.cast(mi_for_grads,'float16') + + return mi class DenseDecoder(tools.Module): - def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu): - self._shape = shape - self._layers = layers - self._units = units - self._dist = dist - self._act = act - - def __call__(self, features): - x = features - for index in range(self._layers): - x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) - x = self.get(f'hout', tfkl.Dense, np.prod(self._shape))(x) - x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) - if self._dist == 'normal': - return tfd.Independent(tfd.Normal(x, 1), len(self._shape)) - if self._dist == 'binary': - return tfd.Independent(tfd.Bernoulli(x), len(self._shape)) - raise NotImplementedError(self._dist) + def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu): + self._shape = shape + self._layers = layers + self._units = units + self._dist = dist + self._act = act + + def __call__(self, features): + x = features + for index in range(self._layers): + x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) + x = self.get(f'hout', tfkl.Dense, np.prod(self._shape))(x) + x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) + if self._dist == 'normal': + return tfd.Independent(tfd.Normal(x, 1), len(self._shape)) + if self._dist == 'binary': + return tfd.Independent(tfd.Bernoulli(x), len(self._shape)) + raise NotImplementedError(self._dist) + +class QNetwork(tools.Module): + + def __init__(self, layers, units, dist='normal', act=tf.nn.elu, shape=()): + self._shape = shape + self._layers = layers + self._units = units + self._dist = dist + self._act = act + + def __call__(self, features): + x = features + for index in range(self._layers): + x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) + x = self.get(f'hout', tfkl.Dense, np.prod(self._shape))(x) + x = tf.reshape(x, 
tf.concat([tf.shape(features)[:-1], self._shape], 0)) + + return x + + + +class InverseActionDecoder(tools.Module): + + def __init__( + self, size, layers, units, dist='tanh_normal', act=tf.nn.elu, + min_std=1e-4, init_std=5, mean_scale=5): + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._min_std = min_std + self._init_std = init_std + self._mean_scale = mean_scale + + def __call__(self, features1, features2): + raw_init_std = np.log(np.exp(self._init_std) - 1) + x = tf.concat([features1, features2], -1) + for index in range(self._layers): + x = self.get(f'h{index+100}', tfkl.Dense, self._units, self._act)(x) + if self._dist == 'tanh_normal': + # https://www.desmos.com/calculator/rcmcf5jwe7 + x = self.get(f'hout1', tfkl.Dense, 2 * self._size)(x) + mean, std = tf.split(x, 2, -1) + mean = self._mean_scale * tf.tanh(mean / self._mean_scale) + std = tf.nn.softplus(std + raw_init_std) + self._min_std + dist = tfd.Normal(mean, std) + dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) + dist = tfd.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == 'onehot': + x = self.get(f'hout1', tfkl.Dense, self._size)(x) + dist = tools.OneHotDist(x) + else: + raise NotImplementedError(dist) + return dist + + def actions_and_log_probs(self, features): + dist = self(features) + action = dist.sample() + log_prob = dist.log_prob(action) + + return action, log_prob + + class ActionDecoder(tools.Module): - def __init__( - self, size, layers, units, dist='tanh_normal', act=tf.nn.elu, - min_std=1e-4, init_std=5, mean_scale=5): - self._size = size - self._layers = layers - self._units = units - self._dist = dist - self._act = act - self._min_std = min_std - self._init_std = init_std - self._mean_scale = mean_scale - - def __call__(self, features): - raw_init_std = np.log(np.exp(self._init_std) - 1) - x = features - for index in range(self._layers): - x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) - if self._dist == 'tanh_normal': - # https://www.desmos.com/calculator/rcmcf5jwe7 - x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) - mean, std = tf.split(x, 2, -1) - mean = self._mean_scale * tf.tanh(mean / self._mean_scale) - std = tf.nn.softplus(std + raw_init_std) + self._min_std - dist = tfd.Normal(mean, std) - dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) - dist = tfd.Independent(dist, 1) - dist = tools.SampleDist(dist) - elif self._dist == 'onehot': - x = self.get(f'hout', tfkl.Dense, self._size)(x) - dist = tools.OneHotDist(x) - else: - raise NotImplementedError(dist) - return dist + def __init__( + self, size, layers, units, dist='tanh_normal', act=tf.nn.elu, + min_std=1e-4, init_std=5, mean_scale=5): + self._size = size + self._layers = layers + self._units = units + self._dist = dist + self._act = act + self._min_std = min_std + self._init_std = init_std + self._mean_scale = mean_scale + + def __call__(self, features): + raw_init_std = np.log(np.exp(self._init_std) - 1) + x = features + for index in range(self._layers): + x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) + if self._dist == 'tanh_normal': + # https://www.desmos.com/calculator/rcmcf5jwe7 + x = self.get(f'hout', tfkl.Dense, 2 * self._size)(x) + mean, std = tf.split(x, 2, -1) + mean = self._mean_scale * tf.tanh(mean / self._mean_scale) + std = tf.nn.softplus(std + raw_init_std) + self._min_std + dist = tfd.Normal(mean, std) + dist = tfd.TransformedDistribution(dist, tools.TanhBijector()) + dist = 
tfd.Independent(dist, 1) + dist = tools.SampleDist(dist) + elif self._dist == 'onehot': + x = self.get(f'hout', tfkl.Dense, self._size)(x) + dist = tools.OneHotDist(x) + else: + raise NotImplementedError(dist) + return dist + + def actions_and_log_probs(self, features): + dist = self(features) + action = dist.sample() + log_prob = dist.log_prob(action) + + return action, log_prob diff --git a/tools.py b/tools.py index d5a80f3..826d37b 100644 --- a/tools.py +++ b/tools.py @@ -13,419 +13,466 @@ from tensorflow.keras.mixed_precision import experimental as prec from tensorflow_probability import distributions as tfd +from PIL import Image + class AttrDict(dict): - __setattr__ = dict.__setitem__ - __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ class Module(tf.Module): - def save(self, filename): - values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) - with pathlib.Path(filename).open('wb') as f: - pickle.dump(values, f) + def save(self, filename): + values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) + with pathlib.Path(filename).open('wb') as f: + pickle.dump(values, f) - def load(self, filename): - with pathlib.Path(filename).open('rb') as f: - values = pickle.load(f) - tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) + def load(self, filename): + with pathlib.Path(filename).open('rb') as f: + values = pickle.load(f) + tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) - def get(self, name, ctor, *args, **kwargs): - # Create or get layer by name to avoid mentioning it in the constructor. - if not hasattr(self, '_modules'): - self._modules = {} - if name not in self._modules: - self._modules[name] = ctor(*args, **kwargs) - return self._modules[name] + def get(self, name, ctor, *args, **kwargs): + # Create or get layer by name to avoid mentioning it in the constructor. + if not hasattr(self, '_modules'): + self._modules = {} + if name not in self._modules: + self._modules[name] = ctor(*args, **kwargs) + return self._modules[name] def nest_summary(structure): - if isinstance(structure, dict): - return {k: nest_summary(v) for k, v in structure.items()} - if isinstance(structure, list): - return [nest_summary(v) for v in structure] - if hasattr(structure, 'shape'): - return str(structure.shape).replace(', ', 'x').strip('(), ') - return '?' + if isinstance(structure, dict): + return {k: nest_summary(v) for k, v in structure.items()} + if isinstance(structure, list): + return [nest_summary(v) for v in structure] + if hasattr(structure, 'shape'): + return str(structure.shape).replace(', ', 'x').strip('(), ') + return '?' 
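The `ContrastiveObsModel` above scores every (latent, embedding) pair in the flattened batch and treats the diagonal as positives; the NWJ and MINE variants apply different bounds to the same score matrix. A NumPy sketch of the InfoNCE objective it computes (projection heads omitted; a plain dot product stands in for the learned scores):

import numpy as np

def info_nce(z, x):
    # z: [N, d] latent features, x: [N, d] observation embeddings; row i is a positive pair.
    scores = z @ x.T                       # [N, N] score matrix
    positive = np.diag(scores)             # score of each positive pair
    # Row-wise log-sum-exp = log normalizer over 1 positive and N - 1 negatives.
    row_max = scores.max(axis=1)
    norm = np.log(np.sum(np.exp(scores - row_max[:, None]), axis=1)) + row_max
    return positive - norm                 # per-sample lower bound on mutual information

N, d = 8, 16
z = np.random.randn(N, d)
loss = -info_nce(z, z + 0.1 * np.random.randn(N, d)).mean()  # minimized during training
print(loss)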
def graph_summary(writer, fn, *args): - step = tf.summary.experimental.get_step() - def inner(*args): - tf.summary.experimental.set_step(step) - with writer.as_default(): - fn(*args) - return tf.numpy_function(inner, args, []) + step = tf.summary.experimental.get_step() + + def inner(*args): + tf.summary.experimental.set_step(step) + with writer.as_default(): + fn(*args) + return tf.numpy_function(inner, args, []) def video_summary(name, video, step=None, fps=20): - name = name if isinstance(name, str) else name.decode('utf-8') - if np.issubdtype(video.dtype, np.floating): - video = np.clip(255 * video, 0, 255).astype(np.uint8) - B, T, H, W, C = video.shape - try: - frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) - summary = tf1.Summary() - image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C) - image.encoded_image_string = encode_gif(frames, fps) - summary.value.add(tag=name + '/gif', image=image) - tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) - except (IOError, OSError) as e: - print('GIF summaries require ffmpeg in $PATH.', e) - frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) - tf.summary.image(name + '/grid', frames, step) + name = name if isinstance(name, str) else str(name) + if np.issubdtype(video.dtype, np.floating): + video = np.clip(255 * video, 0, 255).astype(np.uint8) + B, T, H, W, C = video.shape + try: + frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) + summary = tf1.Summary() + image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C) + image.encoded_image_string = encode_gif(frames, fps) + summary.value.add(tag=name + '/gif', image=image) + if step == None: + step = 100000 + tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) + except (IOError, OSError) as e: + print('GIF summaries require ffmpeg in $PATH.', e) + frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) + if step == None: + step = 100000 + tf.summary.image(name + '/grid', frames, step) def encode_gif(frames, fps): - from subprocess import Popen, PIPE - h, w, c = frames[0].shape - pxfmt = {1: 'gray', 3: 'rgb24'}[c] - cmd = ' '.join([ - f'ffmpeg -y -f rawvideo -vcodec rawvideo', - f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', - f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', - f'-r {fps:.02f} -f gif -']) - proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) - for image in frames: - proc.stdin.write(image.tostring()) - out, err = proc.communicate() - if proc.returncode: - raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) - del proc - return out + from subprocess import Popen, PIPE + h, w, c = frames[0].shape + pxfmt = {1: 'gray', 3: 'rgb24'}[c] + cmd = ' '.join([ + f'ffmpeg -y -f rawvideo -vcodec rawvideo', + f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', + f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', + f'-r {fps:.02f} -f gif -']) + proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in frames: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) + del proc + return out def simulate(agent, envs, steps=0, episodes=0, state=None): - # Initialize or unpack simulation state. 
- if state is None: - step, episode = 0, 0 - done = np.ones(len(envs), np.bool) - length = np.zeros(len(envs), np.int32) - obs = [None] * len(envs) - agent_state = None - else: - step, episode, done, length, obs, agent_state = state - while (steps and step < steps) or (episodes and episode < episodes): - # Reset envs if necessary. - if done.any(): - indices = [index for index, d in enumerate(done) if d] - promises = [envs[i].reset(blocking=False) for i in indices] - for index, promise in zip(indices, promises): - obs[index] = promise() - # Step agents. - obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} - action, agent_state = agent(obs, done, agent_state) - action = np.array(action) - assert len(action) == len(envs) - # Step envs. - promises = [e.step(a, blocking=False) for e, a in zip(envs, action)] - obs, _, done = zip(*[p()[:3] for p in promises]) - obs = list(obs) - done = np.stack(done) - episode += int(done.sum()) - length += 1 - step += (done * length).sum() - length *= (1 - done) - # Return new state to allow resuming the simulation. - return (step - steps, episode - episodes, done, length, obs, agent_state) + # Initialize or unpack simulation state. + if state is None: + step, episode = 0, 0 + done = np.ones(len(envs), np.bool) + length = np.zeros(len(envs), np.int32) + obs = [None] * len(envs) + agent_state = None + else: + step, episode, done, length, obs, agent_state = state + while (steps and step < steps) or (episodes and episode < episodes): + # Reset envs if necessary. + if done.any(): + indices = [index for index, d in enumerate(done) if d] + promises = [envs[i].reset(blocking=False) for i in indices] + for index, promise in zip(indices, promises): + obs[index] = promise() + # Step agents. + obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} + action, agent_state = agent(obs, done, agent_state) + action = np.array(action) + assert len(action) == len(envs) + # Step envs. + promises = [e.step(a, blocking=False) for e, a in zip(envs, action)] + obs, _, done = zip(*[p()[:3] for p in promises]) + obs = list(obs) + done = np.stack(done) + episode += int(done.sum()) + length += 1 + step += (done * length).sum() + length *= (1 - done) + # Return new state to allow resuming the simulation. 
+ return (step - steps, episode - episodes, done, length, obs, agent_state) def count_episodes(directory): - filenames = directory.glob('*.npz') - lengths = [int(n.stem.rsplit('-', 1)[-1]) - 1 for n in filenames] - episodes, steps = len(lengths), sum(lengths) - return episodes, steps + filenames = directory.glob('*.npz') + lengths = [int(n.stem.rsplit('-', 1)[-1]) - 1 for n in filenames] + episodes, steps = len(lengths), sum(lengths) + return episodes, steps def save_episodes(directory, episodes): - directory = pathlib.Path(directory).expanduser() - directory.mkdir(parents=True, exist_ok=True) - timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') - for episode in episodes: - identifier = str(uuid.uuid4().hex) - length = len(episode['reward']) - filename = directory / f'{timestamp}-{identifier}-{length}.npz' - with io.BytesIO() as f1: - np.savez_compressed(f1, **episode) - f1.seek(0) - with filename.open('wb') as f2: - f2.write(f1.read()) + directory = pathlib.Path(directory).expanduser() + directory.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + for episode in episodes: + identifier = str(uuid.uuid4().hex) + length = len(episode['reward']) + filename = directory / f'{timestamp}-{identifier}-{length}.npz' + with io.BytesIO() as f1: + np.savez_compressed(f1, **episode) + f1.seek(0) + with filename.open('wb') as f2: + f2.write(f1.read()) def load_episodes(directory, rescan, length=None, balance=False, seed=0): - directory = pathlib.Path(directory).expanduser() - random = np.random.RandomState(seed) - cache = {} - while True: - for filename in directory.glob('*.npz'): - if filename not in cache: - try: - with filename.open('rb') as f: - episode = np.load(f) - episode = {k: episode[k] for k in episode.keys()} - except Exception as e: - print(f'Could not load episode: {e}') - continue - cache[filename] = episode - keys = list(cache.keys()) - for index in random.choice(len(keys), rescan): - episode = cache[keys[index]] - if length: - total = len(next(iter(episode.values()))) - available = total - length - if available < 1: - print(f'Skipped short episode of length {available}.') - continue - if balance: - index = min(random.randint(0, total), available) - else: - index = int(random.randint(0, available)) - episode = {k: v[index: index + length] for k, v in episode.items()} - yield episode + directory = pathlib.Path(directory).expanduser() + random = np.random.RandomState(seed) + cache = {} + while True: + for filename in directory.glob('*.npz'): + if filename not in cache: + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + cache[filename] = episode + keys = list(cache.keys()) + for index in random.choice(len(keys), rescan): + episode = cache[keys[index]] + if length: + total = len(next(iter(episode.values()))) + available = total - length + if available < 1: + print(f'Skipped short episode of length {available}.') + continue + if balance: + index = min(random.randint(0, total), available) + else: + index = int(random.randint(0, available)) + episode = {k: v[index: index + length] + for k, v in episode.items()} + yield episode class DummyEnv: - def __init__(self): - self._random = np.random.RandomState(seed=0) - self._step = None - - @property - def observation_space(self): - low = np.zeros([64, 64, 3], dtype=np.uint8) - high = 255 * np.ones([64, 64, 3], dtype=np.uint8) - spaces = {'image': 
gym.spaces.Box(low, high)} - return gym.spaces.Dict(spaces) - - @property - def action_space(self): - low = -np.ones([5], dtype=np.float32) - high = np.ones([5], dtype=np.float32) - return gym.spaces.Box(low, high) - - def reset(self): - self._step = 0 - obs = self.observation_space.sample() - return obs - - def step(self, action): - obs = self.observation_space.sample() - reward = self._random.uniform(0, 1) - self._step += 1 - done = self._step >= 1000 - info = {} - return obs, reward, done, info + def __init__(self): + self._random = np.random.RandomState(seed=0) + self._step = None + + @property + def observation_space(self): + low = np.zeros([64, 64, 3], dtype=np.uint8) + high = 255 * np.ones([64, 64, 3], dtype=np.uint8) + spaces = {'image': gym.spaces.Box(low, high)} + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + low = -np.ones([5], dtype=np.float32) + high = np.ones([5], dtype=np.float32) + return gym.spaces.Box(low, high) + + def reset(self): + self._step = 0 + obs = self.observation_space.sample() + return obs + + def step(self, action): + obs = self.observation_space.sample() + reward = self._random.uniform(0, 1) + self._step += 1 + done = self._step >= 1000 + info = {} + return obs, reward, done, info class SampleDist: - def __init__(self, dist, samples=100): - self._dist = dist - self._samples = samples + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples - @property - def name(self): - return 'SampleDist' + @property + def name(self): + return 'SampleDist' - def __getattr__(self, name): - return getattr(self._dist, name) + def __getattr__(self, name): + return getattr(self._dist, name) - def mean(self): - samples = self._dist.sample(self._samples) - return tf.reduce_mean(samples, 0) + def mean(self): + samples = self._dist.sample(self._samples) + return tf.reduce_mean(samples, 0) - def mode(self): - sample = self._dist.sample(self._samples) - logprob = self._dist.log_prob(sample) - return tf.gather(sample, tf.argmax(logprob))[0] + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return tf.gather(sample, tf.argmax(logprob))[0] - def entropy(self): - sample = self._dist.sample(self._samples) - logprob = self.log_prob(sample) - return -tf.reduce_mean(logprob, 0) + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -tf.reduce_mean(logprob, 0) class OneHotDist: - def __init__(self, logits=None, probs=None): - self._dist = tfd.Categorical(logits=logits, probs=probs) - self._num_classes = self.mean().shape[-1] - self._dtype = prec.global_policy().compute_dtype + def __init__(self, logits=None, probs=None): + self._dist = tfd.Categorical(logits=logits, probs=probs) + self._num_classes = self.mean().shape[-1] + self._dtype = prec.global_policy().compute_dtype - @property - def name(self): - return 'OneHotDist' + @property + def name(self): + return 'OneHotDist' - def __getattr__(self, name): - return getattr(self._dist, name) + def __getattr__(self, name): + return getattr(self._dist, name) - def prob(self, events): - indices = tf.argmax(events, axis=-1) - return self._dist.prob(indices) + def prob(self, events): + indices = tf.argmax(events, axis=-1) + return self._dist.prob(indices) - def log_prob(self, events): - indices = tf.argmax(events, axis=-1) - return self._dist.log_prob(indices) + def log_prob(self, events): + indices = tf.argmax(events, axis=-1) + return self._dist.log_prob(indices) - def mean(self): - 
return self._dist.probs_parameter() + def mean(self): + return self._dist.probs_parameter() - def mode(self): - return self._one_hot(self._dist.mode()) + def mode(self): + return self._one_hot(self._dist.mode()) - def sample(self, amount=None): - amount = [amount] if amount else [] - indices = self._dist.sample(*amount) - sample = self._one_hot(indices) - probs = self._dist.probs_parameter() - sample += tf.cast(probs - tf.stop_gradient(probs), self._dtype) - return sample + def sample(self, amount=None): + amount = [amount] if amount else [] + indices = self._dist.sample(*amount) + sample = self._one_hot(indices) + probs = self._dist.probs_parameter() + sample += tf.cast(probs - tf.stop_gradient(probs), self._dtype) + return sample - def _one_hot(self, indices): - return tf.one_hot(indices, self._num_classes, dtype=self._dtype) + def _one_hot(self, indices): + return tf.one_hot(indices, self._num_classes, dtype=self._dtype) class TanhBijector(tfp.bijectors.Bijector): - def __init__(self, validate_args=False, name='tanh'): - super().__init__( - forward_min_event_ndims=0, - validate_args=validate_args, - name=name) + def __init__(self, validate_args=False, name='tanh'): + super().__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) - def _forward(self, x): - return tf.nn.tanh(x) + def _forward(self, x): + return tf.nn.tanh(x) - def _inverse(self, y): - dtype = y.dtype - y = tf.cast(y, tf.float32) - y = tf.where( - tf.less_equal(tf.abs(y), 1.), - tf.clip_by_value(y, -0.99999997, 0.99999997), y) - y = tf.atanh(y) - y = tf.cast(y, dtype) - return y + def _inverse(self, y): + dtype = y.dtype + y = tf.cast(y, tf.float32) + y = tf.where( + tf.less_equal(tf.abs(y), 1.), + tf.clip_by_value(y, -0.99999997, 0.99999997), y) + y = tf.atanh(y) + y = tf.cast(y, dtype) + return y - def _forward_log_det_jacobian(self, x): - log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) - return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) + def _forward_log_det_jacobian(self, x): + log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) + return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) def lambda_return( - reward, value, pcont, bootstrap, lambda_, axis): - # Setting lambda=1 gives a discounted Monte Carlo return. - # Setting lambda=0 gives a fixed 1-step return. - assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) - if isinstance(pcont, (int, float)): - pcont = pcont * tf.ones_like(reward) - dims = list(range(reward.shape.ndims)) - dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] - if axis != 0: - reward = tf.transpose(reward, dims) - value = tf.transpose(value, dims) - pcont = tf.transpose(pcont, dims) - if bootstrap is None: - bootstrap = tf.zeros_like(value[-1]) - next_values = tf.concat([value[1:], bootstrap[None]], 0) - inputs = reward + pcont * next_values * (1 - lambda_) - returns = static_scan( - lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, - (inputs, pcont), bootstrap, reverse=True) - if axis != 0: - returns = tf.transpose(returns, dims) - return returns + reward, value, pcont, bootstrap, lambda_, axis): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. 
+ assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) + if isinstance(pcont, (int, float)): + pcont = pcont * tf.ones_like(reward) + dims = list(range(reward.shape.ndims)) + dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] + if axis != 0: + reward = tf.transpose(reward, dims) + value = tf.transpose(value, dims) + pcont = tf.transpose(pcont, dims) + if bootstrap is None: + bootstrap = tf.zeros_like(value[-1]) + next_values = tf.concat([value[1:], bootstrap[None]], 0) + inputs = reward + pcont * next_values * (1 - lambda_) + returns = static_scan( + lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, + (inputs, pcont), bootstrap, reverse=True) + if axis != 0: + returns = tf.transpose(returns, dims) + return returns class Adam(tf.Module): - def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'): - self._name = name - self._modules = modules - self._clip = clip - self._wd = wd - self._wdpattern = wdpattern - self._opt = tf.optimizers.Adam(lr) - self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic') - self._variables = None - - @property - def variables(self): - return self._opt.variables() - - def __call__(self, tape, loss): - if self._variables is None: - variables = [module.variables for module in self._modules] - self._variables = tf.nest.flatten(variables) - count = sum(np.prod(x.shape) for x in self._variables) - print(f'Found {count} {self._name} parameters.') - assert len(loss.shape) == 0, loss.shape - with tape: - loss = self._opt.get_scaled_loss(loss) - grads = tape.gradient(loss, self._variables) - grads = self._opt.get_unscaled_gradients(grads) - norm = tf.linalg.global_norm(grads) - if self._clip: - grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) - if self._wd: - context = tf.distribute.get_replica_context() - context.merge_call(self._apply_weight_decay) - self._opt.apply_gradients(zip(grads, self._variables)) - return norm - - def _apply_weight_decay(self, strategy): - print('Applied weight decay to variables:') - for var in self._variables: - if re.search(self._wdpattern, self._name + '/' + var.name): - print('- ' + self._name + '/' + var.name) - strategy.extended.update(var, lambda var: self._wd * var) + def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'): + self._name = name + self._modules = modules + self._clip = clip + self._wd = wd + self._wdpattern = wdpattern + self._opt = tf.optimizers.Adam(lr) + self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic') + self._variables = None + + @property + def variables(self): + return self._opt.variables() + + def __call__(self, tape, loss): + if self._variables is None: + variables = [module.variables for module in self._modules] + self._variables = tf.nest.flatten(variables) + count = sum(np.prod(x.shape) for x in self._variables) + print(f'Found {count} {self._name} parameters.') + print("loss.shape",loss.shape) + assert len(loss.shape) == 0, loss.shape + with tape: + loss = self._opt.get_scaled_loss(loss) + grads = tape.gradient(loss, self._variables) + grads = self._opt.get_unscaled_gradients(grads) + norm = tf.linalg.global_norm(grads) + if self._clip: + grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) + if self._wd: + context = tf.distribute.get_replica_context() + context.merge_call(self._apply_weight_decay) + self._opt.apply_gradients(zip(grads, self._variables)) + return norm + + def _apply_weight_decay(self, strategy): + print('Applied weight decay to variables:') + for var in self._variables: + if re.search(self._wdpattern, 
self._name + '/' + var.name): + print('- ' + self._name + '/' + var.name) + strategy.extended.update(var, lambda var: self._wd * var) def args_type(default): - if isinstance(default, bool): - return lambda x: bool(['False', 'True'].index(x)) - if isinstance(default, int): - return lambda x: float(x) if ('e' in x or '.' in x) else int(x) - if isinstance(default, pathlib.Path): - return lambda x: pathlib.Path(x).expanduser() - return type(default) + if isinstance(default, bool): + return lambda x: bool(['False', 'True'].index(x)) + if isinstance(default, int): + return lambda x: float(x) if ('e' in x or '.' in x) else int(x) + if isinstance(default, pathlib.Path): + return lambda x: pathlib.Path(x).expanduser() + return type(default) def static_scan(fn, inputs, start, reverse=False): - last = start - outputs = [[] for _ in tf.nest.flatten(start)] - indices = range(len(tf.nest.flatten(inputs)[0])) - if reverse: - indices = reversed(indices) - for index in indices: - inp = tf.nest.map_structure(lambda x: x[index], inputs) - last = fn(last, inp) - [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] - if reverse: - outputs = [list(reversed(x)) for x in outputs] - outputs = [tf.stack(x, 0) for x in outputs] - return tf.nest.pack_sequence_as(start, outputs) + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + last = fn(last, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs) + + +def static_scan_inverse(fn, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + last = fn(last, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs) + + + + +def static_scan_action(fn1, fn2, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + actions = [] + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + action = fn2(last) + last = fn1(last, action, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + actions.append(action) + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs), actions[0] + def _mnd_sample(self, sample_shape=(), seed=None, name='sample'): - return tf.random.normal( - tuple(sample_shape) + tuple(self.event_shape), - self.mean(), self.stddev(), self.dtype, seed, name) + return tf.random.normal( + tuple(sample_shape) + tuple(self.event_shape), + self.mean(), self.stddev(), self.dtype, seed, name) tfd.MultivariateNormalDiag.sample = _mnd_sample def _cat_sample(self, sample_shape=(), seed=None, name='sample'): - assert len(sample_shape) in (0, 1), sample_shape - assert len(self.logits_parameter().shape) == 2 - indices = tf.random.categorical( - 
self.logits_parameter(), sample_shape[0] if sample_shape else 1, - self.dtype, seed, name) - if not sample_shape: - indices = indices[..., 0] - return indices + assert len(sample_shape) in (0, 1), sample_shape + assert len(self.logits_parameter().shape) == 2 + indices = tf.random.categorical( + self.logits_parameter(), sample_shape[0] if sample_shape else 1, + self.dtype, seed, name) + if not sample_shape: + indices = indices[..., 0] + return indices tfd.Categorical.sample = _cat_sample @@ -433,27 +480,39 @@ def _cat_sample(self, sample_shape=(), seed=None, name='sample'): class Every: - def __init__(self, every): - self._every = every - self._last = None + def __init__(self, every): + self._every = every + self._last = None - def __call__(self, step): - if self._last is None: - self._last = step - return True - if step >= self._last + self._every: - self._last += self._every - return True - return False + def __call__(self, step): + if self._last is None: + self._last = step + return True + if step >= self._last + self._every: + self._last += self._every + return True + return False class Once: - def __init__(self): - self._once = True + def __init__(self): + self._once = True + + def __call__(self): + if self._once: + self._once = False + return True + return False + + +def load_imgnet(train): + import pickle + name = 'train' if train else 'valid' + + with open('/cns/tp-d/home/homanga/logdir/natural_{}.pkl'.format(name), 'rb') as fin: + imgnet = pickle.load(fin) + + imgnet = np.transpose(imgnet, axes=(0, 1, 3, 4, 2)) - def __call__(self): - if self._once: - self._once = False - return True - return False + return imgnet diff --git a/tools_inv.py b/tools_inv.py new file mode 100644 index 0000000..6168f6d --- /dev/null +++ b/tools_inv.py @@ -0,0 +1,522 @@ +import datetime +import io +import pathlib +import pickle +import re +import uuid + +import gym +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tf1 +import tensorflow_probability as tfp +from tensorflow.keras.mixed_precision import experimental as prec +from tensorflow_probability import distributions as tfd + +from PIL import Image + + +class AttrDict(dict): + + __setattr__ = dict.__setitem__ + __getattr__ = dict.__getitem__ + + +class Module(tf.Module): + + def save(self, filename): + values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) + with pathlib.Path(filename).open('wb') as f: + pickle.dump(values, f) + + def load(self, filename): + with pathlib.Path(filename).open('rb') as f: + values = pickle.load(f) + tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) + + def get(self, name, ctor, *args, **kwargs): + # Create or get layer by name to avoid mentioning it in the constructor. + if not hasattr(self, '_modules'): + self._modules = {} + if name not in self._modules: + self._modules[name] = ctor(*args, **kwargs) + return self._modules[name] + + +def nest_summary(structure): + if isinstance(structure, dict): + return {k: nest_summary(v) for k, v in structure.items()} + if isinstance(structure, list): + return [nest_summary(v) for v in structure] + if hasattr(structure, 'shape'): + return str(structure.shape).replace(', ', 'x').strip('(), ') + return '?' 
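
Editor's note: a minimal sketch (not part of the patch) of how the background dataset returned by load_imgnet is intended to feed the distractor wrapper defined later in this patch (wrappers.NaturalMujoco). The pickle path, the frame size of the videos, and the exact wrapping order used by the training scripts are assumptions here; the foreground mask inside NaturalMujoco relies on the DMC agent pixels having a red channel above 100.

import tools
import wrappers

# Natural-video backgrounds, shaped [videos, frames, H, W, C] after the
# transpose inside load_imgnet; requires the pickle file at the configured
# path, and the frames are assumed to already be 64x64.
videos = tools.load_imgnet(True)

env = wrappers.DeepMindControl('walker_walk', size=(64, 64))
env = wrappers.NaturalMujoco(env, videos)  # composite agent pixels over video frames
env = wrappers.ActionRepeat(env, 2)
env = wrappers.NormalizeActions(env)

obs = env.reset()
assert obs['image'].shape == (64, 64, 3)
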
+ + +def graph_summary(writer, fn, *args): + step = tf.summary.experimental.get_step() + + def inner(*args): + tf.summary.experimental.set_step(step) + with writer.as_default(): + fn(*args) + return tf.numpy_function(inner, args, []) + + +def video_summary(name, video, step=None, fps=20): + name = name if isinstance(name, str) else str(name) + if np.issubdtype(video.dtype, np.floating): + video = np.clip(255 * video, 0, 255).astype(np.uint8) + B, T, H, W, C = video.shape + try: + frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) + summary = tf1.Summary() + image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C) + image.encoded_image_string = encode_gif(frames, fps) + summary.value.add(tag=name + '/gif', image=image) + if step == None: + step = 100000 + tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) + except (IOError, OSError) as e: + print('GIF summaries require ffmpeg in $PATH.', e) + frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) + if step == None: + step = 100000 + tf.summary.image(name + '/grid', frames, step) + + +def encode_gif(frames, fps): + from subprocess import Popen, PIPE + h, w, c = frames[0].shape + pxfmt = {1: 'gray', 3: 'rgb24'}[c] + cmd = ' '.join([ + f'ffmpeg -y -f rawvideo -vcodec rawvideo', + f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', + f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', + f'-r {fps:.02f} -f gif -']) + proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) + for image in frames: + proc.stdin.write(image.tostring()) + out, err = proc.communicate() + if proc.returncode: + raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) + del proc + return out + + +def simulate(agent, envs, steps=0, episodes=0, state=None): + # Initialize or unpack simulation state. + if state is None: + step, episode = 0, 0 + done = np.ones(len(envs), np.bool) + length = np.zeros(len(envs), np.int32) + obs = [None] * len(envs) + agent_state = None + old_agent_state = None + old_action = None + else: + step, episode, done, length, obs, agent_state, old_agent_state, old_action = state + while (steps and step < steps) or (episodes and episode < episodes): + # Reset envs if necessary. + if done.any(): + indices = [index for index, d in enumerate(done) if d] + promises = [envs[i].reset(blocking=False) for i in indices] + for index, promise in zip(indices, promises): + obs[index] = promise() + # Step agents. + obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} + action, agent_state, old_agent_state, old_action = agent(obs, done, agent_state) + action = np.array(action) + assert len(action) == len(envs) + # Step envs. + promises = [e.step(a, blocking=False) for e, a in zip(envs, action)] + obs, _, done = zip(*[p()[:3] for p in promises]) + obs = list(obs) + done = np.stack(done) + episode += int(done.sum()) + length += 1 + step += (done * length).sum() + length *= (1 - done) + # Return new state to allow resuming the simulation. 
+ return (step - steps, episode - episodes, done, length, obs, agent_state, old_agent_state, old_action) + + +def count_episodes(directory): + filenames = directory.glob('*.npz') + lengths = [int(n.stem.rsplit('-', 1)[-1]) - 1 for n in filenames] + episodes, steps = len(lengths), sum(lengths) + return episodes, steps + + +def save_episodes(directory, episodes): + directory = pathlib.Path(directory).expanduser() + directory.mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') + for episode in episodes: + identifier = str(uuid.uuid4().hex) + length = len(episode['reward']) + filename = directory / f'{timestamp}-{identifier}-{length}.npz' + with io.BytesIO() as f1: + np.savez_compressed(f1, **episode) + f1.seek(0) + with filename.open('wb') as f2: + f2.write(f1.read()) + + +def load_episodes(directory, rescan, length=None, balance=False, seed=0): + directory = pathlib.Path(directory).expanduser() + random = np.random.RandomState(seed) + cache = {} + while True: + for filename in directory.glob('*.npz'): + if filename not in cache: + try: + with filename.open('rb') as f: + episode = np.load(f) + episode = {k: episode[k] for k in episode.keys()} + except Exception as e: + print(f'Could not load episode: {e}') + continue + cache[filename] = episode + keys = list(cache.keys()) + for index in random.choice(len(keys), rescan): + episode = cache[keys[index]] + if length: + total = len(next(iter(episode.values()))) + available = total - length + if available < 1: + print(f'Skipped short episode of length {available}.') + continue + if balance: + index = min(random.randint(0, total), available) + else: + index = int(random.randint(0, available)) + episode = {k: v[index: index + length] + for k, v in episode.items()} + yield episode + + +class DummyEnv: + + def __init__(self): + self._random = np.random.RandomState(seed=0) + self._step = None + + @property + def observation_space(self): + low = np.zeros([64, 64, 3], dtype=np.uint8) + high = 255 * np.ones([64, 64, 3], dtype=np.uint8) + spaces = {'image': gym.spaces.Box(low, high)} + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + low = -np.ones([5], dtype=np.float32) + high = np.ones([5], dtype=np.float32) + return gym.spaces.Box(low, high) + + def reset(self): + self._step = 0 + obs = self.observation_space.sample() + return obs + + def step(self, action): + obs = self.observation_space.sample() + reward = self._random.uniform(0, 1) + self._step += 1 + done = self._step >= 1000 + info = {} + return obs, reward, done, info + + +class SampleDist: + + def __init__(self, dist, samples=100): + self._dist = dist + self._samples = samples + + @property + def name(self): + return 'SampleDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + def mean(self): + samples = self._dist.sample(self._samples) + return tf.reduce_mean(samples, 0) + + def mode(self): + sample = self._dist.sample(self._samples) + logprob = self._dist.log_prob(sample) + return tf.gather(sample, tf.argmax(logprob))[0] + + def entropy(self): + sample = self._dist.sample(self._samples) + logprob = self.log_prob(sample) + return -tf.reduce_mean(logprob, 0) + + +class OneHotDist: + + def __init__(self, logits=None, probs=None): + self._dist = tfd.Categorical(logits=logits, probs=probs) + self._num_classes = self.mean().shape[-1] + self._dtype = prec.global_policy().compute_dtype + + @property + def name(self): + return 'OneHotDist' + + def __getattr__(self, name): + return getattr(self._dist, name) + + 
def prob(self, events): + indices = tf.argmax(events, axis=-1) + return self._dist.prob(indices) + + def log_prob(self, events): + indices = tf.argmax(events, axis=-1) + return self._dist.log_prob(indices) + + def mean(self): + return self._dist.probs_parameter() + + def mode(self): + return self._one_hot(self._dist.mode()) + + def sample(self, amount=None): + amount = [amount] if amount else [] + indices = self._dist.sample(*amount) + sample = self._one_hot(indices) + probs = self._dist.probs_parameter() + sample += tf.cast(probs - tf.stop_gradient(probs), self._dtype) + return sample + + def _one_hot(self, indices): + return tf.one_hot(indices, self._num_classes, dtype=self._dtype) + + +class TanhBijector(tfp.bijectors.Bijector): + + def __init__(self, validate_args=False, name='tanh'): + super().__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return tf.nn.tanh(x) + + def _inverse(self, y): + dtype = y.dtype + y = tf.cast(y, tf.float32) + y = tf.where( + tf.less_equal(tf.abs(y), 1.), + tf.clip_by_value(y, -0.99999997, 0.99999997), y) + y = tf.atanh(y) + y = tf.cast(y, dtype) + return y + + def _forward_log_det_jacobian(self, x): + log2 = tf.math.log(tf.constant(2.0, dtype=x.dtype)) + return 2.0 * (log2 - x - tf.nn.softplus(-2.0 * x)) + + +def lambda_return( + reward, value, pcont, bootstrap, lambda_, axis): + # Setting lambda=1 gives a discounted Monte Carlo return. + # Setting lambda=0 gives a fixed 1-step return. + assert reward.shape.ndims == value.shape.ndims, (reward.shape, value.shape) + if isinstance(pcont, (int, float)): + pcont = pcont * tf.ones_like(reward) + dims = list(range(reward.shape.ndims)) + dims = [axis] + dims[1:axis] + [0] + dims[axis + 1:] + if axis != 0: + reward = tf.transpose(reward, dims) + value = tf.transpose(value, dims) + pcont = tf.transpose(pcont, dims) + if bootstrap is None: + bootstrap = tf.zeros_like(value[-1]) + next_values = tf.concat([value[1:], bootstrap[None]], 0) + inputs = reward + pcont * next_values * (1 - lambda_) + returns = static_scan( + lambda agg, cur: cur[0] + cur[1] * lambda_ * agg, + (inputs, pcont), bootstrap, reverse=True) + if axis != 0: + returns = tf.transpose(returns, dims) + return returns + + +class Adam(tf.Module): + + def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'): + self._name = name + self._modules = modules + self._clip = clip + self._wd = wd + self._wdpattern = wdpattern + self._opt = tf.optimizers.Adam(lr) + self._opt = prec.LossScaleOptimizer(self._opt, 'dynamic') + self._variables = None + + @property + def variables(self): + return self._opt.variables() + + def __call__(self, tape, loss): + if self._variables is None: + variables = [module.variables for module in self._modules] + self._variables = tf.nest.flatten(variables) + count = sum(np.prod(x.shape) for x in self._variables) + print(f'Found {count} {self._name} parameters.') + print("loss.shape",loss.shape) + assert len(loss.shape) == 0, loss.shape + with tape: + loss = self._opt.get_scaled_loss(loss) + grads = tape.gradient(loss, self._variables) + grads = self._opt.get_unscaled_gradients(grads) + norm = tf.linalg.global_norm(grads) + if self._clip: + grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) + if self._wd: + context = tf.distribute.get_replica_context() + context.merge_call(self._apply_weight_decay) + self._opt.apply_gradients(zip(grads, self._variables)) + return norm + + def _apply_weight_decay(self, strategy): + print('Applied weight decay to 
variables:') + for var in self._variables: + if re.search(self._wdpattern, self._name + '/' + var.name): + print('- ' + self._name + '/' + var.name) + strategy.extended.update(var, lambda var: self._wd * var) + + +def args_type(default): + if isinstance(default, bool): + return lambda x: bool(['False', 'True'].index(x)) + if isinstance(default, int): + return lambda x: float(x) if ('e' in x or '.' in x) else int(x) + if isinstance(default, pathlib.Path): + return lambda x: pathlib.Path(x).expanduser() + return type(default) + + +def static_scan(fn, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + last = fn(last, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs) + + +def static_scan_inverse(fn, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + last = fn(last, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs) + + + + +def static_scan_action(fn1, fn2, inputs, start, reverse=False): + last = start + outputs = [[] for _ in tf.nest.flatten(start)] + indices = range(len(tf.nest.flatten(inputs)[0])) + actions = [] + if reverse: + indices = reversed(indices) + for index in indices: + inp = tf.nest.map_structure(lambda x: x[index], inputs) + action = fn2(last) + last = fn1(last, action, inp) + [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] + actions.append(action) + if reverse: + outputs = [list(reversed(x)) for x in outputs] + outputs = [tf.stack(x, 0) for x in outputs] + return tf.nest.pack_sequence_as(start, outputs), actions[0] + + + +def _mnd_sample(self, sample_shape=(), seed=None, name='sample'): + return tf.random.normal( + tuple(sample_shape) + tuple(self.event_shape), + self.mean(), self.stddev(), self.dtype, seed, name) + + +tfd.MultivariateNormalDiag.sample = _mnd_sample + + +def _cat_sample(self, sample_shape=(), seed=None, name='sample'): + assert len(sample_shape) in (0, 1), sample_shape + assert len(self.logits_parameter().shape) == 2 + indices = tf.random.categorical( + self.logits_parameter(), sample_shape[0] if sample_shape else 1, + self.dtype, seed, name) + if not sample_shape: + indices = indices[..., 0] + return indices + + +tfd.Categorical.sample = _cat_sample + + +class Every: + + def __init__(self, every): + self._every = every + self._last = None + + def __call__(self, step): + if self._last is None: + self._last = step + return True + if step >= self._last + self._every: + self._last += self._every + return True + return False + + +class Once: + + def __init__(self): + self._once = True + + def __call__(self): + if self._once: + self._once = False + return True + return False + + +def load_imgnet(train): + import pickle + name = 'train' if train else 'valid' + + # images_train. 
pkl and images_test.pkl to be downloaded from + + with open('images_{}.pkl'.format(name), 'rb') as fin: + imgnet = pickle.load(fin) + + imgnet = np.transpose(imgnet, axes=(0, 1, 3, 4, 2)) + + return imgnet diff --git a/wrappers.py b/wrappers.py index 6862705..693c9b8 100644 --- a/wrappers.py +++ b/wrappers.py @@ -3,474 +3,613 @@ import sys import threading import traceback - import gym import numpy as np from PIL import Image - +import cv2 class DeepMindControl: - def __init__(self, name, size=(64, 64), camera=None): - domain, task = name.split('_', 1) - if domain == 'cup': # Only domain with multiple words. - domain = 'ball_in_cup' - if isinstance(domain, str): - from dm_control import suite - self._env = suite.load(domain, task) - else: - assert task is None - self._env = domain() - self._size = size - if camera is None: - camera = dict(quadruped=2).get(domain, 0) - self._camera = camera - - @property - def observation_space(self): - spaces = {} - for key, value in self._env.observation_spec().items(): - spaces[key] = gym.spaces.Box( - -np.inf, np.inf, value.shape, dtype=np.float32) - spaces['image'] = gym.spaces.Box( - 0, 255, self._size + (3,), dtype=np.uint8) - return gym.spaces.Dict(spaces) - - @property - def action_space(self): - spec = self._env.action_spec() - return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) - - def step(self, action): - time_step = self._env.step(action) - obs = dict(time_step.observation) - obs['image'] = self.render() - reward = time_step.reward or 0 - done = time_step.last() - info = {'discount': np.array(time_step.discount, np.float32)} - return obs, reward, done, info - - def reset(self): - time_step = self._env.reset() - obs = dict(time_step.observation) - obs['image'] = self.render() - return obs - - def render(self, *args, **kwargs): - if kwargs.get('mode', 'rgb_array') != 'rgb_array': - raise ValueError("Only render mode 'rgb_array' is supported.") - return self._env.physics.render(*self._size, camera_id=self._camera) + def __init__(self, name, size=(64, 64), camera=None): + domain, task = name.split('_', 1) + + if domain == 'manip': + from dm_control import manipulation + self._env = manipulation.load(task + '_vision') + elif isinstance(domain, str): + if domain == 'cup': # Only domain with multiple words. 
+ domain = 'ball_in_cup' + from dm_control import suite + self._env = suite.load(domain, task) + else: + assert task is None + self._env = domain() + self._size = size + if camera is None: + camera = dict(quadruped=2).get(domain, 0) + self._camera = camera + + @property + def observation_space(self): + spaces = {} + for key, value in self._env.observation_spec().items(): + spaces[key] = gym.spaces.Box( + -np.inf, np.inf, value.shape, dtype=np.float32) + spaces['image'] = gym.spaces.Box( + 0, 255, self._size + (3,), dtype=np.uint8) + return gym.spaces.Dict(spaces) + + @property + def action_space(self): + spec = self._env.action_spec() + return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32) + + def step(self, action): + time_step = self._env.step(action) + obs = dict(time_step.observation) + obs['image'] = self.render() + reward = time_step.reward or 0 + done = time_step.last() + info = {'discount': np.array(time_step.discount, np.float32)} + return obs, reward, done, info + + def reset(self): + time_step = self._env.reset() + obs = dict(time_step.observation) + obs['image'] = self.render() + return obs + + def render(self, *args, **kwargs): + if kwargs.get('mode', 'rgb_array') != 'rgb_array': + raise ValueError("Only render mode 'rgb_array' is supported.") + return self._env.physics.render(*self._size, camera_id=self._camera) + + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = self._noisify_obs(obs, False) + obs = {'image': obs['image']} + if self._sparse and self._solved: + reward = 0 + elif self._sparse: + success = self._get_task_reward(self.task, 'success') + reward = float(success) + self._solved = bool(success) + return obs, reward, done, info + + def reset(self): + self._solved = False + obs = self._env.reset() + obs = self._noisify_obs(obs, False) + obs = {'image': obs['image']} + return obs class Atari: - LOCK = threading.Lock() - - def __init__( - self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, - life_done=False, sticky_actions=True): - import gym - version = 0 if sticky_actions else 4 - name = ''.join(word.title() for word in name.split('_')) - with self.LOCK: - self._env = gym.make('{}NoFrameskip-v{}'.format(name, version)) - self._action_repeat = action_repeat - self._size = size - self._grayscale = grayscale - self._noops = noops - self._life_done = life_done - self._lives = None - shape = self._env.observation_space.shape[:2] + (() if grayscale else (3,)) - self._buffers = [np.empty(shape, dtype=np.uint8) for _ in range(2)] - self._random = np.random.RandomState(seed=None) - - @property - def observation_space(self): - shape = self._size + (1 if self._grayscale else 3,) - space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) - return gym.spaces.Dict({'image': space}) - - @property - def action_space(self): - return self._env.action_space - - def close(self): - return self._env.close() - - def reset(self): - with self.LOCK: - self._env.reset() - noops = self._random.randint(1, self._noops + 1) - for _ in range(noops): - done = self._env.step(0)[2] - if done: + LOCK = threading.Lock() + + def __init__( + self, name, action_repeat=4, size=(84, 84), grayscale=True, noops=30, + life_done=False, sticky_actions=True): + import gym + version = 0 if sticky_actions else 4 + name = ''.join(word.title() for word in name.split('_')) + with self.LOCK: + self._env = gym.make('{}NoFrameskip-v{}'.format(name, version)) + self._action_repeat = action_repeat + self._size = size + self._grayscale = grayscale 
+ self._noops = noops + self._life_done = life_done + self._lives = None + shape = self._env.observation_space.shape[:2] + \ + (() if grayscale else (3,)) + self._buffers = [np.empty(shape, dtype=np.uint8) for _ in range(2)] + self._random = np.random.RandomState(seed=None) + + @property + def observation_space(self): + shape = self._size + (1 if self._grayscale else 3,) + space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) + return gym.spaces.Dict({'image': space}) + + @property + def action_space(self): + return self._env.action_space + + def close(self): + return self._env.close() + + def reset(self): with self.LOCK: - self._env.reset() - self._lives = self._env.ale.lives() - if self._grayscale: - self._env.ale.getScreenGrayscale(self._buffers[0]) - else: - self._env.ale.getScreenRGB2(self._buffers[0]) - self._buffers[1].fill(0) - return self._get_obs() - - def step(self, action): - total_reward = 0.0 - for step in range(self._action_repeat): - _, reward, done, info = self._env.step(action) - total_reward += reward - if self._life_done: - lives = self._env.ale.lives() - done = done or lives < self._lives - self._lives = lives - if done: - break - elif step >= self._action_repeat - 2: - index = step - (self._action_repeat - 2) + self._env.reset() + noops = self._random.randint(1, self._noops + 1) + for _ in range(noops): + done = self._env.step(0)[2] + if done: + with self.LOCK: + self._env.reset() + self._lives = self._env.ale.lives() if self._grayscale: - self._env.ale.getScreenGrayscale(self._buffers[index]) + self._env.ale.getScreenGrayscale(self._buffers[0]) else: - self._env.ale.getScreenRGB2(self._buffers[index]) - obs = self._get_obs() - return obs, total_reward, done, info - - def render(self, mode): - return self._env.render(mode) - - def _get_obs(self): - if self._action_repeat > 1: - np.maximum(self._buffers[0], self._buffers[1], out=self._buffers[0]) - image = np.array(Image.fromarray(self._buffers[0]).resize( - self._size, Image.BILINEAR)) - image = np.clip(image, 0, 255).astype(np.uint8) - image = image[:, :, None] if self._grayscale else image - return {'image': image} + self._env.ale.getScreenRGB2(self._buffers[0]) + self._buffers[1].fill(0) + return self._get_obs() + + def step(self, action): + total_reward = 0.0 + for step in range(self._action_repeat): + _, reward, done, info = self._env.step(action) + total_reward += reward + if self._life_done: + lives = self._env.ale.lives() + done = done or lives < self._lives + self._lives = lives + if done: + break + elif step >= self._action_repeat - 2: + index = step - (self._action_repeat - 2) + if self._grayscale: + self._env.ale.getScreenGrayscale(self._buffers[index]) + else: + self._env.ale.getScreenRGB2(self._buffers[index]) + obs = self._get_obs() + return obs, total_reward, done, info + + def render(self, mode): + return self._env.render(mode) + + def _get_obs(self): + if self._action_repeat > 1: + np.maximum(self._buffers[0], + self._buffers[1], out=self._buffers[0]) + image = np.array(Image.fromarray(self._buffers[0]).resize( + self._size, Image.BILINEAR)) + image = np.clip(image, 0, 255).astype(np.uint8) + image = image[:, :, None] if self._grayscale else image + return {'image': image} class Collect: - def __init__(self, env, callbacks=None, precision=32): - self._env = env - self._callbacks = callbacks or () - self._precision = precision - self._episode = None - - def __getattr__(self, name): - return getattr(self._env, name) - - def step(self, action): - obs, reward, done, info = 
self._env.step(action) - obs = {k: self._convert(v) for k, v in obs.items()} - transition = obs.copy() - transition['action'] = action - transition['reward'] = reward - transition['discount'] = info.get('discount', np.array(1 - float(done))) - self._episode.append(transition) - if done: - episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} - episode = {k: self._convert(v) for k, v in episode.items()} - info['episode'] = episode - for callback in self._callbacks: - callback(episode) - return obs, reward, done, info - - def reset(self): - obs = self._env.reset() - transition = obs.copy() - transition['action'] = np.zeros(self._env.action_space.shape) - transition['reward'] = 0.0 - transition['discount'] = 1.0 - self._episode = [transition] - return obs - - def _convert(self, value): - value = np.array(value) - if np.issubdtype(value.dtype, np.floating): - dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] - elif np.issubdtype(value.dtype, np.signedinteger): - dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] - elif np.issubdtype(value.dtype, np.uint8): - dtype = np.uint8 - else: - raise NotImplementedError(value.dtype) - return value.astype(dtype) + def __init__(self, env, callbacks=None, precision=32): + self._env = env + self._callbacks = callbacks or () + self._precision = precision + self._episode = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = {k: self._convert(v) for k, v in obs.items()} + transition = obs.copy() + transition['action'] = action + transition['reward'] = reward + transition['discount'] = info.get( + 'discount', np.array(1 - float(done))) + self._episode.append(transition) + if done: + episode = {k: [t[k] for t in self._episode] + for k in self._episode[0]} + episode = {k: self._convert(v) for k, v in episode.items()} + info['episode'] = episode + for callback in self._callbacks: + callback(episode) + return obs, reward, done, info + + def reset(self): + obs = self._env.reset() + transition = obs.copy() + transition['action'] = np.zeros(self._env.action_space.shape) + transition['reward'] = 0.0 + transition['discount'] = 1.0 + self._episode = [transition] + return obs + + def _convert(self, value): + value = np.array(value) + if np.issubdtype(value.dtype, np.floating): + dtype = {16: np.float16, 32: np.float32, + 64: np.float64}[self._precision] + elif np.issubdtype(value.dtype, np.signedinteger): + dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] + elif np.issubdtype(value.dtype, np.uint8): + dtype = np.uint8 + else: + raise NotImplementedError(value.dtype) + return value.astype(dtype) class TimeLimit: - def __init__(self, env, duration): - self._env = env - self._duration = duration - self._step = None + def __init__(self, env, duration): + self._env = env + self._duration = duration + self._step = None + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + assert self._step is not None, 'Must reset environment.' 
+ obs, reward, done, info = self._env.step(action) + self._step += 1 + if self._step >= self._duration: + done = True + if 'discount' not in info: + info['discount'] = np.array(1.0).astype(np.float32) + self._step = None + return obs, reward, done, info + + def reset(self): + self._step = 0 + return self._env.reset() + +class NaturalMujoco: + + def __init__(self, env, dataset): + self.dataset = dataset + self._pointer = (np.random.randint(self.dataset.shape[0]), 0) + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) + + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = self._noisify_obs(obs, done) + return obs, reward, done, info + + def _noisify_obs(self, obs, done): + obs = obs.copy() + img = obs['image'] + video_id, img_id = self._pointer + # fgbg = cv2.createBackgroundSubtractorKNN() + # fgbg = cv2.createBackgroundSubtractorMOG2(detectShadows=True) + # temp = fgbg.apply(img) != 255 + # fgmask = temp[:, :, None].repeat(3, axis=2) + # fgmask = ~(fgbg.apply(img) == 255)[:, :, None].repeat(3, axis=2) + + # extract only the agent pixels + fgmask = (img[:, :, 0] > 100)[:, :, None].repeat(3, axis=2) + #x = np.ones((64, 64), dtype=bool) + #x[16:48,16:48] = np.zeros((32, 32), dtype=bool) + #fgmask = (np.logical_or(img[:, :, 0] > 100,x) )[:, :, None].repeat(3, axis=2) + + if done: + video_id = np.random.randint(self.dataset.shape[0]) + img_id = 0 + else: + img_id = (img_id + 1) % self.dataset.shape[1] + + background = self.dataset[video_id, img_id] + #print(background) + #print(background.shape) + img = img * fgmask + background * (~fgmask) + + self._pointer = (video_id, img_id) + + obs['image'] = img + + return obs + + def reset(self): + obs = self._env.reset() + obs = self._noisify_obs(obs, False) + return obs + + +class CustomMujoco: + + def __init__(self, env, dataset): + self.dataset = dataset + self._pointer = (np.random.randint(self.dataset.shape[0]), np.random.randint(self.dataset.shape[0])) + self._env = env + + def __getattr__(self, name): + return getattr(self._env, name) - def __getattr__(self, name): - return getattr(self._env, name) + def step(self, action): + obs, reward, done, info = self._env.step(action) + obs = self._noisify_obs(obs, done) + return obs, reward, done, info + + def _noisify_obs(self, obs, done): + obs = obs.copy() + img = obs['image'] + video_id, img_id = self._pointer + # fgbg = cv2.createBackgroundSubtractorKNN() + # fgbg = cv2.createBackgroundSubtractorMOG2(detectShadows=True) + # temp = fgbg.apply(img) != 255 + # fgmask = temp[:, :, None].repeat(3, axis=2) + # fgmask = ~(fgbg.apply(img) == 255)[:, :, None].repeat(3, axis=2) + + # extract only the agent pixels (yellow) + #fgmask = (img[:, :, 0] > 100)[:, :, None].repeat(3, axis=2) + + #fgmask = ((img[:, :, 2] < 60) or (img[:, :, 0] > 170 and img[:, :, 1] > 170 and img[:, :, 1] > 170 ))[:, :, None].repeat(3, axis=2) + + fgmask = (np.logical_or((img[:, :, :] < 30)[:,:,2],np.logical_or(img[:, :, 2] < 60,(img[:, :, :] > 160)[:,:,2])))[:, :, None].repeat(3, axis=2) + + + if done: + video_id = np.random.randint(self.dataset.shape[0]) + img_id = np.random.randint(self.dataset.shape[0]) + else: + img_id = (img_id + 1) % self.dataset.shape[0] + + background = self.dataset[img_id] + #print(background) + #print(background.shape) + img = img * fgmask + background * (~fgmask) + + self._pointer = (video_id, img_id) + + obs['image'] = img + + return obs + + def reset(self): + obs = self._env.reset() + obs = self._noisify_obs(obs, False) + return obs - def 
step(self, action):
-    assert self._step is not None, 'Must reset environment.'
-    obs, reward, done, info = self._env.step(action)
-    self._step += 1
-    if self._step >= self._duration:
-      done = True
-      if 'discount' not in info:
-        info['discount'] = np.array(1.0).astype(np.float32)
-      self._step = None
-    return obs, reward, done, info

-  def reset(self):
-    self._step = 0
-    return self._env.reset()


 class ActionRepeat:

-  def __init__(self, env, amount):
-    self._env = env
-    self._amount = amount
+    def __init__(self, env, amount):
+        self._env = env
+        self._amount = amount

-  def __getattr__(self, name):
-    return getattr(self._env, name)
+    def __getattr__(self, name):
+        return getattr(self._env, name)

-  def step(self, action):
-    done = False
-    total_reward = 0
-    current_step = 0
-    while current_step < self._amount and not done:
-      obs, reward, done, info = self._env.step(action)
-      total_reward += reward
-      current_step += 1
-    return obs, total_reward, done, info
+    def step(self, action):
+        done = False
+        total_reward = 0
+        current_step = 0
+        while current_step < self._amount and not done:
+            obs, reward, done, info = self._env.step(action)
+            total_reward += reward
+            current_step += 1
+        return obs, total_reward, done, info


 class NormalizeActions:

-  def __init__(self, env):
-    self._env = env
-    self._mask = np.logical_and(
-        np.isfinite(env.action_space.low),
-        np.isfinite(env.action_space.high))
-    self._low = np.where(self._mask, env.action_space.low, -1)
-    self._high = np.where(self._mask, env.action_space.high, 1)
+    def __init__(self, env):
+        self._env = env
+        self._mask = np.logical_and(
+            np.isfinite(env.action_space.low),
+            np.isfinite(env.action_space.high))
+        self._low = np.where(self._mask, env.action_space.low, -1)
+        self._high = np.where(self._mask, env.action_space.high, 1)

-  def __getattr__(self, name):
-    return getattr(self._env, name)
+    def __getattr__(self, name):
+        return getattr(self._env, name)

-  @property
-  def action_space(self):
-    low = np.where(self._mask, -np.ones_like(self._low), self._low)
-    high = np.where(self._mask, np.ones_like(self._low), self._high)
-    return gym.spaces.Box(low, high, dtype=np.float32)
+    @property
+    def action_space(self):
+        low = np.where(self._mask, -np.ones_like(self._low), self._low)
+        high = np.where(self._mask, np.ones_like(self._low), self._high)
+        return gym.spaces.Box(low, high, dtype=np.float32)

-  def step(self, action):
-    original = (action + 1) / 2 * (self._high - self._low) + self._low
-    original = np.where(self._mask, original, action)
-    return self._env.step(original)
+    def step(self, action):
+        original = (action + 1) / 2 * (self._high - self._low) + self._low
+        original = np.where(self._mask, original, action)
+        return self._env.step(original)


 class ObsDict:

-  def __init__(self, env, key='obs'):
-    self._env = env
-    self._key = key
+    def __init__(self, env, key='obs'):
+        self._env = env
+        self._key = key

-  def __getattr__(self, name):
-    return getattr(self._env, name)
+    def __getattr__(self, name):
+        return getattr(self._env, name)

-  @property
-  def observation_space(self):
-    spaces = {self._key: self._env.observation_space}
-    return gym.spaces.Dict(spaces)
+    @property
+    def observation_space(self):
+        spaces = {self._key: self._env.observation_space}
+        return gym.spaces.Dict(spaces)

-  @property
-  def action_space(self):
-    return self._env.action_space
+    @property
+    def action_space(self):
+        return self._env.action_space

-  def step(self, action):
-    obs, reward, done, info = self._env.step(action)
-    obs = {self._key: np.array(obs)}
-    return obs, reward, done, info
+    def step(self, action):
+        obs, reward, done, info = self._env.step(action)
+        obs = {self._key: np.array(obs)}
+        return obs, reward, done, info

-  def reset(self):
-    obs = self._env.reset()
-    obs = {self._key: np.array(obs)}
-    return obs
+    def reset(self):
+        obs = self._env.reset()
+        obs = {self._key: np.array(obs)}
+        return obs


 class OneHotAction:

-  def __init__(self, env):
-    assert isinstance(env.action_space, gym.spaces.Discrete)
-    self._env = env
+    def __init__(self, env):
+        assert isinstance(env.action_space, gym.spaces.Discrete)
+        self._env = env

-  def __getattr__(self, name):
-    return getattr(self._env, name)
+    def __getattr__(self, name):
+        return getattr(self._env, name)

-  @property
-  def action_space(self):
-    shape = (self._env.action_space.n,)
-    space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
-    space.sample = self._sample_action
-    return space
+    @property
+    def action_space(self):
+        shape = (self._env.action_space.n,)
+        space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
+        space.sample = self._sample_action
+        return space

-  def step(self, action):
-    index = np.argmax(action).astype(int)
-    reference = np.zeros_like(action)
-    reference[index] = 1
-    if not np.allclose(reference, action):
-      raise ValueError(f'Invalid one-hot action:\n{action}')
-    return self._env.step(index)
+    def step(self, action):
+        index = np.argmax(action).astype(int)
+        reference = np.zeros_like(action)
+        reference[index] = 1
+        if not np.allclose(reference, action):
+            raise ValueError(f'Invalid one-hot action:\n{action}')
+        return self._env.step(index)

-  def reset(self):
-    return self._env.reset()
+    def reset(self):
+        return self._env.reset()

-  def _sample_action(self):
-    actions = self._env.action_space.n
-    index = self._random.randint(0, actions)
-    reference = np.zeros(actions, dtype=np.float32)
-    reference[index] = 1.0
-    return reference
+    def _sample_action(self):
+        actions = self._env.action_space.n
+        index = self._random.randint(0, actions)
+        reference = np.zeros(actions, dtype=np.float32)
+        reference[index] = 1.0
+        return reference


 class RewardObs:

-  def __init__(self, env):
-    self._env = env
+    def __init__(self, env):
+        self._env = env

-  def __getattr__(self, name):
-    return getattr(self._env, name)
+    def __getattr__(self, name):
+        return getattr(self._env, name)

-  @property
-  def observation_space(self):
-    spaces = self._env.observation_space.spaces
-    assert 'reward' not in spaces
-    spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
-    return gym.spaces.Dict(spaces)
+    @property
+    def observation_space(self):
+        spaces = self._env.observation_space.spaces
+        assert 'reward' not in spaces
+        spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
+        return gym.spaces.Dict(spaces)

-  def step(self, action):
-    obs, reward, done, info = self._env.step(action)
-    obs['reward'] = reward
-    return obs, reward, done, info
+    def step(self, action):
+        obs, reward, done, info = self._env.step(action)
+        obs['reward'] = reward
+        return obs, reward, done, info

-  def reset(self):
-    obs = self._env.reset()
-    obs['reward'] = 0.0
-    return obs
+    def reset(self):
+        obs = self._env.reset()
+        obs['reward'] = 0.0
+        return obs
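For context on how these wrappers fit together (this note and the sketch below are not part of the patch): each wrapper forwards unknown attributes to the wrapped environment through __getattr__, so they can be stacked freely. A minimal composition sketch, where the gym environment id, the action-repeat amount, and the particular wrapper order are illustrative assumptions rather than values taken from this diff:

import gym
import wrappers

# Illustrative wrapper stack (example values, not prescribed by this patch).
env = gym.make('Pendulum-v1')               # any continuous-control gym task
env = wrappers.ActionRepeat(env, amount=2)  # repeat each action for 2 frames
env = wrappers.NormalizeActions(env)        # agent always acts in [-1, 1]
env = wrappers.ObsDict(env, key='obs')      # wrap flat observations in a dict
env = wrappers.RewardObs(env)               # mirror the reward into the obs dict

obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())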

 class Async:

-  _ACCESS = 1
-  _CALL = 2
-  _RESULT = 3
-  _EXCEPTION = 4
-  _CLOSE = 5
-
-  def __init__(self, ctor, strategy='process'):
-    self._strategy = strategy
-    if strategy == 'none':
-      self._env = ctor()
-    elif strategy == 'thread':
-      import multiprocessing.dummy as mp
-    elif strategy == 'process':
-      import multiprocessing as mp
-    else:
-      raise NotImplementedError(strategy)
-    if strategy != 'none':
-      self._conn, conn = mp.Pipe()
-      self._process = mp.Process(target=self._worker, args=(ctor, conn))
-      atexit.register(self.close)
-      self._process.start()
-    self._obs_space = None
-    self._action_space = None
-
-  @property
-  def observation_space(self):
-    if not self._obs_space:
-      self._obs_space = self.__getattr__('observation_space')
-    return self._obs_space
-
-  @property
-  def action_space(self):
-    if not self._action_space:
-      self._action_space = self.__getattr__('action_space')
-    return self._action_space
-
-  def __getattr__(self, name):
-    if self._strategy == 'none':
-      return getattr(self._env, name)
-    self._conn.send((self._ACCESS, name))
-    return self._receive()
-
-  def call(self, name, *args, **kwargs):
-    blocking = kwargs.pop('blocking', True)
-    if self._strategy == 'none':
-      return functools.partial(getattr(self._env, name), *args, **kwargs)
-    payload = name, args, kwargs
-    self._conn.send((self._CALL, payload))
-    promise = self._receive
-    return promise() if blocking else promise
-
-  def close(self):
-    if self._strategy == 'none':
-      try:
-        self._env.close()
-      except AttributeError:
-        pass
-      return
-    try:
-      self._conn.send((self._CLOSE, None))
-      self._conn.close()
-    except IOError:
-      # The connection was already closed.
-      pass
-    self._process.join()
-
-  def step(self, action, blocking=True):
-    return self.call('step', action, blocking=blocking)
-
-  def reset(self, blocking=True):
-    return self.call('reset', blocking=blocking)
-
-  def _receive(self):
-    try:
-      message, payload = self._conn.recv()
-    except ConnectionResetError:
-      raise RuntimeError('Environment worker crashed.')
-    # Re-raise exceptions in the main process.
-    if message == self._EXCEPTION:
-      stacktrace = payload
-      raise Exception(stacktrace)
-    if message == self._RESULT:
-      return payload
-    raise KeyError(f'Received message of unexpected type {message}')
-
-  def _worker(self, ctor, conn):
-    try:
-      env = ctor()
-      while True:
+    _ACCESS = 1
+    _CALL = 2
+    _RESULT = 3
+    _EXCEPTION = 4
+    _CLOSE = 5
+
+    def __init__(self, ctor, strategy='process'):
+        self._strategy = strategy
+        if strategy == 'none':
+            self._env = ctor()
+        elif strategy == 'thread':
+            import multiprocessing.dummy as mp
+        elif strategy == 'process':
+            import multiprocessing as mp
+        else:
+            raise NotImplementedError(strategy)
+        if strategy != 'none':
+            self._conn, conn = mp.Pipe()
+            self._process = mp.Process(target=self._worker, args=(ctor, conn))
+            atexit.register(self.close)
+            self._process.start()
+        self._obs_space = None
+        self._action_space = None
+
+    @property
+    def observation_space(self):
+        if not self._obs_space:
+            self._obs_space = self.__getattr__('observation_space')
+        return self._obs_space
+
+    @property
+    def action_space(self):
+        if not self._action_space:
+            self._action_space = self.__getattr__('action_space')
+        return self._action_space
+
+    def __getattr__(self, name):
+        if self._strategy == 'none':
+            return getattr(self._env, name)
+        self._conn.send((self._ACCESS, name))
+        return self._receive()
+
+    def call(self, name, *args, **kwargs):
+        blocking = kwargs.pop('blocking', True)
+        if self._strategy == 'none':
+            return functools.partial(getattr(self._env, name), *args, **kwargs)
+        payload = name, args, kwargs
+        self._conn.send((self._CALL, payload))
+        promise = self._receive
+        return promise() if blocking else promise
+
+    def close(self):
+        if self._strategy == 'none':
+            try:
+                self._env.close()
+            except AttributeError:
+                pass
+            return
+        try:
+            self._conn.send((self._CLOSE, None))
+            self._conn.close()
+        except IOError:
+            # The connection was already closed.
+            pass
+        self._process.join()
+
+    def step(self, action, blocking=True):
+        return self.call('step', action, blocking=blocking)
+
+    def reset(self, blocking=True):
+        return self.call('reset', blocking=blocking)
+
+    def _receive(self):
+        try:
+            message, payload = self._conn.recv()
+        except ConnectionResetError:
+            raise RuntimeError('Environment worker crashed.')
+        # Re-raise exceptions in the main process.
+        if message == self._EXCEPTION:
+            stacktrace = payload
+            raise Exception(stacktrace)
+        if message == self._RESULT:
+            return payload
+        raise KeyError(f'Received message of unexpected type {message}')
+
+    def _worker(self, ctor, conn):
         try:
-          # Only block for short times to have keyboard exceptions be raised.
-          if not conn.poll(0.1):
-            continue
-          message, payload = conn.recv()
-        except (EOFError, KeyboardInterrupt):
-          break
-        if message == self._ACCESS:
-          name = payload
-          result = getattr(env, name)
-          conn.send((self._RESULT, result))
-          continue
-        if message == self._CALL:
-          name, args, kwargs = payload
-          result = getattr(env, name)(*args, **kwargs)
-          conn.send((self._RESULT, result))
-          continue
-        if message == self._CLOSE:
-          assert payload is None
-          break
-        raise KeyError(f'Received message of unknown type {message}')
-    except Exception:
-      stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
-      print(f'Error in environment process: {stacktrace}')
-      conn.send((self._EXCEPTION, stacktrace))
-    conn.close()
+            env = ctor()
+            while True:
+                try:
+                    # Only block for short times to have keyboard exceptions be raised.
+                    if not conn.poll(0.1):
+                        continue
+                    message, payload = conn.recv()
+                except (EOFError, KeyboardInterrupt):
+                    break
+                if message == self._ACCESS:
+                    name = payload
+                    result = getattr(env, name)
+                    conn.send((self._RESULT, result))
+                    continue
+                if message == self._CALL:
+                    name, args, kwargs = payload
+                    result = getattr(env, name)(*args, **kwargs)
+                    conn.send((self._RESULT, result))
+                    continue
+                if message == self._CLOSE:
+                    assert payload is None
+                    break
+                raise KeyError(f'Received message of unknown type {message}')
+        except Exception:
+            stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
+            print(f'Error in environment process: {stacktrace}')
+            conn.send((self._EXCEPTION, stacktrace))
+        conn.close()
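For context (again, not part of the patch): Async constructs the environment inside a worker process from a zero-argument constructor and proxies attribute access, step, and reset over a pipe; passing blocking=False returns a promise that is resolved by calling it. A minimal usage sketch, assuming a module-level constructor (so it can be pickled) and an example gym task, neither of which is taken from this diff:

import gym
import wrappers

def make_env():
    # Constructed inside the worker process, not in the parent.
    env = gym.make('Pendulum-v1')  # example task, not prescribed by the patch
    return wrappers.NormalizeActions(env)

env = wrappers.Async(make_env, strategy='process')
obs = env.reset()                                        # blocking by default
promise = env.step(env.action_space.sample(), blocking=False)
obs, reward, done, info = promise()                      # resolve the async step
env.close()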