Conversation

@PPWangyc (Collaborator) commented Jun 25, 2025

This pull request closes #4 (implement contrastive loss).

  • Implement a batch sampler for anchor frames and positive frames.
  • Implement contrastive loss for ViT model pretraining.
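For context, here is a minimal sketch of what a batch-wise InfoNCE-style contrastive loss looks like. This is a hypothetical illustration only, not the PR's actual batch_wise_contrastive_loss; it assumes the sampler emits anchor/positive pairs as adjacent rows of the feature matrix.

import torch
import torch.nn.functional as F

def infonce_loss_sketch(features, temperature=0.1):
    # features: (2N, D); rows (0, 1), (2, 3), ... are anchor/positive pairs
    # drawn from neighboring frames of the same video
    features = F.normalize(features, dim=1)
    logits = features @ features.T / temperature   # pairwise cosine similarities
    logits.fill_diagonal_(float('-inf'))           # a frame is not its own positive
    # for row 2k the positive is row 2k+1 and vice versa (XOR with 1 flips within a pair)
    targets = torch.arange(features.shape[0], device=features.device) ^ 1
    return F.cross_entropy(logits, targets)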

@themattinthehatt (Contributor) left a comment

  1. need some basic unit tests for ContrastBatchSampler and contrastive_collate_fn. for the former, might be useful to break up some of the logic into smaller helper functions that can be independently tested, outside of the Sampler. for the latter, can use the base_datamodule_contrastive fixture (detailed below) as input to the test, that way you can use __getitem__ from an actual dataset
  2. need some basic unit tests for topk and batch_wise_contrastive_loss
  3. for tests/data/test_datamodules.py::test_base_datamodule, why is the sampler no longer RandomSampler?
  4. to test ContrastBatchSampler, you can (a rough sketch follows this list):
  • add new pytest fixture in tests/conftest.py, base_datamodule_contrastive, that takes in base_dataset (like the current base_datamodule fixture) and then sets use_sampler=True; this will build a datamodule at the beginning of running the tests, and then you can access it throughout the tests
  • make a new test tests/data/test_datamodules.py::test_base_datamodule_contrastive that takes base_datamodule_contrastive as an input. then you can write some tests for this datamodule, making sure it outputs the proper batches, the sampler is the correct one and doing what you expect, etc.
  5. need a test for a vit with the contrastive loss; to do so, you can:
  • add the line config['model_params']['use_infoNCE'] = False in conf.py, config_vit fixture. just to make sure if this default is ever changed the current tests run the way we expect
  • add a test tests/models/test_vits.py::test_vit_autoencoder_contrastive_integration, which will just have one extra line after copying the config, config['model_params']['use_infoNCE'] = True; this will do a full end-to-end test with creating dataset/module, training the model, and running inference post-training (well, training for 1 or 2 epochs)
  6. I also think that we should apply the same augmentation to the anchor, positive, and negative frames in a given sample - otherwise there might be some training instability if, say, the anchor frame is flipped but the positive frame is not. what do you think? we did this in LP when we have context frames (t, t+/-1, t+/-2), I pasted the relevant code below for how we control randomness w/ imgaug for this
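Going back to item 4, a rough sketch of the suggested fixture and test (hypothetical; it assumes the datamodule constructor takes a use_sampler flag and that BaseDataModule and ContrastBatchSampler are importable from the repo):

# tests/conftest.py -- rough sketch of the suggested fixture
import pytest
import torch

# assumed imports from the repo (hypothetical paths):
# from <package>.data.datamodules import BaseDataModule
# from <package>.data.samplers import ContrastBatchSampler

@pytest.fixture
def base_datamodule_contrastive(base_dataset):
    # assumes a use_sampler flag that swaps the default RandomSampler
    # for ContrastBatchSampler on the train dataloader
    return BaseDataModule(base_dataset, use_sampler=True)

# tests/data/test_datamodules.py -- rough sketch of the new test
def test_base_datamodule_contrastive(base_datamodule_contrastive):
    train_dataloader = base_datamodule_contrastive.train_dataloader()
    assert isinstance(train_dataloader.sampler, ContrastBatchSampler)
    batch = next(iter(train_dataloader))
    # each index should appear at most once within a contrastive batch
    assert len(torch.unique(batch['idx'])) == len(batch['idx'])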

after all this, there's still the question of verifying whether the implementation is "correct" or not. what do you think about this? you could try training the model for 10 or 20 epochs, and see how the training curves match your old results. do you think that would be enough to convince you? or would we need to fully train out a model and test it on one or more downstream tasks? eventually we'll need to do the latter, it's mostly a question of timing though.

LP code snippet:

# need to apply the same transform to all context frames
seed = np.random.randint(low=0, high=123456)
transformed_images = []
for img in images:
    # re-seed before each frame so every frame receives the identical transform
    self.imgaug_transform.seed_(seed)
    transformed_image, transformed_keypoints = self.imgaug_transform(
        images=[img], keypoints=[keypoints_on_image.numpy()]
    )
    transformed_images.append(transformed_image)

@PPWangyc (Collaborator, Author):


Hey Matt,

Thanks for your comments!! I updated the code base, especially the unit tests, based on your feedback.

  • Unit tests (1~5): I added helper functions to samplers.py and created unit tests for them in test_samplers. Following your guidance, I also added tests for topk and batch_wise_contrastive_loss, as well as the new test_base_datamodule_contrastive. I also restored the train loader's default sampler to RandomSampler.
  • Augmentation (6): We tried consistent augmentation, but the encoding performance was worse, so I kept the "aggressive" augmentation that is applied separately to each context frame.

@themattinthehatt (Contributor) left a comment

copy of slack comments here for posterity:

I'm still not sure this code works properly:

            subdirs = list(self.data_dir.iterdir())
            self.image_list = {}
            for idx, subdir in enumerate(subdirs):
                start_idx = idx * self.offset + len(self.frame_idx)
                subdir_frames = list(subdir.rglob('*.png'))
                # store the index of each frame
                for i, frame_path in enumerate(subdir_frames):
                    self.image_list[start_idx + i] = frame_path

if offset = 0, say, then the first pass through the for loop will have idx=0 and len(self.frame_idx) = 0
so start_idx=0
when we loop through subdir frames (say there are 15) we get keys for image list 0-14
then next pass through the loop, idx=1 and len(self.frame_idx) = 15 (should be self.image_list)
so start_idx=15
when we loop through subdir frames (say there are 15) we get keys for image list 15-29
ok, that's fine

when offset = 1:
first pass through for loop, idx=0, len(self.frame_idx)=0, and start_idx = 0
when we loop through subdir frames (say there are 15) we get keys for image list 0-14
then next pass through the loop, idx=1 and len(self.frame_idx) = 15
so start_idx=16 (1 * 1 + 15)
when we loop through subdir frames (say there are 15) we get keys for image list 16-30

but this doesn't solve the issue that I was mentioning above, where two frames from the same video that are not close together in time are paired as anchor and positive

my suggestion would be the following: in BaseDataset, remove the most recent additions you made, and just return to self.image_list = list(self.data_dir.rglob('*.png'))

in ContrastBatchSampler constructor, do

self.all_indices = list(range(self.num_samples))
self.anchor_indices = extract_anchor_indices(dataset.image_list)

you'll have to write extract_anchor_indices, but basically it will parse the image paths, and for each image it will extract (1) the video it comes from, and (2) its frame index. then you can iterate through each of these, and add the index to an anchor_indices list if the indices before/after it are valid neighboring frames (i.e. from the same video and +/-1 frame index).
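A minimal sketch of what extract_anchor_indices might look like (hypothetical; it assumes paths look like <video_name>/<frame_number>.png and that image_list is sorted so consecutive frames of a video sit next to each other, matching the i +/- 1 positive lookup used below):

import re
from pathlib import Path

def extract_anchor_indices(image_list):
    # parse (video, frame_number) for every image path
    parsed = []
    for path in image_list:
        path = Path(path)
        parsed.append((path.parent.name, int(re.sub(r'\D', '', path.stem))))
    # an index qualifies as an anchor if the entries directly before and after
    # it in image_list are its true +/-1 temporal neighbors from the same
    # video, so a randomly chosen +/-1 positive is always valid
    anchor_indices = []
    for idx, (video, frame_num) in enumerate(parsed):
        if (
            0 < idx < len(parsed) - 1
            and parsed[idx - 1] == (video, frame_num - 1)
            and parsed[idx + 1] == (video, frame_num + 1)
        ):
            anchor_indices.append(idx)
    return anchor_indices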
then the iter method can be simplified:

def __iter__(self):
        self.epoch += 1
        
        if self.shuffle:
            random.shuffle(self.anchor_indices)
        
        used = set()
        batches_returned = 0
        
        # We'll keep sampling until we form all possible batches
        idx_cursor = 0
        
        while batches_returned < self.num_batches:
            batch = []
            # Keep pairing up references and positives until we have batch_size
            while len(batch) < self.batch_size:
                
                # If we run out of "unused" indices, we break early
                # (especially if drop_last == True)
                while idx_cursor < len(self.anchor_indices) and self.anchor_indices[idx_cursor] in used:
                    idx_cursor += 1
                if idx_cursor >= len(self.anchor_indices):
                    break
                
                i = self.anchor_indices[idx_cursor]
                
                # choose a random positive: the frame just before or after the anchor
                i_p = i + np.random.choice([-1, 1])
                
                # Now we have a reference i, a positive i_p
                # Mark them as used
                used.update([i, i_p])
                batch.extend([i, i_p])
                
                idx_cursor += 1
                if idx_cursor >= len(self.anchor_indices):
                    break
            
            # Use helper function to fill remaining batch
            batch, used = fill_remaining_batch(
                batch, self.batch_size, self.all_indices, used, self.drop_last
            )
            
            # If we failed to get a full batch size, then drop or return partial
            if len(batch) < self.batch_size:
                if self.drop_last:
                    break  # discard partial batch
                # else fill the remainder randomly from unused "far" indices
                needed = self.batch_size - len(batch)
                far_candidates = [x for x in self.all_indices if x not in used]
                if len(far_candidates) < needed:
                    # can't fill
                    break
                chosen = random.sample(far_candidates, needed)
                used.update(chosen)
                batch.extend(chosen)
            yield batch
            batches_returned += 1

this allows you to also get rid of find_positive_candidate and get_neighbor_indices
btw I'm also a bit confused about fill_remaining_batch - if we're here it means we've run out of anchor frames, right? does that mean previously that the batch would be padded out with a random assortment of other frames? if there are two random values will these erroneously get paired as an anchor and positive frame downstream?

# check train batch properties
train_dataloader = base_datamodule.train_dataloader()
- assert isinstance(train_dataloader.sampler, RandomSampler)
+ assert isinstance(train_dataloader.sampler, RandomSampler) or isinstance(train_dataloader.sampler, ContrastBatchSampler)
Contributor:

the ContrastBatchSampler assertion can still be removed from each of these, right? If we're testing base_datamodule and not base_datamodule_contrastive the sampler should always be RandomSampler

@PPWangyc (Collaborator, Author) commented Jul 3, 2025

Right, it should be RandomSampler by default.


# Check that we have unique indices (no duplicates within a batch)
unique_indices = torch.unique(batch['idx'])
assert len(unique_indices) <= expected_batch_size
Contributor:

maybe assert len(unique_indices) == len(batch['idx']) is better?

Collaborator (Author):

Fixed!

@PPWangyc (Collaborator, Author) commented Jul 3, 2025


Thanks for your thoughtful comment :). The reason I keep frame_idx in the sampler is that the frames we sample from long videos are usually non-consecutive; the sampled frame list might be [9, 17, 22, 31, ...], in which case no anchor frame in the list has a +/-1 positive frame. frame_idx remaps [9, 17, 22, 31, ...] --> [0, 1, 2, 3, ...], so that frame 17 becomes the positive for frame 9. Does this make sense?
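A toy illustration of that remapping (not code from the PR):

# non-consecutive frame numbers sampled from a long video
sampled_frames = [9, 17, 22, 31]
frame_idx = {frame: i for i, frame in enumerate(sampled_frames)}  # {9: 0, 17: 1, 22: 2, 31: 3}
# after remapping, frame 17 (new idx 1) is the +1 neighbor of frame 9 (new idx 0),
# so frame 17 can serve as the positive for anchor frame 9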

Regarding the fill_remaining_batch function: it randomly sampled extra frames to fill out a batch. You are right that it makes things messy, and it was not part of the implementation in our original paper, so I removed it in the last commit.

@PPWangyc (Collaborator, Author) commented Jul 6, 2025

Hi @themattinthehatt,
I added the extract_anchor_indices function and its unit tests :). Please let me know if there's anything else I should fix. Thanks!

@themattinthehatt self-requested a review July 9, 2025 03:28
@themattinthehatt merged commit bf81c6f into main Jul 9, 2025
1 check passed
@themattinthehatt deleted the contrastive branch July 9, 2025 03:29