用于半监督医学图像分割的多模态对比互学习和伪标签再学习


Multi-modal constrastive mutual learning and pseudo-label re-learning for semi-supervised medical image segmentation

用于半监督医学图像分割的多模态对比互学习和伪标签再学习

Journal: [ELSEVIER: Medical Image Analysis] (中科院分区一区top,jcr分区Q1) Author: Shuo Zhang, et al. Hebei University of Technology, China Code Link: https://zhenghuaxu.info/publications/ Key Words: 医学图像分割,多模态图像,半监督学习,一致性正则化,伪标签

目录

补充知识

  • 多模态 模态:结构系统的固有震动特性    传统的多模态指的是多种模态的信息,包括:文本、图像、视频、音频等,多模态大多研究的是不同类型的数据的融合的问题[1]    但在医学图像领域,通常是指因成像机理不同而从不同层面提供的信息,多模态图像通常指MRI,PET,CT图像等,如下图

CT PET

   也可以是一种成像模式不同采集参数下的图像,例如MRI不同序列采集的图像,T1,T1CE,T2,Flair等,如下图

four modalities

[1] 什么是多模态?-知乎

  • mask mask 有掩膜或遮罩的意思。    标签,在医学图像中指的是感兴趣区域,一般为疾病病灶,在分割上也就是分割的目标。通常为放射科医生手工标注,称为金标准,或ground turth。    在Segmentation文件中,某类标签通常对应一个具体的值,例如在BraTS2019数据集中,ED区(浮肿)值为2,ET区(增强肿瘤)值为4,NET区(坏疽)值为1

  • 一致性正则化    一致性正则化是半监督深度学习中的一类经典做法,该方法的基本思想是:对于一个输入,即使受到扰动,网络仍然可以产生和原来一致的输出,或者是一个结果近似的、结果向量距离较近的、输出空间分布上聚集的输出。    例如半监督领域的经典论文 Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results 中利用Teacher模型和Student模型的预测计算一致性损失,FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence 中利用强增强的预测和弱增强产生的伪标签计算一致性损失。

  • 对比学习    对比学习(Contrastive Learning)是无监督学习领域的一种方法,它主要侧重于学习同类实例之间的共同特征,区分非同类实例之间的不同之处。对于任意数据x,对比学习的目标是学习一个编码器$f$,使得 $$ score(f(x),f(x^+))>socre(f(x),f(x^-)) $$    其中$$x^+$$是和$$x$$相似的正样本,$$x^-$$是和$$x$$不相似的负样本,$$score$$是一个度量函数来衡量样本间的相似度。    具体来说,对比学习通常会构造一个对比损失,为每个样本生成正实例特征和负实例特征,通过学习一个嵌入空间,利用相似度度量两个嵌入的接近程度,来自同一实例的样本被拉得更近,来自不同实例的样本被推得更远。

  • 蒙特卡洛dropout    Monte-Carlo Dropout(蒙特卡罗 dropout),简称 MC dropout,是一种从贝叶斯理论出发的 Dropout 理解方式,将 Dropout 解释为高斯过程的贝叶斯近似。理论证明看起来挺复杂,参考论文:Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning. 以及这篇论文的 Appendix。    MC dropout用起来比较简单,不需要修改现有的神经网络模型,只需要神经网络模型中带dropout层,无论是标准的dropout还是其变种。在训练的时候,MC dropout表现形式和dropout没有什么区别,按照正常模型训练方式训练即可。区别在于:MC Dropout在测试的时候,在前向传播过程,神经网络的dropout不能关闭(model.eval())。    MC dropout的MC体现在我们需要对同一个输入进行多次前向传播过程,这样在dropout的加持下可以得到“不同网络结构”的输出,将这些输出进行平均和统计方差,即可得到模型的预测结果及uncertainty。而且,这个过程是可以并行的,所以在时间上可以等于进行一次前向传播。

论文主体

摘要

  • 现状
  • 半监督学习在具有少量标记数据的医学图像分割中具很大潜力,但目前的大多数方法都只使用单模态数据,多模态数据能够提供不同的语义信息,从而提升半监督的分割性能
  • 现有少量的半监督多模态分割方法,他们几乎都有一个共同的缺点,多模态数据的处理部分高度耦合,因此模型的训练阶段和推理阶段都需要多模态数据
  • 主要贡献
  • 提出了一种半监督对比互学习(Semi-CML)分割框架,提出一种新的区域相似性对比损失(ASC Loss)利用不同模态之间的跨模态信息和预测一直性来进行对比互学习
  • 进一步开发了伪标签再学习(PReL)方案来解决两种模态的分割性能的差异(即存在一种模态下的分割性能通常优于另外一种模态)
  • 在两个公共多模态数据集上进行了实验。结果表明,使用PReL的Semi-CML大大优于最先进的半监督分割方法,实现了与100%标记数据的完全监督分割方法相似(甚至有时更好)的分割性能,数据标注的成本大大降低
  • 进行了消融实验,评估了ASC Loss和PReL模块的有效性

本文方法

   给定双模态数据集$$D^{'}$$和$$D^{''}$$,标记数据的数量为$$N$$,未标记数据的数量为$$M$$,将标记他们对应的标记数据集和未标记数据集$$D_{L}^{'}$$、$$D_{U}^{'}$$和$$D_{L}^{''}$$、$$D_{U}^{''}$$定义如下:

$$ D_{L}^{'}=\left{\left(x_{i}^{'},y_{i}\right)\right}{i=1}^{N},D^{'}=\left{\left(x_{i}^{'}\right)\right}_{i=N+1}^{N+M} $$

   其中,$$x_{i}^{'},x_{i}^{''}\in R^{H\times W}$$是大小为$$H\times W$$的不同的图像模态,$$y_{i} \in \left{0,1\right}^{C \times H \times W}$$是具有$$C$$类的金标准(ground-truth),已知,两种图像模态$$x_{i}^{'},x_{i}^{''}$$对应相同的掩码$$y_{i}$$,即使用的是配准图像

  • 半监督对比学习

   受多模态数据和对比自监督学习中内在相关性的潜在价值的启发,为克服多模态融合模型中的高耦合问题,提出一种新的低耦合的半监督对比互学习框架Semi-CML,如下图所示

semi-CML stage1

图1 半监督对比互学习(Semi-CML)框架

  • 首先构建两个具有相同结构的U-Net,将他们作为两种不同模态的分割框架
  • 通过两个U-Net的前向传播从而获得不同模态的两个mini-batch的预测图,标记批次进行监督学习,未标记批次通过MSE损失进行简单的相互学习,通过区域相似性对比损失进行深度相互学习

  • 双模态监督学习    训练阶段将具有相同mask的两种不同图像模态作为输入,在网络的输出端得到两个网络的预测结果。对于这两个模型的监督训练,用每个mini-batch中带有mask的来计算监督损失。    监督损失骰子损失(Dice Loss)二元交叉熵损失(Binary CrossEntropy Loss, BCE) 的加权组成。    将两种模态的分割网络分别定义为$$F(·)$$和$$G(·)$$,监督损失定义如下:

$$ L_{sup}(\hat{y},y)=\beta L_{bce}(\hat{y},y)+\gamma L_{dice}(\hat{y},y) $$

   $$\beta$$和$$\gamma$$对应两个损失的权重,总的监督损失如下:

$$ min_{F,G}L_{sup}^{Total}(F,G) $$

  1. 跨模态知识对比互学习    首先采用简单的一致性正则化项来进行简单的互学习,然后针对深度互学习提出基于区域相似度的对比损失。    简单的一致性正则化:MSE一致性损失

$$ \begin{align} \min_{F,G}L_{mse}\left(F,G\right)=\mathbb{E}_{x^{'},x^{''}}\left[{|F(x^{'})-G(x^{''})|}^{2}\right] \end{align} $$

   通过最小化两种模态之间的预测差异来实现简单的相互学习。    但是MSE的性能不够好,因为:

  •    MSE通常只构建成对预测之间的距离误差,只能关注成对模态的预测相似性,而不能关注未配对的跨模态数据和未配对的同模态数据的预测相异性。这导致网络容易陷入过拟合,因为它只关注容易学习的成对模态相似性,导致网络无法学习更深入的未成对跨模态互补信息和同模态互补信息。
  •    MSE只考虑了两次预测中每个像素点之间的欧几里得距离度量,无法关注预测目标的边缘和区域上下文信息。这对于需要关注预测目标的边缘和区域的图像分割来说是一个缺点。    针对以上MSE loss的问题,作者提出了新颖的区域相似性对比(Area-Similarity Contrastive, ASC)损失来克服这些问题。(至于为什么不是用ASC来直接替代MSE而是采用了MSE+ASC的组合,可以在消融实验中找到)    深度互学习:ASC损失    从两种不同的图像模态中随机选择$K$个未标记样本作为mini-batch,因此生成$2K$个数据点。对应于相同模态的两个图像模态作为正样本对。则一共生成$K$个正样本对,定义为集合$I^{+}$,对于每一个正样本对,取mini-batch中剩余$2\left(K-1\right)$个数据作为负样本集,定义为$I^{-}$。在大小为$K$的mini-batch中,正样本集和负样本集描述为:

$$ \begin{align} &I^{+} = {\left{\left(F(x_{i}^{'}),G(x_{i}^{''})\right)\right}}{i=1}^{K} \ &I^{-} = {\left{F(x^{'})\right}}{j \neq i}^{K-1} \cup {\left{G(x^{''})\right}}_{j \neq i}^{K-1} \end{align} $$

   同一疾病的病变区域和位置可能在同一患者扫描的体数据的相邻切片之间比较相似,可能对负样本集的构建产生不良影响。因此值从大量未标记数据中抽取小批量进行对比学习,即在设置mini-batch时将值设置的非常小,尽可能地避免了将相邻或相近切片构建负样本对的可能性。    个人理解: self-understanding for ASC

图2 帮助理解ASC
   如图是一个mini-batch,K=4,共选取2个模态共计8个切片,来自三个体数据α, β, γ。$(a1,b1)$可以构成一个正样本对,$(a2,b2)$可以构成一个正样本对,一共构建4个正样本对,正样本集$I^{+}={(a1,b1),(a2,b2),(a3,b3),(a4,b4)}$。对于正样本集中的任意一个正样本对,例如$(a1,b1)$,正样本集中所有的其他的切片构成该样本对对应的负样本集$I^{-}={a2,b2,a3,b3,a4,b4}$,我们可以发现正样本集中的元素都是样本对,负样本集中的元素都是具体的切片。    由于mini-batch设置的非常小,根据pytorch中dataloader的加载方式,同一模态下的切片最大的概率是来自不同的体数据的,即使如上图所示来自同一体数据($a1$和$a2$),他们也极大的概率是来自于相距较远的切片。而取到相邻或相近切片的情况概率非常小,几乎不会发生。    为衡量正负样本对之间的区域上下文相似度,使用Dice相似度系数作为度量,定义如下: $$ \begin{align} S_{dice}(y_{1},y_{2})=\frac{2 \times \sum_{\substack{\dot {y_{1}} \in y_{1},\\dot {y_{2}} \in y_{2}}}\dot y_{1},\dot y_{2}}{\sum_{\substack{\dot y_{1} \in y_{1},\\dot y_{2} \in y_{2}}}(\dot y_{1}+\dot y_{2})} \end{align} $$    其中,$\dot y_{1}$和$\dot y_{2}$分别表示两个预测$y_{1}$和$y_{2}$的每个像素的值,将正对的$ASC$损失定义如下: $$ \begin{align} l_{asc}(\hat y,\widetilde y)= - \log{\frac {\exp \Big({S_{dice}\left(\hat y, \widetilde y \right)}\Big)}{\exp \Big({S_{dice}\left(\hat y, \widetilde y \right)}\Big) + \sum_{h \in I^{-}} \exp \Big( {S_{dice}(\hat y, h)}\Big)}} \end{align} $$    其中,$(\hat y,\widetilde y)$是$I^{+}$中的正样本对,$h$为$I^{-}$中$\hat y$对应的负样本对,一个mini-batch中$K$个未标记图像的总ASC损失如下: $$ \begin{equation} \begin{aligned} &\min_{F,G}L_{ASC}(F,G)= \mathbb{E_{x^{'},x^{''}}} \ &\left[\frac {1}{2K} \sum_{(F^{'},G^{''}) \in I^{+}} \left(l_{asc}(F^{'},G^{''}) + l_{asc}(G^{''},F^{'}) \right)\right] \end{aligned} \end{equation} $$    其中,$F^{'}=F(x^{'})$和$G^{''}=G(x^{''})$。正样本对中的每个样本都要计算与负样本集中样本的相似度,所以正样本对需要计算两次$l_{asc}$,即两个模态都分别需要计算一次。    另外,为了平衡MSE loss和ASC loss,加入权重系数$\omega_{1}$和$\omega_{2}$,其中$ω_{1}$是一个ramp-up函数,根据epoch数调整权重值,$ω_{2}$是一个标量。最后,Semi-CML框架训练Stage 1的整体loss定义如下: $$ \begin{align} &\min_{F,G}L_{CML}(F,G)=L_{sup}^{total}+\omega_{1}L_{mse}+\omega_{2}L_{ASC} \end{align} $$

  • 使用BMA教师模型的伪标签再学习

   使用半监督对比学习后大大提高了两种模态的分割性能,但实验表明,两种模态的性能之间存在差距,即存在一种模态的分割性能通常优于另一种模态。    为了进一步提高低性能模态的分割精度,本文利用高性能模态的模型设计了一种软伪标签再学习策略。首先设计了一种新颖的最佳模型移动平均(BMA)自集成技术以在第1阶段期间生成最佳且可靠的教师模型(即,在 Semi-CML 达到收敛之前的某个epoch $L_{1}$)。教师模型称为最佳模型移动平均自集成教师模型(BMA Teacher)。其次,在第2阶段(第$L_{1}$ epoch之后),本文使用BMA teacher模型和蒙特卡洛dropout采样来生成具有高可靠性的伪标签。然后,使用伪标签对低性能模态进行重新学习,同时,高性能模型也执行重新学习过程,但在预热时期($L_{2}$,滞后于第$L_{1}$ epoch)之后开始。

semi-CML stage2

图3 BMA更新策略和伪标签再学习过程

  1. 最佳模型移动平均教师模型    教师-学生模型广泛用于半监督学习算法,其中教师模型通常使用指数移动平均线(EMA)来更新网络参数。    这种方法通常会在每个epoch或minibatch中更新模型权重(尽管模型性能较差),这会导致教师模型的性能下降。这是因为深度模型通常会因不稳定的训练而产生性能冲击,当模型陷入低性能阶段可能不可靠。这时候如果把这个模型的权重更新到teacher model上,teacher model的可靠性会降低。    为了仅集成高质量的模型权重,本文提出了一种新颖的最佳模型移动平均(BMA)自集成方法来选择性地更新教师模型的权重,定义为$T(\theta)$。具体来说,根据每个epoch的训练或验证准确率来决定是否更新教师模型,因此只选择最优(与之前所有epoch的模型性能相比)或次优的高性能模型权重来更新教师模型重量。为了确保将最佳或次优模型权重更新到教师模型,设计了一个最佳模型池(BMP),定义为集合$R_{pool}$,以动态存储$p$个(BMP Number)不同epochs的最佳或次优精度值。    具体的做法是:    1、 对于Stage 1,BMA teacher 模型在$m$ epoch之后更新,因为更好的模型权重通常在最开始的$m$ epochs不可用,并且所有更新仅在Stage 1完成(即在epoch $L_{1}$之前)。    2、在第$m$ epoch后的$p$个epochs中,使用这$p$个epochs的精度初始化容量为$p$的BMP。    3、在第$(m+p+1)$个epoch, 直接将高性能模型的权重作为teacher模型的初始权重。    4、然后,对于$(m+p+1)$ epoch和$L_{1}$ epoch之间的每个 epoch,需要在满足某些条件后执行两个步骤:对于第1步,需要不断更新BMP以确保它包含最优和次优精度值。对于第2步,使用最佳模型移动平均函数来不断更新教师模型的权重:将当前epoch的精度与BMP中的最小精度进行比较,以决定是否更新BMP和教师模型。如果大于最小值,用当前精度更新池中的最小值,以确保池中的精度值是前$L_{1}$个epochs中的最好的$p$个精确度。    5、同时,使用BMA更新函数来更新教师模型的权重。教师模型的权重更新率随着准确率的增加而动态变化,定义如下:

$$ \alpha=\min \left(1 - \frac {Acc_{t}-Acc_{min}}{Acc_{t}},a_{0}\right) $$

   其中$Acc_{t}$是训练时期$t$的准确度,$Acc_{min}=\min⁡(R_{pool})$,$α_{0}=0.99$。最后,BMP和BMA teacher的所有更新过程描述如下:

$$ \begin{equation} \left{\begin{array}{cl} A c c_t \Rightarrow \mathbb{R}{\text {pool }} & \text { if } A c c_t>A c c \ \theta_t=\alpha \theta_{t-1}+(1-\alpha) \theta_t^{\prime \prime} & \text { if } A c c_t>A c c_{\min }, \end{array}\right. \end{equation} $$

   其中$θ_{t}$和$θ_{t}^{''}$分别是训练时期$t$的教师模型和高性能模型(例如$G(⋅)$)的权重。该方法更新教师模型权重可以过滤掉性能较差的模型权重,并在模型较好时使权重更新比例较大,从而最大化最高质量模型权重的集成。

Algorithm

图4 BMA算法细节

  1. 生成伪标签和再学习    在第2阶段,通过对BMA教师模型$T(\theta)$执行蒙特卡洛dropout采样,以获得更可靠的预测结果作为未标记数据的伪标签。具体做法是:在教师模型上应用dropout层作为贝叶斯神经网络的近似,并将高斯噪声应用于输入。然后,teacher模型进行D次随机前向传播,并对得到的预测结果进行平均,得到最终的可靠软伪标签。形式上,这个过程可以定义为:

$$ \begin{equation} P_s=\frac{1}{D} \sum_{i=1}^D T\left(x^{\prime \prime}+\xi_i\right) \end{equation} $$

   其中$P_{s}$是可靠的软伪标签,$\xi_{t}$是高斯噪声。然后,使用生成的软伪标签来获得新的损失。它们由两部分组成,包括监督损失($L_{sup}$)和Semi-CML中的ASC损失($L_{ASC}$)。在再学习阶段,ASC损失中的正负样本对由两种模态之一的预测和BMA教师模型生成的软伪标签组成,定义如下:

$$ \begin{equation} L_{\mathrm{ASC}}^{\mathrm{ReL}}\left(\hat{y}, P_s\right)=\frac{1}{2 K} \sum_{\left(\hat{y}, P_s\right) \in \Gamma^{+}}\left(l_{\text {asc }}\left(\hat{y}, P_s\right)+l_{a s c}\left(P_s, \hat{y}\right)\right) \end{equation} $$

   个人理解:这里BMA生成的伪标签可以视作一个新的模态的生成的结果,从而和原有的模态进行对比学习。    其中$\hat y$是监督模型的输出,$K$是一个mini-batch中未标记图像的数量。伪标签再学习过程主要针对低性能模态,因此监督公式定义如下:

$$ \begin{equation} \begin{aligned} & \min F \mathcal{L}}^{\prime}(F)=\mathbb{E}{x^{\prime}, P_s} \ & {\left[\alpha_1 \mathcal{L}}^{\mathrm{ReL}}\left(F\left(x^{\prime}\right), P_s\right)+\left(1-\alpha_1\right) \mathcal{L}_{\text {sup }}^{\prime}\left(F\left(x^{\prime}\right), P_s\right)\right]} \end{aligned} \end{equation} $$

   其中$\alpha_1$是训练低性能模态的平衡因子。BMA教师模型生成的伪标签也有助于训练高性能模态。对高性能模态进行类似的监督,其定义如下:

$$ \begin{equation} \begin{aligned} & \min G \mathcal{L}}^{\prime \prime}(G)=\mathbb{E}{x^{\prime \prime}, P_s} \ & {\left[\alpha_2 \mathcal{L}}^{\mathrm{ReL}}\left(G\left(x^{\prime\prime}\right), P_s\right)+\left(1-\alpha_2\right) \mathcal{L}_{\text {sup }}^{\prime \prime}\left(G\left(x^{\prime\prime}\right), P_s\right)\right]} \end{aligned} \end{equation} $$

   其中$\alpha_2$是高性能模态训练的平衡因子。为了确保低性能模态在重新学习过程中有更稳定的优化,高性能模态的重新学习过程仅在基于epoch $L_{1}$的预热之后开始。    采用伪标签再学习策略后,两个模型的分割性能将得到进一步提升。低性能模态的准确率会增加更多,因为这个学习过程提供了更多的跨模态知识。

实验结果

   非本文档重点    具体可以查看pdf