Description
If the metamer object moves from CUDA to CPU after synthesize has been called, the optimizer object (and probably the scheduler?) doesn't properly move with it. With Adam (our default), the optimizer carries around several tensors in its state_dict(), which remain on CUDA.
Synthesize can be called again, but it looks like optimization no longer happens: the metamer is no longer changing. You can see this because the loss stays the same and, if run with store_progress=True, met.saved_metamer.diff(dim=0).sum((-2,-1)) shows that the metamer is not changing at all.
This state dict contains a tensor, "step", which normally tracks the number of synthesis iterations (i.e., the number of times optimizer.step has been called) but which no longer updates after moving from CUDA to CPU. That can also be used to diagnose this problem.
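A quick way to check both symptoms (a sketch; it assumes Adam's usual per-parameter state and that the optimizer is exposed as met.optimizer, as in the snippets below):

```python
import torch

# Resume synthesis after the move; the loss should change but doesn't.
met.synthesize(10, store_progress=True)
# If optimization has stalled, every entry here is exactly zero:
print(met.saved_metamer.diff(dim=0).sum((-2, -1)))

# Inspect the optimizer state directly: with Adam, each parameter's
# state holds "step", "exp_avg", and "exp_avg_sq" tensors. After
# synthesize -> to('cpu'), some of these remain on CUDA and "step"
# stops increasing.
for idx, state in met.optimizer.state_dict()["state"].items():
    for name, value in state.items():
        if torch.is_tensor(value):
            print(idx, name, value.device, value.dtype)
```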
I originally ran into this issue when synthesizing a metamer on CUDA and then calling met.to('cpu'); met.save('tmp.pt'): the resulting file could not be loaded on a CPU-only machine, despite the fact that all the "important attributes" were on the CPU.
Same thing happens when changing dtypes.
Ideally, we'd find a way to easily move any PyTorch optimizer after synthesize has been called. Unfortunately, optimizer objects don't have a .to method, but maybe we can manually move the tensors in the state dict ourselves? A sketch of that follows.
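One possible shape for this (a minimal sketch, not existing plenoptic API; move_optimizer_state is a hypothetical helper):

```python
import torch

def move_optimizer_state(optimizer, *args, **kwargs):
    """Move every tensor in an optimizer's state, in place.

    Walks optimizer.state directly (rather than the copy returned by
    state_dict()) so the change takes effect immediately; arguments are
    forwarded to Tensor.to, so this covers both devices and dtypes.

    Caveat: PyTorch's default (non-capturable) Adam expects its "step"
    counters on the CPU, so when moving *to* a GPU you may want to skip
    entries named "step".
    """
    for param_state in optimizer.state.values():
        for name, value in param_state.items():
            if torch.is_tensor(value):
                param_state[name] = value.to(*args, **kwargs)

# e.g., after met.to('cpu'):
# move_optimizer_state(met.optimizer, "cpu")
```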
If that's not possible, raise an error when doing synthesize -> to -> synthesize (and maybe on save?). A sketch of such a guard follows.
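What that guard could look like (a hypothetical helper, not existing plenoptic API; it assumes the synthesized image is exposed as met.metamer, and synthesize/save could call it before doing anything):

```python
import torch

def check_optimizer_device_dtype(metamer):
    """Raise instead of silently doing nothing.

    Compares every tensor in the optimizer state against the metamer
    itself, skipping Adam's scalar "step" counters (which live on the
    CPU by default).
    """
    target = metamer.metamer
    for param_state in metamer.optimizer.state.values():
        for name, value in param_state.items():
            if name == "step" or not torch.is_tensor(value):
                continue
            if value.device != target.device or value.dtype != target.dtype:
                raise RuntimeError(
                    f"optimizer state {name!r} is on {value.device} with "
                    f"dtype {value.dtype}, but the metamer is on "
                    f"{target.device} with dtype {target.dtype}; "
                    "synthesis would silently do nothing."
                )
```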
A workaround (for devices, not dtypes) is to save and then load with map_location, e.g.:

```python
import torch

import plenoptic as po

DEVICE = torch.device("cuda:0")
img = po.data.einstein().to(DEVICE).to(torch.float64)
model = po.simul.Gaussian(30).eval()
po.tools.remove_grad(model)
model = model.to(DEVICE).to(img.dtype)
met = po.synth.Metamer(img, model)
met.synthesize(10)
met.save("tmp.pt")
met = po.synth.Metamer(img, model)
met.load("tmp.pt", map_location="cpu")
```

Another workaround, for both dtype and device, is to explicitly reset the state_dict and then save/load. If you want to resume synthesis, this is definitely not what you want to do, so I'm not sure it's actually useful (if you don't care about resuming synthesis, then the map_location trick is fine for devices, and the fact that the optimizer no longer updates the metamer is not an issue):
```python
import torch

import plenoptic as po

DEVICE = torch.device("cuda:0")
img = po.data.einstein().to(DEVICE).to(torch.float64)
model = po.simul.Gaussian(30).eval()
po.tools.remove_grad(model)
model = model.to(DEVICE).to(img.dtype)
met = po.synth.Metamer(img, model)
met.setup()  # needed to initialize the optimizer
init_state_dict = met.optimizer.state_dict()
met.synthesize(10)
met.optimizer.load_state_dict(init_state_dict)
met.to('cpu')
met.to(torch.float32)
met.save("tmp.pt")
met = po.synth.Metamer(img, model)
met.load("tmp.pt")
```

Note that if you move devices/dtypes before calling synthesize (even if after setup, pretty sure), this isn't a problem.