When I compute ∇J(θ) from the downstream dataset at some checkpoint θ and take the dot product between it and the gradient vector obtained from a mislabeled sample, I find that the two directions can still agree. This means that at such a moment, when the directions roughly align, the mislabeled sample can receive a very high score because its gradient magnitude is very large, so mislabeled samples cannot be correctly screened out.
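To make the failure mode concrete, here is a tiny synthetic illustration (the numbers are made up, not from my experiment). Since g · ∇J(θ) = ||g|| · ||∇J(θ)|| · cos(angle), a barely aligned vector with a huge norm can outscore a perfectly aligned vector with a small norm:

import torch

# Toy 2-D example: the raw dot product mixes direction and magnitude
grad_J = torch.tensor([1.0, 0.0])            # downstream gradient direction

clean = 0.05 * torch.tensor([1.0, 0.0])      # perfectly aligned, tiny norm
noisy = 100.0 * torch.tensor([0.1, 0.995])   # barely aligned (cos ≈ 0.1), huge norm

for name, g in [("clean", clean), ("mislabeled", noisy)]:
    dot = torch.dot(g, grad_J)
    cos = dot / (g.norm() * grad_J.norm())
    print(f"{name}: dot={dot:.4f}, norm={g.norm():.2f}, cos={cos:.4f}")

Here the "mislabeled" vector scores 10.0 while the "clean" one scores only 0.05, even though its cosine alignment is ten times worse.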
My implementation is as follows:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def compute_downstream_grad(model, downstream_dataset, device):
    """Compute the gradient of the downstream loss J(θ) w.r.t. the parameters,
    i.e. ∇J(θ) (the ∇J(θ_t*) of Eq. 5)."""
    model.eval()
    dataloader = DataLoader(downstream_dataset, batch_size=32)
    total_grads = 0
    step = 0
    # Accumulate the downstream gradient over all batches
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        grads = torch.autograd.grad(loss, model.parameters())
        grad_vec = torch.cat([g.view(-1) for g in grads])  # flatten into one vector
        total_grads += grad_vec
        step += 1
    # Average over batches; note this is a single flat vector, not a per-parameter list
    avg_grads = total_grads / step
    return avg_grads
for t in range(50):
    # Gradient of the downstream loss J(θ), i.e. ∇J(θ_t*) in Eq. 5
    downstream_dataset = dataset2
    model.load_state_dict(theta_t[t])
    grad_J = compute_downstream_grad(model, downstream_dataset, device)
    lam = lambda_history[t]
    for i in range(15):
        x, y = proxy_dataset[i]            # a single sample
        x = x.unsqueeze(0).to(device)      # add a batch dimension
        y = torch.tensor([y]).to(device)
        outputs = model(x)
        loss = F.nll_loss(outputs, y)
        grads = torch.autograd.grad(loss, model.parameters())
        grad_vec = torch.cat([g.view(-1) for g in grads])  # flatten into one vector
        l2 = torch.norm(grad_vec)          # gradient magnitude
        # Print the sample gradient's alignment with grad_J, its norm, and its alignment with λ*
        print(f"t={t}, index={i}, value={float(torch.dot(grad_J, grad_vec))}, l2={l2}, lam={float(torch.dot(lam, grad_vec))}")
The output is as follows:
t=49, index=0, value=0.0067689307034015656, l2=0.03711961209774017, lam=0.0067689307034015656
t=49, index=1, value=0.011063024401664734, l2=0.11148233711719513, lam=0.011063024401664734
t=49, index=2, value=-0.11018553376197815, l2=2.268237829208374, lam=-0.11018553376197815
t=49, index=3, value=9.259251594543457, l2=126.85670471191406, lam=9.259251594543457 # mislabeled
t=49, index=4, value=11.486056327819824, l2=277.9612121582031, lam=11.486056327819824 # mislabeled
t=49, index=5, value=0.0041717528365552425, l2=0.03683656081557274, lam=0.0041717528365552425
t=49, index=6, value=13.245539665222168, l2=219.6069793701172, lam=13.245539665222168
t=49, index=7, value=-0.011622816324234009, l2=0.26048213243484497, lam=-0.011622816324234009
t=49, index=8, value=-31.391651153564453, l2=247.3076934814453, lam=-31.391651153564453 # mislabeled
t=49, index=9, value=-18.298208236694336, l2=203.1321563720703, lam=-18.298208236694336 # mislabeled
t=49, index=10, value=-0.010450446978211403, l2=0.17120294272899628, lam=-0.010450446978211403
t=49, index=11, value=0.01904112473130226, l2=0.21086302399635315, lam=0.01904112473130226
t=49, index=12, value=-0.09251315891742706, l2=1.8408218622207642, lam=-0.09251315891742706
t=49, index=13, value=-0.0006138212629593909, l2=0.02753078006207943, lam=-0.0006138212629593909
t=49, index=14, value=0.05529172718524933, l2=0.3928619921207428, lam=0.05529172718524933
t=0, index=0, value=1.263808012008667, l2=11.897555351257324, lam=0.9541520476341248
t=0, index=1, value=2.0186030864715576, l2=9.998517990112305, lam=3.284712314605713
t=0, index=2, value=1.4496126174926758, l2=11.505668640136719, lam=1.1915363073349
t=0, index=3, value=-0.13733036816120148, l2=13.236851692199707, lam=1.6994320154190063 # mislabeled
t=0, index=4, value=0.15488216280937195, l2=13.605332374572754, lam=0.9128973484039307 # mislabeled
t=0, index=5, value=4.99484395980835, l2=13.224201202392578, lam=5.8155364990234375
t=0, index=6, value=0.9346104264259338, l2=10.993036270141602, lam=0.6073954701423645
t=0, index=7, value=1.463005781173706, l2=9.999319076538086, lam=1.1618057489395142
t=0, index=8, value=-2.1967873573303223, l2=14.000296592712402, lam=-1.0258207321166992 # mislabeled
t=0, index=9, value=-0.03054516203701496, l2=11.326922416687012, lam=-0.7511491775512695 # mislabeled
t=0, index=10, value=1.8088934421539307, l2=10.429534912109375, lam=2.203115940093994
t=0, index=11, value=2.495976448059082, l2=11.449640274047852, lam=-0.038659755140542984
t=0, index=12, value=1.2066035270690918, l2=11.890233993530273, lam=0.37567102909088135
t=0, index=13, value=0.32585108280181885, l2=11.816431999206543, lam=3.5074808597564697
t=0, index=14, value=1.493959665298462, l2=8.72955322265625, lam=3.2007670402526855
From these results I see that the gradient of a mislabeled sample can also be aligned with the optimization direction, and a mislabeled sample's gradient tends to have a much larger magnitude. The resulting gamma therefore comes out large, and the sample ends up being treated as high-quality data.
So, is there something not quite right about my reproduction here?
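One check I'm considering on my side (my own idea, not something the paper specifies): normalize each sample gradient before the dot product, i.e. score by cosine similarity, so the score reflects direction only and large-magnitude mislabeled gradients are no longer rewarded. A minimal sketch, with alignment_score being a hypothetical helper name:

import torch

def alignment_score(grad_J, grad_vec, eps=1e-12):
    """Cosine similarity between a sample gradient and ∇J(θ).

    Removes the magnitude term from dot(grad_J, grad_vec)
    = ||grad_J|| * ||grad_vec|| * cos(angle), keeping only cos(angle).
    """
    return torch.dot(grad_J, grad_vec) / (grad_J.norm() * grad_vec.norm() + eps)

# In the scoring loop above, the print line would then report, e.g.:
# print(f"t={t}, index={i}, cos={float(alignment_score(grad_J, grad_vec))}, l2={l2}")

With this change, a mislabeled sample such as t=49, index=3 (value=9.26, driven by l2=126.86) would lose its magnitude advantage: dividing by its own norm alone already drops the score to about 0.073.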
My model is a very simple CNN classifier and the dataset is MNIST handwritten digits.
The downstream_dataset above is the MNIST test set.
The proxy_dataset is also the test set, except that I deliberately changed some of its labels to incorrect ones.
The θ checkpoints come from the training phase: at the start of every epoch I save the data and θ, for 50 epochs in total (a sketch of this setup follows).
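For reference, a minimal sketch of how I prepare proxy_dataset and theta_t (illustrative only; train_one_epoch is a hypothetical helper, and the flipped indices just match the "# mislabeled" rows above):

import copy
import random
import torch
from torchvision import datasets, transforms

test_set = datasets.MNIST("./data", train=False, download=True,
                          transform=transforms.ToTensor())
dataset2 = test_set  # downstream_dataset: the clean test set

# proxy_dataset: a copy of the test set with a few labels deliberately flipped
proxy_dataset = copy.deepcopy(test_set)
for i in (3, 4, 8, 9):  # corrupted indices, matching the "# mislabeled" rows above
    true_label = int(proxy_dataset.targets[i])
    proxy_dataset.targets[i] = (true_label + random.randint(1, 9)) % 10  # any wrong class

# theta_t: snapshot θ at the start of each of the 50 epochs
theta_t = []
for epoch in range(50):
    theta_t.append(copy.deepcopy(model.state_dict()))
    train_one_epoch(model, train_loader)  # hypothetical helper: one epoch of training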