Skip to content

ViT

摘要

尽管变换器(Transformer)架构已成为自然语言处理任务的事实标准,但其在计算机视觉中的应用仍然有限。在视觉领域,注意力要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持它们的整体结构不变。我们展示了这种对 CNNs 的依赖不是必需的,直接应用于图像块序列的纯粹变换器可以在图像分类任务上表现非常好。当在大量数据上预训练并转移到多个中等大小或小型图像识别基准(ImageNet、CIFAR-100、VTAB 等)时,视觉变换器(ViT)与最先进的卷积网络相比取得了优异的结果,同时训练所需的计算资源大大减少。

引言

基于自注意力的架构,特别是变换器(Transformers)(Vaswani 等人,2017 年),已成为自然语言处理(NLP)中的首选模型。主导方法是在大型文本语料库上进行预训练,然后在较小的特定任务数据集上进行微调(Devlin 等人,2019 年)。得益于变换器的计算效率和可扩展性,现在已经可以训练前所未有的大型模型,拥有超过 1000 亿个参数(Brown 等人,2020 年;Lepikhin 等人,2020 年)。随着模型和数据集的增长,性能仍然没有显示出饱和的迹象。

然而,在计算机视觉领域,卷积架构仍然占据主导地位(LeCun 等人,1989 年;Krizhevsky 等人,2012 年;He 等人,2016 年)。受 NLP 成功的启发,多项工作尝试将类 CNN 架构与自注意力结合(Wang 等人,2018 年;Carion 等人,2020 年),有些工作完全替代了卷积(Ramachandran 等人,2019 年;Wang 等人,2020a)。后者这些模型,虽然理论上高效,但由于使用了特殊的注意力模式,尚未在现代硬件加速器上有效扩展。因此,在大规模图像识别中,经典的 ResNet-like 架构仍然是最先进的技术(Mahajan 等人,2018 年;Xie 等人,2020 年;Kolesnikov 等人,2020 年)。

受 NLP 中变换器扩展成功的启发,我们尝试直接将标准变换器应用于图像,尽可能少地进行修改。为此,我们将图像分割成补丁,并将这些补丁的线性嵌入序列作为输入提供给变换器。图像补丁被视为与 NLP 应用中的令牌(词)相同。我们以监督方式在图像分类任务上训练模型。

当在中等大小的数据集(如 ImageNet)上训练,且不使用强正则化时,这些模型提供了与相当大小的 ResNets 相比略低几个百分点的准确率。这种看似令人沮丧的结果可能是预期之中的:变换器缺乏 CNN 所固有的一些归纳偏置,如平移等变性和局部性,因此在训练数据量不足时泛化性不佳。

然而,如果在更大的数据集上训练模型(1400 万至 3 亿图像),情况就会发生变化。我们发现大规模训练胜过归纳偏置。我们的视觉变换器(ViT)在足够规模的预训练并转移到数据点较少的任务时,取得了优异的结果。当在公共的 ImageNet-21k 数据集或内部的 JFT-300M 数据集上进行预训练时,ViT 在多个图像识别基准测试中达到或超越了最先进的水平。特别是,最佳模型在 ImageNet 上达到了 88.55% 的准确率,在 ImageNet-ReaL 上达到了 90.72% 的准确率,在 CIFAR-100 上达到了 94.55% 的准确率,在 VTAB 的 19 个任务套件上达到了 77.63% 的准确率。

相关工作

Vaswani 等人(2017 年)为机器翻译提出了变换器(Transformers),自那以来,它已成为许多自然语言处理(NLP)任务中的最先进方法。大型基于变换器的模型通常在大型语料库上进行预训练,然后针对手头的任务进行微调:BERT(Devlin 等人,2019 年)使用去噪自监督预训练任务,而 GPT 系列工作使用语言建模作为其预训练任务(Radford 等人,2018 年;2019 年;Brown 等人,2020 年)。

对图像应用自注意力的朴素方法将要求每个像素都关注每个其他像素。由于在像素数量上的二次成本,这不适用于实际输入大小。因此,为了在图像处理的上下文中应用变换器,过去尝试了几种近似方法。Parmar 等人(2018 年)仅在每个查询像素的局部邻域内应用自注意力,而不是全局应用。这样的局部多头点积自注意力块可以完全替换卷积(Hu 等人,2019 年;Ramachandran 等人,2019 年;Zhao 等人,2020 年)。在不同的研究线中,稀疏变换器(Sparse Transformers)(Child 等人,2019 年)采用可扩展的全局自注意力近似,使其适用于图像。另一种扩展注意力的方式是在不同大小的块中应用它(Weissenborn 等人,2019 年),在极端情况下仅沿单个轴(Ho 等人,2019 年;Wang 等人,2020a 年)。许多这些专门的注意力架构在计算机视觉任务上展示了有希望的结果,但需要复杂的工程实现才能在硬件加速器上高效实现。

与我们最相关的是 Cordonnier 等人(2020 年)的模型,该模型从输入图像中提取 2×2 大小的块,并在其上应用完全自注意力。这个模型与 ViT 非常相似,但我们的工作进一步展示了大规模预训练使得普通变换器与(甚至优于)最先进的 CNNs 竞争。此外,Cordonnier 等人(2020 年)使用的小块尺寸为 2×2 像素,这使得模型仅适用于小分辨率图像,而我们也处理了中分辨率图像。

结合卷积神经网络(CNNs)与自注意力形式的尝试也引起了很多兴趣,例如通过增强图像分类的特征图(Bello 等人,2019 年)或通过使用自注意力进一步处理 CNN 的输出,例如用于对象检测(Hu 等人,2018 年;Carion 等人,2020 年),视频处理(Wang 等人,2018 年;Sun 等人,2019 年),图像分类(Wu 等人,2020 年),无监督对象发现(Locatello 等人,2020 年),或统一文本 - 视觉任务(Chen 等人,2020c;Lu 等人,2019 年;Li 等人,2019 年)。

另一个最近的相关模型是图像 GPT(iGPT)(Chen 等人,2020a),它在降低图像分辨率和颜色空间后将变换器应用于图像像素。该模型以生成模型的形式以无监督方式进行训练,然后可以通过微调或线性探测用于分类性能,达到 ImageNet 上 72% 的最大准确率。

我们的工作增加了探索大规模图像识别的论文集合,这些研究超出了标准 ImageNet 数据集的规模。使用额外的数据源能够在标准基准上实现最先进的结果(Mahajan 等人,2018 年;Touvron 等人,2019 年;Xie 等人,2020 年)。此外,Sun 等人(2017 年)研究了 CNN 性能随数据集大小的扩展情况,而 Kolesnikov 等人(2020 年);Djolonga 等人(2020 年)对 CNN 从大规模数据集如 ImageNet-21k 和 JFT-300M 进行转移学习进行了实证探索。我们也关注这两个后者数据集,但训练的是变换器而不是先前工作中使用的基于 ResNet 的模型。

image.png

图 1:模型概览。我们将图像分割成固定大小的块,对每个块进行线性嵌入,添加位置嵌入,然后将得到的向量序列输入到一个标准的变换器(Transformer)编码器中。为了进行分类,我们使用添加一个额外的可学习的“分类令牌”到序列中的标准方法。变换器编码器的插图受到了 Vaswani 等人(2017 年)的启发。

方法

在模型设计中,我们尽可能紧密地遵循原始的变换器(Transformer)(Vaswani 等人,2017 年)。这种故意简单设置的一个优点是,可扩展的 NLP 变换器架构及其高效实现几乎可以直接使用。

视觉变换器(ViT)

模型的概览如图 1 所示。标准的变换器接收一个 1D 的令牌嵌入序列作为输入。为了处理 2D 图像,我们将图像 \(\mathrm{x} \in \mathbb{R}^{H \times W \times C}\) 重塑为一系列平展的 2D 块 \(\mathrm{x}_p \in \mathbb{R}^{N \times\left(P^2 \cdot C\right)}\),其中 \((H, W)\) 是原始图像的分辨率,\(C\) 是通道数,\((P, P)\) 是每个图像块的分辨率,而 \(N=H W / P^2\) 是生成的块数量,也作为变换器的有效输入序列长度。变换器在其所有层中使用恒定的潜在向量大小 \(D\),因此我们将块平展并通过一个可训练的线性投影映射到 \(D\) 维(等式 1)。我们将这个投影的输出称为块嵌入。

类似于 BERT 的 [class] 令牌,我们在嵌入的块序列前加入一个可学习的嵌入 \(\left(\mathbf{z}_0^0=\mathrm{x}_{\text {class }}\right.)\),其在变换器编码器输出的状态 \(\left(\mathrm{z}_L^0\right)\) 作为图像表示 y(等式 4)。在预训练和微调期间,一个分类头被附加到 \(\mathrm{z}_L^0\)。分类头在预训练时通过一个带有一个隐藏层的 MLP 实现,在微调时通过一个单一的线性层实现。

位置嵌入被添加到块嵌入中以保留位置信息。我们使用标准的可学习 1D 位置嵌入,因为我们没有观察到使用更高级的 2D 感知位置嵌入带来显著的性能提升(附录 D.4)。嵌入向量的结果序列作为编码器的输入。

变换器编码器(Vaswani 等人,2017 年)由交替的多头自注意力(MSA,见附录 A)和 MLP 块(等式 2,3)组成。在每个块之前应用 Layernorm(LN),并在每个块之后应用残差连接(Wang 等人,2019 年;Baevski \& Auli,2019 年)。

MLP 包含两层,使用 GELU 非线性激活函数。

\[ \begin{aligned} \mathbf{z}_0 & =\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_p^1 \mathbf{E} ; \mathbf{x}_p^2 \mathbf{E} ; \cdots ; \mathbf{x}_p^N \mathbf{E}\right]+\mathbf{E}_{p o s}, & & \mathbf{E} \in \mathbb{R}^{\left(P^2 \cdot C\right) \times D}, \mathbf{E}_{p o s} \in \mathbb{R}^{(N+1) \times D} \\ \mathbf{z}_{\ell}^{\prime} & =\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & & \ell=1 \ldots L \\ \mathbf{z}_{\ell} & =\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & & \ell=1 \ldots L \\ \mathbf{y} & =\operatorname{LN}\left(\mathbf{z}_L^0\right) & & \end{aligned} \]

归纳偏见。我们注意到,与 CNNs 相比,视觉变换器(ViT)具有更少的图像特定归纳偏见。在 CNNs 中,局部性、二维邻域结构和平移等变性被整合到整个模型的每一层中。在 ViT 中,只有 MLP 层是局部的和平移等变的,而自注意力层是全局的。二维邻域结构的使用非常有限:在模型的开始通过将图像切割成块,在微调时调整不同分辨率图像的位置嵌入(如下所述)。除此之外,初始化时的位置嵌入不包含关于块的 2D 位置的信息,所有块之间的空间关系都需要从头开始学习。

混合架构。作为原始图像块的替代,输入序列可以由 CNN(LeCun 等人,1989)的特征图组成。在这种混合模型中,块嵌入投影 E(等式 1)应用于从 CNN 特征图中提取的块。作为一个特殊情况,块可以具有 1x1 的空间大小,这意味着输入序列是通过简单地展平特征图的空间维度并投影到变换器维度获得的。分类输入嵌入和位置嵌入按上述添加。

微调和更高分辨率

通常,我们在大型数据集上预训练 ViT,并微调到(较小的)下游任务。为此,我们移除预训练的预测头,并附加一个零初始化的 D×K 前馈层,其中 K 是下游类别的数量。在比预训练时更高的分辨率下微调通常是有益的(Touvron 等人,2019;Kolesnikov 等人,2020)。当输入更高分辨率的图像时,我们保持块的大小不变,这导致更大的有效序列长度。视觉变换器可以处理任意序列长度(受内存限制),然而,预训练的位置嵌入可能不再有意义。因此,我们根据它们在原始图像中的位置进行 2D 插值,来调整预训练的位置嵌入。注意,这种分辨率调整和块提取是人为将关于图像 2D 结构的归纳偏见注入到视觉变换器的唯一点。