diff --git a/.gitmodules b/.gitmodules
index b03c2b0..7d4f44b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -6,3 +6,8 @@
path = models/stylegan2/stylegan2-pytorch
url = https://github.com/harskish/stylegan2-pytorch.git
ignore = untracked
+
+[submodule "stylegan2_ada/stylegan2-ada-pytorch"]
+ path = models/stylegan2_ada/stylegan2-ada-pytorch
+ url = https://github.com/NVlabs/stylegan2-ada-pytorch.git
+ ignore = untracked
diff --git a/README.md b/README.md
index a6a71a2..44d4eda 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,42 @@
+# Changes compared to the original repo
+* **Added StyleGAN2-ada support**
+ The following classes for StyleGAN2-ada are available for automatic download:
+ * `ffhq`
+ * `afhqcat`
+ * `afhqdog`
+ * `afhqwild`
+ * `brecahad`
+ * `cifar10`
+ * `metfaces`
+
+ For a custom class, add the name and the resolution in the `configs` dictionary in `models/wrappers.py` in the `StyleGAN2_ada` constructor and place the checkpoint file at `models/checkpoints/stylegan2_ada/stylegan2_{class_name}_{resolution}.pkl` (replace {class_name} and {resolution} with the ones you added to the `configs` dict.)
+
+ `partial_forward` for StyleGAN2-ada is currently not fully implemented, which means that if you use a layer in the synthesis network as the activation space, it can take longer than with other models, since the complete forward pass is always computed, even if the chosen layer is located somewhere earlier in the network.
+* **Added grayscale image support**
+* **Added another progress bar during the creation of the images**
+* **Added new args for `visualize.py` to control the outcome without changing the code:**
+
+argument | description | arg-datatype
+--- | --- | ---
+`--plot_directions` | Number of components/directions to plot |int
+`--plot_images` | Number of images per component/direction to plot | int
+`--video_directions` | Number of components/directions to create a video of | int
+`--video_images` | Number of frames within a video of one direction/component | int
+* **Added interactive 2D scatter plot of the used activation space:**
+
+
+
+argument | description | arg-datatype
+--- | --- | ---
+`--scatter` | Activate scatter-plot | -
+`--scatter_images` | Activate plotting corresponding generated images for each point | -
+`--scatter_samples` | Number of samples in the 2D scatter plot | int
+`--scatter_x` | Number of principal component for x-axis in the scatter plot | int
+`--scatter_y` | Number of principal component for y-axis in the scatter plot | int
+
+If `--scatter_images` is active, the interactive plot is saved as `.pickle` which can be opened with `python open_scatter.py [path]`.
+
+
# GANSpace: Discovering Interpretable GAN Controls


diff --git a/StyleGAN_scatter.png b/StyleGAN_scatter.png
new file mode 100644
index 0000000..80ebf45
Binary files /dev/null and b/StyleGAN_scatter.png differ
diff --git a/config.py b/config.py
index 5af238a..dbfdb78 100644
--- a/config.py
+++ b/config.py
@@ -27,10 +27,10 @@ def __str__(self):
for k, v in self.__dict__.items():
if k == 'default_args':
continue
-
+
in_default = k in self.default_args
same_value = self.default_args.get(k) == v
-
+
if in_default and same_value:
default[k] = v
else:
@@ -42,15 +42,15 @@ def __str__(self):
}
return json.dumps(config, indent=4)
-
+
def __repr__(self):
return self.__str__()
-
+
def from_dict(self, dictionary):
for k, v in dictionary.items():
setattr(self, k, v)
return self
-
+
def from_args(self, args=sys.argv[1:]):
parser = argparse.ArgumentParser(description='GAN component analysis config')
parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ
@@ -67,6 +67,22 @@ def from_args(self, args=sys.argv[1:]):
parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py')
parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components')
parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition')
+ parser.add_argument('--plot_directions', dest='np_directions', type=int, default=14, help='Number of components/directions to plot')
+ parser.add_argument('--plot_images', dest='np_images', type=int, default=5, help='Number of images per component/direction to plot')
+ parser.add_argument('--video_directions', dest='nv_directions', type=int, default=5, help='Number of components/directions to create a video of')
+ parser.add_argument('--video_images', dest='nv_images', type=int, default=150, help='Number of frames within a video of one direction/component')
+ parser.add_argument('--scatter', dest='show_scatter', action='store_true', help='Plot a 2D scatter-plot of the activation space of two principal components')
+ parser.add_argument('--scatter_samples', dest='scatter_samples', type=int, default=1000, help='Number of samples in the 2D scatter plot of the first two principal components')
+ parser.add_argument('--scatter_images', dest='scatter_images', action='store_true', help='Plot encoded images instead of points within the scatter plot')
+ parser.add_argument('--scatter_x', dest='scatter_x_axis_pc', type=int, default=1, help='Number of PC for x-axis in the scatter plot')
+ parser.add_argument('--scatter_y', dest='scatter_y_axis_pc', type=int, default=2, help='Number of PC for y-axis in the scatter plot')
+
+
args = parser.parse_args(args)
+ assert args.np_images % 2 != 0, 'The number of plotted images per component (--plot_images) have to be odd.'
+
+ if(args.model == "StyleGAN2-ada" and args.layer == "g_mapping"):
+ print("No layer \'g_mapping\' in StyleGAN2-ada. Assuming you meant \'mapping\'")
+ args.layer = "mapping"
- return self.from_dict(args.__dict__)
\ No newline at end of file
+ return self.from_dict(args.__dict__)
diff --git a/decomposition.py b/decomposition.py
index 4819e33..fa548a3 100644
--- a/decomposition.py
+++ b/decomposition.py
@@ -87,7 +87,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
n_samp = max(10_000, config.n) // B * B # make divisible
n_comp = comp.shape[0]
latent_dims = inst.model.get_latent_dims()
-
+
# We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z
# Z = batch of latent vectors (n_samples x latent_dims)
# G'(Z) = batch of activations at intermediate layer
@@ -104,7 +104,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
# Dimensions other way around, so these are actually the transposes
A = np.zeros((n_samp, n_comp), dtype=np.float32)
Z = np.zeros((n_samp, latent_dims), dtype=np.float32)
-
+
# Project tensor X onto PCs, return coordinates
def project(X, comp):
N = X.shape[0]
@@ -131,7 +131,7 @@ def project(X, comp):
# gelsy = complete orthogonal factorization; sometimes faster
# gelss = SVD; slow but less memory hungry
M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :]
-
+
# Solution given by rows of M_t
Z_comp = M_t[:n_comp, :]
Z_mean = np.mean(Z, axis=0, keepdims=True)
@@ -182,6 +182,12 @@ def compute(config, dump_name, instrumented_model):
inst.retain_layer(layer_key)
model.partial_forward(model.sample_latent(1), layer_key)
sample_shape = inst.retained_features()[layer_key].shape
+
+ # StyleGAN2-ada's mapping network copies its result 18 times to [B,18,512], so the sample shape differs from the latent shape
+ # returned by the wrapper (only [B,512]), which lets GANSpace modify only specific style layers
+ if(model.model_name == "StyleGAN2_ada" and model.w_primary):
+ sample_shape = (sample_shape[0],sample_shape[2])
+
sample_dims = np.prod(sample_shape)
print('Feature shape:', sample_shape)
@@ -218,7 +224,7 @@ def compute(config, dump_name, instrumented_model):
# Must not depend on chosen batch size (reproducibility)
NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible!
-
+
samples = None
if not transformer.batch_support:
samples = np.zeros((N + NB, sample_dims), dtype=np.float32)
@@ -236,7 +242,7 @@ def compute(config, dump_name, instrumented_model):
latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()
# Decomposition on non-Gaussian latent space
- samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'
+ samples_are_latents = layer_key in ['g_mapping', 'mapping', 'style'] and inst.model.latent_space_name() == 'W'
canceled = False
try:
@@ -245,7 +251,7 @@ def compute(config, dump_name, instrumented_model):
for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
for mb in range(0, NB, B):
z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)
-
+
if samples_are_latents:
# Decomposition on latents directly (e.g. StyleGAN W)
batch = z.reshape((B, -1))
@@ -253,7 +259,7 @@ def compute(config, dump_name, instrumented_model):
# Decomposition on intermediate layer
with torch.no_grad():
model.partial_forward(z, layer_key)
-
+
# Permuted to place PCA dimensions last
batch = inst.retained_features()[layer_key].reshape((B, -1))
@@ -268,21 +274,21 @@ def compute(config, dump_name, instrumented_model):
except KeyboardInterrupt:
if not transformer.batch_support:
sys.exit(1) # no progress yet
-
+
dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
print(f'Saving current state to "{dump_name.name}" before exiting')
canceled = True
-
+
if not transformer.batch_support:
X = samples # Use all samples
X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...!
X -= X_global_mean
-
+
print(f'[{timestamp()}] Fitting whole batch')
t_start_fit = datetime.datetime.now()
transformer.fit(X)
-
+
print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
else:
@@ -291,7 +297,7 @@ def compute(config, dump_name, instrumented_model):
X -= X_global_mean
X_comp, X_stdev, X_var_ratio = transformer.get_components()
-
+
assert X_comp.shape[1] == sample_dims \
and X_comp.shape[0] == config.components \
and X_global_mean.shape[1] == sample_dims \
@@ -349,6 +355,7 @@ def compute(config, dump_name, instrumented_model):
del inst
del model
+
del X
del X_comp
del random_dirs
@@ -363,20 +370,20 @@ def get_or_compute(config, model=None, submit_config=None, force_recompute=False
if submit_config is None:
wrkdir = str(Path(__file__).parent.resolve())
submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir)
-
+
# Called directly by run.py
return _compute(submit_config, config, model, force_recompute)
def _compute(submit_config, config, model=None, force_recompute=False):
basedir = Path(submit_config.run_dir)
outdir = basedir / 'out'
-
+
if config.n is None:
raise RuntimeError('Must specify number of samples with -n=XXX')
if model and not isinstance(model, InstrumentedModel):
raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"')
-
+
if config.use_w and not 'StyleGAN' in config.model:
raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}')
@@ -398,5 +405,25 @@ def _compute(submit_config, config, model=None, force_recompute=False):
t_start = datetime.datetime.now()
compute(config, dump_path, model)
print('Total time:', datetime.datetime.now() - t_start)
-
- return dump_path
\ No newline at end of file
+
+ return dump_path
+
+
+
+def imscatter(x, y, image, ax=None, zoom=1):
+ if ax is None:
+ ax = plt.gca()
+ try:
+ image = plt.imread(image)
+ except TypeError:
+ # Likely already an array...
+ pass
+ im = OffsetImage(image, zoom=zoom)
+ x, y = np.atleast_1d(x, y)
+ artists = []
+ for x0, y0 in zip(x, y):
+ ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
+ artists.append(ax.add_artist(ab))
+ ax.update_datalim(np.column_stack([x, y]))
+ ax.autoscale()
+ return artists
diff --git a/models/stylegan2_ada/__init__.py b/models/stylegan2_ada/__init__.py
new file mode 100644
index 0000000..e50fd7a
--- /dev/null
+++ b/models/stylegan2_ada/__init__.py
@@ -0,0 +1,18 @@
+import sys
+import os
+import shutil
+import glob
+import platform
+from pathlib import Path
+
+current_path = os.getcwd()
+
+module_path = Path(__file__).parent / 'stylegan2-ada-pytorch'
+sys.path.append(str(module_path.resolve()))
+os.chdir(module_path)
+
+import generate
+import legacy
+import dnnlib
+
+os.chdir(current_path)
diff --git a/models/stylegan2_ada/stylegan2-ada-pytorch b/models/stylegan2_ada/stylegan2-ada-pytorch
new file mode 160000
index 0000000..6f160b3
--- /dev/null
+++ b/models/stylegan2_ada/stylegan2-ada-pytorch
@@ -0,0 +1 @@
+Subproject commit 6f160b3d22b8b178ebe533a50d4d5e63aedba21d
diff --git a/models/wrappers.py b/models/wrappers.py
index 558e928..775bbd3 100644
--- a/models/wrappers.py
+++ b/models/wrappers.py
@@ -8,6 +8,7 @@
# OF ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.
+import math
import torch
import numpy as np
import re
@@ -21,8 +22,10 @@
from . import biggan
from . import stylegan
from . import stylegan2
+from . import stylegan2_ada
from abc import abstractmethod, ABC as AbstractBaseClass
from functools import singledispatch
+import PIL.Image
class BaseModel(AbstractBaseClass, torch.nn.Module):
@@ -80,7 +83,7 @@ def sample_np(self, z=None, n_samples=1, seed=None):
z = torch.tensor(z).to(self.device)
img = self.forward(z)
img_np = img.permute(0, 2, 3, 1).cpu().detach().numpy()
- return np.clip(img_np, 0.0, 1.0).squeeze()
+ return np.clip(img_np, 0.0, 1.0)#.squeeze()
# For models that use part of latent as conditioning
def get_conditional_state(self, z):
@@ -93,6 +96,140 @@ def set_conditional_state(self, z, c):
def named_modules(self, *args, **kwargs):
return self.model.named_modules(*args, **kwargs)
+# StyleGAN2-ada-pytorch
+class StyleGAN2_ada(BaseModel):
+ def __init__(self, device, class_name, truncation=1.0, use_w=False):
+ super(StyleGAN2_ada, self).__init__('StyleGAN2_ada', class_name or 'ffhq')
+ self.device = device
+ self.truncation = truncation
+ self.latent_avg = None
+ self.w_primary = use_w
+
+ # Image widths
+ configs = {
+ 'ffhq': 1024,
+ 'afhqcat': 512,
+ 'afhqdog': 512,
+ 'afhqwild': 512,
+ 'brecahad': 1024,
+ 'cifar10': 32,
+ 'metfaces': 1024,
+ }
+
+ assert self.outclass in configs, \
+ f'Invalid StyleGAN2-ada class {self.outclass}, should be one of [{", ".join(configs.keys())}]'
+
+ self.resolution = configs[self.outclass]
+ self.name = f'StyleGAN2-ada-{self.outclass}'
+ self.has_latent_residual = True
+ self.load_model()
+ self.set_noise_seed(0)
+
+ def latent_space_name(self):
+ return 'W' if self.w_primary else 'Z'
+
+ def use_w(self):
+ self.w_primary = True
+
+ def use_z(self):
+ self.w_primary = False
+
+
+ # Pretrained checkpoint URLs from the official NVlabs stylegan2-ada CDN
+ def download_checkpoint(self, outfile):
+ checkpoints = {
+ 'ffhq': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl',
+ 'afhqcat': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqcat.pkl',
+ 'afhqdog': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqdog.pkl',
+ 'afhqwild': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/afhqwild.pkl',
+ 'brecahad': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/brecahad.pkl',
+ 'cifar10': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl',
+ 'metfaces': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl'
+ }
+
+ url = checkpoints[self.outclass]
+ download_ckpt(url, outfile)
+
+
+ def load_model(self):
+ checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
+ checkpoint = Path(checkpoint_root) / f'stylegan2_ada/stylegan2_{self.outclass}_{self.resolution}.pkl'
+
+ if not checkpoint.is_file():
+ os.makedirs(checkpoint.parent, exist_ok=True)
+ self.download_checkpoint(checkpoint)
+
+ with stylegan2_ada.dnnlib.util.open_url(str(checkpoint)) as f:
+ self.model = stylegan2_ada.legacy.load_network_pkl(f)['G_ema'].to(self.device)
+
+
+ def sample_latent(self, n_samples=1, seed=None, truncation=None):
+ if seed is None:
+ seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state
+
+ z = torch.from_numpy(np.random.RandomState(seed).randn(n_samples, self.model.z_dim)).to(self.device)
+ c = torch.zeros([n_samples, self.model.c_dim], device=self.device) #no conditioning at the moment
+
+ if self.w_primary:
+ z = self.model.mapping(z,c)[:,0,:] #Just use one of the 18 copies
+ return z
+
+ def get_max_latents(self):
+ return self.model.num_ws
+
+ def set_output_class(self, new_class):
+ if self.outclass != new_class:
+ raise RuntimeError('StyleGAN2-ada: cannot change output class without reloading')
+
+ def forward(self, x):
+ if isinstance(x, list):
+ assert len(x) == self.model.num_ws, 'Must provide 1 or {0} latents'.format(str(self.model.num_ws))
+ if not self.w_primary:
+ label_shape = list(x[0].shape)
+ label_shape[-1] = self.model.c_dim
+ label = torch.zeros(label_shape, device=self.device)
+ x = [self.model.mapping.forward(l,label, truncation_psi=self.truncation)[:,0,:] for l in x]
+ x = torch.stack(x, dim=1)
+
+ else:
+ if not self.w_primary:
+ label_shape = list(x.shape)
+ label_shape[-1] = self.model.c_dim
+ label = torch.zeros(label_shape, device=self.device)
+ x = self.model.mapping.forward(x,label, truncation_psi=self.truncation)[:,0,:]
+ if(len(x.shape) != 3 or x.shape[1] != self.model.num_ws):
+ x = x.unsqueeze(1).expand(-1, self.model.num_ws, -1)
+
+ img = self.model.synthesis.forward(x, noise_mode='const',force_fp32= self.device.type == 'cpu')
+
+ return 0.5*(img+1)
+
+ def partial_forward(self, x, layer_name):
+ mapping = self.model.mapping
+ G = self.model.synthesis
+ #trunc = self.model._modules.get('truncation', lambda x : x)
+ #print("TEST1",x.shape)
+
+ if not self.w_primary:
+ c = torch.zeros([x.shape[0], self.model.c_dim], device=self.device)
+ x = mapping.forward(x,c) # handles list inputs
+
+ # Whole mapping
+ if 'mapping' in layer_name:
+ return
+ else:
+ G.forward(x, noise_mode='const',force_fp32= self.device.type == 'cpu') #TODO: Implement partial forward for faster computation
+ return
+
+
+ def set_noise_seed(self, seed):
+ torch.manual_seed(seed)
+ self.noise = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=self.device)]
+
+ for i in range(3, int(math.log(self.model.img_resolution,2)) + 1):
+ for _ in range(2):
+ self.noise.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=self.device))
+
# PyTorch port of StyleGAN 2
class StyleGAN2(BaseModel):
def __init__(self, device, class_name, truncation=1.0, use_w=False):
@@ -153,13 +290,13 @@ def download_checkpoint(self, outfile):
def load_model(self):
checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
checkpoint = Path(checkpoint_root) / f'stylegan2/stylegan2_{self.outclass}_{self.resolution}.pt'
-
+
self.model = stylegan2.Generator(self.resolution, 512, 8).to(self.device)
if not checkpoint.is_file():
os.makedirs(checkpoint.parent, exist_ok=True)
self.download_checkpoint(checkpoint)
-
+
ckpt = torch.load(checkpoint)
self.model.load_state_dict(ckpt['g_ema'], strict=False)
self.latent_avg = ckpt['latent_avg'].to(self.device)
@@ -172,7 +309,7 @@ def sample_latent(self, n_samples=1, seed=None, truncation=None):
z = torch.from_numpy(
rng.standard_normal(512 * n_samples)
.reshape(n_samples, 512)).float().to(self.device) #[N, 512]
-
+
if self.w_primary:
z = self.model.style(z)
@@ -184,7 +321,7 @@ def get_max_latents(self):
def set_output_class(self, new_class):
if self.outclass != new_class:
raise RuntimeError('StyleGAN2: cannot change output class without reloading')
-
+
def forward(self, x):
x = x if isinstance(x, list) else [x]
out, _ = self.model(x, noise=self.noise,
@@ -246,7 +383,7 @@ def partial_forward(self, x, layer_name):
out = conv2(out, latent[:, i + 1], noise=noise[noise_i + 1])
if f'convs.{i}' in layer_name:
return
-
+
skip = to_rgb(out, latent[:, i + 2], skip)
if f'to_rgbs.{i//2}' in layer_name:
return
@@ -280,7 +417,7 @@ def __init__(self, device, class_name, truncation=1.0, use_w=False):
'bedrooms': 256,
'cars': 512,
'cats': 256,
-
+
# From https://github.com/justinpinkney/awesome-pretrained-stylegan
'vases': 1024,
'wikiart': 512,
@@ -311,7 +448,7 @@ def use_z(self):
def load_model(self):
checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
checkpoint = Path(checkpoint_root) / f'stylegan/stylegan_{self.outclass}_{self.resolution}.pt'
-
+
self.model = stylegan.StyleGAN_G(self.resolution).to(self.device)
urls_tf = {
@@ -341,7 +478,7 @@ def load_model(self):
download_ckpt(urls_tf[self.outclass], checkpoint_tf)
print('Converting TensorFlow checkpoint to PyTorch')
self.model.export_from_tf(checkpoint_tf)
-
+
self.model.load_weights(checkpoint)
def sample_latent(self, n_samples=1, seed=None, truncation=None):
@@ -352,10 +489,11 @@ def sample_latent(self, n_samples=1, seed=None, truncation=None):
noise = torch.from_numpy(
rng.standard_normal(512 * n_samples)
.reshape(n_samples, 512)).float().to(self.device) #[N, 512]
-
+
if self.w_primary:
noise = self.model._modules['g_mapping'].forward(noise)
-
+
+ #print("NOISE shape",noise.shape)
return noise
def get_max_latents(self):
@@ -374,7 +512,7 @@ def partial_forward(self, x, layer_name):
mapping = self.model._modules['g_mapping']
G = self.model._modules['g_synthesis']
trunc = self.model._modules.get('truncation', lambda x : x)
-
+ #print("TEST1",x.shape)
if not self.w_primary:
x = mapping.forward(x) # handles list inputs
@@ -423,7 +561,7 @@ def set_noise_seed(self, seed):
def for_each_child(this, name, func):
children = getattr(this, '_modules', [])
for child_name, module in children.items():
- for_each_child(module, f'{name}.{child_name}', func)
+ for_each_child(module, f'{name}.{child_name}', func)
func(this, name)
def modify(m, name):
@@ -460,7 +598,7 @@ def get_conditional_state(self, z):
def set_conditional_state(self, z, c):
z[:, -20:] = c
return z
-
+
def forward(self, x):
out = self.base_model.test(x)
return 0.5*(out+1)
@@ -483,7 +621,7 @@ def __init__(self, device, lsun_class=None):
def load_model(self):
checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
checkpoint = Path(checkpoint_root) / f'progan/{self.outclass}_lsun.pth'
-
+
if not checkpoint.is_file():
os.makedirs(checkpoint.parent, exist_ok=True)
url = f'http://netdissect.csail.mit.edu/data/ganmodel/karras/{self.outclass}_lsun.pth'
@@ -501,7 +639,7 @@ def forward(self, x):
if isinstance(x, list):
assert len(x) == 1, "ProGAN only supports a single global latent"
x = x[0]
-
+
out = self.model.forward(x)
return 0.5*(out+1)
@@ -534,15 +672,15 @@ def __init__(self, device, resolution, class_name, truncation=1.0):
# Default implementaiton fails without an internet
# connection, even if the model has been cached
- def load_model(self, name):
+ def load_model(self, name):
if name not in biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP:
raise RuntimeError('Unknown BigGAN model name', name)
-
+
checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
model_path = Path(checkpoint_root) / name
os.makedirs(model_path, exist_ok=True)
-
+
model_file = model_path / biggan.model.WEIGHTS_NAME
config_file = model_path / biggan.model.CONFIG_NAME
model_url = biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP[name]
@@ -562,10 +700,10 @@ def load_model(self, name):
def sample_latent(self, n_samples=1, truncation=None, seed=None):
if seed is None:
seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state
-
+
noise_vector = biggan.truncated_noise_sample(truncation=truncation or self.truncation, batch_size=n_samples, seed=seed)
- noise = torch.from_numpy(noise_vector) #[N, 128]
-
+ noise = torch.from_numpy(noise_vector) #[N, 128]
+
return noise.to(self.device)
# One extra for gen_z
@@ -577,7 +715,7 @@ def get_conditional_state(self, z):
def set_conditional_state(self, z, c):
self.v_class = c
-
+
def is_valid_class(self, class_id):
if isinstance(class_id, int):
return class_id < 1000
@@ -595,8 +733,8 @@ def set_output_class(self, class_id):
self.v_class = torch.from_numpy(biggan.one_hot_from_names([class_id])).to(self.device)
else:
raise RuntimeError(f'Unknown class identifier {class_id}')
-
- def forward(self, x):
+
+ def forward(self, x):
# Duplicate along batch dimension
if isinstance(x, list):
c = self.v_class.repeat(x[0].shape[0], 1)
@@ -626,7 +764,7 @@ def partial_forward(self, x, layer_name):
else:
class_label = self.v_class.repeat(x[0].shape[0], 1)
embed = len(x)*[self.model.embeddings(class_label)]
-
+
assert len(x) == self.model.n_latents, f'Expected {self.model.n_latents} latents, got {len(x)}'
assert len(embed) == self.model.n_latents, f'Expected {self.model.n_latents} class vectors, got {len(class_label)}'
@@ -653,18 +791,18 @@ def get_model(name, output_class, device, **kwargs):
# Check if optionally provided existing model can be reused
inst = kwargs.get('inst', None)
model = kwargs.get('model', None)
-
+
if inst or model:
cached = model or inst.model
-
+
network_same = (cached.model_name == name)
outclass_same = (cached.outclass == output_class)
can_change_class = ('BigGAN' in name)
-
+
if network_same and (outclass_same or can_change_class):
cached.set_output_class(output_class)
return cached
-
+
if name == 'DCGAN':
import warnings
warnings.filterwarnings("ignore", message="nn.functional.tanh is deprecated")
@@ -678,6 +816,8 @@ def get_model(name, output_class, device, **kwargs):
model = StyleGAN(device, class_name=output_class)
elif name == 'StyleGAN2':
model = StyleGAN2(device, class_name=output_class)
+ elif name == 'StyleGAN2-ada':
+ model = StyleGAN2_ada(device, class_name=output_class)
else:
raise RuntimeError(f'Unknown model {name}')
@@ -709,7 +849,7 @@ def get_instrumented_model(name, output_class, layers, device, **kwargs):
print(f"Layer '{layer_name}' not found in model!")
print("Available layers:", '\n'.join(module_names))
raise RuntimeError(f"Unknown layer '{layer_name}''")
-
+
# Reset StyleGANs to z mode for shape annotation
if hasattr(model, 'use_z'):
model.use_z()
@@ -732,4 +872,4 @@ def get_instrumented_model(name, output_class, layers, device, **kwargs):
@get_instrumented_model.register(Config)
def _(cfg, device, **kwargs):
kwargs['use_w'] = kwargs.get('use_w', cfg.use_w) # explicit arg can override cfg
- return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs)
\ No newline at end of file
+ return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs)
diff --git a/netdissect/modelconfig.py b/netdissect/modelconfig.py
index d0ee37a..caf886c 100644
--- a/netdissect/modelconfig.py
+++ b/netdissect/modelconfig.py
@@ -21,7 +21,7 @@ def create_instrumented_model(args, **kwargs):
gen: True for a generator model. One-pixel input assumed.
imgsize: For non-generator models, (y, x) dimensions for RGB input.
cuda: True to use CUDA.
-
+
The constructed model will be decorated with the following attributes:
input_shape: (usually 4d) tensor shape for single-image input.
output_shape: 4d tensor shape for output.
diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py
index 5f421c4..a338308 100644
--- a/notebooks/notebook_utils.py
+++ b/notebooks/notebook_utils.py
@@ -47,10 +47,10 @@ def _create_strip_impl(inst, mode, layer, latents, x_comp, z_comp, act_stdev, la
act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center)
# Batch over frames if there are more frames in strip than latents
-def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
+def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
inst.close()
batch_frames = [[] for _ in range(len(latents))]
-
+
B = min(num_frames, 5)
lep_padded = ((num_frames - 1) // B + 1) * B
sigma_range = np.linspace(-sigma, sigma, num_frames)
@@ -60,11 +60,11 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st
for i_batch in range(lep_padded // B):
sigmas = sigma_range[i_batch*B:(i_batch+1)*B]
-
+
for i_lat in range(len(latents)):
z_single = latents[i_lat]
z_batch = z_single.repeat_interleave(B, axis=0)
-
+
zeroing_offset_act = 0
zeroing_offset_lat = 0
if center:
@@ -90,6 +90,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st
z[i] = z[i] - zeroing_offset_lat + delta
if mode in ['activation', 'both']:
+ # NOTE(review): leftover debug marker ("in here for the single-sample batch") — consider removing
comp_batch = x_comp.repeat_interleave(B, axis=0)
delta = comp_batch * sigmas.reshape([-1] + [1]*(comp_batch.ndim - 1))
inst.edit_layer(layer, offset=delta * act_stdev - zeroing_offset_act)
@@ -97,7 +98,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st
img_batch = inst.model.sample_np(z)
if img_batch.ndim == 3:
img_batch = np.expand_dims(img_batch, axis=0)
-
+
for j, img in enumerate(img_batch):
idx = i_batch*B + j
if idx < num_frames:
@@ -106,7 +107,7 @@ def _create_strip_batch_sigma(inst, mode, layer, latents, x_comp, z_comp, act_st
return batch_frames
# Batch over latents if there are more latents than frames in strip
-def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
+def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stdev, lat_stdev, act_mean, lat_mean, sigma, layer_start, layer_end, num_frames, center):
n_lat = len(latents)
B = min(n_lat, 5)
@@ -114,7 +115,7 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde
if layer_end < 0 or layer_end > max_lat:
layer_end = max_lat
layer_start = np.clip(layer_start, 0, layer_end)
-
+
len_padded = ((n_lat - 1) // B + 1) * B
batch_frames = [[] for _ in range(n_lat)]
@@ -122,14 +123,14 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde
zs = latents[i_batch*B:(i_batch+1)*B]
if len(zs) == 0:
continue
-
+
z_batch_single = torch.cat(zs, 0)
inst.close() # don't retain, remove edits
sigma_range = np.linspace(-sigma, sigma, num_frames, dtype=np.float32)
normalize = lambda v : v / torch.sqrt(torch.sum(v**2, dim=-1, keepdim=True) + 1e-8)
-
+
zeroing_offset_act = 0
zeroing_offset_lat = 0
if center:
@@ -163,7 +164,7 @@ def _create_strip_batch_lat(inst, mode, layer, latents, x_comp, z_comp, act_stde
img_batch = inst.model.sample_np(z)
if img_batch.ndim == 3:
img_batch = np.expand_dims(img_batch, axis=0)
-
+
for j, img in enumerate(img_batch):
img_idx = i_batch*B + j
if img_idx < n_lat:
@@ -176,12 +177,12 @@ def save_frames(title, model_name, rootdir, frames, strip_width=10):
test_name = prettify_name(title)
outdir = f'{rootdir}/{model_name}/{test_name}'
makedirs(outdir, exist_ok=True)
-
+
# Limit maximum resolution
max_H = 512
real_H = frames[0][0].shape[0]
ratio = min(1.0, max_H / real_H)
-
+
# Combined with first 10
strips = [np.hstack(frames) for frames in frames[:strip_width]]
if len(strips) >= strip_width:
@@ -193,8 +194,8 @@ def save_frames(title, model_name, rootdir, frames, strip_width=10):
im.save(f'{outdir}/{test_name}_all.png')
else:
print('Too few strips to create grid, creating just strips!')
-
+
for ex_num, strip in enumerate(frames[:strip_width]):
im = Image.fromarray(np.uint8(255*np.hstack(pad_frames(strip))))
im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), Image.ANTIALIAS)
- im.save(f'{outdir}/{test_name}_{ex_num}.png')
\ No newline at end of file
+ im.save(f'{outdir}/{test_name}_{ex_num}.png')
diff --git a/open_scatter.py b/open_scatter.py
new file mode 100644
index 0000000..7b7f3f9
--- /dev/null
+++ b/open_scatter.py
@@ -0,0 +1,23 @@
+import matplotlib.pyplot as plt
+import pickle
+import os
+import sys
+
+def show_figure(fig):
+ # create a dummy figure and use its
+ # manager to display "fig"
+ dummy = plt.figure(num=1)
+ new_manager = dummy.canvas.manager
+ new_manager.canvas.figure = fig
+ fig.set_canvas(new_manager.canvas)
+ plt.axis('equal')
+ plt.show()
+
+plt.switch_backend('TkAgg')
+
+path = sys.argv[1]
+if os.path.isfile(path):
+ if(path.split('.')[-1] == 'pickle'):
+ print("Loading", path.split('/')[-1])
+ figx = pickle.load(open(path, 'rb'))
+ show_figure(figx)
diff --git a/utils.py b/utils.py
index e08ff90..5d09ea8 100644
--- a/utils.py
+++ b/utils.py
@@ -30,13 +30,13 @@ def pad_frames(strip, pad_fract_horiz=64, pad_fract_vert=0, pad_value=None):
pad_value = 1.0
else:
pad_value = np.iinfo(dtype).max
-
+
frames = [strip[0]]
for frame in strip[1:]:
if pad_fract_horiz > 0:
- frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, 3), dtype=dtype))
+ frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, frame.shape[2]), dtype=dtype))
elif pad_fract_vert > 0:
- frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], 3), dtype=dtype))
+ frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], frame.shape[2]), dtype=dtype))
frames.append(frame)
return frames
@@ -53,7 +53,7 @@ def download_google_drive(url, output_name):
if tokens is None:
tokens = re.search('(confirm=.)', str(r.content))
assert tokens is not None, 'Could not extract token from response'
-
+
url = url.replace('id=', f'{tokens[1]}&id=')
r = session.get(url, allow_redirects=True)
r.raise_for_status()
@@ -89,4 +89,4 @@ def download_ckpt(url, output_name):
elif 'mega.nz' in url:
download_manual(url, output_name)
else:
- download_generic(url, output_name)
\ No newline at end of file
+ download_generic(url, output_name)
diff --git a/visualize.py b/visualize.py
index 433ae2e..13def69 100644
--- a/visualize.py
+++ b/visualize.py
@@ -28,10 +28,128 @@
import sys
import datetime
import argparse
-from tqdm import trange
+from tqdm import trange, tqdm
from config import Config
from decomposition import get_random_dirs, get_or_compute, get_max_batch_size, SEED_VISUALIZATION
-from utils import pad_frames
+from utils import pad_frames
+from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+import pickle
+from skimage.transform import resize
+from matplotlib.patches import Ellipse
+from estimators import get_estimator
+
+def make_2Dscatter(X_comp,X_global_mean,X_stdev,inst,model,layer_key,outdir,device,n_samples=100,with_images=False,x_axis_pc=1,y_axis_pc=2):
+ assert n_samples % 5 == 0, "n_samples has to be divisible by 5"
+ samples_are_from_w = layer_key in ['g_mapping', 'mapping', 'style'] and inst.model.latent_space_name() == 'W'
+ with torch.no_grad():
+ #draw new latents
+ latents = model.sample_latent(n_samples=n_samples)
+
+ if(samples_are_from_w):
+ activations = latents
+ else:
+ all_activations = []
+ for i in range(0,int(n_samples/5)):
+ z = latents[i*5:(i+1)*5:1]
+ model.partial_forward(z,layer_name)
+
+ activations_part = inst.retained_features()[layer_key].reshape((5, -1))
+ all_activations.append(activations_part)
+ activations = torch.cat(all_activations)
+
+ global_mean = torch.from_numpy(X_global_mean.reshape(-1))
+ activations = torch.sub(activations.cpu(),global_mean)
+
+ X_comp_2 = X_comp.squeeze().reshape((X_comp.shape[0],-1)).transpose(1,0)[:,[x_axis_pc-1,y_axis_pc-1]]
+ activations_reduced = activations @ X_comp_2
+ x = activations_reduced[:,0]
+ y = activations_reduced[:,1]
+
+ fig, ax = plt.subplots(1)
+ plt.scatter(x,y)
+ plt.xlabel("PC"+str(x_axis_pc))
+ plt.ylabel("PC"+str(y_axis_pc))
+ plt.plot(0,0,'rx',alpha=0.5,markersize=10,zorder=10)
+
+ sigma1 = Ellipse(xy=(0, 0), width=X_stdev[x_axis_pc-1], height=X_stdev[y_axis_pc-1],
+ edgecolor='r', fc='None', lw=2,alpha=0.3,zorder=10)
+ sigma2 = Ellipse(xy=(0, 0), width=2*X_stdev[x_axis_pc-1], height=2*X_stdev[y_axis_pc-1],
+ edgecolor='r', fc='None', lw=2,alpha=0.25,zorder=10)
+ sigma3 = Ellipse(xy=(0, 0), width=3*X_stdev[x_axis_pc-1], height=3*X_stdev[y_axis_pc-1],
+ edgecolor='r', fc='None', lw=2,alpha=0.2,zorder=10)
+ sigma4 = Ellipse(xy=(0, 0), width=4*X_stdev[x_axis_pc-1], height=4*X_stdev[y_axis_pc-1],
+ edgecolor='r', fc='None', lw=2,alpha=0.15,zorder=10)
+ ax.add_patch(sigma1)
+ ax.add_patch(sigma2)
+ ax.add_patch(sigma3)
+ ax.add_patch(sigma4)
+
+ if(with_images):
+ w_primary_save = model.w_primary
+ model.w_primary = True
+ pbar = tqdm(total=len(list(zip(x, y))),desc='Generating images')
+ for x0, y0 in zip(x, y):
+ #activations are already centered | x_0 and y_0 are therefore offsets based on the mean
+ #shift the mean in the corresponding directions and pass it to the network
+ with torch.no_grad():
+ #print(X_global_mean.reshape((X_global_mean.shape[0],-1)).shape,X_comp_2.shape,np.array([x0,y0]).shape)
+ latent = (X_global_mean.reshape((X_global_mean.shape[0],-1)).squeeze() + X_comp_2 @ np.array([x0,y0]).T).reshape(X_global_mean.shape)
+ latent = torch.from_numpy(latent).to(device)
+ #print(latent.shape)
+ img = model.forward(latent).squeeze()
+
+ if(len(img.shape) == 3):
+ _cmap = 'viridis'
+ img = np.clip(img.cpu().numpy().transpose(1,2,0).astype(np.float32),0,1)
+ else:
+ _cmap = 'gray'
+ img = img.cpu().numpy()
+
+ img = resize(img,(256,256)) #downscale images
+ ab = AnnotationBbox(OffsetImage(img,0.2,cmap=_cmap), (x0, y0), frameon=False)
+ ax.add_artist(ab)
+ pbar.update(1)
+
+ pbar.close()
+ model.w_primary = w_primary_save
+ #Save interactive image as binary
+ with open(outdir/model.name/layer_key.lower()/est_id/f'scatter_images{str(n_samples)}_{"PC"+str(x_axis_pc)}_{"PC"+str(y_axis_pc)}.pickle', 'wb') as pickle_file:
+ pickle.dump(fig, pickle_file)
+ else:
+ plt.savefig(outdir/model.name/layer_key.lower()/est_id / f'scatter{str(n_samples)}_{"PC"+str(x_axis_pc)}_{"PC"+str(y_axis_pc)}.jpg', dpi=300)
+
+ show()
+
+def plot_explained_variance(X_var_ratio,X_dim,args):
+ #PCA on complete random space to compare:
+ transformer = get_estimator(args.estimator, args.components, args.sparsity)
+ seed = np.random.randint(np.iinfo(np.int32).max) # use (reproducible) global rand state
+ random_samples = torch.from_numpy(np.random.RandomState(seed).randn(10000, X_dim[-1]))
+ transformer.fit(random_samples)
+ _, _, random_var_ratio = transformer.get_components()
+
+ fig, ax = plt.subplots(1)
+
+ X_cumm_lst = []
+ X_cumm = 0
+ random_cumm_lst = []
+ random_cumm = 0
+ for i in range(X_var_ratio.shape[0]):
+ X_cumm += X_var_ratio[i]
+ X_cumm_lst.append(X_cumm)
+ random_cumm += random_var_ratio[i]
+ random_cumm_lst.append(random_cumm)
+
+ plt.plot(X_cumm_lst,label="activation space")
+ plt.plot(random_cumm_lst,label="random space")
+ #plt.plot(X_var_ratio,label="activation space")
+ #plt.plot(random_var_ratio,label="random space")
+ plt.title("cumulative variance ratio")
+ #plt.title("variance ratio")
+ plt.xlabel("principal component")
+ plt.ylabel("ratio")
+ plt.legend()
+ plt.show()
def x_closest(p):
distances = np.sqrt(np.sum((X - p)**2, axis=-1))
@@ -50,7 +168,7 @@ def make_mp4(imgs, duration_secs, outname):
FFMPEG_BIN = shutil.which("ffmpeg")
assert FFMPEG_BIN is not None, 'ffmpeg not found, install with "conda install -c conda-forge ffmpeg"'
assert len(imgs[0].shape) == 3, 'Invalid shape of frame data'
-
+ format_str = 'rgb24' if imgs[0].shape[-1] > 1 else 'gray'
resolution = imgs[0].shape[0:2]
fps = int(len(imgs) / duration_secs)
@@ -59,7 +177,7 @@ def make_mp4(imgs, duration_secs, outname):
'-f', 'rawvideo',
'-vcodec','rawvideo',
'-s', f'{resolution[0]}x{resolution[1]}', # size of one frame
- '-pix_fmt', 'rgb24',
+ '-pix_fmt', f'{format_str}',
'-r', f'{fps}',
'-i', '-', # imput from pipe
'-an', # no audio
@@ -67,8 +185,10 @@ def make_mp4(imgs, duration_secs, outname):
'-preset', 'slow',
'-crf', '17',
str(Path(outname).with_suffix('.mp4')) ]
-
+
+ print((imgs[0] * 255).astype(np.uint8).reshape(-1).shape)
frame_data = np.concatenate([(x * 255).astype(np.uint8).reshape(-1) for x in imgs])
+ print(frame_data.shape)
with sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE) as p:
ret = p.communicate(frame_data.tobytes())
if p.returncode != 0:
@@ -83,29 +203,37 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std
x_range = np.linspace(-scale, scale, n_cols, dtype=np.float32) # scale in sigmas
rows = []
- for r in range(n_rows):
+ for r in trange(n_rows,unit="component"):
curr_row = []
out_batch = create_strip_centered(inst, edit_type, layer_key, [latent],
act_comp[r], lat_comp[r], act_stdev[r], lat_stdev[r], act_mean, lat_mean, scale, 0, -1, n_cols)[0]
+ #print("len(out_batch) =",len(out_batch))
for i, img in enumerate(out_batch):
curr_row.append(('c{}_{:.2f}'.format(r, x_range[i]), img))
rows.append(curr_row[:n_cols])
inst.remove_edits()
-
+
if make_plots:
# If more rows than columns, make several blocks side by side
n_blocks = 2 if n_rows > n_cols else 1
-
+
for r, data in enumerate(rows):
# Add white borders
- imgs = pad_frames([img for _, img in data])
-
+ imgs = pad_frames([img for _, img in data])
+
coord = ((r * n_blocks) % n_rows) + ((r * n_blocks) // n_rows)
plt.subplot(n_rows//n_blocks, n_blocks, 1 + coord)
- plt.imshow(np.hstack(imgs))
-
+
+ if(imgs[0].shape[2] > 1):
+ img_row = np.hstack(imgs)
+ _cmap = 'viridis'
+ else:
+ img_row = np.hstack(imgs).squeeze()
+ _cmap = 'gray'
+ plt.imshow(img_row,cmap=_cmap)
+
# Custom x-axis labels
W = imgs[0].shape[1] # image width
P = imgs[1].shape[1] # padding width
@@ -119,6 +247,23 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std
return [img for row in rows for img in row]
+def get_edit_name(mode):
+ if mode == 'activation':
+ is_stylegan = 'StyleGAN' in args.model
+ is_w = layer_key in ['style', 'g_mapping','mapping']
+ return 'W' if (is_stylegan and is_w) else 'ACT'
+ elif mode == 'latent':
+ return model.latent_space_name()
+ elif mode == 'both':
+ return 'BOTH'
+ else:
+ raise RuntimeError(f'Unknown edit mode {mode}')
+
+def show():
+ if args.batch_mode:
+ plt.close('all')
+ else:
+ plt.show()
######################
### Visualize results
@@ -144,7 +289,6 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std
device = torch.device('cuda' if has_gpu else 'cpu')
layer_key = args.layer
layer_name = layer_key #layer_key.lower().split('.')[-1]
-
basedir = Path(__file__).parent.resolve()
outdir = basedir / 'out'
@@ -152,6 +296,7 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std
inst = get_instrumented_model(args.model, args.output_class, layer_key, device, use_w=args.use_w)
model = inst.model
feature_shape = inst.feature_shape[layer_key]
+
latent_shape = model.get_latent_shape()
print('Feature shape:', feature_shape)
@@ -196,12 +341,6 @@ def make_grid(latent, lat_mean, lat_comp, lat_stdev, act_mean, act_comp, act_std
max_batch = args.batch_size or (get_max_batch_size(inst, device) if has_gpu else 1)
print('Batch size:', max_batch)
- def show():
- if args.batch_mode:
- plt.close('all')
- else:
- plt.show()
-
print(f'[{timestamp()}] Creating visualizations')
# Ensure visualization gets new samples
@@ -221,60 +360,61 @@ def show():
sparsity = np.mean(X_comp == 0) # percentage of zero values in components
print(f'Sparsity: {sparsity:.2f}')
- def get_edit_name(mode):
- if mode == 'activation':
- is_stylegan = 'StyleGAN' in args.model
- is_w = layer_key in ['style', 'g_mapping']
- return 'W' if (is_stylegan and is_w) else 'ACT'
- elif mode == 'latent':
- return model.latent_space_name()
- elif mode == 'both':
- return 'BOTH'
- else:
- raise RuntimeError(f'Unknown edit mode {mode}')
-
# Only visualize applicable edit modes
- if args.use_w and layer_key in ['style', 'g_mapping']:
+ if args.use_w and layer_key in ['style', 'g_mapping','mapping']:
edit_modes = ['latent'] # activation edit is the same
else:
edit_modes = ['activation', 'latent']
+ #plot_explained_variance(X_var_ratio,X_comp.shape[1:],args)
+
+ #Scatter 2D of PC1 - PC2
+ # Interactive scatter plot of the activation space projected onto two selected PCs
+ if(args.show_scatter):
+ make_2Dscatter(X_comp,X_global_mean,X_stdev,inst,model,layer_key,outdir,device,
+ n_samples=args.scatter_samples,with_images=args.scatter_images,x_axis_pc=args.scatter_x_axis_pc,y_axis_pc=args.scatter_y_axis_pc)
+
# Summary grid, real components
for edit_mode in edit_modes:
+ print("edit_mode =",edit_mode)
plt.figure(figsize = (14,12))
plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16)
make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev, tensors.X_global_mean,
- tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14)
+ tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images)
plt.savefig(outdir_summ / f'components_{get_edit_name(edit_mode)}.jpg', dpi=300)
show()
+ print("args.make_video =",args.make_video)
if args.make_video:
- components = 15
- instances = 150
-
+ components = args.nv_directions #10
+ instances = args.nv_images#150
+
# One reasonable, one over the top
for sigma in [args.sigma, 3*args.sigma]:
for c in range(components):
for edit_mode in edit_modes:
+ print("Make grid for video")
frames = make_grid(tensors.Z_global_mean, tensors.Z_global_mean, tensors.Z_comp[c:c+1, :, :], tensors.Z_stdev[c:c+1], tensors.X_global_mean,
tensors.X_comp[c:c+1, :, :], tensors.X_stdev[c:c+1], n_rows=1, n_cols=instances, scale=sigma, make_plots=False, edit_type=edit_mode)
plt.close('all')
+ print("Done!")
frames = [x for _, x in frames]
frames = frames + frames[::-1]
+ print("num_frames =",len(frames)) #num_frames = 300
make_mp4(frames, 5, outdir_comp / f'{get_edit_name(edit_mode)}_sigma{sigma}_comp{c}.mp4')
-
+
# Summary grid, random directions
# Using the stdevs of the principal components for same norm
random_dirs_act = torch.from_numpy(get_random_dirs(n_comp, np.prod(sample_shape)).reshape(-1, *sample_shape)).to(device)
random_dirs_z = torch.from_numpy(get_random_dirs(n_comp, np.prod(inst.input_shape)).reshape(-1, *latent_shape)).to(device)
-
+
for edit_mode in edit_modes:
plt.figure(figsize = (14,12))
plt.suptitle(f"{model.name} - {layer_name}, random directions w/ PC stdevs, {get_edit_name(edit_mode)} edit", size=16)
make_grid(tensors.Z_global_mean, tensors.Z_global_mean, random_dirs_z, tensors.Z_stdev,
- tensors.X_global_mean, random_dirs_act, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14)
+ tensors.X_global_mean, random_dirs_act, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images)
plt.savefig(outdir_summ / f'random_dirs_{get_edit_name(edit_mode)}.jpg', dpi=300)
show()
@@ -291,14 +431,14 @@ def get_edit_name(mode):
plt.figure(figsize = (14,12))
plt.suptitle(f"{args.estimator.upper()}: {model.name} - {layer_name}, {get_edit_name(edit_mode)} edit", size=16)
make_grid(z, tensors.Z_global_mean, tensors.Z_comp, tensors.Z_stdev,
- tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=14)
+ tensors.X_global_mean, tensors.X_comp, tensors.X_stdev, scale=args.sigma, edit_type=edit_mode, n_rows=args.np_directions, n_cols=args.np_images)
plt.savefig(outdir_summ / f'samp{img_idx}_real_{get_edit_name(edit_mode)}.jpg', dpi=300)
show()
if args.make_video:
- components = 5
- instances = 150
-
+ components = args.nv_directions #10
+ instances = args.nv_images#150
+
# One reasonable, one over the top
for sigma in [args.sigma, 3*args.sigma]: #[2, 5]:
for edit_mode in edit_modes:
@@ -311,4 +451,4 @@ def get_edit_name(mode):
frames = frames + frames[::-1]
make_mp4(frames, 5, outdir_inst / f'{get_edit_name(edit_mode)}_sigma{sigma}_img{img_idx}_comp{c}.mp4')
- print('Done in', datetime.datetime.now() - t_start)
\ No newline at end of file
+ print('Done in', datetime.datetime.now() - t_start)