Skip to content

Feat/sgl#9

Open
sherinrasheed wants to merge 4 commits intofdalvi:masterfrom
sherinrasheed:feat/sgl
Open

Feat/sgl#9
sherinrasheed wants to merge 4 commits intofdalvi:masterfrom
sherinrasheed:feat/sgl

Conversation

@sherinrasheed
Copy link

No description provided.

1) added interpretation.utils.get_group_index(): function returns group_index parameter required for training SGL probe. 2) added interpretation.linear_probe._train_sgl_probe() : function to train a logistic regression probe with SGL regularization using GroupLasso library object. 3) imported libraries -time and -LogisticGroupLasso in interpretation.linear_probe
…_probe() to include options for regularization method

 Edited function  interpretation.linear_probe.train_logistic_regression_probe() to accomodate options for Sparse Group Lasso and Group Lasso Regularization. Function will choose ElasticNet Regularization by default
…be parameter initialization

this function is called within the function interpretation.linear_probe.train_logistic_regression_probe()
Copy link
Owner

@fdalvi fdalvi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty good to me, if you can just generally add comments to the new functions you've added, I think we will be good to merge!

plt.title('Layer '+str(layer)+' weights')
plt.show()

def sparsity(model_name, data, **kwargs) :
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the following functions seem "internal" to the lib, can you prepend them with an underscore so the name becomes _sparsity? Also it might be better to name it to something a bit more clear, is this plotting sparsity? Computing it?

import torch
import torch.nn as nn
from torch.autograd import Variable
from group_lasso import LogisticGroupLasso
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add this lib to requirements.txt along with a version that you know works well

regularization = "elastic_net",
**kwargs
):

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update the doc string below with the new "regularization" parameter?

@fdalvi
Copy link
Owner

fdalvi commented Apr 24, 2022

Hello @sherinrasheed, hope you are well! I know you've been busy lately, but it would be awesome if we can push these changes (its mostly documentation related). We want to do some major changes to the codebase and merging this after those changes would be a bit painful.

Let me know if you need any assistance!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants