基于注意力的多示例学习可解释性¶
在本次会议中,我们将探讨两种用于理解基于注意力的多示例学习模型行为的事后解释方法,该模型经过训练,用于将肺癌分为LUAD和LUSC两种最常见的肺癌亚型。
关于医疗人工智能应用中可解释性的重要性¶
信任与控制:医疗人工智能系统通常用于直接影响患者健康和生命的关键决策过程中。可解释性旨在提供正确水平的洞察力和对系统的控制,从而建立用户与系统之间的信任和信心。
洞察力和理解:可解释的模型还有助于提供对影响预测的潜在因素或特征的洞察力。在医疗人工智能中,这种理解对于医疗专业人员获取对疾病机制、识别风险因素或发现新的生物标志物的洞察力至关重要。这一方面符合生物标志物发现的方向,我们假设人工智能系统可能使用不同于当前标准的特征(例如,癌症的当前分级标准)。
错误分析与诊断:可解释性有助于进行错误分析,允许识别和理解模型的错误或错误预测。在医疗人工智能中,错误诊断可能会产生严重后果,可解释性使临床医生能够评估模型失败的情况并诊断潜在的问题或限制。这种反馈循环可以指导模型、数据集或特征工程的改进,从而实现更好的性能和更可靠的预测。
而令人向往的考虑...
法律和道德考虑:人工智能模型的可解释性可以用于解决法律和道德问题。在医疗保健领域,由人工智能系统做出的决策需要向患者、医疗专业人员、监管机构和其他利益相关者解释。通过提供可解释性,人工智能系统可以符合法律要求,例如《通用数据保护条例》(GDPR),该法赋予个人对对其产生重大影响的自动化决策进行解释的权利。
安全与鲁棒性:深度学习模型容易受到偏见、对抗性攻击或数据分布转移等问题的影响,这可能导致不正确或不可靠的预测。可解释性有助于检测这些问题,并评估模型的安全性和鲁棒性。通过了解模型的内部工作原理,可以识别潜在的偏见,调查模型可能过于自信或表现不佳的情况,并设计防范措施以减轻风险。
符合法规:在各个领域,包括医疗保健领域,可解释性越来越成为法规要求。监管机构,例如美国食品药品监督管理局(FDA),通常在批准部署之前要求对人工智能系统做出的决定进行解释和说明。可解释性允许审计、验证模型的行为,并符合监管标准,确保符合性和患者安全。
# 所有数据可以从此处下载:https://drive.google.com/drive/folders/1TmAfG7EWC1hjD7cHFGiJzUx2y3jLXdcP?usp=sharing
# 下载后,请将数据转移到您本地克隆的存储库中
use_drive = False
if use_drive:
from google.colab import drive
drive.mount('/content/drive')
!mkdir -p "/content/drive/My Drive/ai4healthsummerschool/"
# 加载和打印之前训练的 ABMIL 模型
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
# 定义全局注意力池化层
class AttentionTanhSigmoidGating(nn.Module):
def __init__(self, D=64, L=64, dropout=0.25):
r"""
具有 tanh 非线性和 sigmoid 门控的全局注意力池化层 (Ilse et al. 2018)。
Args:
D (int): 输入特征维度。
L (int): 隐藏层维度。Notation changed from M from Ilse et al 2018, as M is overloaded to also describe # of patch embeddings in a WSI.
dropout (float): Dropout 概率。
Returns:
A_norm (torch.Tensor): 归一化的注意力分数的 [M x 1] 维张量(总和为 1)
"""
super(AttentionTanhSigmoidGating, self).__init__()
self.tanhV = nn.Sequential(*[nn.Linear(D, L), nn.Tanh(), nn.Dropout(dropout)])
self.sigmU = nn.Sequential(*[nn.Linear(D, L), nn.Sigmoid(), nn.Dropout(dropout)])
self.w = nn.Linear(L, 1)
def forward(self, H, return_raw_attention=False):
A_raw = self.w(self.tanhV(H).mul(self.sigmU(H))) # 指数项
A_norm = F.softmax(A_raw, dim=0) # 应用 softmax 对权重进行归一化为 1
assert abs(A_norm.sum() - 1) < 1e-3 # 断言语句检查 sum(A) ~= 1
if return_raw_attention:
return A_norm, A_raw
return A_norm
# 定义 ABMIL 模型
class ABMIL(nn.Module):
def __init__(self, input_dim=320, hidden_dim=64, dropout=0.25, n_classes=2):
r"""
基于注意力的多实例学习 (Ilse et al. 2018)。
Args:
input_dim (int): 输入特征维度。
hidden_dim (int): 隐藏层维度。
dropout (float): Dropout 概率。
n_classes (int): 类别数量。
"""
super(ABMIL, self).__init__()
self.inst_level_fc = nn.Sequential(*[nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]) # 全连接层,对每个嵌入应用 "实例级" 处理
self.global_attn = AttentionTanhSigmoidGating(L=hidden_dim, D=hidden_dim) # 注意力函数
self.bag_level_classifier = nn.Linear(hidden_dim, n_classes) # 包级分类器
def forward(self, X: torch.randn(100, 320), return_raw_attention=False):
r"""
输入 [M x D] 维的图像块特征集合(代表一个 WSI),并输出:1) 分类的对数几率,2) 未归一化的注意力分数。
Args:
X (torch.Tensor): [M x D] 维的图像块特征集合(代表一个 WSI)
Returns:
logits (torch.Tensor): [1 x n_classes] 维的未归一化分类对数几率张量。
A_norm (torch.Tensor): [M,] 或 [M x 1] 维的注意力分数张量。
"""
H_inst = self.inst_level_fc(X) # 1. 处理每个特征嵌入,使其大小为 "hidden-dim"
if return_raw_attention:
A_norm, A_raw = self.global_attn(H_inst, return_raw_attention=True)
else:
A_norm = self.global_attn(H_inst) # 2. 获取每个嵌入的归一化注意力分数(使 sum(A_norm) ~= 1)
z = torch.sum(A_norm * H_inst, dim=0) # 3. 对图像块进行全局注意力池化的输出
logits = self.bag_level_classifier(z).unsqueeze(dim=0) # 4. 获取用于分类任务的未归一化对数几率
try:
assert logits.shape == (1,2)
except:
print(f"Logit tensor shape is not formatted correctly. Should output [1 x 2] shape, but got {logits.shape} shape")
if return_raw_attention:
return logits, A_raw
return logits, A_norm
def captum(self, X: torch.randn(100, 320)):
r"""
输入 [M x D] 维的图像块特征集合(代表一个 WSI),并输出:1) 分类的对数几率,2) 未归一化的注意力分数。
Args:
X (torch.Tensor): [M x D] 维的图像块特征集合(代表一个 WSI)
Returns:
logits (torch.Tensor): [1 x n_classes] 维的未归一化分类对数几率张量。
A_norm (torch.Tensor): [M,] 或 [M x 1] 维的注意力分数张量。
"""
H_inst = self.inst_level_fc(X) # 1. 处理每个特征嵌入,使其大小为 "hidden-dim"
A_norm = self.global_attn(H_inst) # 2. 获取每个嵌入的归一化注意力分数(使 sum(A_norm) ~= 1)
z = torch.sum(A_norm * H_inst, dim=0) # 3. 对图像块进行全局注意力池化的输出
logits = self.bag_level_classifier(z).unsqueeze(dim=0) # 4. 获取用于分类任务的未归一化对数几率
try:
assert logits.shape == (1,2)
except:
print(f"Logit tensor shape is not formatted correctly. Should output [1 x 2] shape, but got {logits.shape} shape")
return logits
device = torch.device('cpu')
model = ABMIL(input_dim=320, hidden_dim=64).to(device)
# 如果使用 Google Drive,则设置路径
if use_drive:
path = '/content/drive/My Drive/ai4healthsummerschool/abmil.ckpt'
else:
path = os.path.join('data', 'checkpoints', 'abmil.ckpt')
model.load_state_dict(torch.load(path)) # 加载模型参数
model.eval() # 将模型设置为评估模式
利用注意力权重解释基于注意力的 MIL 模型¶
注意力权重是深度学习模型中的一种机制,用于确定输入数据的不同部分的重要性或相关性。这些权重可以通过提供关于哪些输入部分对模型的决策过程贡献更大的见解,来解释深度学习的预测。
import numpy as np
import os
import pandas as pd
import torch
# 加载特征及相应的标签
if use_drive: # 如果使用 Google Drive
feats_dirpath = '/content/drive/My Drive/ai4healthsummerschool/feats_pt'
csv_fpath = '/content/drive/My Drive/ai4healthsummerschool/tcga_lung_splits.csv'
else:
feats_dirpath = os.path.join('data', 'processed', 'feats_pt') # 如果不使用 Google Drive,则从本地路径加载
csv_fpath = os.path.join('data', 'processed', 'tcga_lung_splits.csv')
index = 5 # (LUAD 样本)
csv = pd.read_csv(csv_fpath) # 读取 CSV 文件
which_labelcol = 'OncoTreeCode_Binarized' # 标签列名
csv_split = csv[csv['split'] == 'test'] # 选择测试集数据
features = torch.load(os.path.join(feats_dirpath, csv_split.iloc[index]['slide_id'] + '.pt')) # 加载特定样本的特征数据
label = csv_split.iloc[index][which_labelcol] # 获取标签值
print('我们将分析以下预测和注意力分数:')
print(csv_split.iloc[index]['slide_id'] + '.pt') # 输出文件名
print('标签:', label) # 输出标签值
print('特征:', features.shape) # 输出特征数据的形状
# 进行推断并存储注意力权重
logits, attention = model(features, return_raw_attention=True) # 使用模型进行推断,返回原始注意力权重
logits = logits.squeeze() # 去除维度为1的维度
attention = attention.squeeze().detach().numpy() # 去除维度为1的维度并转换为 NumPy 数组
print(logits) # 输出预测值
print('形状:', attention.shape, '最小值:', np.min(attention).item(), '最大值:', np.max(attention).item()) # 输出注意力权重的形状、最小值和最大值
!pip install openslide-python
import h5py
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from openslide import OpenSlide
# 加载与感兴趣样本对应的坐标
if use_drive: # 如果使用 Google Drive
path = '/content/drive/My Drive/ai4healthsummerschool/TCGA-35-3615-01Z-00-DX1.585128eb-6652-4b05-9a83-dc8f242904a6_patches.h5'
else:
path = os.path.join('data', 'processed', 'TCGA-35-3615-01Z-00-DX1.585128eb-6652-4b05-9a83-dc8f242904a6_patches.h5')
with h5py.File(path, "r") as f:
coords = f['coords'][:] # 读取坐标数据
print('坐标:', coords.shape) # 输出坐标数据的形状
# 使用 OpenSlide 加载整个切片图像
if use_drive: # 如果使用 Google Drive
slide_path = '/content/drive/My Drive/ai4healthsummerschool/TCGA-35-3615-01Z-00-DX1.585128eb-6652-4b05-9a83-dc8f242904a6.svs'
else:
slide_path = os.path.join('data', 'slides', 'TCGA-35-3615-01Z-00-DX1.585128eb-6652-4b05-9a83-dc8f242904a6.tiff')
wsi = OpenSlide(slide_path) # 打开切片图像
def draw_heatmap(scores, coords, wsi, vis_level=-1,
patch_size=(256, 256),
blank_canvas=False, canvas_color=(220, 20, 50), alpha=0.4,
overlap=0.0, use_holes=True,
convert_to_percentiles=False, thresh=0.5,
max_size=None, custom_downsample=4,
cmap='coolwarm'):
"""
绘制热图
Args:
scores (numpy array of float): 注意力分数
coords (numpy array of int, n_patches x 2): 对应的坐标(相对于级别0)
wsi (openslide): 使用 openslide 打开的 WSI
vis_level (int): 要可视化的 WSI 金字塔级别
patch_size (tuple of int): 补丁尺寸(相对于级别0)
blank_canvas (bool): 是否使用空白画布绘制热图(而不是使用原始切片)
canvas_color (tuple of uint8): 画布颜色
alpha (float [0, 1]): 覆盖热图到原始切片上的混合系数
overlap (float [0 1]): 相邻补丁之间的重叠百分比(仅影响模糊的半径)
use_holes (bool): 是否还剪裁出检测到的组织空洞(仅在 segment == True 时有效)
convert_to_percentiles (bool): 是否将注意力分数转换为百分位数
thresh (float): 二值化阈值
max_size (int): 最大画布大小(如果超过则裁剪)
custom_downsample (int): 通过指定的因子另外缩小热图
cmap (str): 要使用的 matplotlib 颜色映射的名称
"""
downsample = (0.25, 0.25)
patch_size = np.ceil(np.array(patch_size)).astype(int)
coords = np.ceil(coords * np.array(downsample)).astype(int)
region_size = wsi.level_dimensions[vis_level]
w, h = region_size
print('\n创建热图:')
print('宽度:{},高度:{}'.format(w, h))
print('缩放后的补丁尺寸:', patch_size)
# 热图叠加:跟踪每个热图像素上的注意力分数
# 叠加计数器:跟踪每个热图像素上积累的注意力分数的次数
overlay = np.full(np.flip(region_size), 0).astype(float)
counter = np.full(np.flip(region_size), 0).astype(np.uint16)
count = 0
for idx in range(len(coords)):
score = scores[idx].item()
coord = coords[idx]
# 累积分数
overlay[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0]] += score
# 累积计数
counter[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0]] += 1
# 获取关注区域并平均累积的注意力
zero_mask = counter == 0
overlay[~zero_mask] = overlay[~zero_mask] / counter[~zero_mask]
del counter
img = np.array(wsi.read_region((0, 0), vis_level, region_size).convert("RGB"))
print('\n计算热图图像')
print('总共 {} 个补丁'.format(len(coords)))
twenty_percent_chunk = max(1, int(len(coords) * 0.2))
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
norm = plt.Normalize(scores.min(), scores.max())
for idx in range(len(coords)):
if (idx + 1) % twenty_percent_chunk == 0:
print('进度:{}/{}'.format(idx, len(coords)))
score = scores[idx].item()
coord = coords[idx]
# 注意力块
raw_block = overlay[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0]]
# 图像块(空白画布或原始图像)
img_block = img[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0]].copy()
# 颜色块(将颜色映射应用于注意力块)
color_block = (cmap(norm(raw_block)) * 255)[:,:,:3].astype(np.uint8)
# 复制整个颜色块
img_block = color_block
# 重写图像块
img[coord[1]:coord[1]+patch_size[1], coord[0]:coord[0]+patch_size[0]] = img_block.copy()
#return Image.fromarray(img) #overlay
print('完成')
del overlay
img = Image.fromarray(img)
w, h = img.size
if custom_downsample > 1:
img = img.resize((int(w/custom_downsample), int(h/custom_downsample)))
return img
heatmap = draw_heatmap(
scores=attention, # 注意力分数
coords=coords, # 坐标
wsi=wsi, # WSI
use_holes=True, # 是否使用孔洞
vis_level=1, # 可视化级别
blank_canvas=False, # 是否使用空白画布
convert_to_percentiles=False # 是否转换为百分位数
)
heatmap.save(os.path.join('data', 'interpretability', 'attention_heatmap.png')) # 保存热图为图片文件
# 提取最重要的补丁
n_samples = 10
# 对注意力权重进行排序
to_keep = np.argsort(attention)[-n_samples:]
scores_to_keep = attention[to_keep]
coords_to_keep = coords[to_keep]
print('要保留的索引:', to_keep)
print('分数:', scores_to_keep)
for idx in range(n_samples):
patch = wsi.read_region(coords_to_keep[idx], 1, (256, 256)) # 读取指定坐标处的补丁
patch.show() # 显示补丁图像
使用集成梯度解释¶
在这一部分中,我们将使用 Captum,一个开源软件包,提供即用即得的后期解释技术,包括集成梯度(IG)。
!pip install captum
符号表示:
- 假设有一个深度学习模型,其输入向量为
x
,输出为f(x)
。 - 基线或参考点表示为
x'
,通常选择为复杂度较低的点(例如,全零或随机噪声)。 - 对于每个输入特征
i
,归因分数表示为A_i
。
- 假设有一个深度学习模型,其输入向量为
梯度计算:
计算模型输出关于输入特征的梯度:
$$\vec{\nabla} f(x) = \left(\frac{\partial f(x)}{\partial x_1}, \frac{\partial f(x)}{\partial x_2}, \ldots, \frac{\partial f(x)}{\partial x_n}\right)$$
集成梯度公式:
计算每个特征
i
的集成梯度分数如下:$$A_i = (x_i - x'_i) \times \int_{\alpha=0}^1 \left(\frac{\partial f(x'+\alpha(x-x'))}{\partial x_i}\right) d\alpha$$
解释:
- 集成梯度通过考虑输入
x
与基线x'
之间的差异来计算每个特征i
的贡献。 - 然后,它沿着从基线
x'
到输入x
的直线路径积分模型输出关于特征i
的梯度。 - 积分是在从 0 到 1 的一系列步骤(α)上计算的,表示基线和输入之间的插值。
- 每个插值点处的梯度衡量了输出对于从基线到输入的变化的特征
i
的敏感性。 - 特征
i
的贡献乘以该特征的输入和基线之间的差异,捕获了该特征导致的模型输出变化。
- 集成梯度通过考虑输入
实施步骤:
- 选择一个基线点
x'
(全零、随机噪声或其他相关选择)。 - 定义积分的步数或间隔数。
- 对于从 0 到 1 的每个步骤 α,计算
x'+α(x-x')
作为中间输入。 - 计算每个步骤的模型输出关于中间输入的梯度。
- 累积梯度,并将其乘以每个特征
i
的输入和基线之间的差异,以计算归因分数A_i
。
- 选择一个基线点
集成梯度技术提供了特征级别的归因分数,使我们能够了解模型预测中每个输入特征的重要性。通过可视化或分析这些分数,我们可以深入了解哪些特征对于模型的决策过程具有影响力或关键性。
# 使用 captum 获取 Integrated Gradients(IG)分数。
from captum.attr import IntegratedGradients # 导入 Integrated Gradients
num_classes = 2 # 类别数
def interpret_sample(features):
return model.captum(X=features) # 定义用于解释样本的函数
ig = IntegratedGradients(interpret_sample) # 创建 Integrated Gradients 对象
features.requires_grad_() # 将特征设置为需要梯度
patch_preds = [] # 用于存储补丁预测的列表
for target in range(num_classes):
# 计算 IG 属性
ig_attr = ig.attribute((features), n_steps=50, target=target)
ig_attr = ig_attr.squeeze().sum(dim=1).cpu().detach() # 将 IG 属性压缩并计算每个特征的总和
patch_preds.append(ig_attr) # 将 IG 属性添加到列表中
patch_preds = torch.stack(patch_preds, dim=0) # 将列表转换为张量,并在指定维度上堆叠
print(patch_preds.shape) # 打印补丁预测的形状
# 显示 IG 分数的直方图
scores = patch_preds[0, :].detach().numpy() # 提取 IG 分数并转换为 NumPy 数组
bins = np.linspace(scores.min(), scores.max(), 50) # 创建直方图的分箱
plt.hist(scores, bins, histtype='bar', rwidth=0.8) # 绘制直方图
plt.show() # 显示直方图
# 使用 Captum 绘制热图
heatmap = draw_heatmap(
scores=patch_preds[0, :], # IG 分数
coords=coords, # 坐标
wsi=wsi, # WSI
cmap='jet', # 颜色映射
alpha=1.0, # 透明度
use_holes=True, # 是否使用孔洞
vis_level=1, # 可视化级别
blank_canvas=False, # 是否使用空白画布
convert_to_percentiles=False # 是否转换为百分位数
)
heatmap.save(os.path.join('data', 'interpretability', 'ig_heatmap.png')) # 保存热图为图片文件
# 提取最重要的补丁
n_samples = 10
# 对 IG 分数进行排序
to_keep = np.argsort(patch_preds[0, :])[-n_samples:]
scores_to_keep = patch_preds[0, to_keep]
coords_to_keep = coords[to_keep]
print('要保留的索引:', to_keep)
print('分数:', scores_to_keep)
for idx in range(n_samples):
patch = wsi.read_region(coords_to_keep[idx], 1, (256, 256)) # 读取指定坐标处的补丁
patch.show() # 显示补丁图像
The following link at http://clam.mahmoodlab.org visualizes high-attention heatmaps for LUAD vs LUSC subtyping via CLAM (similar to ABMIL
) and confidence scores for each slides.
讨论¶
IG 相对于注意力方法在模型解释性方面的主要优势是什么?
特征归因方法的局限性是什么?
解释性和可解释性之间有什么区别?将其与对控制的概念联系起来。
如果你是一名临床病理学家,观察到这些可视化效果,你会对让 AI 算法协助你进行医学诊断有什么见解或担忧?