Skip to content

Vision Transformer(Pyramid TNT)

Attention

原文地址:https://zhuanlan.zhihu.com/p/454761367

Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

Transformer in Transformer 针对 ViT 处理图片的方式:将输入图片划分成一个个块 (patch) ,然后针对将这些 patch 看成一个序列 (Sequence) 的不完美之处,提出了一种 TNT 架构,它不仅考虑 patch 之间的信息,还考虑每个 patch 的内部信息,使得 Transformer 模型分别对整体和局部信息进行建模,提升性能。

TNT 架构没有使用 PVT 提出的 Transformer 模型金字塔结构,而金字塔结构在大多数 Vision Transformer 和 MLP 模型上都被证明了有很好的建模性能,所以 Pyramid TNT 作为 TNT 的 Extended Version,进一步验证了金字塔结构对于 TNT Backbone 的作用。

41 Pyramid TNT:使用金字塔结构改进的 TNT Baseline

论文名称:PyramidTNT: Improved Transformer-in-Transformer Baselines with Pyramid Architecture

TNT 论文地址:

https://arxiv.org/pdf/2103.00112.pdf​arxiv.org/pdf/2103.00112.pdf

Pyramid TNT 论文地址:

https://arxiv.org/pdf/2201.00978.pdf​arxiv.org/pdf/2201.00978.pdf

Transformer 需要的是序列 (Sequence) 的输入信号,而我们有的是 image 这种 2D 的输入信号,那直接把图片分块以后进行 Flatten 操作是一种很直觉的处理方式。但是,这种 intuitive 的方法能不能够完美地建模图像,因为我们缺少了一部分非常重要的信息,即:每个 patch 的内部信息

TNT 认为,每个输入的内部信息,即每个 patch 的内部信息,没有被 Transformer 所建模。是一个欠考虑的因素。所以 TNT 使得 Transformer 模型既建模那些不同 patch 之间的关系,也要建模每个 patch 内部的关系。

第 1 步还是将输入图片划分成 n 个块 (patch):

X=[X1,X2,,Xn]Rn×p×p×3

式中 p 是每个块的大小。ViT,DeiT,IPT,SETR,ViT-FRCNN 到这里就把它们输入 Transformer 了,TNT 为了更好地学习图片中 global 和 local 信息的关系,还要再进行一步。 在 TNT 中,作者将 patch 视为表示图像的视觉 "sentence"。每个 patch 进一步分成 m 个子块,即一个 "sentence" 由一系列视觉 "words" 组成。

Xi[xi,1,xi,2,,xi,m],

式中, xi,jRs×s×3 代表第 i 个视觉 "sentence" 的第 j 个视觉 "words",这一步其实是把每个 patch 通过 PyTorch 的 unfold 操作划分成更小的 patch,之后把这些小 patch 通过线性投影展平,就得到了:

Yi=[yi,1,yi,2,,yi,m],yi,j=FC(Vec(xi,j)),

其中, yi,jRc 是第 j 个视觉 "words" 的 Embedding, c 代表 Embedding dimension。

如下图 1 所示,输入是一个大 patch,输出的黄色大长条是这个 patch 展平以后的 sentence embedding,输出的彩色小长条是这个 patch 划分成更小的 patch 之后再展平以后的 word embedding。

图 2 的操作进行完之后就得到了大 patch 的 sentence embedding 以及小 patch 的 word embedding。接下来把它们送入 Transformer 的 Block 里面建模特征,如下图 2 所示。Transformer 是由很多 TNT Blocks 组成的,每个 TNT Block 包含 2 个 Transformer Block,分别是:

  • Outer block 建模 sentence embedding 之间的 global relationship。
  • Inner block 建模 word embedding 之间的 local structure information。

这两种 Block 对应 2 个不同的数据流,其中 Outer block 的数据流在不同 patch 之间运行,而 Inner block 的数据流在每个 patch 内部运行。

Inner Transformer:

定义 YilRp×p×c=Rm×c ,我们把这个值传入 Inner Transformer Tin ,则有:

Yil=Yil1+MSA(LN(Yil1)),Yil=Yil+MLP(LN(Yil)).

注意正常的 Transformer 的输入应该是 (b,n,d) 的张量,这里 b 代表 batch size, n 代表序列长度, d 代表 hidden dimension。不考虑 batch size 这一维,就是一个 (n,d) 的矩阵,也可以看做是 nd 维向量,那么对于 Inner Transformer Tin 来讲,这里的 d=mc 。也就是说,Inner Transformer Tin 的输入是 nmc 维的向量。注意这里的 YilRp×p×c=Rm×c 就是这 n 个向量的其中一个。所以 Inner Transformer 的第 l 个 layer 的输出就可以写为:

Yl=[Y1l,Y2l,,Ynl],YlRn×m×c

Inner Transformer Tin 建模的是更细小的像素级别的 relationship,例如,在一张人脸中,属于眼睛的像素与眼睛的其他像素更相关,而与前额像素的 relationship 较少。

Outer Transformer:

Outer Transformer Tout 就相当于是 ViT 中的 Transformer,它建模的是更答大的 patch 级别的 relationship,输入的 patch embedding 使用 ViT 类似的做法,添加 class tokenZclass ,它们初始化为 0。

Z0=[Zclass,Z10,Z20,,Zn0]R(n+1)×d

定义 ZilRd 为第 l 个 layer 的第 i 个向量,则 Outer Transformer 的表达式为:

Zil=Zil1+MSA(LN(Zil1)),Zil=Zil+MLP(LN(Zil)).

那么现在既有 Outer Transformer 的第 l 个 layer 的输出向量:

Zl=[Zclass,Z1l,Z2l,,Znl]R(n+1)×d

也有 Inner Transformer 的第 l 个 layer 的输出向量:

Yl=[Y1l,Y2l,,Ynl],YlRn×m×c

下面的问题是:要如何把它们结合起来,以融合 global 和 local 的信息呢?

作者采用的方式是:

Zil1=Zil1+Vec(Yil1)Wl1+bl1,

式中, Zil1Rd,Vec() 代表 Flatten 操作, Wl1Rmc×d,bl1Rd 代表权重。

通过这种方式,把第 l 个 layer 的第 i 个 sentence embedding 向量和第 i 个 word embedding 向量融合起来,即对应图 2 的结构。

总的来说,TNT Block 第 l 个 layer 的输入和输出可以表示为:

Yl,Zl=TNT(Yl1,Zl1)

在 TNT Block 中,Inner Transformer 建模 word embedding 之间的 local structure information 之间的关系,而 Outer block 建模 sentence embedding 之间的 global relationship。通过将 TNT Block 堆叠 L 次,作者构建了 Transformer in Transformer。最后,使用一个分类头对图像进行分类。

位置编码:

位置编码的作用是让像素间保持空间位置关系,对于图像就是保持二维信息,它对于图像识别任务来讲很重要。具体来说,就需要对 sentence embedding 和 word embedding 分别设计一种位置编码。

  • sentence positional encoding:

作者这里使用的是可学习的 1D 位置编码:

Z0Z0+Esentence,

式中, EsentenceR(n+1)×d 是给 sentence embedding 使用的位置编码,它用来编码全局空间信息 (global spatial information)。

  • word positional encoding:

作者这里使用的是可学习的 1D 位置编码:

Yi0Yi0+Eword, i=1,2,,n

式中, EwordRm×c 是给 word embedding 使用的位置编码,它们用来编码局部相对信息 (local relative information)。

  • 40.2 Pyramid TNT 原理分析:

TNT 作为一种通用的视觉任务 Backbone,取得了优异的性能。Pyramid TNT 受到 Transformer 模型两种主流改进方法:金字塔架构 (PVT,Swin Transformer,CycleMLP 等等)卷积 stem (Convolutional Stem) 的启发,改进了 TNT 架构。

Pyramid TNT 将它们融入 TNT 中,金字塔架构 (Pyramid Structure) 用于提取多尺度信息,卷积 stem (Convolutional Stem) 用于改善图片分块的方法和使得训练过程更加稳定。此外,Pyramid TNT 还包括其他一些 trick 比如相对位置编码等。

Convolutional Stem

给定输入图片 XRH×W ,ViT 的做法是通过一个 stride=kernel=patch size 的卷积进行图片的分块操作。Early convolutions help transformers see better (NeurIPS 2021) 这篇论文发现:将这个卷积操作替换成几个连续的卷积操作能够使得 Transformer 模型获得更好的性能,且对优化器更加鲁棒,增加了优化稳定性。基于这个发现,作者也对 TNT 模型应用了 Convolutional Stem。

具体而言 Pyramid TNT 的 Convolutional Stem 是 5 个 3×3 卷积。对于 Outer Transformer,Convolutional Stem 将输入图片变成 YRH2×W2×C ,式中 C 是 sentence embedding 维度。对于 Inner Transformer,Convolutional Stem 将输入图片变成 YRH8×W8×D ,式中 D 是 word embedding 的维度。对于位置编码,sentence positional encoding 和 word positional encoding 被分别添加在了 sentence embedding 和 word embedding 上。

Pyramid Architecture

[原始的 TNT 网络]:

在原始 TNT 中,在每个 Block 中保持相同数量的 tokens,遵循 ViT 的设计方式。视觉 "sentence" 和视觉 "words" 的数量自下而上一直保持不变。

视觉 "sentence" 的特征图分辨率自下而上一直是 Hp×Wp=H16×W16=14×14

视觉 "words" 的特征图分辨率自下而上一直是 p4×p4=4×4

[Pyramid TNT 网络]:

在 Pyramid TNT 中,网络在每个 stage 中保持不同数量的 tokens,遵循 PVT 的设计方式。视觉 "sentence" 和视觉 "words" 的数量自下而上分阶段变化。

视觉 "words" 的特征图分辨率在 4 个 stage 中分别是: H2×W2,H4×W4,H8×W8,H16×W16

视觉 "sentence" 的特征图分辨率在 4 个 stage 中分别是: H8×W8,H16×W16,H32×W32,H64×W64

通过 Convolution Stem,把 224×224 的输入图片分成 8×8 的大 patch,一共是 28×28 个。所以 Outer Transformer 特征图的分辨率是:Hout×Wout=28×28=H8×W8

通过 Convolution Stem,把 8×8 的大 patch 分成 2×2 的小 patch,一共是 4×4×28×28 个。所以 Inner Transformer 特征图的分辨率是:Hin×Win×Hout×Wout=4×4×28×28=H2×W2

不同 stage 之间通过一个 stride=2 的卷积操作降低特征分辨率。注意不同的 stage 的 Outer Transformer,即视觉 "sentence" 的特征图分辨率 Hout×Wout 是大小一直变化的,而不同的 stage 的 Inner Transformer视觉 "words" 的特征图分辨率一直是 4×4×Hout×Wout

实验结果

分类任务实验结果

数据集:ImageNet-1k (1,280,000 Training data, 50,000 validation data,1000 classes)

超参数设置:

实验结果如下图 5 所示。与原始 TNT 相比,Pyramid TNT 获得了更好的性能。 Pyramid TNT-S 比 TNT-S 少 1.9B 计算量,精度提高了 0.5%。作者还将 Pyramid TNT 与其他有代表性的 CNN、MLP 和基于 Transformer 的模型进行了比较。从结果中,我们可以看到 Pyramid TNT 是最先进的视觉 Backbone。

目标检测实验结果

数据集:COCO 2017 (118,000 Training data, 50,000 validation data)

对比的框架:RetinaNet,Mask R-CNN

超参数:Batch size=2,AdamW Optimizer,initial lr=1e-4,在第 8 和第 11 个 Epoch 分别乘以 0.1,weight decay=0.05,"1x" schedule (12 epochs),输入图片 resize 成 (1333, 800)。

金字塔的四个阶段的空间分辨率被设置为: H8×W8,H16×W16,H32×W32,H64×W64 。作者使用了 stride=2 的转置卷积和 BN 和 GeLU 激活函数加上 stride=1,kernel=3 的卷积和 BN 和 GeLU 激活函数,以产生 H4×W4,H8×W8,H16×W16,H32×W32 的分辨率的特征图。

在具有相似计算成本的 one-stage 和 two-stage 的检测器上,Pyramid-S 明显优于其他 Backbone。例如,基于 Pyramid-S 的 RetinaNet 达到了 42.0 AP 和 57.7 APL 。这些结果表明,金字塔结构有助于捕获更好的全局信息。

实例分割实验结果

数据集:COCO 2017 (118,000 Training data, 50,000 validation data)

对比的框架:Mask R-CNN,Cascade Mask R-CNN

超参数:Batch size=16,AdamW Optimizer,initial lr=1e-4,在第 27 和第 33 个 Epoch 分别乘以 0.1,weight decay=0.05,"3x" schedule,输入图片 resize 成 (1333, 800)。

Pyramid-S 在 Mask R-CNN 和 Cascade Mask R-CNN 上可以获得比其他 Backbone 好得多的 APbAPm ,显示出其更好的特征表示能力。例如,Pyramid-S 在 Mask R-CNN 上 Wave-MLP 高出 0.9 的 APb

  • 40.3 Pyramid TNT 代码解读:

代码来自:

https://github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch​github.com/huawei-noah/CV-Backbones/tree/master/tnt_pytorch

一些张量的维度的大小已经在代码中以注释的形式进行标注。

Convolutional Stem:

class Stem(nn.Module):
    """ Image to Visual Word Embedding
    Overlap: https://arxiv.org/pdf/2106.13797.pdf
    """
    def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
        super().__init__()
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.inner_dim = inner_dim
        self.num_patches = img_size[0] // 8 * img_size[1] // 8
        self.num_words = 16

        self.common_conv = nn.Sequential(
            nn.Conv2d(in_chans, inner_dim*2, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*2),
            nn.ReLU(inplace=True),
        )
        self.inner_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(inplace=False),
        )
        self.outer_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim*4, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*4),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*4, inner_dim*8, 3, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*8),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(outer_dim),
            nn.ReLU(inplace=False),
        )

        self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)

    def forward(self, x):
        B, C, H, W = x.shape
        H_out, W_out = H // 8, W // 8
        H_in, W_in = 4, 4
        x = self.common_conv(x)
        # inner_tokens
        # inner_tokens: (B, inner_dim, H/2, W/2)
        inner_tokens = self.inner_convs(x) # B, C, H, W
        # inner_tokens: (B, H/8, W/8, inner_dim*4*4)
        inner_tokens = self.unfold(inner_tokens).transpose(1, 2) # B, N, Ck2
        # inner_tokens: (B, inner_dim, H/8*W/8, 4*4)
        inner_tokens = inner_tokens.reshape(B * H_out * W_out, self.inner_dim, H_in*W_in).transpose(1, 2) # B*N, C, 4*4
        # outer_tokens
        # outer_tokens: (B, outer_dim, H/8, W/8)
        outer_tokens = self.outer_convs(x) # B, C, H_out, W_out
        # outer_tokens: (B, H/8*W/8, outer_dim)
        outer_tokens = outer_tokens.permute(0, 2, 3, 1).reshape(B, H_out * W_out, -1)
        return inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in)

注意 Convolution Stem 返回的 inner_tokens 和 outer_tokens 张量的维度:
outer_tokens: (B, H/8W/8, outer_dim)
inner_tokens: (B, inner_dim, H/8
W/8, 44)
Convolution Stem 返回的 inner_tokens 和 outer_tokens 分别通过后面 Block 类的 Inner Attention 和 Outer Attention,二者输出的张量维度分别是:(B
H/8W/8, 44, inner_dim) 和 (B, H/8*W/8, outer_dim)。之后,这两个张量再按照上式 10 中的方式融合在一起。
其实 Pyramid Transformer in Transformer 代码的核心是通过这个 Convolution Stem 分别得到两个不同维度的张量,一个输入 Outer Transformer Block,一个输入 Inner Transformer Block。这两个 Transformer Block 的输出再按照上式 10 中的方式融合在一起。

MLP:

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

Attention 类 (这里作者用了 PVT V2 的轻量 attention 类的实现):

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.pool = nn.AvgPool2d(sr_ratio, stride=sr_ratio)
            self.linear = nn.Linear(dim, dim)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W, relative_pos=None):
        B, N, C = x.shape
        # q: (B, nH, N, C/nH)
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # x_: (B, C, H, W)
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            # x_: (B, N/4, C)
            x_ = self.pool(x_).reshape(B, C, -1).permute(0, 2, 1)
            # x_: (B, N/4, C)
            x_ = self.norm(self.linear(x_))
            # x_: (2, B, nH, N/4, C/nH)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # k,v: (B, nH, N/4, C/nH)
        k, v = kv[0], kv[1]

        # attn: (B, nH, N, N/4)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if relative_pos is not None:
            attn += relative_pos
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # X: (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

一个 Pyramid TNT Block 的实现:

class Block(nn.Module):
    """ TNT Block
    """
    def __init__(self, outer_dim, inner_dim, outer_head, inner_head, num_words, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, se=0, sr_ratio=1):
        super().__init__()
        self.has_inner = inner_dim > 0
        if self.has_inner:
            # Inner
            self.inner_norm1 = norm_layer(num_words * inner_dim)
            self.inner_attn = Attention(
                inner_dim, num_heads=inner_head, qkv_bias=qkv_bias,
                qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
            self.inner_norm2 = norm_layer(num_words * inner_dim)
            self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio),
                                 out_features=inner_dim, act_layer=act_layer, drop=drop)

            self.proj_norm1 = norm_layer(num_words * inner_dim)
            self.proj = nn.Linear(num_words * inner_dim, outer_dim, bias=False)
            self.proj_norm2 = norm_layer(outer_dim)
        # Outer
        self.outer_norm1 = norm_layer(outer_dim)
        self.outer_attn = Attention(
            outer_dim, num_heads=outer_head, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.outer_norm2 = norm_layer(outer_dim)
        self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
                             out_features=outer_dim, act_layer=act_layer, drop=drop)
        # SE
        self.se = se
        self.se_layer = None
        if self.se > 0:
            self.se_layer = SE(outer_dim, 0.25)

    def forward(self, x, outer_tokens, H_out, W_out, H_in, W_in, relative_pos):
        # outer_tokens: (B, H/8*W/8, outer_dim)
        B, N, C = outer_tokens.size()
        if self.has_inner:
            # x: (B*H/8*W/8, 4*4, inner_dim)
            x = x + self.drop_path(self.inner_attn(self.inner_norm1(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1), H_in, W_in)) # B*N, k*k, c
            # x: (B*H/8*W/8, 4*4, inner_dim)
            x = x + self.drop_path(self.inner_mlp(self.inner_norm2(x.reshape(B, N, -1)).reshape(B*N, H_in*W_in, -1))) # B*N, k*k, c
            # outer_tokens: (B, H/8*W/8, outer_dim)
            outer_tokens = outer_tokens + self.proj_norm2(self.proj(self.proj_norm1(x.reshape(B, N, -1)))) # B, N, C
        if self.se > 0:
            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens), H_out, W_out, relative_pos))
            tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
            outer_tokens = outer_tokens + self.drop_path(tmp_ + self.se_layer(tmp_))
        else:
            # outer_tokens: (B, H/8*W/8, outer_dim)
            outer_tokens = outer_tokens + self.drop_path(self.outer_attn(self.outer_norm1(outer_tokens), H_out, W_out, relative_pos))
            # outer_tokens: (B, H/8*W/8, outer_dim)
            outer_tokens = outer_tokens + self.drop_path(self.outer_mlp(self.outer_norm2(outer_tokens)))
        # x: (B*H/8*W/8, 4*4, inner_dim)
        # outer_tokens: (B, H/8*W/8, outer_dim)
        return x, outer_tokens

和 TNT 基本一致,不同之处是前向函数中还需要传入 H_out, W_out, H_in, W_in, relative_pos 这些参数,它们分别代表大 patch 和小 patch 的特征分辨率大小。

一个 Pyramid TNT Stage 的实现:

class Stage(nn.Module):
    """ PyramidTNT stage
    """
    def __init__(self, num_blocks, outer_dim, inner_dim, outer_head, inner_head, num_patches, num_words, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, se=0, sr_ratio=1):
        super().__init__()
        blocks = []
        drop_path = drop_path if isinstance(drop_path, list) else [drop_path] * num_blocks

        for j in range(num_blocks):
            if j == 0:
                _inner_dim = inner_dim
            elif j == 1 and num_blocks > 6:
                _inner_dim = inner_dim
            else:
                _inner_dim = -1
            blocks.append(Block(
                outer_dim, _inner_dim, outer_head=outer_head, inner_head=inner_head,
                num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop,
                attn_drop=attn_drop, drop_path=drop_path[j], act_layer=act_layer, norm_layer=norm_layer,
                se=se, sr_ratio=sr_ratio))

        self.blocks = nn.ModuleList(blocks)
        self.relative_pos = nn.Parameter(torch.randn(
                        1, outer_head, num_patches, num_patches // sr_ratio // sr_ratio))

    def forward(self, inner_tokens, outer_tokens, H_out, W_out, H_in, W_in):
        for blk in self.blocks:
            inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens, H_out, W_out, H_in, W_in, self.relative_pos)
        return inner_tokens, outer_tokens

不同的 stage 之间应有下采样的操作。"sentence" level 和 "word" level 的下采样分别通过下面的 SentenceAggregation 类和 WordAggregation 类来解决:

class SentenceAggregation(nn.Module):
    """ Sentence Aggregation
    """
    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )

    def forward(self, x, H, W):
        B, N, C = x.shape # B, N, C
        x = self.norm(x)
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.conv(x)
        H, W = math.ceil(H / self.stride), math.ceil(W / self.stride)
        x = x.reshape(B, -1, H * W).transpose(1, 2)
        return x, H, W

class WordAggregation(nn.Module):
    """ Word Aggregation
    """
    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.dim_out = dim_out
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )

    def forward(self, x, H_out, W_out, H_in, W_in):
        B_N, M, C = x.shape # B*N, M, C
        x = self.norm(x)
        x = x.reshape(-1, H_out, W_out, H_in, W_in, C)

        # padding to fit (1333, 800) in detection.
        pad_input = (H_out % 2 == 1) or (W_out % 2 == 1)
        if pad_input:
            x = F.pad(x.permute(0, 3, 4, 5, 1, 2), (0, W_out % 2, 0, H_out % 2))
            x = x.permute(0, 4, 5, 1, 2, 3)            
        # patch merge
        x1 = x[:, 0::2, 0::2, :, :, :]  # B, H/2, W/2, H_in, W_in, C
        x2 = x[:, 1::2, 0::2, :, :, :]
        x3 = x[:, 0::2, 1::2, :, :, :]
        x4 = x[:, 1::2, 1::2, :, :, :]
        x = torch.cat([torch.cat([x1, x2], 3), torch.cat([x3, x4], 3)], 4) # B, H/2, W/2, 2*H_in, 2*W_in, C
        x = x.reshape(-1, 2*H_in, 2*W_in, C).permute(0, 3, 1, 2) # B_N/4, C, 2*H_in, 2*W_in
        x = self.conv(x)  # B_N/4, C, H_in, W_in
        x = x.reshape(-1, self.dim_out, M).transpose(1, 2)
        return x

我们可以发现 "sentence" level 和 "word" level 的下采样都是通过一个卷积操作完成。
第 1 个 stage 结束后的下采样:
SentenceAggregation 输入维度: (B,H8×W8,outer dim)输出维度: (B,H16×W16,2×outer dim)
WordAggregation 输入维度:(B×H8×W8,4×4,inner dim)输出维度: (B×H16×W16,4×4,2×inner dim)
第 2 个 stage 结束后的下采样:
SentenceAggregation 输入维度: (B,H16×W16,2×outer dim)输出维度: (B,H32×W32,4×outer dim)
WordAggregation 输入维度:(B×H16×W16,4×4,2×inner dim)输出维度: (B×H32×W32,4×4,4×inner dim)
第 3 个 stage 结束后的下采样:
SentenceAggregation 输入维度: (B,H32×W32,4×outer dim)输出维度: (B,H64×W64,4×outer dim)
WordAggregation 输入维度:(B×H32×W32,4×4,4×inner dim)输出维度: (B×H64×W64,4×4,4×inner dim)

Pyramid TNT 整体模型架构:

class PyramidTNT(nn.Module):
    """ PyramidTNT (Transformer in Transformer) for computer vision
    """
    def __init__(self, configs=None, img_size=224, in_chans=3, num_classes=1000, mlp_ratio=4., qkv_bias=False,
                qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, se=0):
        super().__init__()
        self.num_classes = num_classes
        depths = configs['depths']
        outer_dims = configs['outer_dims']
        inner_dims = configs['inner_dims']
        outer_heads = configs['outer_heads']
        inner_heads = configs['inner_heads']
        sr_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule 
        self.num_features = outer_dims[-1]  # num_features for consistency with other models 

        self.patch_embed = Stem(
            img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0])
        num_patches = self.patch_embed.num_patches
        num_words = self.patch_embed.num_words

        self.outer_pos = nn.Parameter(torch.zeros(1, num_patches, outer_dims[0]))
        self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dims[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)

        depth = 0
        self.word_merges = nn.ModuleList([])
        self.sentence_merges = nn.ModuleList([])
        self.stages = nn.ModuleList([])
        for i in range(4):
            if i > 0:
                self.word_merges.append(WordAggregation(inner_dims[i-1], inner_dims[i], stride=2))
                self.sentence_merges.append(SentenceAggregation(outer_dims[i-1], outer_dims[i], stride=2))
            self.stages.append(Stage(depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i],
                        outer_head=outer_heads[i], inner_head=inner_heads[i],
                        num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words, mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                        drop_path=dpr[depth:depth+depths[i]], norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i])
            )
            depth += depths[i]

        self.norm = norm_layer(outer_dims[-1])

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(outer_dim, representation_size)
        # self.repr_act = nn.Tanh()

        # Classifier head
        self.head = nn.Linear(outer_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.outer_pos, std=.02)
        trunc_normal_(self.inner_pos, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

 @torch.jit.ignore
    def no_weight_decay(self):
        return {'outer_pos', 'inner_pos'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) = self.patch_embed(x)
        inner_tokens += self.inner_pos # B*N, 8*8, C
        outer_tokens += self.pos_drop(self.outer_pos)  # B, N, D

        for i in range(4):
            if i > 0:
                inner_tokens = self.word_merges[i-1](inner_tokens, H_out, W_out, H_in, W_in)
                outer_tokens, H_out, W_out = self.sentence_merges[i-1](outer_tokens, H_out, W_out)
            inner_tokens, outer_tokens = self.stages[i](inner_tokens, outer_tokens, H_out, W_out, H_in, W_in)

        outer_tokens = self.norm(outer_tokens)
        return outer_tokens.mean(dim=1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

不同大小的 Pyramid TNT 配置信息:

@register_model
def ptnt_ti_patch16_192(pretrained=False, **kwargs):
    outer_dim = 80
    inner_dim = 5
    outer_head = 2
    inner_head = 1
    configs = {
        'depths': [2, 6, 3, 2],
        'outer_dims': [outer_dim, outer_dim*2, outer_dim*4, outer_dim*4],
        'inner_dims': [inner_dim, inner_dim*2, inner_dim*4, inner_dim*4],
        'outer_heads': [outer_head, outer_head*2, outer_head*4, outer_head*4],
        'inner_heads': [inner_head, inner_head*2, inner_head*4, inner_head*4],
    }

    model = PyramidTNT(configs=configs, img_size=192, qkv_bias=False, **kwargs)
    model.default_cfg = default_cfgs['tnt_s_patch16_192']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
    return model

小结

本文介绍了 Pyramid TNT 架构的原理和 PyTorch 代码实现。TNT 作为一种通用的视觉任务 Backbone,取得了优异的性能。Pyramid TNT 受到 Transformer 模型两种主流改进方法:金字塔架构 (PVT,Swin Transformer,CycleMLP 等等) 和卷积 stem (Convolutional Stem) 的启发,改进了 TNT 架构。Pyramid TNT 将它们融入 TNT 中,金字塔架构 (Pyramid Structure) 用于提取多尺度信息,卷积 stem (Convolutional Stem) 用于改善图片分块的方法和使得训练过程更加稳定。此外,Pyramid TNT 还包括其他一些 trick 比如相对位置编码等。