MedCLIPSeg: Probabilistic Vision-Language Adaptation for Data-Efficient and Generalizable Medical Image Segmentation

med clip seg:医学图像分割,基于概率的视觉语言适应用于高效通用的医学图像分割

  • Paper: https://arxiv.org/abs/2602.20423
  • Code: https://github.com/HealthX-Lab/MedCLIPSeg
  • Project: https://tahakoleilat.github.io/MedCLIPSeg
  • 来源:cvpr 2026
  • 模型/数据:huggingface.co/TahaKoleilat/MedCLIPSeg
  • 作者:康科迪亚大学-加拿大

我感觉这一工作做得还行,可以作为借鉴。用于少样本、半监督模式的一个范式

目的

医学图像分割(medical image segmentation)仍因标注数据优先,解剖特征模糊以及领域偏移等问题具有挑战性。尽管视觉-语言模型如CLIP提供了强大的跨模态表示,其在密集,文本引导的医学图像分割的潜力仍旧被忽视(remain under explored)。

懂了,make CLIP great again.

背景

准确的医学分割仍旧hi诊断,治疗计划和定量临床随访的基石。有三个方面的障碍和限制:

  • 分割GT的标注成本,标注的不一致性
  • 病灶和器官可能展示模糊的边界,由于强度过渡或者部分体积效应。
  • 扫描设备、采集协议和患者个体差异,扫描中常见的领域转移,导致仅基于分布内的数据训练得到的模型面对分布外的数据时失效。

迫切需要一种同时具备数据高效性、不确定性敏感、跨领域泛化能力的分割模型。

在医学分割领域,以往可以通过以下方法实现:

  • U-Net等CNN模型
  • ViT架构

这些模型依赖大量的像素级监督,大多数模型运行时确定性的(deterministically).在处理分布外的输入和边界模糊的样本时,存在系统性的过度自信(systematically over-confident),在缺乏预警机制的情况下,产生不可靠的分割结果。

在传统的确定性特征学习方法的问题下,未能充分考虑模糊性和特征之间的局部不一致。

Learning methods that can adaptively weigh evidence and/or modulate attention based on contextual reliability of data distribution would be hopeful to address this [53].

[53] Alireza Mehrtash, William M Wells, Clare M Tempany, Purang Abolmaesumi, and Tina Kapur. Confidence calibration and predictive uncertainty estimation for deep medical image segmentation. IEEE transactions on medical imaging, 39(12):3868–3878, 2020. 2

这两个方法可以解决这个问题。

具体实施方法

image-20260324161507614
  • 输入部分

    • Xt为文本输入,就是分割的文本描述,就是任务描述“乳腺超声图像上右上方有一个中等大小的圆形不规则肿物。”意思就是医生很少做逐像素标注CT图的活,但是这类文本描述做得很多。
    • Xv为图像输入就是CT图,待处理的图像。
  • 模型

    • 几乎是根据CLIP模型修改的。
    • CLIP就是把文本-图像全部转换到高维特征,计算两个高维特征的相似度。
    • 看到图中,把CLIP的文本编码器和图像编码器全部Frozen,参数不可变,直接使用的预训练的权重,不参与参数变动。
    • 可训练的参数只有PVL Adapter模块和后面的MLP
    • PVL Adapter,训练新加入的轻量级 Adapter,这个模块非常重要。属于即插即用的,类Lora的设计思路。

PVL Adapter

这一模块的源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class PVL_Adapter(nn.Module):
def __init__(self,
in_channels_vis: int,
in_channels_txt: int,
adapter_channels: int,
beta: float,
gate_init: int):

super().__init__()

# Down projection
self.proj_vis_down = nn.Sequential(nn.Linear(in_channels_vis, adapter_channels, bias=False))
self.proj_txt_down = nn.Linear(in_channels_txt, adapter_channels, bias=False)

# Up projection
self.proj_vis_up = nn.Linear(adapter_channels, in_channels_vis, bias=False)
self.proj_txt_up = nn.Linear(adapter_channels, in_channels_txt, bias=False)

# Cross-modal interaction
self.two_way = TwoWayTransformerLayer(adapter_channels, beta, gate_init)

def forward(self, vis, text):

v = self.proj_vis_down(vis)
t = self.proj_txt_down(text)

v_fused, t_fused = self.two_way(v, t)

vis_out = self.proj_vis_up(v_fused)
txt_out = self.proj_txt_up(t_fused)

return vis_out, txt_out

这一结构作为两种特征提取的连接部分,负责特征融合,学习不确定性的功能。 其实目标分割的任务用原本的CLIP预训练模型就能做,这一部分是这篇论文的核心创新点,让模型学习到不确定性,不确定性大的可以说“不知道”。 可以双向查询,即不仅可以文本T查询图像V,也可以图像V查询文本T。

分为三个小部分,可以看做一个U-net结构的残差块。

  • 下采样部分
  • 双向交互部分
  • 上采样部分

具体可见代码部分 只能说K,V的方差和均值都是硬拆的 只有Q的方差和均值是算出来的 方差可以看做是一个惩罚项,方差越大,置信度就越低,给的注意力权重就越低。 这个模型预估了一个分布,而不是一个点,所以模型可以学习到不确定性

其中,Softplus激活函数的计算方法如下,这个函数在torch.nn模块中有实现 $$ \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) $$

Attn_PVL模块

1
2
3
4
5
6
7
8
9
10
class TwoWayTransformerLayer(nn.Module):
def __init__(self, embed_dim, beta=2.35, gate_init=0.0):
super().__init__()
self.cross_attn_img_to_txt = ProbCrossAttention(embed_dim, beta, gate_init)
self.cross_attn_txt_to_img = ProbCrossAttention(embed_dim, beta, gate_init)

def forward(self, img_tokens, txt_tokens):
img_tokens = self.cross_attn_img_to_txt(img_tokens, txt_tokens)
txt_tokens = self.cross_attn_txt_to_img(txt_tokens, img_tokens)
return img_tokens, txt_tokens
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class ProbCrossAttention(nn.Module):
"""
Simple and stable probabilistic cross-attention:
- keys/values each have mean + variance
- softplus to keep variance positive
- attention scores adjusted by key uncertainty
"""
def __init__(self, dim, beta: float = 2.35, gate_init: float = 0.0):
super().__init__()
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim * 2) # mean + logvar
self.v_proj = nn.Linear(dim, dim * 2) # mean + logvar
self.out_proj = nn.Linear(dim, dim)
self.norm_k = nn.LayerNorm(dim)
self.norm_v = nn.LayerNorm(dim)
self.eps = 1e-6
self.gate = nn.Parameter(torch.tensor(gate_init)) # Initial residual mix
self.beta = beta

def forward(self, query, context, sample=True, num_samples=1):
B, Tq, C = query.shape
_, Tk, _ = context.shape

Q = self.q_proj(query) # [B, Tq, C]

# Keys: mean + variance
K_out = self.k_proj(context)
K_mu, K_logvar = K_out[..., :C], K_out[..., C:]
K_mu = self.norm_k(K_mu)
K_var = F.softplus(K_logvar) + self.eps # positive

# Values: mean + variance
V_out = self.v_proj(context)
V_mu, V_logvar = V_out[..., :C], V_out[..., C:]
V_mu = self.norm_v(V_mu)
V_var = F.softplus(V_logvar) + self.eps # positive

# Attention scores
scale = math.sqrt(C)
mean_scores = torch.matmul(Q, K_mu.transpose(1, 2)) / scale

var_penalty = torch.matmul(Q.pow(2), K_var.transpose(1, 2)) / C
scores = mean_scores - self.beta * torch.sqrt(var_penalty)

attn_weights = F.softmax(scores, dim=-1)

eps = torch.randn_like(V_var)
V_sample = V_mu + torch.sqrt(V_var) * eps
out = torch.matmul(attn_weights, V_sample) # average over samples

gate = torch.sigmoid(self.gate)
proj_out = self.out_proj(out)
fused = gate * proj_out + (1 - gate) * query

return fused

用于PVL Adapter中部的特征融合,键Keys和值Values建模为具有可学习均值和方差的概率分布,以纳入其中的不确定性。

1
2
3
# 这是V_sample部分
eps = torch.randn_like(V_var)
V_sample = V_mu + torch.sqrt(V_var) * eps

V_sample不分把原本的V点转换为了一个分布。即均值+方差*(0,1)的随机数。

1
2
3
gate = torch.sigmoid(self.gate)
proj_out = self.out_proj(out)
fused = gate * proj_out + (1 - gate) * query

最后一个门控输出,是原Q值和加权后得出的q的加权和。

loss

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Compute logits
logits_per_image = (patch_mean @ text_features.T) / self.temperature # (B, B)
logits_per_text = (text_features @ patch_mean.T) / self.temperature # (B, B)

# --- Soft targets based on text similarity ---
with torch.no_grad():
text_sim = (text_features @ text_features.T) / self.temperature # (B, B)
text_sim = text_sim / text_sim.norm(dim=-1, keepdim=True)
soft_targets = F.softmax(text_sim, dim=-1) # temperature-controlled soft labels

loss_i2t = self.soft_cross_entropy(logits_per_image, soft_targets)
loss_t2i = self.soft_cross_entropy(logits_per_text, soft_targets.T)

clip_loss = (loss_i2t + loss_t2i) / 2

Soft Contrastive Loss

使用的soft_targets是各个文本的相似性矩阵,表示相似的标签,虽然这两个词不一样,但它们很像。如果两个文本描述(例如“肿瘤”和“恶性肿块”)语义相近,它们的相似度就会很高;而以往CLIP模型使用的是1的对角矩阵,表示不同的标签之间完全无关。

image-20260324205533036

再加上分割loss就是最终的loss,其中Seg loss属于常规写法,没有进行改动。

1
2
3
4
5
6
# Seg loss
def calc_loss(low_res_logits, low_res_label_batch, ce_loss, dice_loss, cfg):
loss_ce = ce_loss(low_res_logits, low_res_label_batch.float())
loss_dice = dice_loss(low_res_logits, low_res_label_batch)
loss = cfg.TRAIN.DICE_WEIGHT * loss_dice + cfg.TRAIN.CE_WEIGHT * loss_ce
return loss

实验

左图为消融实验,右图为不同text数据集(矛盾/缺失位置/过描述/不足描述/原始数据)下的数据。

image-20260324210543922

总结

实际上是在讲如何让 AI 学会“犹豫”。通过 PVL Adapter,MedCLIPSeg 将 CLIP 强大的图文匹配能力与贝叶斯深度学习的严谨性结合在一起,使其不仅能完成分割任务,还能实时评估任务的可信度。

PVL Adapter 是 MedCLIPSeg 的“安全阀”和“增强器”。它让 CLIP 在处理高难度的医学图像时,不仅拥有了双向理解的能力,更重要的是拥有了“自我怀疑”的能力——这种能力在医疗诊断中比单纯的准确率更珍贵。