add contrastive loss and batch sampler #7
Conversation
- need some basic unit tests for `ContrastBatchSampler` and `contrastive_collate_fn`. For the former, it might be useful to break up some of the logic into smaller helper functions that can be independently tested, outside of the Sampler. For the latter, you can use the `base_datamodule_contrastive` fixture (detailed below) as input to the test; that way you can use `__getitem__` from an actual dataset
- need some basic unit tests for `topk` and `batch_wise_contrastive_loss`
- for `tests/data/test_datamodules.py::test_base_datamodule`, why is the sampler no longer `RandomSampler`?
- to test `ContrastBatchSampler`, you can (a rough sketch follows this list):
  - add a new pytest fixture in `tests/conftest.py`, `base_datamodule_contrastive`, that takes in `base_dataset` (like the current `base_datamodule` fixture) and then sets `use_sampler=True`; this will build a datamodule at the beginning of running the tests, and then you can access it throughout the tests
  - make a new test `tests/data/test_datamodules.py::test_base_datamodule_contrastive` that takes `base_datamodule_contrastive` as an input. Then you can write some tests for this datamodule, making sure it outputs the proper batches, the sampler is the correct one and doing what you expect, etc.
- need a test for a ViT with the contrastive loss; to do so, you can:
  - add the line `config['model_params']['use_infoNCE'] = False` in `conf.py`, `config_vit` fixture, just to make sure that if this default is ever changed, the current tests run the way we expect
  - add a test `tests/models/test_vits.py::test_vit_autoencoder_contrastive_integration`, which will just have one extra line after copying the config, `config['model_params']['use_infoNCE'] = True`; this will do a full end-to-end test with creating the dataset/module, training the model, and running inference post-training (well, training for 1 or 2 epochs)
- I also think we should apply the same augmentation to the anchor, positive, and negative frames in a given sample; otherwise there might be some training instability if, say, the anchor frame is flipped but the positive frame is not. What do you think? We did this in LP when we have context frames (t, t+/-1, t+/-2); I pasted the relevant code below for how we control randomness with imgaug (a triplet version is sketched after that snippet)
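A minimal sketch of the fixture and test described above. The class name `BaseDataModule`, the `use_sampler` flag, and the `idx` batch key follow the wording in this thread; the real constructor arguments and import paths may differ in this repo.

```python
import pytest
import torch

# BaseDataModule and the base_dataset fixture are assumed to be importable from
# this repo; the exact module paths are not shown in the thread.

@pytest.fixture
def base_datamodule_contrastive(base_dataset):
    # build a datamodule once with the contrastive sampler enabled
    datamodule = BaseDataModule(base_dataset, use_sampler=True)
    datamodule.setup()
    return datamodule


def test_base_datamodule_contrastive(base_datamodule_contrastive):
    train_dataloader = base_datamodule_contrastive.train_dataloader()
    batch = next(iter(train_dataloader))
    # every index in a contrastive batch should appear exactly once
    assert len(torch.unique(batch['idx'])) == len(batch['idx'])
    # one could also assert the sampler type here, depending on how the
    # DataLoader exposes it (.sampler vs .batch_sampler)
```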
After all this, there's still the question of verifying whether the implementation is "correct" or not. What do you think about this? You could try training the model for 10 or 20 epochs and see how the training curves match your old results. Do you think that would be enough to convince you? Or would we need to fully train out a model and test it on one or more downstream tasks? Eventually we'll need to do the latter; it's mostly a question of timing, though.
LP code snippet:
# need to apply the same transform to all context frames
seed = np.random.randint(low=0, high=123456)
transformed_images = []
for img in images:
    self.imgaug_transform.seed_(seed)
    transformed_image, transformed_keypoints = self.imgaug_transform(
        images=[img], keypoints=[keypoints_on_image.numpy()]
    )
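A minimal sketch of how the same re-seeding trick could carry over to this PR's anchor/positive/negative triplets, assuming an imgaug pipeline; the function name and arguments are hypothetical, not names from this repo.

```python
import numpy as np

def augment_triplet(imgaug_pipeline, anchor, positive, negative):
    """Apply the same random augmentation to all three frames (hypothetical helper)."""
    # re-seeding the pipeline before each call makes the sampled transform identical,
    # so e.g. a horizontal flip is applied to all frames or to none of them
    seed = np.random.randint(low=0, high=123456)
    augmented = []
    for frame in (anchor, positive, negative):
        imgaug_pipeline.seed_(seed)
        augmented.append(imgaug_pipeline(images=[frame])[0])
    return augmented
```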
Hey Matt, thanks for your comments! I fixed the codebase, and especially the unit tests, based on your feedback.
themattinthehatt left a comment
copy of slack comments here for posterity:
I'm still not sure this code works properly:
subdirs = list(self.data_dir.iterdir())
self.image_list = {}
for idx, subdir in enumerate(subdirs):
    start_idx = idx * self.offset + len(self.frame_idx)
    subdir_frames = list(subdir.rglob('*.png'))
    # store the index of each frame
    for i, frame_path in enumerate(subdir_frames):
        self.image_list[start_idx + i] = frame_path
if offset = 0, say, then the first pass through the for loop will have idx=0 and len(self.frame_idx) = 0
so start_idx=0
when we loop through subdir frames (say there are 15) we get keys for image list 0-14
then next pass through the loop, idx=1 and len(self.frame_idx) = 15 (should be self.image_list)
so start_idx=15
when we loop through subdir frames (say there are 15) we get keys for image list 15-29
ok, that's fine
when offset = 1:
first pass through for loop, idx=0, len(self.frame_idx)=0, and start_idx = 0
when we loop through subdir frames (say there are 15) we get keys for image list 0-14
then next pass through the loop, idx=1 and len(self.frame_idx) = 15
so start_idx=16 (1 * 1 + 15)
when we loop through subdir frames (say there are 15) we get keys for image list 16-30
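A self-contained reproduction of that arithmetic (hypothetical numbers: two subdirectories of 15 frames each; `len(image_list)` stands in for `len(self.frame_idx)`), showing that with offset = 1 the key 15 never gets assigned:

```python
def build_keys(n_subdirs=2, frames_per_subdir=15, offset=0):
    # mirrors the loop above with toy inputs
    image_list = {}
    for idx in range(n_subdirs):
        start_idx = idx * offset + len(image_list)
        for i in range(frames_per_subdir):
            image_list[start_idx + i] = f"video{idx}/frame{i:03d}.png"
    return sorted(image_list)

print(build_keys(offset=0))  # 0..29, contiguous
print(build_keys(offset=1))  # 0..14 and 16..30 -- key 15 is skipped
```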
But this doesn't solve the issue that I was mentioning above, where two frames from the same video that are not close together in time are paired as anchor and positive.
My suggestion would be the following: in `BaseDataset`, remove the most recent additions you made, and just return to `self.image_list = list(self.data_dir.rglob('*.png'))`.
In the `ContrastBatchSampler` constructor, do
self.all_indices = list(range(self.num_samples))
self.anchor_indices = extract_anchor_indices(dataset.image_list)
You'll have to write `extract_anchor_indices`, but basically it will parse the image paths and, for each image, extract (1) the video it comes from and (2) its frame index. Then you can iterate through each of these and add the index to an `anchor_indices` list if the indices before/after it are valid neighboring frames (i.e. from the same video and +/-1 frame index).
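A minimal sketch of what `extract_anchor_indices` could look like, assuming `image_list` is a list of `pathlib.Path` objects laid out as `<video_name>/<frame_number>.png`; the path-parsing details are assumptions and would need to match the real directory layout.

```python
import re

def extract_anchor_indices(image_list):
    """Hypothetical helper: return dataset indices whose +/-1 neighbors in image_list
    come from the same video and are exactly one frame away in time."""
    def parse(path):
        # e.g. data/video_03/img000123.png -> ('video_03', 123)
        return path.parent.name, int(re.search(r'(\d+)', path.stem).group(1))

    info = [parse(p) for p in image_list]
    anchor_indices = []
    for i in range(1, len(info) - 1):
        video, frame = info[i]
        prev_video, prev_frame = info[i - 1]
        next_video, next_frame = info[i + 1]
        # require both neighbors so that a random +/-1 positive always exists
        if prev_video == video == next_video and prev_frame == frame - 1 and next_frame == frame + 1:
            anchor_indices.append(i)
    return anchor_indices
```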
then the iter method can be simplified:
def __iter__(self):
    self.epoch += 1
    if self.shuffle:
        random.shuffle(self.anchor_indices)
    used = set()
    batches_returned = 0
    # keep sampling until we form all possible batches
    idx_cursor = 0
    while batches_returned < self.num_batches:
        batch = []
        # keep pairing up anchors and positives until we have batch_size
        while len(batch) < self.batch_size:
            # if we run out of unused anchor indices, we break early
            # (especially if drop_last == True)
            while idx_cursor < len(self.anchor_indices) and self.anchor_indices[idx_cursor] in used:
                idx_cursor += 1
            if idx_cursor >= len(self.anchor_indices):
                break
            i = self.anchor_indices[idx_cursor]
            # choose a random positive: the frame just before or after the anchor
            i_p = i + np.random.choice([-1, 1])
            # now we have an anchor i and a positive i_p; mark them as used
            used.update([i, i_p])
            batch.extend([i, i_p])
            idx_cursor += 1
            if idx_cursor >= len(self.anchor_indices):
                break
        # use helper function to fill remaining batch
        batch, used = fill_remaining_batch(
            batch, self.batch_size, self.all_indices, used, self.drop_last
        )
        # if we failed to get a full batch, then drop it or return a partial one
        if len(batch) < self.batch_size:
            if self.drop_last:
                break  # discard partial batch
            # else fill the remainder randomly from unused "far" indices
            needed = self.batch_size - len(batch)
            far_candidates = [x for x in self.all_indices if x not in used]
            if len(far_candidates) < needed:
                # can't fill
                break
            chosen = random.sample(far_candidates, needed)
            used.update(chosen)
            batch.extend(chosen)
        yield batch
        batches_returned += 1
this allows you to also get rid of `find_positive_candidate` and `get_neighbor_indices`
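For completeness, a sketch of how a sampler whose `__iter__` yields whole index lists would plug into the dataloader; the constructor arguments are assumptions based on the attributes referenced above.

```python
from torch.utils.data import DataLoader

# constructor arguments are assumptions inferred from the attributes used in __iter__
sampler = ContrastBatchSampler(dataset, batch_size=16, shuffle=True, drop_last=True)
# passed as batch_sampler (not sampler=) because __iter__ yields complete batches of indices
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=contrastive_collate_fn)
for batch in loader:
    ...  # each batch interleaves (anchor, positive) pairs plus any fill frames
```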
btw, I'm also a bit confused about `fill_remaining_batch` - if we're here, it means we've run out of anchor frames, right? Does that mean previously the batch would be padded out with a random assortment of other frames? If there are two random values, will these erroneously get paired as an anchor and positive frame downstream?
tests/data/test_datamodules.py (Outdated)
# check train batch properties
train_dataloader = base_datamodule.train_dataloader()
assert isinstance(train_dataloader.sampler, RandomSampler)
assert isinstance(train_dataloader.sampler, RandomSampler) or isinstance(train_dataloader.sampler, ContrastBatchSampler)
the ContrastBatchSampler assertion can still be removed from each of these, right? If we're testing base_datamodule and not base_datamodule_contrastive the sampler should always be RandomSampler
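In other words, the check in `test_base_datamodule` could presumably go back to the simpler form:

```python
# test_base_datamodule exercises the plain datamodule, so only RandomSampler is expected
train_dataloader = base_datamodule.train_dataloader()
assert isinstance(train_dataloader.sampler, RandomSampler)
```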
Right, it should be RandomSampler by default.
tests/data/test_datamodules.py (Outdated)
# Check that we have unique indices (no duplicates within a batch)
unique_indices = torch.unique(batch['idx'])
assert len(unique_indices) <= expected_batch_size
maybe `assert len(unique_indices) == len(batch['idx'])` is better?
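That check catches duplicates directly; something like this, using the names from the snippet above:

```python
# every sampled index should be unique within the batch
unique_indices = torch.unique(batch['idx'])
assert len(unique_indices) == len(batch['idx'])
```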
fix!
Thanks for your thoughtful comment :). The reason I keep `frame_idx` in the sampler is that our sampled frames from long videos are usually non-consecutive; the sampled frame list might be [9, 17, 22, 31, ...], so there is no positive frame for any anchor frame in the list. `frame_idx` converts [9, 17, 22, 31, ...] --> [0, 1, 2, 3, ...]; in this way, the positive frame of idx 9 in the list will be 17. Does this make sense? Regarding the `fill_remaining_batch` function, it is used to randomly sample some frames to fill out a batch. You are right, `fill_remaining_batch` makes things messy; moreover, I didn't include this function in our original paper. I just removed it in the last commit.
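A tiny sketch of the remapping being described; the variable names are illustrative, not necessarily the repo's.

```python
# sampled (non-consecutive) video frames mapped to contiguous dataset indices
sampled_frames = [9, 17, 22, 31]
frame_idx = {frame: i for i, frame in enumerate(sampled_frames)}  # {9: 0, 17: 1, 22: 2, 31: 3}
# under contiguous indices, the "positive" of frame 9 (index 0) is index 1, i.e. frame 17
```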
Hi @themattinthehatt,
This pull request closes #4