论文阅读-Tail-Enhanced Representation Learning for Surgical Triplet Recognition

2025-03-28发布于论文随笔 | 最后更新于2025-03-28 14:03:00

long-tail triplet

MICCAI 2024原文链接 Tail-Enhanced Representation Learning for Surgical Triplet Recognition

基本介绍

主要针对三元组类别间特征方差小的问题,之前方法会导致模型在不同类别间不自信。此时自然可以想到借助对比学习,但直接利用对比学习会导致:由于同一图片中存在多组类别的信息,特征学习时会存在相互干扰;长尾类别样本少,很难学习到足够的语义特征。

在该问题上,现有方法主要利用多任务知识蒸馏来加强尾部三元组识别,但这忽视了三元组类间特征的方差建模(事实上有些三元组类间特征差异很小),使模型在一些类别上不自信。如下图的top5预测概率所示:

top5概率结果图

本文为解决上述问题进行了如下设计:

  1. 添加一个解构模块,以求在同一图像中获得不同元的特征
  2. 使用一个全局记忆库来加强尾部类别

整体模型结构如下图:

模型结构图

模块详解

Tail -Enhanced Triplet Recognition

多头分类器分为上下两部分,上半部分用作各元以及整个三元组的分类任务,下半部分用于尾部类别的处理。

此处将特征提取器得到的结果记为\(\mathbf{F}\in\mathrm{R}^{D\times \frac{H}{32}\times\frac{W}{32}}\);将class activation map(CAM)记为\(\mathbf{C}_k\in\mathrm{R}^{1\times \frac{H}{32}\times\frac{W}{32}}\)\(k\)表示是取的是CAM的第k维,拼接后的特征经过单格卷积又减为D维,此时的特征记为\(\widetilde{\mathbf{F}}_k\in\mathrm{R}^{D\times \frac{H}{32}\times\frac{W}{32}}\)

此模块用于计算\(\mathcal{L}_{Cls}=\sum_{a\in A}{\mathcal{L}^a_{Cls}}\)

$$ \mathcal{L}_{Cls}^a = \frac{1}{K_a} \sum_{k=1}^{K_a} \begin{cases} (1 - \hat{y}_k^a)^{\gamma^+} \log(\hat{y}_k^a), & y_k^a = 1, \\ (\hat{y}_k^a)^{\gamma^-} \log(1 - \hat{y}_k^a), & y_k^a = 0, \end{cases} $$

其中的a表示各任务,\(A=\left\{I,V,T,IVT,Ta\right\}\)\(Ta\)为三元组尾部类分类任务;\(K_a\)为任务a的总类数,由于是多标签分类,用one-hot的0、1来表示标签;batch的总loss为各样本loss的均值;\(\gamma^+=0,\gamma^-=2\)为超参。

Instance-Level Contrastive Learning

通过原图\(\mathbf{x}\)和增强图\(\mathbf{x}'\)进行对比学习。定义记忆库\(\mathcal{B}=\left\{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_L\right\}\);以及正例\(\mathbf{E}_p = \left\{ e_i \in \mathcal{B} : y_i = y_q \right\}\)\(y_q\)为当前embedding \(\mathbf{e}\) 的三元组标签;从\(\mathbf{E}_p\)中随机取k个构成\(\mathbf{E}_p^k\)

$$ \mathcal{L}_{CL} = \frac{1}{k + 1} \sum_{\mathbf{e}_p \in \mathbf{E}_p^k \cup \mathbf{e}'} \log \frac{\exp \left( \mathbf{e}^\top \mathbf{e}_p / \tau \right)}{\exp \left( \mathbf{e}^\top \mathbf{e}' / \tau \right) + \Sigma_{j=1}^L \exp \left( \mathbf{e}^\top \mathbf{e}_j / \tau \right)} $$

温度系数\(\tau=0.07\)\(k=7\)为超参数

Prototype-Based Semantic Enhancement

主要思想就是通过最大化长尾embedding和其对应元的prototype的相似度来加强语义学习。通过下式计算prototype,其中\(\mathcal{B}^a = \left\{ \mathbf{e}_i \in \mathcal{B} : \tilde{y}_i^a = \tilde{y}_q^a \right\}\),即在子任务a下,记忆库中与待查询embedding \(\mathbf{e}\) 类别相同的embedding;

$$ p_{q,d}^a = \frac{1}{|\mathcal{B}^a|} \sum_{i=1}^{|\mathcal{B}^a|} e_{i,d} $$

最终可以得到\(\mathbf{p}_q^a = \left( p_{q,1}^a, p_{q,2}^a, \ldots, p_{q,D}^a \right)^\top\),损失由下式计算:

$$ \mathcal{L}_P = \sum_{a \in \left\{I, V, T\right\}} \log \frac{\exp \left( \mathbf{e}^\top \mathbf{p}_q^a \right)}{\sum_{j=1}^{K_a} \exp \left( \mathbf{e}^\top \mathbf{p}_j^a \right)} $$

实验结果

常规性能对比与模块消融

表2的消融是在fold1上做的。

文章还探究了将CAM拼接在特征后的作用(图3左侧),以及是否对子任务应用本文提到的加强方法的影响

模块细节消融

下图展示了不同超参数下的表现

不同超参数比较