查看原文
其他

ECCV 2022 | 多教师对抗鲁棒性蒸馏方法

陈兆宇 PaperWeekly 2022-12-13

©作者 | 陈兆宇

单位 | 复旦大学ROILab

研究方向 | 对抗样本



论文标题:

Enhanced Accuracy and Robustness via Multi-teacher Adversarial Distillation

论文来源:

ECCV 2022

论文链接:

https://link.springer.com/10.1007/978-3-031-19772-7_34

开源代码:

https://github.com/zhaoshiji123/MTARD





解决的问题


对抗训练是目前提高模型对抗鲁棒性最主要的方法。然而,在提高 DNN 的鲁棒性的同时,对抗训练在一些通用场景中存在一些缺点。首先,从对抗训练中获得的模型的鲁棒性与模型的大小有关一般来说,模型越大意味着鲁棒性越好。然而,由于各种实际因素的限制,大型模型在实际部署中往往不受青睐。

其次,经过对抗训练的 DNN 识别干净样本的准确性远不如经过正常训练的 DNN,这限制了在实际场景中的大规模使用。一些研究人员试图降低对抗训练带来的对正常样本准确性的负面影响,但效果仍然不理想。

在本文中,作者研究了通过对抗蒸馏来提高小型 DNN 的正常精度和鲁棒精度的方法。Adversarial Robustness Distillation(ARD)用于通过从大型鲁棒模型中提炼来提高小型模型的鲁棒性,它将大型模型视为教师,将小型模型视为学生。尽管之前的工作(RSLAD)通过鲁棒的软标签提高了鲁棒性,但与常规训练的性能相比,正常样本的准确性仍然不理想。

受多任务学习的启发,作者提出了多教师对抗鲁棒性蒸馏(MTARD),通过使用不同的教师模型,每个教师模型负责他们精通的内容。以提高学生模型的鲁棒性和识别正常样本的准确性。例如,我们应用一个鲁棒的教师模型和一个正常训练的教师模型来同时指导鲁棒性和准确性。

然而,由于神经网络的复杂性,教师模型对学生模型有不同程度的影响,甚至会造成灾难性的遗忘。为了缓解这种现象,作者还设计了一种联合训练算法,在对抗蒸馏的不同阶段动态调整教师模型对学生网络的影响。



提出的方法


2.1 多教师对抗鲁棒蒸馏


受多任务学习的启发,作者希望不仅能提高模型的鲁棒性,还能在对抗蒸馏中保持正常样本的准确性以往的对抗蒸馏方法只带来了对抗训练过的单一模型,鲁棒性强,但对正常图像的识别能力较弱。作为唯一的指导,学生模型往往符合教师模型的分布,导致识别正常图像的能力较低。使用 GT one-hot 标签作为学习目标来提高正常样本的识别率仍然不是一个理想的选择。因此,作者引入了一个预训练的正常教师模型来指导对抗蒸馏的过程。

MTARD 中对学生的训练仍然基于对抗训练。在知识蒸馏的对抗教师和正常教师(clean teacher)的指导下,希望学生能够从对抗教师那里学习鲁棒性,并从干净教师那里学习识别正常样本的能力。为了产生履行两位教师责任的软标签,正常教师的输入是来自原始数据集的初始正常样本。

相反,对抗教师的输入是学生模型在内部最大化中产生的对抗样本。学生输入分为正常样本和对抗样本。正常样本和对抗样本的输出将由对抗软标签和干净软标签引导,以监督外部最小化的学生模型训练。基本 MTARD 的 minimax 优化框架定义如下:


MTARD 的目标是学习一个小型学生网络,该网络既具有作为对抗预训练教师网络的鲁棒性能,又具有作为正常预训练教师网络的正常性能。然而在实际操作过程中,不同教师模型同时进行的知识蒸馏会影响学生模型的学习。学生从多个老师那里学习的强度不容易控制。如果一个老师主导了学生的学习,学生模型就很难从另一个老师那里学到相对的能力,甚至会导致灾难性的遗忘,所以处理多位教师的情况就成了要解决的问题。


2.2 自适应归一化损失函数


在数学层面上,最终用于学生模型更新的 MTARD 在时间 t 的总损失可以表示为 ,可以表示如下:


由于教师对学生的影响程度可以表示为对抗损失 和正常训练损失 的值,因此控制多教师学习程度的关键是控制 。受多任务学习中梯度正则化方法的启发,作者提出了一种算法来控制对抗教师和正常教师的鲁棒学习,这被称为 MTARD 中使用的自适应归一化损失。

对损失函数进行推广,每个教师模型的损失函数记为 ,其权重为 ,那么总的损失函记为:


其中,N 代表 loss 的数量,t 代表第 t 次时间。目标是通过在每次更新时以相似的速率动态调整 来通过它们的相对大小将 置于一个共同的尺度上,并且每个 在整个更新过程后都有一个相对公平的下降,最终训练好的模型可以同样受到损失背后各种影响因素的影响。

作者选择相对损失函数(relative loss)记为评价指标,定义为:


是最开始的损失函数值。特别是在作者的设置中,假设 相比较小的值意味着模型适合目标。 作为度量可以反映 从最开始到时间 t 的变化幅度。 的较低值对应于 相对较快的训练速度。通过引入相对损失 ,可以动态平衡 的影响,以此作为客观标准,得到相对损失权重 ,其公式如下( 为指数):


因此, 意味着 的学习率,则权重可以表示为:


在本文中,N=2,即 。因此,权重的更新可以表述为:


在实际层面上,MTARD 中使用的 Adaptive Normalization Loss 可以在整个训练周期中抑制更强的教师的快速增长。如果一个老师在一段时间内与另一个老师相比过度指导学生,Adaptive Normalization Loss 可以通过控制损失权重动态地抑制这个老师的教学能力,而另一个老师的能力在接下来的一段时间内会变得更强。

然而,这种趋势并不是绝对的。如果注意到原来的强老师变弱了,Adaptive Normalization Loss 会让原来的强老师再次变强。最后,学生可以从两个老师那里学习得很好,获得干净和鲁棒的能力,而不是在自适应归一化损失的调整下出现部分能力。




实验和效果


数据集:CIFAR10 和 100

对比方法:SAT,TRADES,ARD,RSLAD

学生和教师模型:学生模型为 ResNet-18 和 MobileNet-V2。在 CIFAR10 上,正常教师模型为 ResNet-56,对抗教师模型为 WideResNet-34-10(TRADES);在 CIFAR100 上,正常教师模型为 WideResNet-22-6,对抗教师模型为 WideResNet-70-16。


评价指标:除了正常和鲁棒准确率,作者引入了 Weighted Roubst Accuracy(W-Robust)作为新指标,其中 ,意味着正常准确率和鲁棒的一样重要:


白盒攻击(ResNet-18):



白盒攻击(MobileNet-V2):



黑盒攻击(ResNet-18):


黑盒攻击(MobileNet-V2):


模块的消融:




总结和不足


该方法主要是在 clean 和 robust 准确性上达到了平衡。只看鲁棒性的话,还是不如 RSLAD 的。



更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存