
I tried using the method from the paper "DATA SELECTION VIA OPTIMAL CONTROL FOR LANGUAGE MODELS" to filter data with incorrect labels, and found that mislabeled samples can still receive high scores. #326

@huangluyao

Description


After I solve for θ at a given time step and use the downstream dataset to compute ∇J(θ), I take the gradient vector of a mislabeled sample and compute its dot product with ∇J(θ). I find that their directions can still be consistent. This means that at that time step, when the directions align, a mislabeled sample can be assigned a very high score because the magnitude of its gradient is large. As a result, mislabeled samples cannot be reliably filtered out.
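To make the issue concrete, here is a small synthetic illustration (my own numbers, not from the experiment): because the score is a raw inner product, it factors as dot(g_i, ∇J(θ)) = ||g_i|| · ||∇J(θ)|| · cos(angle), so a huge gradient that is only weakly aligned can outscore a small, perfectly aligned one.

import torch

torch.manual_seed(0)
grad_J = torch.randn(1000)
u = grad_J / grad_J.norm()               # unit vector along grad_J

# build a unit vector orthogonal to grad_J
v = torch.randn(1000)
v = v - torch.dot(v, u) * u
v = v / v.norm()

g_clean = 0.05 * u                       # tiny gradient, perfectly aligned (cos = 1)
g_noisy = 100.0 * (0.1 * u + 0.995 * v)  # huge gradient, weakly aligned (cos ≈ 0.1)

for name, g in [("clean", g_clean), ("noisy", g_noisy)]:
    score = torch.dot(g, grad_J)              # the raw dot-product score
    cos = score / (g.norm() * grad_J.norm())  # direction-only alignment
    print(f"{name}: score={float(score):.3f}, l2={float(g.norm()):.3f}, cos={float(cos):.3f}")

The weakly aligned large-norm vector gets a score dozens of times larger than the perfectly aligned small one, which is exactly the pattern I see below.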

The implementation code is as follows:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def compute_downstream_grad(model, downstream_dataset, device):
    """Compute the gradient of the downstream loss J(θ) w.r.t. the parameters, i.e. ∇J(θ_t*) in Eq. 5."""
    model.eval()
    dataloader = DataLoader(downstream_dataset, batch_size=32)
    total_grads = 0
    step = 0
    # accumulate the flattened gradient over the downstream batches
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        grads = torch.autograd.grad(loss, model.parameters())
        grad_vec = torch.cat([g.view(-1) for g in grads])  # flatten the per-parameter gradients into one vector
        total_grads += grad_vec
        step += 1
    # average over batches and return a single flattened gradient vector
    avg_grads = total_grads / step
    return avg_grads


for t in range(50):

    # gradient of the downstream loss at checkpoint θ_t, i.e. ∇J(θ_t*) in Eq. 5
    downstream_dataset = dataset2
    model.load_state_dict(theta_t[t])
    grad_J = compute_downstream_grad(model, downstream_dataset, device)
    lam = lambda_history[t]

    for i in range(15):
        x, y = proxy_dataset[i]  # a single proxy sample
        x = x.unsqueeze(0).to(device)  # add a batch dimension
        y = torch.tensor([y]).to(device)
        outputs = model(x)
        loss = F.nll_loss(outputs, y)
        grads = torch.autograd.grad(loss, model.parameters())
        grad_vec = torch.cat([g.view(-1) for g in grads])  # flatten the per-parameter gradients into one vector
        l2 = torch.norm(grad_vec)  # gradient magnitude
        # print the sample gradient's alignment with grad_J, its magnitude, and its alignment with λ*
        print(f"t={t}, index={i}, value={float(torch.dot(grad_J, grad_vec))}, l2={l2}, lam={float(torch.dot(lam, grad_vec))}")

The results are as follows:

t=49, index=0, value=0.0067689307034015656, l2=0.03711961209774017, lam=0.0067689307034015656
t=49, index=1, value=0.011063024401664734, l2=0.11148233711719513, lam=0.011063024401664734
t=49, index=2, value=-0.11018553376197815, l2=2.268237829208374, lam=-0.11018553376197815
t=49, index=3, value=9.259251594543457, l2=126.85670471191406, lam=9.259251594543457         # mislabeled
t=49, index=4, value=11.486056327819824, l2=277.9612121582031, lam=11.486056327819824        # mislabeled
t=49, index=5, value=0.0041717528365552425, l2=0.03683656081557274, lam=0.0041717528365552425
t=49, index=6, value=13.245539665222168, l2=219.6069793701172, lam=13.245539665222168
t=49, index=7, value=-0.011622816324234009, l2=0.26048213243484497, lam=-0.011622816324234009
t=49, index=8, value=-31.391651153564453, l2=247.3076934814453, lam=-31.391651153564453      # mislabeled
t=49, index=9, value=-18.298208236694336, l2=203.1321563720703, lam=-18.298208236694336      # mislabeled
t=49, index=10, value=-0.010450446978211403, l2=0.17120294272899628, lam=-0.010450446978211403
t=49, index=11, value=0.01904112473130226, l2=0.21086302399635315, lam=0.01904112473130226
t=49, index=12, value=-0.09251315891742706, l2=1.8408218622207642, lam=-0.09251315891742706
t=49, index=13, value=-0.0006138212629593909, l2=0.02753078006207943, lam=-0.0006138212629593909
t=49, index=14, value=0.05529172718524933, l2=0.3928619921207428, lam=0.05529172718524933


t=0, index=0, value=1.263808012008667, l2=11.897555351257324, lam=0.9541520476341248
t=0, index=1, value=2.0186030864715576, l2=9.998517990112305, lam=3.284712314605713
t=0, index=2, value=1.4496126174926758, l2=11.505668640136719, lam=1.1915363073349
t=0, index=3, value=-0.13733036816120148, l2=13.236851692199707, lam=1.6994320154190063   # mislabeled
t=0, index=4, value=0.15488216280937195, l2=13.605332374572754, lam=0.9128973484039307    # mislabeled
t=0, index=5, value=4.99484395980835, l2=13.224201202392578, lam=5.8155364990234375
t=0, index=6, value=0.9346104264259338, l2=10.993036270141602, lam=0.6073954701423645
t=0, index=7, value=1.463005781173706, l2=9.999319076538086, lam=1.1618057489395142
t=0, index=8, value=-2.1967873573303223, l2=14.000296592712402, lam=-1.0258207321166992   # mislabeled
t=0, index=9, value=-0.03054516203701496, l2=11.326922416687012, lam=-0.7511491775512695  # mislabeled
t=0, index=10, value=1.8088934421539307, l2=10.429534912109375, lam=2.203115940093994
t=0, index=11, value=2.495976448059082, l2=11.449640274047852, lam=-0.038659755140542984
t=0, index=12, value=1.2066035270690918, l2=11.890233993530273, lam=0.37567102909088135
t=0, index=13, value=0.32585108280181885, l2=11.816431999206543, lam=3.5074808597564697
t=0, index=14, value=1.493959665298462, l2=8.72955322265625, lam=3.2007670402526855

From these results I can see that the gradient of a mislabeled sample can still point in the same direction as the optimization direction, and the magnitude of a mislabeled sample's gradient is often large. As a result, the gamma score it receives ends up very large, and the sample is treated as high-quality data.
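To separate direction from magnitude, one check I am considering (my own diagnostic, not part of the paper's Eq. 5) is to also report the cosine alignment of each sample gradient with ∇J(θ):

import torch

def cosine_alignment(grad_vec, grad_J, eps=1e-12):
    """Direction-only alignment between a flattened sample gradient and grad_J,
    independent of the sample gradient's magnitude."""
    return float(torch.dot(grad_vec, grad_J) / (grad_vec.norm() * grad_J.norm() + eps))

Calling this inside the inner loop above on the same grad_vec and grad_J should show whether the mislabeled samples are genuinely aligned with the downstream gradient, or whether their high scores mostly come from their large norms.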

So, is there something wrong with my reproduction here?
My model is a very simple CNN classification network, and the dataset is MNIST.
The downstream_dataset is the MNIST test set.
The proxy_dataset is also the test set, but I deliberately changed some of the labels so that they are incorrect.
The model's θ comes from the training stage: at the start of each epoch I save the data and θ, for 50 epochs in total.
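For completeness, the checkpoints are collected roughly like this (a simplified, self-contained sketch with a stand-in linear model; in my real code the model is the CNN and each epoch trains on the proxy data):

import copy
import torch
import torch.nn as nn

model = nn.Linear(784, 10)  # stand-in for my small CNN
theta_t = []                # one parameter snapshot per epoch, used as theta_t[t] above

for epoch in range(50):
    # snapshot θ *before* this epoch's updates
    theta_t.append(copy.deepcopy(model.state_dict()))
    # ... one epoch of training on the proxy data goes here ...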
