MedCLIPSeg:
Probabilistic Vision-Language Adaptation for Data-Efficient and
Generalizable Medical Image Segmentation
med clip
seg:医学图像分割,基于概率的视觉语言适应用于高效通用的医学图像分割
Paper: https://arxiv.org/abs/2602.20423
Code: https://github.com/HealthX-Lab/MedCLIPSeg
Project: https://tahakoleilat.github.io/MedCLIPSeg
来源:cvpr 2026
模型/数据:huggingface.co/TahaKoleilat/MedCLIPSeg
作者:康科迪亚大学-加拿大
我感觉这一工作做得还行,可以作为借鉴。用于少样本、半监督模式的一个范式
目的
医学图像分割(medical image
segmentation)仍因标注数据优先,解剖特征模糊以及领域偏移等问题具有挑战性。尽管视觉-语言模型如CLIP提供了强大的跨模态表示,其在密集,文本引导的医学图像分割的潜力仍旧被忽视(remain
under explored)。
懂了,make CLIP great again.
背景
准确的医学分割仍旧hi诊断,治疗计划和定量临床随访的基石。有三个方面的障碍和限制:
分割GT的标注成本,标注的不一致性
病灶和器官可能展示模糊的边界,由于强度过渡或者部分体积效应。
扫描设备、采集协议和患者个体差异,扫描中常见的领域转移,导致仅基于分布内的数据训练得到的模型面对分布外的数据时失效。
迫切需要一种同时具备数据高效性、不确定性敏感、跨领域泛化能力的分割模型。
在医学分割领域,以往可以通过以下方法实现:
这些模型依赖大量的像素级监督,大多数模型运行时确定性的(deterministically).在处理分布外的输入和边界模糊的样本时,存在系统性的过度自信(systematically
over-confident),在缺乏预警机制的情况下,产生不可靠的分割结果。
在传统的确定性特征学习方法的问题下,未能充分考虑模糊性和特征之间的局部不一致。
Learning methods that can adaptively weigh evidence and/or modulate
attention based on contextual reliability of data distribution would be
hopeful to address this [53].
[53] Alireza Mehrtash, William M Wells, Clare M Tempany, Purang
Abolmaesumi, and Tina Kapur. Confidence calibration and predictive
uncertainty estimation for deep medical image segmentation. IEEE
transactions on medical imaging, 39(12):3868–3878, 2020. 2
这两个方法可以解决这个问题。
具体实施方法
image-20260324161507614
输入部分
Xt为文本输入,就是分割的文本描述,就是任务描述“乳腺超声图像上右上方有一个中等大小的圆形不规则肿物。”意思就是医生很少做逐像素标注CT图的活,但是这类文本描述做得很多。
Xv为图像输入就是CT图,待处理的图像。
模型
几乎是根据CLIP模型修改的。
CLIP就是把文本-图像全部转换到高维特征,计算两个高维特征的相似度。
看到图中,把CLIP的文本编码器和图像编码器全部Frozen,参数不可变,直接使用的预训练的权重,不参与参数变动。
可训练的参数只有PVL Adapter模块和后面的MLP
PVL Adapter,训练新加入的轻量级
Adapter,这个模块非常重要。属于即插即用的,类Lora的设计思路。
PVL Adapter
这一模块的源码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class PVL_Adapter (nn.Module): def __init__ (self, in_channels_vis: int , in_channels_txt: int , adapter_channels: int , beta: float , gate_init: int ): super ().__init__() self.proj_vis_down = nn.Sequential(nn.Linear(in_channels_vis, adapter_channels, bias=False )) self.proj_txt_down = nn.Linear(in_channels_txt, adapter_channels, bias=False ) self.proj_vis_up = nn.Linear(adapter_channels, in_channels_vis, bias=False ) self.proj_txt_up = nn.Linear(adapter_channels, in_channels_txt, bias=False ) self.two_way = TwoWayTransformerLayer(adapter_channels, beta, gate_init) def forward (self, vis, text ): v = self.proj_vis_down(vis) t = self.proj_txt_down(text) v_fused, t_fused = self.two_way(v, t) vis_out = self.proj_vis_up(v_fused) txt_out = self.proj_txt_up(t_fused) return vis_out, txt_out
这一结构作为两种特征提取的连接部分,负责特征融合,学习不确定性的功能。
其实目标分割的任务用原本的CLIP预训练模型就能做,这一部分是这篇论文的核心创新点,让模型学习到不确定性,不确定性大的可以说“不知道”。
可以双向查询,即不仅可以文本T查询图像V,也可以图像V查询文本T。
分为三个小部分,可以看做一个U-net结构的残差块。
具体可见代码部分
只能说K,V的方差和均值都是硬拆的
只有Q的方差和均值是算出来的
方差可以看做是一个惩罚项,方差越大,置信度就越低,给的注意力权重就越低。
这个模型预估了一个分布,而不是一个点,所以模型可以学习到不确定性
其中,Softplus激活函数的计算方法如下,这个函数在torch.nn模块中有实现
$$
\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
$$
Attn_PVL模块
1 2 3 4 5 6 7 8 9 10 class TwoWayTransformerLayer (nn.Module): def __init__ (self, embed_dim, beta=2.35 , gate_init=0.0 ): super ().__init__() self.cross_attn_img_to_txt = ProbCrossAttention(embed_dim, beta, gate_init) self.cross_attn_txt_to_img = ProbCrossAttention(embed_dim, beta, gate_init) def forward (self, img_tokens, txt_tokens ): img_tokens = self.cross_attn_img_to_txt(img_tokens, txt_tokens) txt_tokens = self.cross_attn_txt_to_img(txt_tokens, img_tokens) return img_tokens, txt_tokens
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 class ProbCrossAttention (nn.Module): """ Simple and stable probabilistic cross-attention: - keys/values each have mean + variance - softplus to keep variance positive - attention scores adjusted by key uncertainty """ def __init__ (self, dim, beta: float = 2.35 , gate_init: float = 0.0 ): super ().__init__() self.q_proj = nn.Linear(dim, dim) self.k_proj = nn.Linear(dim, dim * 2 ) self.v_proj = nn.Linear(dim, dim * 2 ) self.out_proj = nn.Linear(dim, dim) self.norm_k = nn.LayerNorm(dim) self.norm_v = nn.LayerNorm(dim) self.eps = 1e-6 self.gate = nn.Parameter(torch.tensor(gate_init)) self.beta = beta def forward (self, query, context, sample=True , num_samples=1 ): B, Tq, C = query.shape _, Tk, _ = context.shape Q = self.q_proj(query) K_out = self.k_proj(context) K_mu, K_logvar = K_out[..., :C], K_out[..., C:] K_mu = self.norm_k(K_mu) K_var = F.softplus(K_logvar) + self.eps V_out = self.v_proj(context) V_mu, V_logvar = V_out[..., :C], V_out[..., C:] V_mu = self.norm_v(V_mu) V_var = F.softplus(V_logvar) + self.eps scale = math.sqrt(C) mean_scores = torch.matmul(Q, K_mu.transpose(1 , 2 )) / scale var_penalty = torch.matmul(Q.pow (2 ), K_var.transpose(1 , 2 )) / C scores = mean_scores - self.beta * torch.sqrt(var_penalty) attn_weights = F.softmax(scores, dim=-1 ) eps = torch.randn_like(V_var) V_sample = V_mu + torch.sqrt(V_var) * eps out = torch.matmul(attn_weights, V_sample) gate = torch.sigmoid(self.gate) proj_out = self.out_proj(out) fused = gate * proj_out + (1 - gate) * query return fused
用于PVL
Adapter中部的特征融合,键Keys 和值Values 建模为具有可学习均值和方差的概率分布,以纳入其中的不确定性。
1 2 3 eps = torch.randn_like(V_var) V_sample = V_mu + torch.sqrt(V_var) * eps
V_sample不分把原本的V点转换为了一个分布。即均值+方差*(0,1)的随机数。
1 2 3 gate = torch.sigmoid(self.gate) proj_out = self.out_proj(out) fused = gate * proj_out + (1 - gate) * query
最后一个门控输出,是原Q值和加权后得出的q的加权和。
loss
1 2 3 4 5 6 7 8 9 10 11 12 13 14 logits_per_image = (patch_mean @ text_features.T) / self.temperature logits_per_text = (text_features @ patch_mean.T) / self.temperature with torch.no_grad(): text_sim = (text_features @ text_features.T) / self.temperature text_sim = text_sim / text_sim.norm(dim=-1 , keepdim=True ) soft_targets = F.softmax(text_sim, dim=-1 ) loss_i2t = self.soft_cross_entropy(logits_per_image, soft_targets) loss_t2i = self.soft_cross_entropy(logits_per_text, soft_targets.T) clip_loss = (loss_i2t + loss_t2i) / 2
Soft Contrastive Loss
使用的soft_targets是各个文本的相似性矩阵,表示相似的标签,虽然这两个词不一样,但它们很像。如果两个文本描述(例如“肿瘤”和“恶性肿块”)语义相近,它们的相似度就会很高;而以往CLIP模型使用的是1的对角矩阵,表示不同的标签之间完全无关。
image-20260324205533036
再加上分割loss就是最终的loss,其中Seg
loss属于常规写法,没有进行改动。
1 2 3 4 5 6 def calc_loss (low_res_logits, low_res_label_batch, ce_loss, dice_loss, cfg ): loss_ce = ce_loss(low_res_logits, low_res_label_batch.float ()) loss_dice = dice_loss(low_res_logits, low_res_label_batch) loss = cfg.TRAIN.DICE_WEIGHT * loss_dice + cfg.TRAIN.CE_WEIGHT * loss_ce return loss
实验
左图为消融实验,右图为不同text数据集(矛盾/缺失位置/过描述/不足描述/原始数据)下的数据。
image-20260324210543922
总结
实际上是在讲如何让 AI 学会“犹豫” 。通过 PVL
Adapter,MedCLIPSeg 将 CLIP
强大的图文匹配能力与贝叶斯深度学习的严谨性结合在一起,使其不仅能完成分割任务,还能实时评估任务的可信度。
PVL Adapter 是 MedCLIPSeg 的“安全阀”和“增强器” 。它让
CLIP
在处理高难度的医学图像时,不仅拥有了双向理解 的能力,更重要的是拥有了“自我怀疑” 的能力——这种能力在医疗诊断中比单纯的准确率更珍贵。