Weakly Supervised Deep Learning for Cancer Diagnosis in Computational Pathology¶
- Presenter:
- Guillaume Jaume (gjaume@bwh.harvard.edu)
- Postdoctoral Research Fellow, Harvard Medical School and Brigham and Women's Hospital
- Originally developed and written by Richard J. Chen (richardchen@g.harvard.edu)
Definitions:
Computational pathology (CPath): computational methods for studying disease based on the microscopic analysis of cells and tissue.
Digital pathology: the set of tools and systems used to acquire, manage, and diagnose pathology slides in a digital environment.
Whole-slide image (WSI): an image obtained by digitizing a glass slide at high resolution with a scanner.
Hematoxylin and eosin (H&E) stain: the reference stain for histological analysis, used to visualize cell nuclei (purple) as well as extracellular material and cytoplasm (pink).
Background:
Computational pathology aims to automate, assist, and augment the clinical practice of pathology with AI-based computational tools.
Tissue phenotyping is a fundamental problem in computational pathology (CPath), used to characterize histopathological features for cancer diagnosis, prognosis, and therapeutic-response estimation. Unlike natural images, whole-slide imaging is a challenging computer-vision domain in which image resolutions can reach $150{,}000 \times 150{,}000$ pixels (loading the entire image can require more than 50 GB of memory).
To address this computational and memory bottleneck, most state-of-the-art methods use a three-stage, weakly supervised pipeline based on multiple instance learning (MIL):
1. Patch the tissue at a single magnification ("zoom"), e.g., 20×.
2. Extract patch-level features to build a set of patch embeddings (compressing the patches roughly 100-500×).
3. Globally pool the embeddings into a slide-level representation used for weak supervision with slide-level labels (e.g., subtype, grade, stage, survival, origin). A minimal sketch of this pipeline is given right after this list.
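To make the three stages concrete before working with the real dataset, here is a minimal, self-contained sketch; the encoder, patch count, and dimensions are illustrative placeholders and not the ones used later in this notebook.

import torch, torch.nn as nn, torchvision

# Stage 1 (illustrative): assume the WSI has already been tiled into M non-overlapping 256 x 256 patches.
M = 16                                  # a real slide typically yields thousands of patches
patches = torch.randn(M, 3, 256, 256)   # [M x C x H x W]

# Stage 2: embed each patch with a CNN encoder (an un-pretrained MobileNetV3 here, purely as a stand-in).
cnn = torchvision.models.mobilenet_v3_small().eval()
with torch.no_grad():
    embeddings = torch.flatten(cnn.avgpool(cnn.features(patches)), 1)   # [M x D] bag of patch embeddings

# Stage 3: pool the bag into a single slide-level vector and classify it with a slide-level label.
slide_repr = embeddings.mean(dim=0, keepdim=True)                       # [1 x D] (mean pooling, as in AverageMIL below)
classifier = nn.Linear(embeddings.shape[1], 2)
logits = classifier(slide_repr)                                         # [1 x n_classes]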
Notebook goal: The following tutorial aims to distinguish lung adenocarcinoma (LUAD, ~40% of all lung cancers) from lung squamous cell carcinoma (LUSC, ~30% of all lung cancers) (see Lu et al., Nature BME 2021, and the CLAM codebase). Specifically, we will:
- Train and evaluate a "naive" MIL algorithm called AverageMIL, which takes the mean of the patch embeddings (as the global pooling operator).
- Implement a more sophisticated algorithm called attention-based multiple instance learning (ABMIL), which learns attention weights to compute a weighted average of the patch embeddings.
- Compare and contrast AverageMIL and ABMIL, discuss which algorithm is better, and discuss potential limitations.
About this notebook:
The model implementation and training are adapted directly from CLAM. CLAM includes many additional features (e.g., letting the user set the optimizer, model type, logging, and other hyperparameters), but this notebook has been simplified for teaching purposes. To use all features, please refer to CLAM.
Although this notebook is built on CLAM, the method of interest that you will implement is not CLAM itself, but a different method from Ilse et al., ICML 2018, called ABMIL.
While the pre-extracted features were generated with the CLAM codebase, the encoder is not the truncated ResNet-50 pretrained on ImageNet (dimension 1024) applied at 20× resolution. Instead, we used a smaller CNN encoder (dimension 320) at 10× resolution to extract features, which shrinks the dataset from roughly 11 GB to roughly 3.96 GB of storage (a download link for the pre-extracted features is provided in the cell below). In addition, a torch seed is set for reproducibility (all outputs should be deterministic).
Colab Installation, Data Download, and Dependencies¶
- Fetch the clinical-metadata CSV for tcga-luad and tcga-lusc with the predefined train / val / test splits
- Fetch the pre-extracted features for the tcga-luad and tcga-lusc diagnostic WSIs (1,043 WSIs in total, ~3.96 GB, ~67 s download time)
Alternatively, you can download the data from Dropbox directly to your local machine and run this Colab notebook locally.
use_drive = False
if use_drive:
    from google.colab import drive
    drive.mount('/content/drive')
    !mkdir -p "/content/drive/My Drive/ai4healthsummerschool/"
# either download in colab (data will be deleted when re-starting) or mount your labdrive (preferred, but requires 4GB of storage)
if use_drive:
    !wget https://www.dropbox.com/s/5wuvu791vwntg9o/tcga_lung_splits.csv -P "/content/drive/My Drive/ai4healthsummerschool"
    !wget https://www.dropbox.com/s/euepd2owxvuwr7v/feats_pt.zip
    !unzip -q feats_pt.zip
    !mv feats_pt "/content/drive/My Drive/ai4healthsummerschool"
else:
    !wget https://www.dropbox.com/s/5wuvu791vwntg9o/tcga_lung_splits.csv
    !wget https://www.dropbox.com/s/euepd2owxvuwr7v/feats_pt.zip
    !unzip -q feats_pt.zip
‘tcga_lung_splits.csv’ saved [170283/170283]
‘feats_pt.zip’ saved [3961539230/3961539230] (3.7G)
import os
import copy
import matplotlib.pyplot as plt
import seaborn
import numpy as np
import pandas as pd
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
print(torch.__version__)
2.0.1+cu118
WSI data preprocessing for histology slides in the TCGA-Lung cohort¶
To process WSIs, tools such as CLAM are typically used for tissue patching and non-overlapping patch feature extraction. Although easy to use, feature processing with CLAM requires downloading the gigapixel WSIs (more than 1,000 WSIs across TCGA-LUAD and TCGA-LUSC), which takes over 100 GB of storage. To mitigate this, this problem set provides pre-extracted features (processed via CLAM, but with a vision encoder whose output dimension, $D=320$, is much smaller). To still illustrate how the CLAM preprocessing works, the cell below shows how a WSI is structured as an [M x D]-dimensional bag of patch embeddings, where M is the number of tissue patches and D is the hidden dimension of your encoder. Again, if you are interested in regenerating these features, please use CLAM.
Note: running this cell is not required to train the final models!
# Suppose we have a collection of M [256 x 256 x 3] image patches taken from non-overlapping regions of a WSI (M = 2 here to keep the demo fast).
M = 2
X = torch.randn(M, 3, 256, 256) # arranged as (Batch, Channel, Height, Width), or (B, C, H, W) for short
print("WSI Shape:", X.shape)
# As an example, we will use a CNN (of the kind typically pretrained on ImageNet) as our vision encoder to pre-extract a "compressed" representation of each patch.
# Note: instantiated here without downloading pretrained weights, since this cell is only illustrative.
cnn = torchvision.models.mobilenet_v3_small()
cnn.eval()
# Since this model comes from torchvision and targets ImageNet, its output is a score for each of the 1000 ImageNet classes.
# To extract useful features from each patch, we instead need the output of the CNN's penultimate layer, before it is fed into the final linear classifier.
print("Probability Scores for ImageNet:", cnn.forward(X[:1]).shape)
# To extract penultimate-layer features, we can define a new function that returns the features rather than passing them through the classifier layer inside the model.
# Again, we want the ImageNet-pretrained features, but not the classification scores for the "ImageNet" classes!
# See the documentation below for how the forward pass works in MobileNetV3.
# https://pytorch.org/vision/main/_modules/torchvision/models/mobilenetv3.html#mobilenet_v3_small
encoder = lambda x: torch.flatten(cnn.avgpool(cnn.features(x)), 1)
print("Feature Embedding Shape:", encoder(X[:1]).shape)
# Now we can use our encoder to extract features for every patch.
# A WSI typically contains on the order of 15,000 non-overlapping patches, so we usually extract patch features in mini-batches.
batch_size = 32
H = []
for bag_idx in range(0, M, batch_size):
    H.append(encoder(X[bag_idx:(bag_idx+batch_size)]).cpu().detach().numpy())
print("Bag Shape", np.vstack(H).shape)
WSI Shape: torch.Size([2, 3, 256, 256]) Probability Scores for ImageNet: torch.Size([1, 1000]) Feature Embedding Shape: torch.Size([1, 576]) Bag Shape (2, 576)
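In a full preprocessing run, the loop above would be repeated for every slide and each bag written to disk. Below is a rough, hedged sketch of how a bag could be saved in the format the MILDataset class below expects (one <slide_id>.pt file per WSI); the slide_id and output directory are hypothetical placeholders, and the imports and the list H come from the cells above.

import os
import numpy as np
import torch

out_dir = "./feats_pt/"                        # hypothetical output directory
os.makedirs(out_dir, exist_ok=True)
slide_id = "EXAMPLE-SLIDE-ID"                  # placeholder; the real files are named after the TCGA slide_id
bag = torch.from_numpy(np.vstack(H))           # [M x D] bag of patch embeddings from the loop above
torch.save(bag, os.path.join(out_dir, f"{slide_id}.pt"))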
Data Exploration¶
Note: running this cell is not required to train the final models!
# Where we downloaded the features and the label CSV
use_drive = False
if use_drive:
    feats_dirpath, csv_fpath = '/content/drive/My Drive/ai4healthsummerschool/feats_pt/', '/content/drive/My Drive/ai4healthsummerschool/tcga_lung_splits.csv'
else:
    feats_dirpath, csv_fpath = './feats_pt/', './tcga_lung_splits.csv'
# The label CSV matches case_id (patient), slide_id (WSI filename), and diagnosis (LUAD vs. LUSC),
# along with the predefined splits (train / val / test)
df = pd.read_csv(csv_fpath)
display(df)
display(df[['split', 'OncoTreeCode']].value_counts())
# The extracted-feature filenames match the slide_id column
feats_pt_fnames = pd.Series(os.listdir(feats_dirpath))
print("Example filenames for extracted features:", list(feats_pt_fnames[:5]))
print("Overlap of extracted feature filenames + slide_id column:",
      len(set(df['slide_id']).intersection(set(feats_pt_fnames.str[:-3]))))
# Statistics on the size of each bag
bag_sizes = []
for e in os.scandir(feats_dirpath):
    feats_pt = torch.load(e.path) # [M x D]-dim tensor
    bag_sizes.append(feats_pt.shape[0])
print('Mean Bag Size:', np.mean(bag_sizes))
print('Std Bag Size:', np.std(bag_sizes))
 | case_id | slide_id | tumor_type | OncoTreeSiteCode | main_cancer_type | sex | project_id | Diagnosis | OncoTreeCode | OncoTreeCode_Binarized | split
---|---|---|---|---|---|---|---|---|---|---|---|
0 | TCGA-73-4676 | TCGA-73-4676-01Z-00-DX1.4d781bbc-a45e-4f9d-b6b... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
1 | TCGA-MP-A4T6 | TCGA-MP-A4T6-01Z-00-DX1.085C4F5A-DB1B-434A-9D6... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
2 | TCGA-78-7167 | TCGA-78-7167-01Z-00-DX1.f79e1a9b-a3eb-4c91-a1f... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
3 | TCGA-L9-A444 | TCGA-L9-A444-01Z-00-DX1.88CF6F01-0C1F-4572-81E... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
4 | TCGA-55-8097 | TCGA-55-8097-01Z-00-DX1.2f847b65-a5dc-41be-9dd... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1038 | TCGA-21-A5DI | TCGA-21-A5DI-01Z-00-DX1.E9123261-ADE7-468C-9E9... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUSC | Lung Squamous Cell Carcinoma | LUSC | 1 | test |
1039 | TCGA-77-7465 | TCGA-77-7465-01Z-00-DX1.25e4b0b4-4948-432f-801... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUSC | Lung Squamous Cell Carcinoma | LUSC | 1 | test |
1040 | TCGA-34-8454 | TCGA-34-8454-01Z-00-DX1.A2308ED3-E430-4448-853... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUSC | Lung Squamous Cell Carcinoma | LUSC | 1 | test |
1041 | TCGA-77-7138 | TCGA-77-7138-01Z-00-DX1.8c912762-0829-4692-92a... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUSC | Lung Squamous Cell Carcinoma | LUSC | 1 | test |
1042 | TCGA-77-8131 | TCGA-77-8131-01Z-00-DX1.dcb8e2c7-0d2f-4b38-9db... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUSC | Lung Squamous Cell Carcinoma | LUSC | 1 | test |
1043 rows × 11 columns
split OncoTreeCode train LUAD 433 LUSC 415 test LUAD 49 LUSC 49 val LUAD 49 LUSC 48 dtype: int64
Example filenames for extracted features: ['TCGA-75-6205-01Z-00-DX1.B75BC6BA-5196-4F62-BDA2-3F1D320ABD7C.pt', 'TCGA-77-8153-01Z-00-DX1.E8E40968-E7AD-4EA2-A832-8EC04D5CB7A1.pt', 'TCGA-56-8083-01Z-00-DX1.140c8d5b-f660-4fef-b8da-6bb2c119c021.pt', 'TCGA-73-A9RS-01Z-00-DX1.EDCEFE41-61E2-48C9-B8D5-28B55372E0CA.pt', 'TCGA-56-8201-01Z-00-DX1.883903fb-d70d-4c72-be76-6788b1bc3b35.pt'] Overlap of extracted feature filenames + slide_id column: 1043 Mean Bag Size: 3259.9090038314175 Std Bag Size: 2133.97437395412
Model 1: AverageMIL¶
We implement one of the simplest possible training setups: weakly supervised learning of LUAD vs. LUSC subtyping with AverageMIL, using 1,043 diagnostic H&E tissue slides from The Cancer Genome Atlas (the features were pre-extracted and downloaded during setup, along with the clinical metadata for all case and slide IDs).
You can run these cells in the Google Colab notebook and see how the algorithm performs over 20 epochs.
class AverageMIL(nn.Module):
    def __init__(self, input_dim=320, hidden_dim=64, dropout=0.25, n_classes=2):
        r"""
        AverageMIL, a simple MIL algorithm that mean-pools all patch features.
        Args:
            input_dim (int): Input feature dimension.
            hidden_dim (int): Hidden layer dimension.
            dropout (float): Dropout probability.
            n_classes (int): Number of classes.
        """
        super(AverageMIL, self).__init__()
        self.inst_level_fc = nn.Sequential(*[nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]) # "instance-level" fully connected layer applied to each embedding
        self.bag_level_classifier = nn.Linear(hidden_dim, n_classes) # bag-level classifier

    def forward(self, H):
        r"""
        Takes an [M x D]-dim bag of patch features (representing a WSI) and outputs: 1) logits for classification, 2) unnormalized attention scores.
        Args:
            H (torch.Tensor): [M x D]-dim bag of patch features (representing a WSI)
        Returns:
            logits (torch.Tensor): [1 x n_classes]-dim tensor of unnormalized logits for the classification task.
            None (no attention scores are returned)
        """
        H = self.inst_level_fc(H) # 1. preprocess each "instance-level" embedding to be of "hidden-dim" size
        z = H.mean(dim=0).unsqueeze(dim=0) # 2. mean of the patch embeddings
        logits = self.bag_level_classifier(z) # 3. bag-level classifier
        return logits, None
class MILDataset(torch.utils.data.dataset.Dataset):
    r"""
    torch.utils.data.dataset.Dataset object that loads the pre-extracted features for each WSI listed in a CSV.
    Args:
        feats_dirpath (str): Path to the pre-extracted patch features (assumed to be saved as *.pt objects named after the corresponding slide_id).
        csv_fpath (str): Path to a CSV containing: 1) Case ID, 2) Slide ID, 3) split information (train / val / test), and 4) the label column for classification.
        which_split (str): Split used to subset the CSV (choices: ['train', 'val', 'test']).
        n_classes (int): Number of classes (default 2, for LUAD vs. LUSC subtyping).
    """
    def __init__(self, feats_dirpath='./', csv_fpath='./tcga_lung_splits.csv', which_split='train', which_labelcol='OncoTreeCode_Binarized'):
        self.feats_dirpath, self.csv, self.which_labelcol = feats_dirpath, pd.read_csv(csv_fpath), which_labelcol
        self.csv_split = self.csv[self.csv['split']==which_split]

    def __getitem__(self, index):
        features = torch.load(os.path.join(self.feats_dirpath, self.csv_split.iloc[index]['slide_id']+'.pt'))
        label = self.csv_split.iloc[index][self.which_labelcol]
        return features, label

    def __len__(self):
        return self.csv_split.shape[0]
def traineval_epoch(epoch, model, loader, optimizer=None, loss_fn=nn.CrossEntropyLoss(), split='train', device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), verbose=1, print_every=300):
    r"""
    Runs one epoch of training / evaluation of a torch.nn model over a torch.utils.data.DataLoader object.
    Typically these functions are defined separately for training and validation, but they are merged here to save lines.
    Args:
        epoch (int): Current training / evaluation epoch (used for logging).
        model (torch.nn): MIL model that processes bags of patch features.
        loader (torch.utils.data.DataLoader): Object that fetches the bag of patch features for each WSI.
        loss_fn (torch.nn): Loss function.
        split (str): Which split is being used, to set the model mode + compute the loss + compute gradients.
        device (torch): Object representing the device that torch.Tensors will be allocated to.
        verbose (int): Whether to print summary epoch results (verbose >= 1) and iteration information (verbose >= 2).
        print_every (int): How many batch iterations between prints.
    Returns:
        log_dict (dict): Dictionary logging the loss and performance for the train / val / test split.
    """
    model.train() if (split == 'train') else model.eval() # set whether the model is used for training or evaluation, depending on the split
    total_loss, Y_probs, labels = 0.0, [], [] # track the loss + logits/labels for performance metrics
    for batch_idx, (X_bag, label) in enumerate(loader):
        # since we assume batch size == 1, we want to keep torch from collating our bag of patch features as a [1 x M x D] torch tensor
        X_bag, label = X_bag[0].to(device), label.to(device)
        if (split == 'train'):
            logits, A_norm = model(X_bag)
            loss = loss_fn(logits, label)
            loss.backward(), optimizer.step(), optimizer.zero_grad()
        else:
            with torch.no_grad(): logits, A_norm = model(X_bag)
            loss = loss_fn(logits, label)
        # track the total loss, logits, and current progress
        total_loss += loss.item()
        Y_probs.append(torch.softmax(logits, dim=-1).cpu().detach().numpy())
        labels.append(label.cpu().detach().numpy())
        if ((batch_idx + 1) % print_every == 0) and (verbose >= 2):
            print(f'Epoch {epoch}:\t Batch {batch_idx}\t Avg Loss: {total_loss / (batch_idx+1):.04f}\t Label: {label.item()}\t Bag Size: {X_bag.shape[0]}')
    # compute balanced accuracy and AUC-ROC from the saved logits / labels
    Y_probs, labels = np.vstack(Y_probs), np.concatenate(labels)
    log_dict = {f'{split} loss': total_loss/len(loader),
                f'{split} acc': sklearn.metrics.balanced_accuracy_score(labels, Y_probs.argmax(axis=1)),
                f'{split} auc': sklearn.metrics.roc_auc_score(labels, Y_probs[:, 1])}
    # print end-of-epoch information
    if (verbose >= 1):
        print(f'### ({split.capitalize()} Summary) ###')
        print(f'Epoch {epoch}:\t' + f'\t'.join([f'{k.capitalize().rjust(10)}: {log_dict[k]:.04f}' for k,v in log_dict.items()]))
    return log_dict
# set a random seed (for reproducibility)
torch.manual_seed(2023)
# get the dataloaders for train-val-test split evaluation
feats_dirpath, csv_fpath = './feats_pt/', './tcga_lung_splits.csv'
loader_kwargs = {
    'batch_size': 1,
    'num_workers': 2,
    'pin_memory': False
}
train_dataset, val_dataset, test_dataset = [MILDataset(feats_dirpath, csv_fpath, which_split=split) for split in ['train', 'val', 'test']]
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)
# get the model, optimizer, and loss function
device = torch.device('cpu')
model = AverageMIL(input_dim=320, hidden_dim=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()
# set up the train-val loop with early stopping
num_epochs, min_early_stopping, patience, counter = 20, 10, 5, 0
lowest_val_loss, best_model = np.inf, None
all_train_logs, all_val_logs = [], []
for epoch in range(num_epochs):
    # train the model
    train_log = traineval_epoch(epoch, model, train_loader, optimizer=optimizer, split='train', device=device, verbose=2, print_every=200)
    # validate the model
    val_log = traineval_epoch(epoch, model, val_loader, optimizer=None, split='val', device=device, verbose=1)
    val_loss = val_log['val loss']
    # early stopping: stop training if the validation loss does not decrease within <patience> epochs after the first <min_early_stopping> epochs
    if (epoch > min_early_stopping):
        if (val_loss < lowest_val_loss):
            print(f'Resetting early-stopping counter: {lowest_val_loss:.04f} -> {val_loss:.04f}...')
            lowest_val_loss, counter, best_model = val_loss, 0, copy.deepcopy(model)
        else:
            print(f'Early-stopping counter updating: {counter}/{patience} -> {counter+1}/{patience}...')
            counter += 1
            if counter >= patience: break
    print()
# report the best model (lowest validation loss) on the test split
best_model = model if (best_model is None) else best_model
test_log = traineval_epoch(epoch, best_model, test_loader, optimizer=None, split='test', device=device, verbose=1)
Epoch 0: Batch 199 Avg Loss: 0.7137 Label: 1 Bag Size: 957 Epoch 0: Batch 399 Avg Loss: 0.7084 Label: 1 Bag Size: 1131 Epoch 0: Batch 599 Avg Loss: 0.7093 Label: 1 Bag Size: 1342 Epoch 0: Batch 799 Avg Loss: 0.7075 Label: 1 Bag Size: 2319 ### (Train Summary) ### Epoch 0: Train loss: 0.7067 Train acc: 0.5178 Train auc: 0.5205 ### (Val Summary) ### Epoch 0: Val loss: 0.6804 Val acc: 0.5627 Val auc: 0.6947 Epoch 1: Batch 199 Avg Loss: 0.6892 Label: 1 Bag Size: 2830 Epoch 1: Batch 399 Avg Loss: 0.6933 Label: 0 Bag Size: 4905 Epoch 1: Batch 599 Avg Loss: 0.6951 Label: 0 Bag Size: 5624 Epoch 1: Batch 799 Avg Loss: 0.6919 Label: 0 Bag Size: 1298 ### (Train Summary) ### Epoch 1: Train loss: 0.6945 Train acc: 0.5530 Train auc: 0.5598 ### (Val Summary) ### Epoch 1: Val loss: 0.6736 Val acc: 0.5525 Val auc: 0.7105 Epoch 2: Batch 199 Avg Loss: 0.6909 Label: 0 Bag Size: 3509 Epoch 2: Batch 399 Avg Loss: 0.6808 Label: 0 Bag Size: 6975 Epoch 2: Batch 599 Avg Loss: 0.6799 Label: 0 Bag Size: 1555 Epoch 2: Batch 799 Avg Loss: 0.6784 Label: 0 Bag Size: 4302 ### (Train Summary) ### Epoch 2: Train loss: 0.6773 Train acc: 0.5668 Train auc: 0.6022 ### (Val Summary) ### Epoch 2: Val loss: 0.6703 Val acc: 0.5215 Val auc: 0.7105 Epoch 3: Batch 199 Avg Loss: 0.6628 Label: 0 Bag Size: 6237 Epoch 3: Batch 399 Avg Loss: 0.6645 Label: 0 Bag Size: 6577 Epoch 3: Batch 599 Avg Loss: 0.6678 Label: 1 Bag Size: 223 Epoch 3: Batch 799 Avg Loss: 0.6640 Label: 0 Bag Size: 3368 ### (Train Summary) ### Epoch 3: Train loss: 0.6653 Train acc: 0.5916 Train auc: 0.6377 ### (Val Summary) ### Epoch 3: Val loss: 0.6588 Val acc: 0.6507 Val auc: 0.7164 Epoch 4: Batch 199 Avg Loss: 0.6642 Label: 1 Bag Size: 1563 Epoch 4: Batch 399 Avg Loss: 0.6751 Label: 0 Bag Size: 1512 Epoch 4: Batch 599 Avg Loss: 0.6637 Label: 1 Bag Size: 8248 Epoch 4: Batch 799 Avg Loss: 0.6586 Label: 1 Bag Size: 933 ### (Train Summary) ### Epoch 4: Train loss: 0.6589 Train acc: 0.5951 Train auc: 0.6462 ### (Val Summary) ### Epoch 4: Val loss: 0.7084 Val acc: 0.5000 Val auc: 0.7092 Epoch 5: Batch 199 Avg Loss: 0.6623 Label: 0 Bag Size: 377 Epoch 5: Batch 399 Avg Loss: 0.6632 Label: 0 Bag Size: 1856 Epoch 5: Batch 599 Avg Loss: 0.6511 Label: 1 Bag Size: 3239 Epoch 5: Batch 799 Avg Loss: 0.6467 Label: 0 Bag Size: 356 ### (Train Summary) ### Epoch 5: Train loss: 0.6434 Train acc: 0.6376 Train auc: 0.6816 ### (Val Summary) ### Epoch 5: Val loss: 0.6920 Val acc: 0.5000 Val auc: 0.7083 Epoch 6: Batch 199 Avg Loss: 0.6496 Label: 1 Bag Size: 4575 Epoch 6: Batch 399 Avg Loss: 0.6349 Label: 1 Bag Size: 4743 Epoch 6: Batch 599 Avg Loss: 0.6317 Label: 0 Bag Size: 4420 Epoch 6: Batch 799 Avg Loss: 0.6311 Label: 1 Bag Size: 4480 ### (Train Summary) ### Epoch 6: Train loss: 0.6323 Train acc: 0.6517 Train auc: 0.7069 ### (Val Summary) ### Epoch 6: Val loss: 0.7329 Val acc: 0.5000 Val auc: 0.7083 Epoch 7: Batch 199 Avg Loss: 0.6297 Label: 0 Bag Size: 4443 Epoch 7: Batch 399 Avg Loss: 0.6245 Label: 1 Bag Size: 1261 Epoch 7: Batch 599 Avg Loss: 0.6233 Label: 1 Bag Size: 4893 Epoch 7: Batch 799 Avg Loss: 0.6257 Label: 1 Bag Size: 5530 ### (Train Summary) ### Epoch 7: Train loss: 0.6244 Train acc: 0.6617 Train auc: 0.7160 ### (Val Summary) ### Epoch 7: Val loss: 0.6490 Val acc: 0.6471 Val auc: 0.7117 Epoch 8: Batch 199 Avg Loss: 0.6155 Label: 1 Bag Size: 5553 Epoch 8: Batch 399 Avg Loss: 0.6293 Label: 0 Bag Size: 8315 Epoch 8: Batch 599 Avg Loss: 0.6184 Label: 0 Bag Size: 4905 Epoch 8: Batch 799 Avg Loss: 0.6141 Label: 1 Bag Size: 3630 ### (Train Summary) ### Epoch 8: Train loss: 0.6119 
Train acc: 0.6746 Train auc: 0.7349 ### (Val Summary) ### Epoch 8: Val loss: 0.6426 Val acc: 0.6477 Val auc: 0.7122 Epoch 9: Batch 199 Avg Loss: 0.6157 Label: 0 Bag Size: 2482 Epoch 9: Batch 399 Avg Loss: 0.6002 Label: 0 Bag Size: 1856 Epoch 9: Batch 599 Avg Loss: 0.6062 Label: 0 Bag Size: 384 Epoch 9: Batch 799 Avg Loss: 0.6094 Label: 0 Bag Size: 579 ### (Train Summary) ### Epoch 9: Train loss: 0.6099 Train acc: 0.6767 Train auc: 0.7334 ### (Val Summary) ### Epoch 9: Val loss: 0.6621 Val acc: 0.6112 Val auc: 0.7151 Epoch 10: Batch 199 Avg Loss: 0.5707 Label: 1 Bag Size: 3912 Epoch 10: Batch 399 Avg Loss: 0.6012 Label: 0 Bag Size: 6484 Epoch 10: Batch 599 Avg Loss: 0.6039 Label: 1 Bag Size: 3609 Epoch 10: Batch 799 Avg Loss: 0.5971 Label: 0 Bag Size: 345 ### (Train Summary) ### Epoch 10: Train loss: 0.5959 Train acc: 0.6872 Train auc: 0.7599 ### (Val Summary) ### Epoch 10: Val loss: 0.6437 Val acc: 0.6577 Val auc: 0.7164 Epoch 11: Batch 199 Avg Loss: 0.5891 Label: 0 Bag Size: 1329 Epoch 11: Batch 399 Avg Loss: 0.6049 Label: 0 Bag Size: 3423 Epoch 11: Batch 599 Avg Loss: 0.5980 Label: 1 Bag Size: 5204 Epoch 11: Batch 799 Avg Loss: 0.5910 Label: 1 Bag Size: 1342 ### (Train Summary) ### Epoch 11: Train loss: 0.5950 Train acc: 0.6931 Train auc: 0.7554 ### (Val Summary) ### Epoch 11: Val loss: 0.6243 Val acc: 0.6913 Val auc: 0.7151 Resetting early-stopping counter: inf -> 0.6243... Epoch 12: Batch 199 Avg Loss: 0.5887 Label: 1 Bag Size: 5400 Epoch 12: Batch 399 Avg Loss: 0.5753 Label: 0 Bag Size: 6074 Epoch 12: Batch 599 Avg Loss: 0.5858 Label: 0 Bag Size: 4085 Epoch 12: Batch 799 Avg Loss: 0.5852 Label: 0 Bag Size: 2415 ### (Train Summary) ### Epoch 12: Train loss: 0.5874 Train acc: 0.7049 Train auc: 0.7621 ### (Val Summary) ### Epoch 12: Val loss: 0.6305 Val acc: 0.6482 Val auc: 0.7194 Early-stopping counter updating: 0/5 -> 1/5... Epoch 13: Batch 199 Avg Loss: 0.5654 Label: 0 Bag Size: 4486 Epoch 13: Batch 399 Avg Loss: 0.5794 Label: 0 Bag Size: 464 Epoch 13: Batch 599 Avg Loss: 0.5889 Label: 0 Bag Size: 8841 Epoch 13: Batch 799 Avg Loss: 0.5794 Label: 0 Bag Size: 1334 ### (Train Summary) ### Epoch 13: Train loss: 0.5819 Train acc: 0.6895 Train auc: 0.7633 ### (Val Summary) ### Epoch 13: Val loss: 0.6198 Val acc: 0.6913 Val auc: 0.7224 Resetting early-stopping counter: 0.6243 -> 0.6198... Epoch 14: Batch 199 Avg Loss: 0.5543 Label: 0 Bag Size: 2093 Epoch 14: Batch 399 Avg Loss: 0.5699 Label: 0 Bag Size: 9103 Epoch 14: Batch 599 Avg Loss: 0.5784 Label: 1 Bag Size: 2906 Epoch 14: Batch 799 Avg Loss: 0.5836 Label: 1 Bag Size: 1805 ### (Train Summary) ### Epoch 14: Train loss: 0.5861 Train acc: 0.6993 Train auc: 0.7626 ### (Val Summary) ### Epoch 14: Val loss: 0.6432 Val acc: 0.6412 Val auc: 0.7245 Early-stopping counter updating: 0/5 -> 1/5... Epoch 15: Batch 199 Avg Loss: 0.5710 Label: 1 Bag Size: 711 Epoch 15: Batch 399 Avg Loss: 0.5635 Label: 1 Bag Size: 980 Epoch 15: Batch 599 Avg Loss: 0.5850 Label: 0 Bag Size: 3167 Epoch 15: Batch 799 Avg Loss: 0.5800 Label: 1 Bag Size: 5346 ### (Train Summary) ### Epoch 15: Train loss: 0.5789 Train acc: 0.7088 Train auc: 0.7732 ### (Val Summary) ### Epoch 15: Val loss: 0.6718 Val acc: 0.5946 Val auc: 0.7287 Early-stopping counter updating: 1/5 -> 2/5... 
Epoch 16: Batch 199 Avg Loss: 0.5943 Label: 0 Bag Size: 4156 Epoch 16: Batch 399 Avg Loss: 0.5866 Label: 1 Bag Size: 2068 Epoch 16: Batch 599 Avg Loss: 0.5768 Label: 1 Bag Size: 3720 Epoch 16: Batch 799 Avg Loss: 0.5750 Label: 1 Bag Size: 6535 ### (Train Summary) ### Epoch 16: Train loss: 0.5765 Train acc: 0.6930 Train auc: 0.7688 ### (Val Summary) ### Epoch 16: Val loss: 0.6682 Val acc: 0.6050 Val auc: 0.7313 Early-stopping counter updating: 2/5 -> 3/5... Epoch 17: Batch 199 Avg Loss: 0.6217 Label: 1 Bag Size: 1594 Epoch 17: Batch 399 Avg Loss: 0.6043 Label: 1 Bag Size: 5042 Epoch 17: Batch 599 Avg Loss: 0.5981 Label: 1 Bag Size: 4271 Epoch 17: Batch 799 Avg Loss: 0.5737 Label: 0 Bag Size: 3418 ### (Train Summary) ### Epoch 17: Train loss: 0.5760 Train acc: 0.7109 Train auc: 0.7726 ### (Val Summary) ### Epoch 17: Val loss: 0.6160 Val acc: 0.6798 Val auc: 0.7326 Resetting early-stopping counter: 0.6198 -> 0.6160... Epoch 18: Batch 199 Avg Loss: 0.5472 Label: 0 Bag Size: 2924 Epoch 18: Batch 399 Avg Loss: 0.5504 Label: 0 Bag Size: 5320 Epoch 18: Batch 599 Avg Loss: 0.5530 Label: 0 Bag Size: 2522 Epoch 18: Batch 799 Avg Loss: 0.5583 Label: 0 Bag Size: 446 ### (Train Summary) ### Epoch 18: Train loss: 0.5615 Train acc: 0.6861 Train auc: 0.7838 ### (Val Summary) ### Epoch 18: Val loss: 0.6216 Val acc: 0.6405 Val auc: 0.7351 Early-stopping counter updating: 0/5 -> 1/5... Epoch 19: Batch 199 Avg Loss: 0.5789 Label: 1 Bag Size: 1063 Epoch 19: Batch 399 Avg Loss: 0.5672 Label: 1 Bag Size: 3959 Epoch 19: Batch 599 Avg Loss: 0.5663 Label: 0 Bag Size: 1455 Epoch 19: Batch 799 Avg Loss: 0.5632 Label: 0 Bag Size: 4642 ### (Train Summary) ### Epoch 19: Train loss: 0.5647 Train acc: 0.7148 Train auc: 0.7823 ### (Val Summary) ### Epoch 19: Val loss: 0.6092 Val acc: 0.6607 Val auc: 0.7385 Resetting early-stopping counter: 0.6160 -> 0.6092... ### (Test Summary) ### Epoch 19: Test loss: 0.5563 Test acc: 0.6837 Test auc: 0.8401
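The data loaders above use batch_size=1 because each WSI yields a bag with a different number of patches, which the default collate function cannot stack into one tensor. If you ever want larger batches, one hedged option (not used in this notebook, and the models would also need to be made mask-aware) is a custom collate function that zero-pads bags to a common length and returns a mask; the sketch below is illustrative only.

import torch

def pad_collate(batch):
    """Zero-pad variable-length [M_i x D] bags into a [B x M_max x D] tensor plus a boolean mask."""
    feats, labels = zip(*batch)
    max_len = max(f.shape[0] for f in feats)
    dim = feats[0].shape[1]
    padded = torch.zeros(len(feats), max_len, dim)
    mask = torch.zeros(len(feats), max_len, dtype=torch.bool)
    for i, f in enumerate(feats):
        padded[i, :f.shape[0]] = f
        mask[i, :f.shape[0]] = True
    return padded, mask, torch.as_tensor(labels)

# Hypothetical usage:
# loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=pad_collate)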
Model 2: Implementing Attention-Based Multiple Instance Learning (ABMIL)¶
Now that you have experimented with AverageMIL, you are ready to implement a more sophisticated model for LUAD vs. LUSC subtyping. Formally, let $\mathbf{H}=\left\{\mathbf{h}_1, \ldots, \mathbf{h}_M\right\} \in \mathbb{R}^{M \times D}$ be a bag of $M$ patch embeddings, each of dimension $D$. Ilse et al. (2018) proposed the following attention-based MIL pooling operation:
$$ \mathbf{z} =\sum_{i=1}^M a_i \mathbf{h}_i, \quad \text{where} \enspace a_i=\frac{\exp \left\{\mathbf{w}^{\top}\left(\tanh \left(\mathbf{V} \mathbf{h}_{i}^{\top}\right) \odot \operatorname{sigm}\left(\mathbf{U} \mathbf{h}_i^{\top}\right)\right)\right\}}{\sum_{j=1}^M \exp \left\{\mathbf{w}^{\top}\left(\tanh \left(\mathbf{V} \mathbf{h}_j^{\top}\right) \odot \operatorname{sigm}\left(\mathbf{U} \mathbf{h}_j^{\top}\right)\right)\right\}} $$
where $\mathbf{w} \in \mathbb{R}^{L \times 1}$, $\mathbf{V} \in \mathbb{R}^{L \times D}$, and $\mathbf{U} \in \mathbb{R}^{L \times D}$ are learnable neural-network parameters (implemented as fully connected layers), and $\mathbf{z} \in \mathbb{R}^{D}$ is the weighted average of all patch embeddings in $\mathbf{H}$. The element-wise hyperbolic tangent $\tanh(\cdot)$ and sigmoid non-linearities are used for proper gradient flow.
In PyTorch, the mathematical expression for computing $a_i$ is implemented as the torch.nn module AttentionTanhSigmoidGating, which we use as a layer in ABMIL to compute the weighted average of the patch embeddings.
class AttentionTanhSigmoidGating(nn.Module):
    def __init__(self, D=64, L=64, dropout=0.25):
        r"""
        Global attention-pooling layer with tanh non-linearity and sigmoid gating (Ilse et al. 2018).
        Args:
            D (int): Input feature dimension.
            L (int): Hidden layer dimension. The notation is changed from M in Ilse et al. 2018 to L, since M is also used here for the number of patch embeddings in a WSI.
            dropout (float): Dropout probability.
        Returns:
            A_norm (torch.Tensor): [M x 1]-dim tensor of normalized attention scores (summing to 1).
        """
        super(AttentionTanhSigmoidGating, self).__init__()
        self.tanhV = nn.Sequential(*[nn.Linear(D, L), nn.Tanh(), nn.Dropout(dropout)])
        self.sigmU = nn.Sequential(*[nn.Linear(D, L), nn.Sigmoid(), nn.Dropout(dropout)])
        self.w = nn.Linear(L, 1)

    def forward(self, H):
        A_raw = self.w(self.tanhV(H).mul(self.sigmU(H))) # the term inside the exponential
        A_norm = F.softmax(A_raw, dim=0) # apply softmax so the weights sum to 1
        assert abs(A_norm.sum() - 1) < 1e-3 # sanity check that sum(A) is approximately 1
        return A_norm
class ABMIL(nn.Module):
    def __init__(self, input_dim=320, hidden_dim=64, dropout=0.25, n_classes=2):
        r"""
        Attention-based multiple instance learning (Ilse et al. 2018).
        Args:
            input_dim (int): Input feature dimension.
            hidden_dim (int): Hidden layer dimension.
            dropout (float): Dropout probability.
            n_classes (int): Number of classes.
        """
        super(ABMIL, self).__init__()
        self.inst_level_fc = nn.Sequential(*[nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]) # "instance-level" fully connected layer applied to each embedding
        self.global_attn = AttentionTanhSigmoidGating(L=hidden_dim, D=hidden_dim) # attention function
        self.bag_level_classifier = nn.Linear(hidden_dim, n_classes) # bag-level classifier

    def forward(self, X: torch.Tensor):
        r"""
        Takes an [M x D]-dim bag of patch features (representing a WSI) and outputs: 1) unnormalized logits for classification, 2) normalized attention scores.
        Args:
            X (torch.Tensor): [M x D]-dim bag of patch features (representing a WSI)
        Returns:
            logits (torch.Tensor): [1 x n_classes]-dim tensor of unnormalized logits for the classification task.
            A_norm (torch.Tensor): [M,]- or [M x 1]-dim tensor of attention scores.
        """
        H_inst = self.inst_level_fc(X) # 1. process each feature embedding to be of "hidden-dim" size
        A_norm = self.global_attn(H_inst) # 2. get the normalized attention score for each embedding (so that sum(A_norm) is approximately 1)
        z = torch.sum(A_norm * H_inst, dim=0) # 3. global attention pooling of the patches to get the output
        logits = self.bag_level_classifier(z).unsqueeze(dim=0) # 4. get the unnormalized logits for the classification task
        try:
            assert logits.shape == (1, 2)
        except AssertionError:
            print(f"Logit tensor shape is not formatted correctly. Should output [1 x 2] shape, but got {logits.shape} shape")
        return logits, A_norm
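Before training, a quick sanity check of the implementation on a random bag can be useful; the bag size of 100 and feature dimension of 320 below are purely illustrative.

# Quick shape check on a random [M x D] bag (illustrative values only).
_model = ABMIL(input_dim=320, hidden_dim=64)
_model.eval()
with torch.no_grad():
    _logits, _A = _model(torch.randn(100, 320))
print(_logits.shape)               # expected: torch.Size([1, 2])
print(_A.shape, float(_A.sum()))   # expected: torch.Size([100, 1]) and a sum of ~1.0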
# set a random seed (for reproducibility)
torch.manual_seed(2023)
# get the dataloaders for train-val-test split evaluation
if use_drive:
    feats_dirpath = '/content/drive/My Drive/ai4healthsummerschool/feats_pt/'
    csv_fpath = '/content/drive/My Drive/ai4healthsummerschool/tcga_lung_splits.csv'
else:
    feats_dirpath, csv_fpath = './feats_pt/', './tcga_lung_splits.csv'
display(pd.read_csv(csv_fpath).head(10)) # visualize the data
loader_kwargs = {'batch_size': 1, 'num_workers': 2, 'pin_memory': False} # batch size set to 1 because of the variable bag sizes, which are hard to collate
train_dataset, val_dataset, test_dataset = [MILDataset(feats_dirpath, csv_fpath, which_split=split) for split in ['train', 'val', 'test']]
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)
# get the model, optimizer, and loss function
device = torch.device('cpu')
model = ABMIL(input_dim=320, hidden_dim=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()
# set up the train-val loop with early stopping
num_epochs, min_early_stopping, patience, counter = 20, 10, 5, 0
lowest_val_loss, best_model = np.inf, None
all_train_logs, all_val_logs = [], [] # TODO: do something with train_log / val_log each epoch to help visualize the performance curves?
for epoch in range(num_epochs):
    train_log = traineval_epoch(epoch, model, train_loader, optimizer=optimizer, split='train', device=device, verbose=2, print_every=200)
    val_log = traineval_epoch(epoch, model, val_loader, optimizer=None, split='val', device=device, verbose=1)
    val_loss = val_log['val loss']
    # early stopping: stop training if the validation loss does not decrease within <patience> epochs after the first <min_early_stopping> epochs
    if (epoch > min_early_stopping):
        if (val_loss < lowest_val_loss):
            print(f'Resetting early-stopping counter: {lowest_val_loss:.04f} -> {val_loss:.04f}...')
            lowest_val_loss, counter, best_model = val_loss, 0, copy.deepcopy(model)
        else:
            print(f'Early-stopping counter updating: {counter}/{patience} -> {counter+1}/{patience}...')
            counter += 1
            if counter >= patience: break
    print()
# report the best model (lowest validation loss) on the test split
best_model = model if (best_model is None) else best_model
test_log = traineval_epoch(epoch, best_model, test_loader, optimizer=None, split='test', device=device, verbose=1)
 | case_id | slide_id | tumor_type | OncoTreeSiteCode | main_cancer_type | sex | project_id | Diagnosis | OncoTreeCode | OncoTreeCode_Binarized | split
---|---|---|---|---|---|---|---|---|---|---|---|
0 | TCGA-73-4676 | TCGA-73-4676-01Z-00-DX1.4d781bbc-a45e-4f9d-b6b... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
1 | TCGA-MP-A4T6 | TCGA-MP-A4T6-01Z-00-DX1.085C4F5A-DB1B-434A-9D6... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
2 | TCGA-78-7167 | TCGA-78-7167-01Z-00-DX1.f79e1a9b-a3eb-4c91-a1f... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
3 | TCGA-L9-A444 | TCGA-L9-A444-01Z-00-DX1.88CF6F01-0C1F-4572-81E... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
4 | TCGA-55-8097 | TCGA-55-8097-01Z-00-DX1.2f847b65-a5dc-41be-9dd... | Primary | LUNG | Non-Small Cell Lung Cancer | F | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
5 | TCGA-44-8119 | TCGA-44-8119-01Z-00-DX1.1EBEBFA7-22DB-4365-9DF... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
6 | TCGA-49-AAR2 | TCGA-49-AAR2-01Z-00-DX1.1F09F896-446E-4C55-8D0... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
7 | TCGA-L9-A743 | TCGA-L9-A743-01Z-00-DX1.27ED2955-E8B5-4A3C-ADA... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
8 | TCGA-99-8032 | TCGA-99-8032-01Z-00-DX1.7380b78f-ea25-43e0-ac9... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
9 | TCGA-55-6972 | TCGA-55-6972-01Z-00-DX1.0b441ad0-c30f-4f63-849... | Primary | LUNG | Non-Small Cell Lung Cancer | M | TCGA-LUAD | Lung Adenocarcinoma | LUAD | 0 | train |
Epoch 0: Batch 199 Avg Loss: 0.7363 Label: 0 Bag Size: 507 Epoch 0: Batch 399 Avg Loss: 0.7275 Label: 0 Bag Size: 1163 Epoch 0: Batch 599 Avg Loss: 0.7064 Label: 1 Bag Size: 4417 Epoch 0: Batch 799 Avg Loss: 0.7082 Label: 0 Bag Size: 1598 ### (Train Summary) ### Epoch 0: Train loss: 0.7074 Train acc: 0.5111 Train auc: 0.5243 ### (Val Summary) ### Epoch 0: Val loss: 0.6894 Val acc: 0.4991 Val auc: 0.6386 Epoch 1: Batch 199 Avg Loss: 0.7016 Label: 0 Bag Size: 6939 Epoch 1: Batch 399 Avg Loss: 0.6918 Label: 0 Bag Size: 1229 Epoch 1: Batch 599 Avg Loss: 0.6928 Label: 0 Bag Size: 519 Epoch 1: Batch 799 Avg Loss: 0.6887 Label: 1 Bag Size: 4335 ### (Train Summary) ### Epoch 1: Train loss: 0.6899 Train acc: 0.5462 Train auc: 0.5605 ### (Val Summary) ### Epoch 1: Val loss: 0.6863 Val acc: 0.5000 Val auc: 0.6586 Epoch 2: Batch 199 Avg Loss: 0.6807 Label: 1 Bag Size: 7186 Epoch 2: Batch 399 Avg Loss: 0.6799 Label: 1 Bag Size: 2611 Epoch 2: Batch 599 Avg Loss: 0.6813 Label: 0 Bag Size: 708 Epoch 2: Batch 799 Avg Loss: 0.6780 Label: 1 Bag Size: 4235 ### (Train Summary) ### Epoch 2: Train loss: 0.6780 Train acc: 0.5782 Train auc: 0.6022 ### (Val Summary) ### Epoch 2: Val loss: 0.6763 Val acc: 0.5595 Val auc: 0.6739 Epoch 3: Batch 199 Avg Loss: 0.6632 Label: 1 Bag Size: 4657 Epoch 3: Batch 399 Avg Loss: 0.6718 Label: 1 Bag Size: 6929 Epoch 3: Batch 599 Avg Loss: 0.6729 Label: 0 Bag Size: 4582 Epoch 3: Batch 799 Avg Loss: 0.6697 Label: 1 Bag Size: 2143 ### (Train Summary) ### Epoch 3: Train loss: 0.6716 Train acc: 0.5818 Train auc: 0.6167 ### (Val Summary) ### Epoch 3: Val loss: 0.7256 Val acc: 0.5000 Val auc: 0.6705 Epoch 4: Batch 199 Avg Loss: 0.6751 Label: 0 Bag Size: 681 Epoch 4: Batch 399 Avg Loss: 0.6710 Label: 0 Bag Size: 4601 Epoch 4: Batch 599 Avg Loss: 0.6649 Label: 0 Bag Size: 3559 Epoch 4: Batch 799 Avg Loss: 0.6613 Label: 0 Bag Size: 287 ### (Train Summary) ### Epoch 4: Train loss: 0.6596 Train acc: 0.5896 Train auc: 0.6441 ### (Val Summary) ### Epoch 4: Val loss: 0.6605 Val acc: 0.6794 Val auc: 0.6913 Epoch 5: Batch 199 Avg Loss: 0.6653 Label: 1 Bag Size: 1096 Epoch 5: Batch 399 Avg Loss: 0.6557 Label: 1 Bag Size: 4743 Epoch 5: Batch 599 Avg Loss: 0.6457 Label: 0 Bag Size: 4152 Epoch 5: Batch 799 Avg Loss: 0.6487 Label: 0 Bag Size: 6785 ### (Train Summary) ### Epoch 5: Train loss: 0.6471 Train acc: 0.6375 Train auc: 0.6780 ### (Val Summary) ### Epoch 5: Val loss: 0.6603 Val acc: 0.5699 Val auc: 0.7007 Epoch 6: Batch 199 Avg Loss: 0.6459 Label: 1 Bag Size: 1111 Epoch 6: Batch 399 Avg Loss: 0.6323 Label: 0 Bag Size: 1265 Epoch 6: Batch 599 Avg Loss: 0.6366 Label: 1 Bag Size: 4928 Epoch 6: Batch 799 Avg Loss: 0.6297 Label: 1 Bag Size: 6929 ### (Train Summary) ### Epoch 6: Train loss: 0.6280 Train acc: 0.6475 Train auc: 0.7111 ### (Val Summary) ### Epoch 6: Val loss: 0.6431 Val acc: 0.6403 Val auc: 0.7062 Epoch 7: Batch 199 Avg Loss: 0.6292 Label: 0 Bag Size: 1555 Epoch 7: Batch 399 Avg Loss: 0.6124 Label: 1 Bag Size: 3609 Epoch 7: Batch 599 Avg Loss: 0.6162 Label: 0 Bag Size: 3900 Epoch 7: Batch 799 Avg Loss: 0.6058 Label: 1 Bag Size: 6125 ### (Train Summary) ### Epoch 7: Train loss: 0.6057 Train acc: 0.6645 Train auc: 0.7375 ### (Val Summary) ### Epoch 7: Val loss: 0.6388 Val acc: 0.7100 Val auc: 0.7177 Epoch 8: Batch 199 Avg Loss: 0.6160 Label: 0 Bag Size: 569 Epoch 8: Batch 399 Avg Loss: 0.6007 Label: 1 Bag Size: 3959 Epoch 8: Batch 599 Avg Loss: 0.5961 Label: 0 Bag Size: 3229 Epoch 8: Batch 799 Avg Loss: 0.5946 Label: 0 Bag Size: 4215 ### (Train Summary) ### Epoch 8: Train loss: 0.5906 
Train acc: 0.6884 Train auc: 0.7590 ### (Val Summary) ### Epoch 8: Val loss: 0.6215 Val acc: 0.6607 Val auc: 0.7338 Epoch 9: Batch 199 Avg Loss: 0.5595 Label: 0 Bag Size: 1061 Epoch 9: Batch 399 Avg Loss: 0.5697 Label: 0 Bag Size: 1080 Epoch 9: Batch 599 Avg Loss: 0.5724 Label: 1 Bag Size: 1111 Epoch 9: Batch 799 Avg Loss: 0.5676 Label: 1 Bag Size: 2370 ### (Train Summary) ### Epoch 9: Train loss: 0.5709 Train acc: 0.7127 Train auc: 0.7802 ### (Val Summary) ### Epoch 9: Val loss: 0.6258 Val acc: 0.6996 Val auc: 0.7474 Epoch 10: Batch 199 Avg Loss: 0.5501 Label: 0 Bag Size: 657 Epoch 10: Batch 399 Avg Loss: 0.5604 Label: 0 Bag Size: 4582 Epoch 10: Batch 599 Avg Loss: 0.5507 Label: 1 Bag Size: 2964 Epoch 10: Batch 799 Avg Loss: 0.5445 Label: 0 Bag Size: 1877 ### (Train Summary) ### Epoch 10: Train loss: 0.5479 Train acc: 0.7466 Train auc: 0.8090 ### (Val Summary) ### Epoch 10: Val loss: 0.6297 Val acc: 0.6418 Val auc: 0.7721 Epoch 11: Batch 199 Avg Loss: 0.5339 Label: 1 Bag Size: 4775 Epoch 11: Batch 399 Avg Loss: 0.5165 Label: 0 Bag Size: 1163 Epoch 11: Batch 599 Avg Loss: 0.5147 Label: 0 Bag Size: 1045 Epoch 11: Batch 799 Avg Loss: 0.5211 Label: 1 Bag Size: 3985 ### (Train Summary) ### Epoch 11: Train loss: 0.5204 Train acc: 0.7333 Train auc: 0.8233 ### (Val Summary) ### Epoch 11: Val loss: 0.5948 Val acc: 0.7517 Val auc: 0.7874 Resetting early-stopping counter: inf -> 0.5948... Epoch 12: Batch 199 Avg Loss: 0.5097 Label: 1 Bag Size: 7475 Epoch 12: Batch 399 Avg Loss: 0.4951 Label: 1 Bag Size: 2272 Epoch 12: Batch 599 Avg Loss: 0.5076 Label: 0 Bag Size: 1681 Epoch 12: Batch 799 Avg Loss: 0.4966 Label: 1 Bag Size: 4230 ### (Train Summary) ### Epoch 12: Train loss: 0.4967 Train acc: 0.7786 Train auc: 0.8461 ### (Val Summary) ### Epoch 12: Val loss: 0.5608 Val acc: 0.7324 Val auc: 0.8014 Resetting early-stopping counter: 0.5948 -> 0.5608... Epoch 13: Batch 199 Avg Loss: 0.5048 Label: 0 Bag Size: 1846 Epoch 13: Batch 399 Avg Loss: 0.5039 Label: 0 Bag Size: 4170 Epoch 13: Batch 599 Avg Loss: 0.4992 Label: 1 Bag Size: 5204 Epoch 13: Batch 799 Avg Loss: 0.4776 Label: 1 Bag Size: 2459 ### (Train Summary) ### Epoch 13: Train loss: 0.4785 Train acc: 0.7714 Train auc: 0.8557 ### (Val Summary) ### Epoch 13: Val loss: 0.6134 Val acc: 0.6728 Val auc: 0.8134 Early-stopping counter updating: 0/5 -> 1/5... Epoch 14: Batch 199 Avg Loss: 0.4360 Label: 1 Bag Size: 1246 Epoch 14: Batch 399 Avg Loss: 0.4248 Label: 1 Bag Size: 511 Epoch 14: Batch 599 Avg Loss: 0.4275 Label: 0 Bag Size: 2878 Epoch 14: Batch 799 Avg Loss: 0.4469 Label: 1 Bag Size: 4986 ### (Train Summary) ### Epoch 14: Train loss: 0.4499 Train acc: 0.7972 Train auc: 0.8744 ### (Val Summary) ### Epoch 14: Val loss: 0.6795 Val acc: 0.6118 Val auc: 0.8355 Early-stopping counter updating: 1/5 -> 2/5... Epoch 15: Batch 199 Avg Loss: 0.4929 Label: 0 Bag Size: 8899 Epoch 15: Batch 399 Avg Loss: 0.4438 Label: 1 Bag Size: 2629 Epoch 15: Batch 599 Avg Loss: 0.4495 Label: 1 Bag Size: 5989 Epoch 15: Batch 799 Avg Loss: 0.4487 Label: 1 Bag Size: 3173 ### (Train Summary) ### Epoch 15: Train loss: 0.4431 Train acc: 0.8009 Train auc: 0.8764 ### (Val Summary) ### Epoch 15: Val loss: 0.7838 Val acc: 0.5508 Val auc: 0.8304 Early-stopping counter updating: 2/5 -> 3/5... 
Epoch 16: Batch 199 Avg Loss: 0.4658 Label: 1 Bag Size: 2192 Epoch 16: Batch 399 Avg Loss: 0.4532 Label: 0 Bag Size: 8059 Epoch 16: Batch 599 Avg Loss: 0.4323 Label: 0 Bag Size: 3019 Epoch 16: Batch 799 Avg Loss: 0.4167 Label: 0 Bag Size: 5014 ### (Train Summary) ### Epoch 16: Train loss: 0.4073 Train acc: 0.8277 Train auc: 0.8994 ### (Val Summary) ### Epoch 16: Val loss: 0.6598 Val acc: 0.6322 Val auc: 0.8410 Early-stopping counter updating: 3/5 -> 4/5... Epoch 17: Batch 199 Avg Loss: 0.4421 Label: 0 Bag Size: 2352 Epoch 17: Batch 399 Avg Loss: 0.4002 Label: 1 Bag Size: 502 Epoch 17: Batch 599 Avg Loss: 0.3995 Label: 0 Bag Size: 5177 Epoch 17: Batch 799 Avg Loss: 0.4163 Label: 1 Bag Size: 2739 ### (Train Summary) ### Epoch 17: Train loss: 0.4133 Train acc: 0.8065 Train auc: 0.8932 ### (Val Summary) ### Epoch 17: Val loss: 0.5700 Val acc: 0.6828 Val auc: 0.8520 Early-stopping counter updating: 4/5 -> 5/5... ### (Test Summary) ### Epoch 17: Test loss: 0.5071 Test acc: 0.7755 Test auc: 0.8917
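The TODO in the training cell suggests keeping the per-epoch logs for visualization. A minimal sketch, assuming all_train_logs.append(train_log) and all_val_logs.append(val_log) were added inside the epoch loop (the cell above does not do this by default):

import matplotlib.pyplot as plt

# Assumes all_train_logs / all_val_logs were populated with the per-epoch log_dicts returned by traineval_epoch.
epochs = range(len(all_train_logs))
fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(10, 4))
ax_loss.plot(epochs, [log['train loss'] for log in all_train_logs], label='train')
ax_loss.plot(epochs, [log['val loss'] for log in all_val_logs], label='val')
ax_loss.set_xlabel('epoch'), ax_loss.set_ylabel('loss'), ax_loss.legend()
ax_auc.plot(epochs, [log['train auc'] for log in all_train_logs], label='train')
ax_auc.plot(epochs, [log['val auc'] for log in all_val_logs], label='val')
ax_auc.set_xlabel('epoch'), ax_auc.set_ylabel('AUC'), ax_auc.legend()
plt.show()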
Discussion: Comparing and Contrasting AverageMIL and ABMIL¶
Compare and contrast the validation and test performance of AverageMIL and ABMIL. Specifically:
- Which model achieves better overall AUC and balanced accuracy on the test set? Which class (LUAD or LUSC) is each model more prone to misclassify? (A small sketch for inspecting per-class errors is given after this list.)
- http://clam.mahmoodlab.org provides high-attention heatmap visualizations for LUAD and LUSC subtyping via CLAM (similar to ABMIL), along with per-slide confidence scores. As a clinical pathologist looking at these visualizations, what insights or concerns would you have about an AI algorithm assisting your medical diagnoses?
- The experimental setup in this problem set is limited to evaluation on data from TCGA. List three techniques used in Lu et al. 2021 (or other relevant biomedical imaging × AI studies) that could be used to assess 1) the data efficiency, 2) the generalization performance, and 3) the consistency of the attention-based interpretability of ABMIL.
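As a starting point for the first question, here is a minimal, hedged sketch of how one might inspect per-class errors on the test split. It reuses best_model, test_loader, and device from the cells above and assumes label 0 = LUAD and 1 = LUSC, as in the OncoTreeCode_Binarized column.

import numpy as np
import sklearn.metrics
import torch

# Collect predictions of the best model on the test split (batch_size == 1, as above).
best_model.eval()
preds, labels = [], []
with torch.no_grad():
    for X_bag, label in test_loader:
        logits, _ = best_model(X_bag[0].to(device))
        preds.append(int(logits.argmax(dim=-1).item()))
        labels.append(int(label.item()))

# Rows = true class, columns = predicted class; off-diagonal counts show which class is misclassified more often.
cm = sklearn.metrics.confusion_matrix(labels, preds, labels=[0, 1])
print("Confusion matrix (rows: true LUAD/LUSC, cols: predicted LUAD/LUSC):\n", cm)
print("Per-class recall (LUAD, LUSC):", cm.diagonal() / cm.sum(axis=1))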
# save best_model for next session
if use_drive:
    torch.save(best_model.state_dict(), '/content/drive/My Drive/ai4healthsummerschool/abmil.ckpt')