这篇论文可以概括成一句话: 它是在说,原始的 SplitNN(也就是不加密、只靠“切开模型”来隐私保护的 split learning)并不可靠;一个“老老实实按协议执行、但心里想偷看”的服务器,仅凭客户端模型的网络结构和训练时看到的 smashed data,就能一边把输入样本反推出来,一边把客户端那部分模型“偷”出一个功能相近的替身;如果标签被放在客户端最后一层来“保护”,服务器还可能把标签也推出来。 论文发表在 WPES 2022,题目就叫 UnSplit ,意思是“把 split learning 再拼回去”。
先讲背景。SplitNN 的基本想法是:把一个神经网络切成前半段和后半段,前半段在客户端算,后半段在服务器算;客户端不把原始数据发给服务器,只发前半段输出的中间表示,也就是 smashed data。论文里画了三种典型设置:一种是样本和标签都在客户端;一种是样本在服务器、标签在客户端;还有一种是样本和标签都在客户端,但最后一小段模型也留在客户端,这样标签就不用显式发给服务器了。作者的关键观点是: “服务器看不到原始输入”不等于“服务器拿不到输入信息” ,因为 smashed data 本身就带着大量关于输入和客户端模型的信息。
这篇论文的定位很明确:它不是讨论一个特别强的恶意服务器,而是讨论 honest-but-curious (诚实但好奇)服务器。这里“诚实”指它并不篡改训练流程,不往客户端发奇怪的梯度,不主动偏离协议;“好奇”指它会把训练过程中本来就能看到的信息,拿来做额外攻击。作者强调,这种攻击最麻烦的地方在于: 客户端几乎没法从协议行为上察觉服务器在偷看 。对第一个攻击,服务器只假设自己知道客户端模型的 架构 (architecture),不知道参数、没有相似公开数据、也不能主动查询客户端模型;对第二个攻击,服务器还需要一个额外假设:客户端最后那段“保标签”的模型只有一层,而且服务器能拿到这层反向传播回来的梯度。
论文的第一部分,也是它最核心的部分,是 模型反演 + 模型窃取联合攻击 。这里要先分清两个目标。
第一, 模型反演 (model inversion)是想把客户端输入 $x$ 反推出来。
第二, 模型窃取 (model stealing)是想得到一个替代客户端模型的克隆模型 $\tilde f_1$ 。
这篇论文的独特点在于:它不是只做“把输入倒出来”,而是把“倒输入”和“偷模型”绑在一起做。作者认为,这两个目标是互相帮助的:如果你连客户端那部分模型也一起估,就更容易找到能产生同样 smashed data 的输入;反过来,如果你找到一批能对上 smashed data 的输入,也更容易把客户端模型拟合出来。论文自己把这一点称为一种“symbiotic combination”,也就是“共生式结合”。
它的数学形式并不复杂。作者把整个网络写成
$$ F(\theta, x)=f_2(\theta_2, f_1(\theta_1, x)), $$
其中 $f_1$ 是客户端前半段, $f_2$ 是服务器后半段。服务器真正能看到的是 $f_1(\theta_1,x)$ ,也就是 smashed data。攻击者就自己造一个同构的客户端克隆 $\tilde f_1(\tilde\theta_1,\tilde x)$ ,同时去优化“假的输入” $\tilde x$ 和“假的客户端参数” $\tilde\theta_1$ ,让它的输出尽量接近真实 smashed data。作者写成两个目标:
$$ \tilde x^*=\arg\min_{\tilde x}\; \mathrm{MSE}(\tilde f_1(\tilde\theta_1,\tilde x), f_1(\theta_1,x))+\lambda\,TV(\tilde x), $$
$$ \tilde\theta_1^*=\arg\min_{\tilde\theta_1}\; \mathrm{MSE}(\tilde f_1(\tilde\theta_1,\tilde x), f_1(\theta_1,x)). $$
这里我把每个符号都说清楚: $x$ 是真实输入; $\tilde x$ 是攻击者猜的输入; $\theta_1$ 是真实客户端参数; $\tilde\theta_1$ 是攻击者猜的客户端参数; $\mathrm{MSE}$ 是均方误差,用来衡量两个 smashed data 有多接近; $TV(\tilde x)$ 是 total variation(总变分)正则项,它偏好更平滑的图像,减少噪声伪影; $\lambda$ 是这个平滑项的权重。
一个很关键的实现细节是: 作者没有同时更新 $\tilde x$ 和 $\tilde\theta_1$ 。他们说,直接联合做一次梯度更新,实验上效果不好,容易陷入不理想的状态。所以他们采用了 coordinate gradient descent (坐标式梯度下降):先固定 $\tilde\theta_1$ ,只更新 $\tilde x$ ;再固定 $\tilde x$ ,只更新 $\tilde\theta_1$ ;如此交替,直到收敛。你可以把它理解成“先让伪输入去对 smashed data,再让伪模型去对 smashed data,轮流逼近”。这也是这篇论文方法上最重要的技巧之一。
这一步为什么能工作?直觉上说,smashed data 不是随便的随机向量,它是“真实客户端模型对真实输入加工后的结果”。如果服务器知道客户端模型架构,那它就知道这个结果应该来自什么样的函数类。于是它虽然不知道真正的参数和真正的输入,但可以在“同架构模型 + 候选输入”的空间里找一个组合,让输出对上观察到的 smashed data。因为早层特征通常还保留很多局部纹理和轮廓,所以这个优化往往能找到“看起来像原图”的 $\tilde x$ ;与此同时,那个能稳定产出相近 smashed data 的 $\tilde f_1$ 也就成了一个功能相近的客户端替身。这个解释是我对论文方法机制的归纳,论文本身主要靠实验支持,并没有给出形式化可恢复性定理。方法描述本身见论文第 3.2 节。
论文的第二个攻击是 标签推断 (label inference)。这里作者把记号换了一下,容易看晕:在这一小节里, $f_1$ 表示服务器端那部分模型, $f_2$ 表示客户端最后那一层模型。攻击场景是:客户端不想把标签直接给服务器,于是把最后一层和 loss 留在本地算;但服务器会收到从这最后一层反传回来的梯度。作者的想法是:如果客户端最后一层只有一层,而且标签空间是离散、可枚举的,那么服务器就可以把每个候选标签都试一遍,用自己随机初始化的同构模型 $\tilde f_2$ 计算每个候选标签下的梯度,再和真实收到的梯度做比较,哪个最接近,就猜哪个标签。公式是
$$ \tilde y^*=\arg\min_{\tilde y}\; \mathrm{MSE}\!\left( \frac{\partial L(f_2(f_1(x)),y)}{\partial \theta_2}, \frac{\partial L(\tilde f_2(f_1(x)),\tilde y)}{\partial \tilde\theta_2} \right). $$
这里 $y$ 是真实标签, $\tilde y$ 是候选标签, $L$ 是训练损失, $\theta_2,\tilde\theta_2$ 分别是真实和克隆最后一层的参数。简单说就是: 标签不同,最后一层梯度的“指纹”不同;服务器枚举标签,找梯度最像的那个。
这个标签攻击为什么能做到那么强?因为最后一层如果只有一层,梯度和“输入到最后一层的特征 + 真实标签”绑定得非常紧。作者报告说,在这种设定下, 标签可以被 perfect accuracy,也就是 100% 正确率地推出来 。而一旦标签被推出来,服务器其实就已经拿到了训练监督信号,再把自己偷出来的克隆模型接上去训练,就能得到和客户端差不多的整体功能。也就是说,这不是一个“只泄露一点标签统计”的攻击,而是会直接击穿“我把最后一层留本地就安全了”的想法。
论文还专门解释了 多客户端为什么也挡不住 。理由有两个。第一,SplitNN 的服务器在任一时刻实际上只和一个客户端交互,所以攻击可以对每个客户端轮流做。第二,多客户端训练时通常共享一组不断更新的客户端参数,因此从服务器视角看,多客户端只是把数据分散在不同物理端上,本质上仍然像“一个更大的客户端数据集”。所以作者说,他们的攻击几乎不需要修改就能推广到 $n$ -client 设置。