diff --git a/.gitmodules b/.gitmodules index b03c2b0..7d4f44b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,8 @@ path = models/stylegan2/stylegan2-pytorch url = https://github.com/harskish/stylegan2-pytorch.git ignore = untracked + +[submodule "stylegan2_ada/stylegan2-ada-pytorch"] + path = models/stylegan2_ada/stylegan2-ada-pytorch + url = https://github.com/NVlabs/stylegan2-ada-pytorch.git + ignore = untracked diff --git a/README.md b/README.md index a6a71a2..44d4eda 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,42 @@ +# Changes compared to the original repo +* **Added StyleGAN2-ada support**
+ The following classes for StyleGAN2-ada are available for automatic download: + * `ffhq` + * `afhqcat` + * `afhqdog` + * `afhqwild` + * `brecahad` + * `cifar10` + * `metfaces` + + For a custom class add the name and the resolution in the `configs` dictionary in `models/wrappers.py` in the `StyleGAN2_ada` constructor and place the checkpoint-file at `models/checkpoints/stylegan2_ada/stylegan2_{class_name}_{resolution}.pkl` (replace {class_name} and {resolution} with the ones you added to the `configs` dict.) + + `partial_forward` for StyleGAN2-ada is currently not fully implemented, which means if you use a layer in the synthesis network as activation space, it could take longer than with other models, since the complete forward-pass is always computed, even if the used layer is located somewhere earlier. +* **Added grayscale image support** +* **Added another progress bar during the creation of the images** +* **Added new args for `visualize.py` to control the outcome without changing the code:** + +argument | description | arg-datatype +--- | --- | --- +`--plot_directions` | Number of components/directions to plot |int +`--plot_images` | Number of images per component/direction to plot | int +`--video_directions` | Number of components/directions to create a video of | int +`--video_images` | Number of frames within a video of one direction/component | int +* **Added interactive 2D scatter plot of the used activation space:** + + + +argument | description | arg-datatype +--- | --- | --- +`--scatter` | Activate scatter-plot | - +`--scatter_images` | Activate plotting corresponding generated images for each point | - +`--scatter_samples` | Number of samples in the 2D scatter plot | int +`--scatter_x` | Number of principal component for x-axis in the scatter plot | int +`--scatter_y` | Number of principal component for y-axis in the scatter plot | int + +If `--scatter_images` is active, the interactive plot is saved as `.pickle` which can be opened with `python 
open_scatter.py [path]`. + + # GANSpace: Discovering Interpretable GAN Controls ![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg) ![PyTorch 1.3](https://img.shields.io/badge/pytorch-1.3-green.svg) diff --git a/StyleGAN_scatter.png b/StyleGAN_scatter.png new file mode 100644 index 0000000..80ebf45 Binary files /dev/null and b/StyleGAN_scatter.png differ diff --git a/config.py b/config.py index 5af238a..dbfdb78 100644 --- a/config.py +++ b/config.py @@ -27,10 +27,10 @@ def __str__(self): for k, v in self.__dict__.items(): if k == 'default_args': continue - + in_default = k in self.default_args same_value = self.default_args.get(k) == v - + if in_default and same_value: default[k] = v else: @@ -42,15 +42,15 @@ def __str__(self): } return json.dumps(config, indent=4) - + def __repr__(self): return self.__str__() - + def from_dict(self, dictionary): for k, v in dictionary.items(): setattr(self, k, v) return self - + def from_args(self, args=sys.argv[1:]): parser = argparse.ArgumentParser(description='GAN component analysis config') parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ @@ -67,6 +67,22 @@ def from_args(self, args=sys.argv[1:]): parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py') parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components') parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition') + parser.add_argument('--plot_directions', dest='np_directions', type=int, default=14, help='Number of components/directions to plot') + parser.add_argument('--plot_images', dest='np_images', type=int, default=5, help='Number of images per component/direction to plot') + parser.add_argument('--video_directions', dest='nv_directions', type=int, default=5, help='Number of components/directions to create a video of') + 
parser.add_argument('--video_images', dest='nv_images', type=int, default=150, help='Number of frames within a video of one direction/component') + parser.add_argument('--scatter', dest='show_scatter', action='store_true', help='Plot a 2D scatter-plot of the activation space of two principal components') + parser.add_argument('--scatter_samples', dest='scatter_samples', type=int, default=1000, help='Number of samples in the 2D scatter plot of the first two principal components') + parser.add_argument('--scatter_images', dest='scatter_images', action='store_true', help='Plot encoded images instead of points within the scatter plot') + parser.add_argument('--scatter_x', dest='scatter_x_axis_pc', type=int, default=1, help='Number of PC for x-axis in the scatter plot') + parser.add_argument('--scatter_y', dest='scatter_y_axis_pc', type=int, default=2, help='Number of PC for y-axis in the scatter plot') + + args = parser.parse_args(args) + assert args.np_images % 2 != 0, 'The number of plotted images per component (--plot_images) have to be odd.' + + if(args.model == "StyleGAN2-ada" and args.layer == "g_mapping"): + print("No layer \'g_mapping\' in StyleGAN2-ada. Assuming you meant \'mapping\'") + args.layer = "mapping" - return self.from_dict(args.__dict__) \ No newline at end of file + return self.from_dict(args.__dict__) diff --git a/decomposition.py b/decomposition.py index 4819e33..fa548a3 100644 --- a/decomposition.py +++ b/decomposition.py @@ -87,7 +87,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config): n_samp = max(10_000, config.n) // B * B # make divisible n_comp = comp.shape[0] latent_dims = inst.model.get_latent_dims() - + # We're looking for M s.t. 
M*P*G'(Z) = Z => M*A = Z # Z = batch of latent vectors (n_samples x latent_dims) # G'(Z) = batch of activations at intermediate layer @@ -104,7 +104,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config): # Dimensions other way around, so these are actually the transposes A = np.zeros((n_samp, n_comp), dtype=np.float32) Z = np.zeros((n_samp, latent_dims), dtype=np.float32) - + # Project tensor X onto PCs, return coordinates def project(X, comp): N = X.shape[0] @@ -131,7 +131,7 @@ def project(X, comp): # gelsy = complete orthogonal factorization; sometimes faster # gelss = SVD; slow but less memory hungry M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :] - + # Solution given by rows of M_t Z_comp = M_t[:n_comp, :] Z_mean = np.mean(Z, axis=0, keepdims=True) @@ -182,6 +182,12 @@ def compute(config, dump_name, instrumented_model): inst.retain_layer(layer_key) model.partial_forward(model.sample_latent(1), layer_key) sample_shape = inst.retained_features()[layer_key].shape + + #StyleGAN2-ada's mapping networks copies it's result 18 times to [B,18,512] so the sample shape is different than the latent shape + #from wrapper, because it only returns [B,512], so that GANSpace can modify only specifc Style-Layers + if(model.model_name == "StyleGAN2_ada" and model.w_primary): + sample_shape = (sample_shape[0],sample_shape[2]) + sample_dims = np.prod(sample_shape) print('Feature shape:', sample_shape) @@ -218,7 +224,7 @@ def compute(config, dump_name, instrumented_model): # Must not depend on chosen batch size (reproducibility) NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible! 
- + samples = None if not transformer.batch_support: samples = np.zeros((N + NB, sample_dims), dtype=np.float32) @@ -236,7 +242,7 @@ def compute(config, dump_name, instrumented_model): latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy() # Decomposition on non-Gaussian latent space - samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W' + samples_are_latents = layer_key in ['g_mapping', 'mapping', 'style'] and inst.model.latent_space_name() == 'W' canceled = False try: @@ -245,7 +251,7 @@ def compute(config, dump_name, instrumented_model): for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True): for mb in range(0, NB, B): z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device) - + if samples_are_latents: # Decomposition on latents directly (e.g. StyleGAN W) batch = z.reshape((B, -1)) @@ -253,7 +259,7 @@ def compute(config, dump_name, instrumented_model): # Decomposition on intermediate layer with torch.no_grad(): model.partial_forward(z, layer_key) - + # Permuted to place PCA dimensions last batch = inst.retained_features()[layer_key].reshape((B, -1)) @@ -268,21 +274,21 @@ def compute(config, dump_name, instrumented_model): except KeyboardInterrupt: if not transformer.batch_support: sys.exit(1) # no progress yet - + dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}') print(f'Saving current state to "{dump_name.name}" before exiting') canceled = True - + if not transformer.batch_support: X = samples # Use all samples X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...! 
X -= X_global_mean - + print(f'[{timestamp()}] Fitting whole batch') t_start_fit = datetime.datetime.now() transformer.fit(X) - + print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}') assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero' else: @@ -291,7 +297,7 @@ def compute(config, dump_name, instrumented_model): X -= X_global_mean X_comp, X_stdev, X_var_ratio = transformer.get_components() - + assert X_comp.shape[1] == sample_dims \ and X_comp.shape[0] == config.components \ and X_global_mean.shape[1] == sample_dims \ @@ -349,6 +355,7 @@ def compute(config, dump_name, instrumented_model): del inst del model + del X del X_comp del random_dirs @@ -363,20 +370,20 @@ def get_or_compute(config, model=None, submit_config=None, force_recompute=False if submit_config is None: wrkdir = str(Path(__file__).parent.resolve()) submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir) - + # Called directly by run.py return _compute(submit_config, config, model, force_recompute) def _compute(submit_config, config, model=None, force_recompute=False): basedir = Path(submit_config.run_dir) outdir = basedir / 'out' - + if config.n is None: raise RuntimeError('Must specify number of samples with -n=XXX') if model and not isinstance(model, InstrumentedModel): raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"') - + if config.use_w and not 'StyleGAN' in config.model: raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}') @@ -398,5 +405,25 @@ def _compute(submit_config, config, model=None, force_recompute=False): t_start = datetime.datetime.now() compute(config, dump_path, model) print('Total time:', datetime.datetime.now() - t_start) - - return dump_path \ No newline at end of file + + return dump_path + + + +def imscatter(x, y, image, ax=None, zoom=1): + if ax is None: + ax = plt.gca() + try: + image = plt.imread(image) + except TypeError: + # Likely 
already an array... + pass + im = OffsetImage(image, zoom=zoom) + x, y = np.atleast_1d(x, y) + artists = [] + for x0, y0 in zip(x, y): + ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False) + artists.append(ax.add_artist(ab)) + ax.update_datalim(np.column_stack([x, y])) + ax.autoscale() + return artists diff --git a/models/stylegan2_ada/__init__.py b/models/stylegan2_ada/__init__.py new file mode 100644 index 0000000..e50fd7a --- /dev/null +++ b/models/stylegan2_ada/__init__.py @@ -0,0 +1,18 @@ +import sys +import os +import shutil +import glob +import platform +from pathlib import Path + +current_path = os.getcwd() + +module_path = Path(__file__).parent / 'stylegan2-ada-pytorch' +sys.path.append(str(module_path.resolve())) +os.chdir(module_path) + +import generate +import legacy +import dnnlib + +os.chdir(current_path) diff --git a/models/stylegan2_ada/stylegan2-ada-pytorch b/models/stylegan2_ada/stylegan2-ada-pytorch new file mode 160000 index 0000000..6f160b3 --- /dev/null +++ b/models/stylegan2_ada/stylegan2-ada-pytorch @@ -0,0 +1 @@ +Subproject commit 6f160b3d22b8b178ebe533a50d4d5e63aedba21d diff --git a/models/wrappers.py b/models/wrappers.py index 558e928..775bbd3 100644 --- a/models/wrappers.py +++ b/models/wrappers.py @@ -8,6 +8,7 @@ # OF ANY KIND, either express or implied. See the License for the specific language # governing permissions and limitations under the License. +import math import torch import numpy as np import re @@ -21,8 +22,10 @@ from . import biggan from . import stylegan from . import stylegan2 +from . 
import stylegan2_ada from abc import abstractmethod, ABC as AbstractBaseClass from functools import singledispatch +import PIL.Image class BaseModel(AbstractBaseClass, torch.nn.Module): @@ -80,7 +83,7 @@ def sample_np(self, z=None, n_samples=1, seed=None): z = torch.tensor(z).to(self.device) img = self.forward(z) img_np = img.permute(0, 2, 3, 1).cpu().detach().numpy() - return np.clip(img_np, 0.0, 1.0).squeeze() + return np.clip(img_np, 0.0, 1.0)#.squeeze() # For models that use part of latent as conditioning def get_conditional_state(self, z): @@ -93,6 +96,140 @@ def set_conditional_state(self, z, c): def named_modules(self, *args, **kwargs): return self.model.named_modules(*args, **kwargs) +# StyleGAN2-ada-pytorch +class StyleGAN2_ada(BaseModel): + def __init__(self, device, class_name, truncation=1.0, use_w=False): + super(StyleGAN2_ada, self).__init__('StyleGAN2_ada', class_name or 'ffhq') + self.device = device + self.truncation = truncation + self.latent_avg = None + self.w_primary = use_w + + # Image widths + configs = { + 'ffhq': 1024, + 'afhqcat': 512, + 'afhqdog': 512, + 'afhqwild': 512, + 'brecahad': 1024, + 'cifar10': 32, + 'metfaces': 1024, + } + + assert self.outclass in configs, \ + f'Invalid StyleGAN2-ada class {self.outclass}, should be one of [{", ".join(configs.keys())}]' + + self.resolution = configs[self.outclass] + self.name = f'StyleGAN2-ada-{self.outclass}' + self.has_latent_residual = True + self.load_model() + self.set_noise_seed(0) + + def latent_space_name(self): + return 'W' if self.w_primary else 'Z' + + def use_w(self): + self.w_primary = True + + def use_z(self): + self.w_primary = False + + + # URLs created with https://sites.google.com/site/gdocs2direct/ + def download_checkpoint(self, outfile): + checkpoints = { + 'ffhq': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl', + 'afhqcat': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl', + 'afhqdog': 
'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl', + 'afhqwild': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqwild.pkl', + 'brecahad': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/brecahad.pkl', + 'cifar10': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl', + 'metfaces': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl' + } + + url = checkpoints[self.outclass] + download_ckpt(url, outfile) + + + def load_model(self): + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') + checkpoint = Path(checkpoint_root) / f'stylegan2_ada/stylegan2_{self.outclass}_{self.resolution}.pkl' + + if not checkpoint.is_file(): + os.makedirs(checkpoint.parent, exist_ok=True) + self.download_checkpoint(checkpoint) + + with stylegan2_ada.dnnlib.util.open_url(str(checkpoint)) as f: + self.model = stylegan2_ada.legacy.load_network_pkl(f)['G_ema'].to(self.device) + + + def sample_latent(self, n_samples=1, seed=None, truncation=None): + if seed is None: + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + + z = torch.from_numpy(np.random.RandomState(seed).randn(n_samples, self.model.z_dim)).to(self.device) + c = torch.zeros([n_samples, self.model.c_dim], device=self.device) #no conditioning at the moment + + if self.w_primary: + z = self.model.mapping(z,c)[:,0,:] #Just use one of the 18 copies + return z + + def get_max_latents(self): + return self.model.num_ws + + def set_output_class(self, new_class): + if self.outclass != new_class: + raise RuntimeError('StyleGAN2-ada: cannot change output class without reloading') + + def forward(self, x): + if isinstance(x, list): + assert len(x) == self.model.num_ws, 'Must provide 1 or {0} latents'.format(str(self.model.num_ws)) + if not self.w_primary: + label_shape = list(x[0].shape) + label_shape[-1] = self.model.c_dim + label = torch.zeros(label_shape, 
device=self.device) + x = [self.model.mapping.forward(l,label, truncation_psi=self.truncation)[:,0,:] for l in x] + x = torch.stack(x, dim=1) + + else: + if not self.w_primary: + label_shape = list(x.shape) + label_shape[-1] = self.model.c_dim + label = torch.zeros(label_shape, device=self.device) + x = self.model.mapping.forward(x,label, truncation_psi=self.truncation)[:,0,:] + if(len(x.shape) != 3 or x.shape[1] != self.model.num_ws): + x = x.unsqueeze(1).expand(-1, self.model.num_ws, -1) + + img = self.model.synthesis.forward(x, noise_mode='const',force_fp32= self.device.type == 'cpu') + + return 0.5*(img+1) + + def partial_forward(self, x, layer_name): + mapping = self.model.mapping + G = self.model.synthesis + #trunc = self.model._modules.get('truncation', lambda x : x) + #print("TEST1",x.shape) + + if not self.w_primary: + c = torch.zeros([x.shape[0], self.model.c_dim], device=self.device) + x = mapping.forward(x,c) # handles list inputs + + # Whole mapping + if 'mapping' in layer_name: + return + else: + G.forward(x, noise_mode='const',force_fp32= self.device.type == 'cpu') #TODO: Implement partial foreward for faster computation + return + + + def set_noise_seed(self, seed): + torch.manual_seed(seed) + self.noise = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=self.device)] + + for i in range(3, int(math.log(self.model.img_resolution,2)) + 1): + for _ in range(2): + self.noise.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=self.device)) + # PyTorch port of StyleGAN 2 class StyleGAN2(BaseModel): def __init__(self, device, class_name, truncation=1.0, use_w=False): @@ -153,13 +290,13 @@ def download_checkpoint(self, outfile): def load_model(self): checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') checkpoint = Path(checkpoint_root) / f'stylegan2/stylegan2_{self.outclass}_{self.resolution}.pt' - + self.model = stylegan2.Generator(self.resolution, 512, 8).to(self.device) if not checkpoint.is_file(): 
os.makedirs(checkpoint.parent, exist_ok=True) self.download_checkpoint(checkpoint) - + ckpt = torch.load(checkpoint) self.model.load_state_dict(ckpt['g_ema'], strict=False) self.latent_avg = ckpt['latent_avg'].to(self.device) @@ -172,7 +309,7 @@ def sample_latent(self, n_samples=1, seed=None, truncation=None): z = torch.from_numpy( rng.standard_normal(512 * n_samples) .reshape(n_samples, 512)).float().to(self.device) #[N, 512] - + if self.w_primary: z = self.model.style(z) @@ -184,7 +321,7 @@ def get_max_latents(self): def set_output_class(self, new_class): if self.outclass != new_class: raise RuntimeError('StyleGAN2: cannot change output class without reloading') - + def forward(self, x): x = x if isinstance(x, list) else [x] out, _ = self.model(x, noise=self.noise, @@ -246,7 +383,7 @@ def partial_forward(self, x, layer_name): out = conv2(out, latent[:, i + 1], noise=noise[noise_i + 1]) if f'convs.{i}' in layer_name: return - + skip = to_rgb(out, latent[:, i + 2], skip) if f'to_rgbs.{i//2}' in layer_name: return @@ -280,7 +417,7 @@ def __init__(self, device, class_name, truncation=1.0, use_w=False): 'bedrooms': 256, 'cars': 512, 'cats': 256, - + # From https://github.com/justinpinkney/awesome-pretrained-stylegan 'vases': 1024, 'wikiart': 512, @@ -311,7 +448,7 @@ def use_z(self): def load_model(self): checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') checkpoint = Path(checkpoint_root) / f'stylegan/stylegan_{self.outclass}_{self.resolution}.pt' - + self.model = stylegan.StyleGAN_G(self.resolution).to(self.device) urls_tf = { @@ -341,7 +478,7 @@ def load_model(self): download_ckpt(urls_tf[self.outclass], checkpoint_tf) print('Converting TensorFlow checkpoint to PyTorch') self.model.export_from_tf(checkpoint_tf) - + self.model.load_weights(checkpoint) def sample_latent(self, n_samples=1, seed=None, truncation=None): @@ -352,10 +489,11 @@ def sample_latent(self, n_samples=1, seed=None, truncation=None): noise = 
torch.from_numpy( rng.standard_normal(512 * n_samples) .reshape(n_samples, 512)).float().to(self.device) #[N, 512] - + if self.w_primary: noise = self.model._modules['g_mapping'].forward(noise) - + + #print("NOISE shape",noise.shape) return noise def get_max_latents(self): @@ -374,7 +512,7 @@ def partial_forward(self, x, layer_name): mapping = self.model._modules['g_mapping'] G = self.model._modules['g_synthesis'] trunc = self.model._modules.get('truncation', lambda x : x) - + #print("TEST1",x.shape) if not self.w_primary: x = mapping.forward(x) # handles list inputs @@ -423,7 +561,7 @@ def set_noise_seed(self, seed): def for_each_child(this, name, func): children = getattr(this, '_modules', []) for child_name, module in children.items(): - for_each_child(module, f'{name}.{child_name}', func) + for_each_child(module, f'{name}.{child_name}', func) func(this, name) def modify(m, name): @@ -460,7 +598,7 @@ def get_conditional_state(self, z): def set_conditional_state(self, z, c): z[:, -20:] = c return z - + def forward(self, x): out = self.base_model.test(x) return 0.5*(out+1) @@ -483,7 +621,7 @@ def __init__(self, device, lsun_class=None): def load_model(self): checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') checkpoint = Path(checkpoint_root) / f'progan/{self.outclass}_lsun.pth' - + if not checkpoint.is_file(): os.makedirs(checkpoint.parent, exist_ok=True) url = f'http://netdissect.csail.mit.edu/data/ganmodel/karras/{self.outclass}_lsun.pth' @@ -501,7 +639,7 @@ def forward(self, x): if isinstance(x, list): assert len(x) == 1, "ProGAN only supports a single global latent" x = x[0] - + out = self.model.forward(x) return 0.5*(out+1) @@ -534,15 +672,15 @@ def __init__(self, device, resolution, class_name, truncation=1.0): # Default implementaiton fails without an internet # connection, even if the model has been cached - def load_model(self, name): + def load_model(self, name): if name not in 
biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP: raise RuntimeError('Unknown BigGAN model name', name) - + checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints') model_path = Path(checkpoint_root) / name os.makedirs(model_path, exist_ok=True) - + model_file = model_path / biggan.model.WEIGHTS_NAME config_file = model_path / biggan.model.CONFIG_NAME model_url = biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP[name] @@ -562,10 +700,10 @@ def load_model(self, name): def sample_latent(self, n_samples=1, truncation=None, seed=None): if seed is None: seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state - + noise_vector = biggan.truncated_noise_sample(truncation=truncation or self.truncation, batch_size=n_samples, seed=seed) - noise = torch.from_numpy(noise_vector) #[N, 128] - + noise = torch.from_numpy(noise_vector) #[N, 128] + return noise.to(self.device) # One extra for gen_z @@ -577,7 +715,7 @@ def get_conditional_state(self, z): def set_conditional_state(self, z, c): self.v_class = c - + def is_valid_class(self, class_id): if isinstance(class_id, int): return class_id < 1000 @@ -595,8 +733,8 @@ def set_output_class(self, class_id): self.v_class = torch.from_numpy(biggan.one_hot_from_names([class_id])).to(self.device) else: raise RuntimeError(f'Unknown class identifier {class_id}') - - def forward(self, x): + + def forward(self, x): # Duplicate along batch dimension if isinstance(x, list): c = self.v_class.repeat(x[0].shape[0], 1) @@ -626,7 +764,7 @@ def partial_forward(self, x, layer_name): else: class_label = self.v_class.repeat(x[0].shape[0], 1) embed = len(x)*[self.model.embeddings(class_label)] - + assert len(x) == self.model.n_latents, f'Expected {self.model.n_latents} latents, got {len(x)}' assert len(embed) == self.model.n_latents, f'Expected {self.model.n_latents} class vectors, got {len(class_label)}' @@ -653,18 +791,18 @@ def get_model(name, output_class, device, **kwargs): # Check if 
optionally provided existing model can be reused inst = kwargs.get('inst', None) model = kwargs.get('model', None) - + if inst or model: cached = model or inst.model - + network_same = (cached.model_name == name) outclass_same = (cached.outclass == output_class) can_change_class = ('BigGAN' in name) - + if network_same and (outclass_same or can_change_class): cached.set_output_class(output_class) return cached - + if name == 'DCGAN': import warnings warnings.filterwarnings("ignore", message="nn.functional.tanh is deprecated") @@ -678,6 +816,8 @@ def get_model(name, output_class, device, **kwargs): model = StyleGAN(device, class_name=output_class) elif name == 'StyleGAN2': model = StyleGAN2(device, class_name=output_class) + elif name == 'StyleGAN2-ada': + model = StyleGAN2_ada(device, class_name=output_class) else: raise RuntimeError(f'Unknown model {name}') @@ -709,7 +849,7 @@ def get_instrumented_model(name, output_class, layers, device, **kwargs): print(f"Layer '{layer_name}' not found in model!") print("Available layers:", '\n'.join(module_names)) raise RuntimeError(f"Unknown layer '{layer_name}''") - + # Reset StyleGANs to z mode for shape annotation if hasattr(model, 'use_z'): model.use_z() @@ -732,4 +872,4 @@ def get_instrumented_model(name, output_class, layers, device, **kwargs): @get_instrumented_model.register(Config) def _(cfg, device, **kwargs): kwargs['use_w'] = kwargs.get('use_w', cfg.use_w) # explicit arg can override cfg - return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs) \ No newline at end of file + return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs) diff --git a/netdissect/modelconfig.py b/netdissect/modelconfig.py index d0ee37a..caf886c 100644 --- a/netdissect/modelconfig.py +++ b/netdissect/modelconfig.py @@ -21,7 +21,7 @@ def create_instrumented_model(args, **kwargs): gen: True for a generator model. One-pixel input assumed. 
imgsize: For non-generator models, (y, x) dimensions for RGB input. cuda: True to use CUDA. - + The constructed model will be decorated with the following attributes: input_shape: (usually 4d) tensor shape for single-image input. output_shape: 4d tensor shape for output. diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index 5f421c4..a338308 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -47,10 +47,10 @@ def _create_strip_impl(inst, mode, layer, latents, x_comp, z_comp, act_stdev, la act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center) # Batch over frames if there are more frames in strip than latents -def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center): +def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center): inst.close() batch_frames = [[] for _ in range(len(latents))] - + B = min(num_frames, 5) lep_padded = ((num_frames - 1) // B + 1) * B sigma_range = np.linspace(-sigma, sigma, num_frames) @@ -60,11 +60,11 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st for i_batch in range(lep_padded // B): sigmas = sigma_range[i_batch*B:(i_batch+1)*B] - + for i_lat in range(len(latents)): z_single = latents[i_lat] z_batch = z_single.repeat_interleave(B, axis=0) - + zeroing_offset_act = 0 zeroing_offset_lat = 0 if center: @@ -90,6 +90,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st z[i] = z[i] - zeroing_offset_lat + delta if mode in ['activation', 'both']: + #HIER DRIN bei dem 1-ner batch sample comp_batch = x_comp.repeat_interleave(B, axis=0) delta = comp_batch * sigmas.reshape([-1] + [1]*(comp_batch.ndim - 1)) inst.edit_layer(layer, offset=delta * act_stdev - zeroing_offset_act) @@ -97,7 
+98,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st img_batch = inst.model.sample_np(z) if img_batch.ndim == 3: img_batch = np.expand_dims(img_batch, axis=0) - + for j, img in enumerate(img_batch): idx = i_batch*B + j if idx < num_frames: @@ -106,7 +107,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st return batch_frames # Batch over latents if there are more latents than frames in strip -def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center): +def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center): n_lat = len(latents) B = min(n_lat, 5) @@ -114,7 +115,7 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde if layer_end < 0 or layer_end > max_lat: layer_end = max_lat layer_start = np.clip(layer_start, 0, layer_end) - + len_padded = ((n_lat - 1) // B + 1) * B batch_frames = [[] for _ in range(n_lat)] @@ -122,14 +123,14 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde zs = latents[i_batch*B:(i_batch+1)*B] if len(zs) == 0: continue - + z_batch_single = torch.cat(zs, 0) inst.close() # don't retain, remove edits sigma_range = np.linspace(-sigma, sigma, num_frames, dtype=np.float32) normalize = lambda v : v / torch.sqrt(torch.sum(v**2, dim=-1, keepdim=True) + 1e-8) - + zeroing_offset_act = 0 zeroing_offset_lat = 0 if center: @@ -163,7 +164,7 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde img_batch = inst.model.sample_np(z) if img_batch.ndim == 3: img_batch = np.expand_dims(img_batch, axis=0) - + for j, img in enumerate(img_batch): img_idx = i_batch*B + j if img_idx < n_lat: @@ -176,12 +177,12 @@ def save_frames(title, model_name, rootdir, frames, strip_width=10): test_name = 
prettify_name(title) outdir = f'{rootdir}/{model_name}/{test_name}' makedirs(outdir, exist_ok=True) - + # Limit maximum resolution max_H = 512 real_H = frames[0][0].shape[0] ratio = min(1.0, max_H / real_H) - + # Combined with first 10 strips = [np.hstack(frames) for frames in frames[:strip_width]] if len(strips) >= strip_width: @@ -193,8 +194,8 @@ def save_frames(title, model_name, rootdir, frames, strip_width=10): im.save(f'{outdir}/{test_name}_all.png') else: print('Too few strips to create grid, creating just strips!') - + for ex_num, strip in enumerate(frames[:strip_width]): im = Image.fromarray(np.uint8(255*np.hstack(pad_frames(strip)))) im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), Image.ANTIALIAS) - im.save(f'{outdir}/{test_name}_{ex_num}.png') \ No newline at end of file + im.save(f'{outdir}/{test_name}_{ex_num}.png') diff --git a/open_scatter.py b/open_scatter.py new file mode 100644 index 0000000..7b7f3f9 --- /dev/null +++ b/open_scatter.py @@ -0,0 +1,23 @@ +import matplotlib.pyplot as plt +import pickle +import os +import sys + +def show_figure(fig): + # create a dummy figure and use its + # manager to display "fig" + dummy = plt.figure(num=1) + new_manager = dummy.canvas.manager + new_manager.canvas.figure = fig + fig.set_canvas(new_manager.canvas) + plt.axis('equal') + plt.show() + +plt.switch_backend('TkAgg') + +path = sys.argv[1] +if os.path.isfile(path): + if(path.split('.')[-1] == 'pickle'): + print("Loading", path.split('/')[-1]) + figx = pickle.load(open(path, 'rb')) + show_figure(figx) diff --git a/utils.py b/utils.py index e08ff90..5d09ea8 100644 --- a/utils.py +++ b/utils.py @@ -30,13 +30,13 @@ def pad_frames(strip, pad_fract_horiz=64, pad_fract_vert=0, pad_value=None): pad_value = 1.0 else: pad_value = np.iinfo(dtype).max - + frames = [strip[0]] for frame in strip[1:]: if pad_fract_horiz > 0: - frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, 3), dtype=dtype)) + 
frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, frame.shape[2]), dtype=dtype)) elif pad_fract_vert > 0: - frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], 3), dtype=dtype)) + frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], frame.shape[2]), dtype=dtype)) frames.append(frame) return frames @@ -53,7 +53,7 @@ def download_google_drive(url, output_name): if tokens is None: tokens = re.search('(confirm=.)', str(r.content)) assert tokens is not None, 'Could not extract token from response' - + url = url.replace('id=', f'{tokens[1]}&id=') r = session.get(url, allow_redirects=True) r.raise_for_status() @@ -89,4 +89,4 @@ def download_ckpt(url, output_name): elif 'mega.nz' in url: download_manual(url, output_name) else: - download_generic(url, output_name) \ No newline at end of file + download_generic(url, output_name) diff --git a/visualize.py b/visualize.py index 433ae2e..13def69 100644 --- a/visualize.py +++ b/visualize.py @@ -28,10 +28,128 @@ import sys import datetime import argparse -from tqdm import trange +from tqdm import trange, tqdm from config import Config from decomposition import get_random_dirs, get_or_compute, get_max_batch_size, SEED_VISUALIZATION -from utils import pad_frames +from utils import pad_frames +from matplotlib.offsetbox import OffsetImage, AnnotationBbox +import pickle +from skimage.transform import resize +from matplotlib.patches import Ellipse +from estimators import get_estimator + +def make_2Dscatter(X_comp,X_global_mean,X_stdev,inst,model,layer_key,outdir,device,n_samples=100,with_images=False,x_axis_pc=1,y_axis_pc=2): + assert n_samples % 5 == 0, "n_samples has to be dividable by 5" + samples_are_from_w = layer_key in ['g_mapping', 'mapping', 'style'] and inst.model.latent_space_name() == 'W' + with torch.no_grad(): + #draw new latents + latents = model.sample_latent(n_samples=n_samples) + + if(samples_are_from_w): + activations = latents + 
else: + all_activations = [] + for i in range(0,int(n_samples/5)): + z = latents[i*5:(i+1)*5:1] + model.partial_forward(z,layer_name) + + activations_part = inst.retained_features()[layer_key].reshape((5, -1)) + all_activations.append(activations_part) + activations = torch.cat(all_activations) + + global_mean = torch.from_numpy(X_global_mean.reshape(-1)) + activations = torch.sub(activations.cpu(),global_mean) + + X_comp_2 = X_comp.squeeze().reshape((X_comp.shape[0],-1)).transpose(1,0)[:,[x_axis_pc-1,y_axis_pc-1]] + activations_reduced = activations @ X_comp_2 + x = activations_reduced[:,0] + y = activations_reduced[:,1] + + fig, ax = plt.subplots(1) + plt.scatter(x,y) + plt.xlabel("PC"+str(x_axis_pc)) + plt.ylabel("PC"+str(y_axis_pc)) + plt.plot(0,0,'rx',alpha=0.5,markersize=10,zorder=10) + + sigma1 = Ellipse(xy=(0, 0), width=X_stdev[x_axis_pc-1], height=X_stdev[y_axis_pc-1], + edgecolor='r', fc='None', lw=2,alpha=0.3,zorder=10) + sigma2 = Ellipse(xy=(0, 0), width=2*X_stdev[x_axis_pc-1], height=2*X_stdev[y_axis_pc-1], + edgecolor='r', fc='None', lw=2,alpha=0.25,zorder=10) + sigma3 = Ellipse(xy=(0, 0), width=3*X_stdev[x_axis_pc-1], height=3*X_stdev[y_axis_pc-1], + edgecolor='r', fc='None', lw=2,alpha=0.2,zorder=10) + sigma4 = Ellipse(xy=(0, 0), width=4*X_stdev[x_axis_pc-1], height=4*X_stdev[y_axis_pc-1], + edgecolor='r', fc='None', lw=2,alpha=0.15,zorder=10) + ax.add_patch(sigma1) + ax.add_patch(sigma2) + ax.add_patch(sigma3) + ax.add_patch(sigma4) + + if(with_images): + w_primary_save = model.w_primary + model.w_primary = True + pbar = tqdm(total=len(list(zip(x, y))),desc='Generating images') + for x0, y0 in zip(x, y): + #activations are already centered | x_0 and y_0 are therefore offsets based from the mean + #shift the mean in the corresponding directions and pass it to the network + with torch.no_grad(): + #print(X_global_mean.reshape((X_global_mean.shape[0],-1)).shape,X_comp_2.shape,np.array([x0,y0]).shape) + latent = 
(X_global_mean.reshape((X_global_mean.shape[0],-1)).squeeze() + X_comp_2 @ np.array([x0,y0]).T).reshape(X_global_mean.shape) + latent = torch.from_numpy(latent).to(device) + #print(latent.shape) + img = model.forward(latent).squeeze() + + if(len(img.shape) == 3): + _cmap = 'viridis' + img = np.clip(img.cpu().numpy().transpose(1,2,0).astype(np.float32),0,1) + else: + _cmap = 'gray' + img = img.cpu().numpy() + + img = resize(img,(256,256)) #downscale images + ab = AnnotationBbox(OffsetImage(img,0.2,cmap=_cmap), (x0, y0), frameon=False) + ax.add_artist(ab) + pbar.update(1) + + pbar.close() + model.w_primary = w_primary_save + #Save interactive image as binary + with open(outdir/model.name/layer_key.lower()/est_id/f'scatter_images{str(n_samples)}_{"PC"+str(x_axis_pc)}_{"PC"+str(y_axis_pc)}.pickle', 'wb') as pickle_file: + pickle.dump(fig, pickle_file) + else: + plt.savefig(outdir/model.name/layer_key.lower()/est_id / f'scatter{str(n_samples)}_{"PC"+str(x_axis_pc)}_{"PC"+str(y_axis_pc)}.jpg', dpi=300) + + show() + +def plot_explained_variance(X_var_ratio,X_dim,args): + #PCA on complete random space to compare: + transformer = get_estimator(args.estimator, args.components, args.sparsity) + seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state + random_samples = torch.from_numpy(np.random.RandomState(seed).randn(10000, X_dim[-1])) + transformer.fit(random_samples) + _, _, random_var_ratio = transformer.get_components() + + fig, ax = plt.subplots(1) + + X_cumm_lst = [] + X_cumm = 0 + random_cumm_lst = [] + random_cumm = 0 + for i in range(X_var_ratio.shape[0]): + X_cumm += X_var_ratio[i] + X_cumm_lst.append(X_cumm) + random_cumm += random_var_ratio[i] + random_cumm_lst.append(random_cumm) + + plt.plot(X_cumm_lst,label="activation space") + plt.plot(random_cumm_lst,label="random space") + #plt.plot(X_var_ratio,label="activation space") + #plt.plot(random_var_ratio,label="random space") + plt.title("cumulative variance ratio") + 
#plt.title("variance ratio") + plt.xlabel("principal component") + plt.ylabel("ratio") + plt.legend() + plt.show() def x_closest(p): distances = np.sqrt(np.sum((X - p)**2, axis=-1)) @@ -50,7 +168,7 @@ def make_mp4(imgs, duration_secs, outname): FFMPEG_BIN = shutil.which("ffmpeg") assert FFMPEG_BIN is not None, 'ffmpeg not found, install with "conda install -c conda-forge ffmpeg"' assert len(imgs[0].shape) == 3, 'Invalid shape of frame data' - + format_str = 'rgb24' if imgs[0].shape[-1] > 1 else 'gray' resolution = imgs[0].shape[0:2] fps = int(len(imgs) / duration_secs) @@ -59,7 +177,7 @@ def make_mp4(imgs, duration_secs, outname): '-f', 'rawvideo', '-vcodec','rawvideo', '-s', f'{resolution[0]}x{resolution[1]}', # size of one frame - '-pix_fmt', 'rgb24', + '-pix_fmt', f'{format_str}', '-r', f'{fps}', '-i', '-', # imput from pipe '-an', # no audio @@ -67,8 +185,10 @@ def make_mp4(imgs, duration_secs, outname): '-preset', 'slow', '-crf', '17', str(Path(outname).with_suffix('.mp4')) ] - + + print((imgs[0] * 255).astype(np.uint8).reshape(-1).shape) frame_data = np.concatenate([(x * 255).astype(np.uint8).reshape(-1) for x in imgs]) + print(frame_data.shape) with sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE) as p: ret = p.communicate(frame_data.tobytes()) if p.returncode != 0: @@ -83,29 +203,37 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std x_range = np.linspace(-scale, scale, n_cols, dtype=np.float32) # scale in sigmas rows = [] - for r in range(n_rows): + for r in trange(n_rows,unit="component"): curr_row = [] out_batch = create_strip_centered(inst, edit_type, layer_key, [latent], act_comp[r], lat_comp[r], act_stdev[r], lat_stdev[r], act_mean, lat_mean, scale, 0, -1, n_cols)[0] + #print("len(out_batch) =",len(out_batch)) for i, img in enumerate(out_batch): curr_row.append(('c{}_{:.2f}'.format(r, x_range[i]), img)) rows.append(curr_row[:n_cols]) inst.remove_edits() - + if make_plots: # If more rows than columns, 
make several blocks side by side n_blocks = 2 if n_rows > n_cols else 1 - + for r, data in enumerate(rows): # Add white borders - imgs = pad_frames([img for _, img in data]) - + imgs = pad_frames([img for _, img in data]) + coord = ((r * n_blocks) % n_rows) + ((r * n_blocks) // n_rows) plt.subplot(n_rows//n_blocks, n_blocks, 1 + coord) - plt.imshow(np.hstack(imgs)) - + + if(imgs[0].shape[2] > 1): + img_row = np.hstack(imgs) + _cmap = 'viridis' + else: + img_row = np.hstack(imgs).squeeze() + _cmap = 'gray' + plt.imshow(img_row,cmap=_cmap) + # Custom x-axis labels W = imgs[0].shape[1] # image width P = imgs[1].shape[1] # padding width @@ -119,6 +247,23 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std return [img for row in rows for img in row] +def get_edit_name(mode): + if mode == 'activation': + is_stylegan = 'StyleGAN' in args.model + is_w = layer_key in ['style', 'g_mapping','mapping'] + return 'W' if (is_stylegan and is_w) else 'ACT' + elif mode == 'latent': + return model.latent_space_name() + elif mode == 'both': + return 'BOTH' + else: + raise RuntimeError(f'Unknown edit mode {mode}') + +def show(): + if args.batch_mode: + plt.close('all') + else: + plt.show() ###################### ### Visualize results @@ -144,7 +289,6 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std device = torch.device('cuda' if has_gpu else 'cpu') layer_key = args.layer layer_name = layer_key #layer_key.lower().split('.')[-1] - basedir = Path(__file__).parent.resolve() outdir = basedir / 'out' @@ -152,6 +296,7 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std inst = get_instrumented_model(args.model, args.output_class, layer_key, device, use_w=args.use_w) model = inst.model feature_shape = inst.feature_shape[layer_key] + latent_shape = model.get_latent_shape() print('Feature shape:', feature_shape) @@ -196,12 +341,6 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, 
act_comp, act_std max_batch = args.batch_size or (get_max_batch_size(inst, device) if has_gpu else 1) print('Batch size:', max_batch) - def show(): - if args.batch_mode: - plt.close('all') - else: - plt.show() - print(f'[{timestamp()}] Creating visualizations') # Ensure visualization gets new samples @@ -221,60 +360,61 @@ def show(): sparsity = np.mean(X_comp == 0) # percentage of zero values in components print(f'Sparsity: {sparsity:.2f}') - def get_edit_name(mode): - if mode == 'activation': - is_stylegan = 'StyleGAN' in args.model - is_w = layer_key in ['style', 'g_mapping'] - return 'W' if (is_stylegan and is_w) else 'ACT' - elif mode == 'latent': - return model.latent_space_name() - elif mode == 'both': - return 'BOTH' - else: - raise RuntimeError(f'Unknown edit mode {mode}') - # Only visualize applicable edit modes - if args.use_w and layer_key in ['style', 'g_mapping']: + if args.use_w and layer_key in ['style', 'g_mapping','mapping']: edit_modes = ['latent'] # activation edit is the same else: edit_modes = ['activation', 'latent'] + #plot_explained_variance(X_var_ratio,X_comp.shape[1:],args) + + #Scatter 2D of PC1 - PC2 + #(X_comp,inst,model,layer_key,outdir,n_samples=100 + if(args.show_scatter): + make_2Dscatter(X_comp,X_global_mean,X_stdev,inst,model,layer_key,outdir,device, + n_samples=args.scatter_samples,with_images=args.scatter_images,x_axis_pc=args.scatter_x_axis_pc,y_axis_pc=args.scatter_y_axis_pc) + # Summary grid, real components for edit_mode in edit_modes: + print("edit_mode =",edit_mode) plt.figure(figsize = (14,12)) plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16) make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, tensors.X_global_mean, - tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images) 
plt.savefig(outdir_summ / f'components_{get_edit_name(edit_mode)}.jpg', dpi=300) show() + print("args.make_video =",args.make_video) if args.make_video: - components = 15 - instances = 150 - + components = args.nv_directions #10 + instances = args.nv_images#150 + # One reasonable, one over the top for sigma in [args.sigma, 3*args.sigma]: for c in range(components): for edit_mode in edit_modes: + print("Make grid for video") frames = make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp[c:c+1, :, :], tensors.Z_stdev[c:c+1], tensors.X_global_mean, tensors.X_comp[c:c+1, :, :], tensors.X_stdev[c:c+1], n_rows=1, n_cols=instances, scale=sigma, make_plots=False, edit_type=edit_mode) plt.close('all') + print("Done!") frames = [x for _, x in frames] frames = frames + frames[::-1] + print("num_frames =",len(frames)) #num_frames = 300 make_mp4(frames, 5, outdir_comp / f'{get_edit_name(edit_mode)}_sigma{sigma}_comp{c}.mp4') - + # Summary grid, random directions # Using the stdevs of the principal components for same norm random_dirs_act = torch.from_numpy(get_random_dirs(n_comp, np.prod(sample_shape)).reshape(-1, *sample_shape)).to(device) random_dirs_z = torch.from_numpy(get_random_dirs(n_comp, np.prod(inst.input_shape)).reshape(-1, *latent_shape)).to(device) - + for edit_mode in edit_modes: plt.figure(figsize = (14,12)) plt.suptitle(f"{model.name} - {layer_name}, random directions w/ PC stdevs, {get_edit_name(edit_mode)} edit", size=16) make_grid(tensors.Z_global_mean, tensors.Z_global_mean, random_dirs_z, tensors.Z_stdev, - tensors.X_global_mean, random_dirs_act, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + tensors.X_global_mean, random_dirs_act, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images) plt.savefig(outdir_summ / f'random_dirs_{get_edit_name(edit_mode)}.jpg', dpi=300) show() @@ -291,14 +431,14 @@ def get_edit_name(mode): plt.figure(figsize = (14,12)) 
plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16) make_grid(z, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, - tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14) + tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images) plt.savefig(outdir_summ / f'samp{img_idx}_real_{get_edit_name(edit_mode)}.jpg', dpi=300) show() if args.make_video: - components = 5 - instances = 150 - + components = args.nv_directions #10 + instances = args.nv_images#150 + # One reasonable, one over the top for sigma in [args.sigma, 3*args.sigma]: #[2, 5]: for edit_mode in edit_modes: @@ -311,4 +451,4 @@ def get_edit_name(mode): frames = frames + frames[::-1] make_mp4(frames, 5, outdir_inst / f'{get_edit_name(edit_mode)}_sigma{sigma}_img{img_idx}_comp{c}.mp4') - print('Done in', datetime.datetime.now() - t_start) \ No newline at end of file + print('Done in', datetime.datetime.now() - t_start)