ST-GCN 代码解读
参见博文:ST-GCN 论文解读
论文代码:https://github.com/yysijie/st-gcn
graph.py
文件路径:st-gcn/net/utils/graph.py
import numpy as npclass Graph():""" The Graph to model the skeletons extracted by the openposeArgs:strategy (string): must be one of the follow candidates- uniform: Uniform Labeling- distance: Distance Partitioning- spatial: Spatial ConfigurationFor more information, please refer to the section 'Partition Strategies'in our paper (https://arxiv.org/abs/1801.07455).layout (string): must be one of the follow candidates- openpose: Is consists of 18 joints. For more information, pleaserefer to https://github.com/CMU-Perceptual-Computing-Lab/openpose#output- ntu-rgb+d: Is consists of 25 joints. For more information, pleaserefer to https://github.com/shahroudy/NTURGB-Dmax_hop (int): the maximal distance between two connected nodesdilation (int): controls the spacing between the kernel points"""def __init__(self,layout='openpose',strategy='uniform',max_hop=1,dilation=1):self.max_hop = max_hopself.dilation = dilationself.get_edge(layout)self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop)self.get_adjacency(strategy)def __str__(self):return self.Adef get_edge(self, layout):if layout == 'openpose':self.num_node = 18self_link = [(i, i) for i in range(self.num_node)]neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,11),(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]self.edge = self_link + neighbor_linkself.center = 1elif layout == 'ntu-rgb+d':self.num_node = 25self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),(6, 5), (7, 6), (8, 7), (9, 21), (10, 9),(11, 10), (12, 11), (13, 1), (14, 13), (15, 14),(16, 15), (17, 1), (18, 17), (19, 18), (20, 19),(22, 23), (23, 8), (24, 25), (25, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 21 - 1elif layout == 'ntu_edge':self.num_node = 24self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),(8, 7), (9, 2), (10, 9), (11, 10), (12, 11),(13, 1), (14, 13), (15, 14), (16, 15), (17, 1),(18, 17), (19, 18), (20, 19), (21, 22), (22, 8),(23, 24), (24, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 2# elif layout=='customer settings'# passelse:raise ValueError("Do Not Exist This Layout.")def get_adjacency(self, strategy):valid_hop = range(0, self.max_hop + 1, self.dilation)adjacency = np.zeros((self.num_node, self.num_node))for hop in valid_hop:adjacency[self.hop_dis == hop] = 1normalize_adjacency = normalize_digraph(adjacency)if strategy == 'uniform':A = np.zeros((1, self.num_node, self.num_node))A[0] = normalize_adjacencyself.A = Aelif strategy == 'distance':A = np.zeros((len(valid_hop), self.num_node, self.num_node))for i, hop in enumerate(valid_hop):A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==hop]self.A = Aelif strategy == 'spatial':A = []for hop in valid_hop:a_root = np.zeros((self.num_node, self.num_node))a_close = np.zeros((self.num_node, self.num_node))a_further = np.zeros((self.num_node, self.num_node))for i in range(self.num_node):for j in range(self.num_node):if self.hop_dis[j, i] == hop:if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:a_root[j, i] = normalize_adjacency[j, i]elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:a_close[j, i] = normalize_adjacency[j, i]else:a_further[j, i] = normalize_adjacency[j, i]if hop == 0:A.append(a_root)else:A.append(a_root + a_close)A.append(a_further)A = np.stack(A)self.A = Aelse:raise ValueError("Do Not Exist This Strategy")def get_hop_distance(num_node, edge, max_hop=1):A = np.zeros((num_node, num_node))for i, j in edge:A[j, i] = 1A[i, j] = 1# compute hop stepshop_dis = np.zeros((num_node, num_node)) + np.inftransfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]arrive_mat = (np.stack(transfer_mat) > 0)for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = dreturn hop_disdef normalize_digraph(A):Dl = np.sum(A, 0)num_node = A.shape[0]Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i](-1)AD = np.dot(A, Dn)return ADdef normalize_undigraph(A):Dl = np.sum(A, 0)num_node = A.shape[0]Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i](-0.5)DAD = np.dot(np.dot(Dn, A), Dn)return DAD
这段代码是一个 Graph 类,用于建立骨架(skeleton)的连接关系。该类接受四个参数,包括骨架类型(layout),连接策略(strategy),最大跳数(max_hop)以及步长(dilation)。根据不同的骨架类型,会有不同数量(openpose:18,ntu-rgb+d:25)和位置的节点,同时也有不同的连接关系。
- 连接策略包括三种:均匀(uniform)、距离(distance)和空间(spatial)。
- 在初始化过程中,类会根据节点数和连接关系计算出每个节点与其他节点的距离(hop_dis),根据距离设置每个节点之间的邻接矩阵(adjacency),并对邻接矩阵进行归一化(normalize_digraph)。
- 最后,根据连接策略,将邻接矩阵转化为所需的图矩阵(AAA)。
其中,距离策略需要生成多个不同阈值下的矩阵;空间策略,则需要进一步分类根节点、近邻、远邻等不同的节点,并生成多个矩阵。
分段理解
def get_edge(self, layout):if layout == 'openpose':self.num_node = 18self_link = [(i, i) for i in range(self.num_node)]neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,11),(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]self.edge = self_link + neighbor_linkself.center = 1elif layout == 'ntu-rgb+d':self.num_node = 25self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21),(6, 5), (7, 6), (8, 7), (9, 21), (10, 9),(11, 10), (12, 11), (13, 1), (14, 13), (15, 14),(16, 15), (17, 1), (18, 17), (19, 18), (20, 19),(22, 23), (23, 8), (24, 25), (25, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 21 - 1elif layout == 'ntu_edge':self.num_node = 24self_link = [(i, i) for i in range(self.num_node)]neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6),(8, 7), (9, 2), (10, 9), (11, 10), (12, 11),(13, 1), (14, 13), (15, 14), (16, 15), (17, 1),(18, 17), (19, 18), (20, 19), (21, 22), (22, 8),(23, 24), (24, 12)]neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base]self.edge = self_link + neighbor_linkself.center = 2# elif layout=='customer settings'# passelse:raise ValueError("Do Not Exist This Layout.")
函数get_edge(self, layout) 用于将不同骨架类型(layout)的节点连接关系(edge)表示成 edge 矩阵。在函数中,
- 根据输入的骨架类型,设置相应的节点数(num_node)、自连接(self_link)和邻接连接(neighbor_link);
- 然后将自连接和邻接连接拼接成完整的 edge 矩阵(edge)。
- 函数还会返回中心节点(center)的编号。
- 其中,‘openpose’、‘ntu-rgb+d’ 和 ‘ntu_edge’ 是预先定义好的三种骨架类型,如果输入的骨架类型不是这三种,则抛出一个值错误(ValueError)。如果需要定制其他类型的节点连接关系,可以在该函数中添加相关代码实现。
self_link = [(i, i) for i in range(self.num_node)]
用于生成自连接(self_link)的边。其中,self.num_node 是一个整数,表示节点总数。range(self.num_node)返回从0到 self.num_node-1的整数序列,即所有节点的编号。对于每一个节点 i 来说,(i, i)表示该节点自身与自身之间的连接。因此,这行代码生成了一个包含所有节点自连接边的列表 self_link。例如,如果 self.num_node 等于3,则该行代码的结果为[(0,0), (1,1), (2,2)]。
def get_adjacency(self, strategy):valid_hop = range(0, self.max_hop + 1, self.dilation)adjacency = np.zeros((self.num_node, self.num_node))for hop in valid_hop:adjacency[self.hop_dis == hop] = 1normalize_adjacency = normalize_digraph(adjacency)if strategy == 'uniform':A = np.zeros((1, self.num_node, self.num_node))A[0] = normalize_adjacencyself.A = Aelif strategy == 'distance':A = np.zeros((len(valid_hop), self.num_node, self.num_node))for i, hop in enumerate(valid_hop):A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==hop]self.A = Aelif strategy == 'spatial':A = []for hop in valid_hop:a_root = np.zeros((self.num_node, self.num_node))a_close = np.zeros((self.num_node, self.num_node))a_further = np.zeros((self.num_node, self.num_node))for i in range(self.num_node):for j in range(self.num_node):if self.hop_dis[j, i] == hop:if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]:a_root[j, i] = normalize_adjacency[j, i]elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]:a_close[j, i] = normalize_adjacency[j, i]else:a_further[j, i] = normalize_adjacency[j, i]if hop == 0:A.append(a_root)else:A.append(a_root + a_close)A.append(a_further)A = np.stack(A)self.A = Aelse:raise ValueError("Do Not Exist This Strategy")
函数 get_adjacency(self, strategy)用于根据不同的输入策略(strategy)生成不同的节点邻接矩阵(adjacency matrix),并将其组成张量(A)。
- 首先,在循环中,生成了一系列整数(valid_hop),表示在几跳之内可以到达相邻节点。
- 然后,程序构造了一个全零矩阵 adjacency,其大小为(self.num_node, self.num_node),其中 self.num_node 表示图中节点的数量。这个矩阵将被用于存储不同跳数下的邻接矩阵信息。
- 接下来,程序遍历 valid_hop 迭代器对象,获取当前的跳数 hop。对于当前的 hop,程序通过
self.hop_dis == hop
语句获取所有距离为 hop 的节点对,并将它们在 adjacency 矩阵中的对应值设置为1。这样,当 hop 取得不同的值时,不同的节点对将被加入到邻接矩阵中,从而构成了一个由多个邻接矩阵组成的图。
接着,根据不同的输入策略,生成不同的节点邻接矩阵:
- uniform:直接将归一化的 adjacency 矩阵放入张量 AAA 中,形状为(1, num_node, num_node),即单通道矩阵。
- distance:根据 valid_hop 列表的长度,生成一个三维矩阵 AAA,每一层是一种距离,一个像素是一个节点,每一层的矩阵用 adjacency 矩阵中对应距离的部分来填充。例如,如果 valid_hop=[0,2],则生成两层矩阵,第一层是距离为0的邻接矩阵,第二层是距离为2的邻接矩阵。
- spatial:生成一个三维矩阵 AAA,每一层是一种距离,一个像素是一个节点,每一层的矩阵将节点划分为三类:与中心节点相同距离的节点、比中心节点近的节点、比中心节点远的节点。分别用三个矩阵来表示这三类节点的邻接情况,然后合并在一起即可。
最后,函数会将张量 A 赋值给 self.A。如果输入的策略不是 ‘uniform’ 、‘distance’ 或 ‘spatial’,则抛出一个值错误(ValueError)。
valid_hop = range(0, self.max_hop + 1, self.dilation)
valid_hop 是一个迭代器对象,其中的元素表示节点之间允许的最大跳数范围。具体来说,valid_hop 通过 range函数生成,起点为0,终点为 self.max_hop+1(不包含终点),步长为 self.dilation。举例说明,比如一张图中有7个节点,self.max_hop=3,self.dilation=2,那么 valid_hop 就等同于[0, 2, 4]。其中,0表示两个节点之间允许直接相连的情况,2表示最多经过两个节点即可到达目标节点的情况,4表示最多经过四个节点即可到达目标节点的情况。由此可见,随着跳数的增加,节点之间的连接方式会越来越复杂。
def get_hop_distance(num_node, edge, max_hop=1):A = np.zeros((num_node, num_node))for i, j in edge:A[j, i] = 1A[i, j] = 1# compute hop stepshop_dis = np.zeros((num_node, num_node)) + np.inftransfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]arrive_mat = (np.stack(transfer_mat) > 0)for d in range(max_hop, -1, -1):hop_dis[arrive_mat[d]] = dreturn hop_dis
- 首先根据输入的节点数和边信息构造了一个二维邻接矩阵 AAA,其中如果节点 iii 和节点 jjj 之间存在一条边,则 A[i,j]A[i,j]A[i,j] 和 A[j,i]A[j,i]A[j,i] 位置上的元素值均为1,否则都为0。这里需要注意的是,矩阵 AAA 是基于无向图的。
- 接着,程序创建了一个名为 hop_dis 的二维数组,用于存储任意两个节点之间的跳数。hop_dis 的大小和 AAA 矩阵相同,初始值设为 np.inf,表示所有节点之间的距离都是无穷大。
- 然后,程序通过循环迭代计算 AAA 矩阵的幂次方,从而确定每两个节点之间的可达性。这里采用了 numpy.linalg 下的 matrix_power 函数,可以快速求解矩阵的幂次方。transfer_mat 是一个列表,其中第 iii 个元素表示 AAA 矩阵的 iii 次幂,也就是说,transfer_mat[i]表示从距离为 iii 的节点到距离为0的节点所需经过的路径数,如果该值大于0,则说明距离为 iii 的节点可以到达距离为0的节点。注意,这个计算包含了自身到自身的情况(跳数为0),因此 transfer_mat 的长度为 max_hop+1。
- 然后,程序通过 np.stack 将所有 transfer_mat 合并成一个张量 arrive_mat,其中第 iii 个二维张量表示跳数为 iii 的所有节点对之间的可达性。由于 transfer_mat 的取值均为0和1,因此 arrive_mat 也只含有0和1两种取值,进一步减小了存储空间和计算开销。
- 最后,程序通过逆序迭代处理 arrive_mat 张量的各个二维矩阵,从而得到任意两个节点之间的跳数。具体来说,从跳数为 max_hop 的矩阵开始,将 arrive_mat[d]中所有非零元素所对应的位置在 hop_dis 中设为 ddd;再逐次处理跳数为 d−1,d−2,...,1d-1,d-2,...,1d−1,d−2,...,1,直到跳数为0时完成计算。这样,hop_dis 中就会存储所有节点之间的跳数,其中 hop_dis[i,j]表示从节点 iii 到节点 jjj 所需的最少跳数(如果节点 iii 和节点 jjj 之间不连通,则 hop_dis[i,j]为 np.inf)。
def normalize_digraph(A):Dl = np.sum(A, 0)num_node = A.shape[0]Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i](-1)AD = np.dot(A, Dn)return AD
给定一个有向图的邻接矩阵 AAA,该函数将其归一化为一种称为“随机游走”的形式。具体来说,
- 首先求出每个节点的出度和入度之和,也就是沿着每个节点的出边和入边各遍历一次时经过该节点的次数总和;
- 然后对于每个节点,计算其出度和入度之和,如果大于零则将其得到一个归一化因子,表示从该节点出发的所有出边上权重之和的倒数;
- 最后,将原始邻接矩阵 AAA 乘以这些归一化因子组成的对角矩阵 DnD_nDn,得到归一化后的邻接矩阵 ADADAD。
这个操作有助于将图中每个节点的影响力(PageRank)传递给其他节点,在随机游走模型中起着重要的作用。
def normalize_undigraph(A):Dl = np.sum(A, 0)num_node = A.shape[0]Dn = np.zeros((num_node, num_node))for i in range(num_node):if Dl[i] > 0:Dn[i, i] = Dl[i](-0.5)DAD = np.dot(np.dot(Dn, A), Dn)return DAD
给定一个无向图的邻接矩阵 AAA,该函数将其归一化为一种称为“谱聚类”的形式。具体来说,
- 首先求出每个节点的度数之和,也就是沿着每个节点的边(不区分入边和出边)走过该节点的次数总和;
- 然后对于每个节点,计算其度数的平方根并求倒数,表示从该节点出发经过的边数对该节点影响力的调整因素;
- 最后,将原始邻接矩阵 AAA 乘以这些调整因子组成的对角矩阵 DnD_nDn 两次,即 DADDADDAD,得到归一化后的邻接矩阵。
这个操作有助于将图中每个节点的特征转化为一种线性代数形式,并且降低了图像素之间距离的差异,使得聚类更加准确。在谱聚类模型中,该操作通常被用来计算拉普拉斯矩阵,从而实现图像素聚类。
tgcn.py
文件路径:st-gcn/net/utils/tgcn.py
# The based unit of graph convolutional networks.import torch
import torch.nn as nnclass ConvTemporalGraphical(nn.Module):r"""The basic module for applying a graph convolution.Args:in_channels (int): Number of channels in the input sequence dataout_channels (int): Number of channels produced by the convolutionkernel_size (int): Size of the graph convolving kernelt_kernel_size (int): Size of the temporal convolving kernelt_stride (int, optional): Stride of the temporal convolution. Default: 1t_padding (int, optional): Temporal zero-padding added to both sides ofthe input. Default: 0t_dilation (int, optional): Spacing between temporal kernel elements.Default: 1bias (bool, optional): If ``True``, adds a learnable bias to the output.Default: ``True``Shape:- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format- Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` formatwhere:math:`N` is a batch size,:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,:math:`T_{in}/T_{out}` is a length of input/output sequence,:math:`V` is the number of graph nodes. """def __init__(self,in_channels,out_channels,kernel_size,t_kernel_size=1,t_stride=1,t_padding=0,t_dilation=1,bias=True):super().__init__()self.kernel_size = kernel_sizeself.conv = nn.Conv2d(in_channels,out_channels * kernel_size,kernel_size=(t_kernel_size, 1),padding=(t_padding, 0),stride=(t_stride, 1),dilation=(t_dilation, 1),bias=bias)def forward(self, x, A):assert A.size(0) == self.kernel_sizex = self.conv(x)n, kc, t, v = x.size()x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)x = torch.einsum('nkctv,kvw->nctw', (x, A))return x.contiguous(), A
ConvTemporalGraphical 模块的初始化函数中的输入参数说明:
- in_channels:输入数据序列的通道数。
- out_channels:卷积操作产生的输出数据的通道数。
- kernel_size:图卷积核的大小。
- t_kernel_size:时间卷积核的大小。
- t_stride:时间卷积的步幅,默认为1。
- t_padding:在输入数据的两侧添加的时间维度的零填充数量,默认为0。
- t_dilation:时间卷积核内部元素之间的间隔,默认为1。
- bias:是否使用可学习的偏差项,默认为True。
它通过输入的图序列数据和邻接矩阵对数据进行卷积操作,得到输出数据和输出邻接矩阵。其中,输入数据和输出数据的格式为(N, in_channels/out_channels, T_in/T_out, V),
- N 为 batch size,
- in_channels/out_channels 为输入/输出特征图的通道数,
- T_in/T_out 为序列长度,
- V 为节点数;
邻接矩阵格式为(K, V, V),其中 K 为空间卷积核大小,等于 kernel_size。具体实现上,该模块使用了二维卷积以及张量乘法的方式,经过卷积操作后使用 einsum 函数将结果与邻接矩阵相乘。
st-gcn.py
文件路径:st-gcn/net/st_gcn.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variablefrom net.utils.tgcn import ConvTemporalGraphical
from net.utils.graph import Graphclass Model(nn.Module):r"""Spatial temporal graph convolutional networks.Args:in_channels (int): Number of channels in the input datanum_class (int): Number of classes for the classification taskgraph_args (dict): The arguments for building the graphedge_importance_weighting (bool): If ``True``, adds a learnableimportance weighting to the edges of the graphkwargs (optional): Other parameters for graph convolution unitsShape:- Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`- Output: :math:`(N, num_class)` where:math:`N` is a batch size,:math:`T_{in}` is a length of input sequence,:math:`V_{in}` is the number of graph nodes,:math:`M_{in}` is the number of instance in a frame."""def __init__(self, in_channels, num_class, graph_args,edge_importance_weighting, kwargs):super().__init__()# load graphself.graph = Graph(graph_args)A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)self.register_buffer('A', A)# build networksspatial_kernel_size = A.size(0)temporal_kernel_size = 9kernel_size = (temporal_kernel_size, spatial_kernel_size)self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}self.st_gcn_networks = nn.ModuleList((st_gcn(in_channels, 64, kernel_size, 1, residual=False, kwargs0),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 128, kernel_size, 2, kwargs),st_gcn(128, 128, kernel_size, 1, kwargs),st_gcn(128, 128, kernel_size, 1, kwargs),st_gcn(128, 256, kernel_size, 2, kwargs),st_gcn(256, 256, kernel_size, 1, kwargs),st_gcn(256, 256, kernel_size, 1, kwargs),))# initialize parameters for edge importance weightingif edge_importance_weighting:self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size()))for i in self.st_gcn_networks])else:self.edge_importance = [1] * len(self.st_gcn_networks)# fcn for predictionself.fcn = nn.Conv2d(256, num_class, kernel_size=1)def forward(self, x):# data normalizationN, C, T, V, M = x.size()x = x.permute(0, 4, 3, 1, 2).contiguous()x = x.view(N * M, V * C, T)x = self.data_bn(x)x = x.view(N, M, V, C, T)x = x.permute(0, 1, 3, 4, 2).contiguous()x = x.view(N * M, C, T, V)# forwadfor gcn, importance in zip(self.st_gcn_networks, self.edge_importance):x, _ = gcn(x, self.A * importance)# global poolingx = F.avg_pool2d(x, x.size()[2:])x = x.view(N, M, -1, 1, 1).mean(dim=1)# predictionx = self.fcn(x)x = x.view(x.size(0), -1)return xdef extract_feature(self, x):# data normalizationN, C, T, V, M = x.size()x = x.permute(0, 4, 3, 1, 2).contiguous()x = x.view(N * M, V * C, T)x = self.data_bn(x)x = x.view(N, M, V, C, T)x = x.permute(0, 1, 3, 4, 2).contiguous()x = x.view(N * M, C, T, V)# forwadfor gcn, importance in zip(self.st_gcn_networks, self.edge_importance):x, _ = gcn(x, self.A * importance)_, c, t, v = x.size()feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)# predictionx = self.fcn(x)output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1)return output, featureclass st_gcn(nn.Module):r"""Applies a spatial temporal graph convolution over an input graph sequence.Args:in_channels (int): Number of channels in the input sequence dataout_channels (int): Number of channels produced by the convolutionkernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernelstride (int, optional): Stride of the temporal convolution. Default: 1dropout (int, optional): Dropout rate of the final output. Default: 0residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``Shape:- Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format- Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format- Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format- Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` formatwhere:math:`N` is a batch size,:math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,:math:`T_{in}/T_{out}` is a length of input/output sequence,:math:`V` is the number of graph nodes."""def __init__(self,in_channels,out_channels,kernel_size,stride=1,dropout=0,residual=True):super().__init__()assert len(kernel_size) == 2assert kernel_size[0] % 2 == 1padding = ((kernel_size[0] - 1) // 2, 0)self.gcn = ConvTemporalGraphical(in_channels, out_channels,kernel_size[1])self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,(kernel_size[0], 1),(stride, 1),padding,),nn.BatchNorm2d(out_channels),nn.Dropout(dropout, inplace=True),)if not residual:self.residual = lambda x: 0elif (in_channels == out_channels) and (stride == 1):self.residual = lambda x: xelse:self.residual = nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=(stride, 1)),nn.BatchNorm2d(out_channels),)self.relu = nn.ReLU(inplace=True)def forward(self, x, A):res = self.residual(x)x, A = self.gcn(x, A)x = self.tcn(x) + resreturn self.relu(x), A
这是一个空间时间图卷积神经网络的模型。它的输入包括一个 5D 张量,表示 N 个样本,每个样本具有 C 个通道,T 个时间步长,V 个图节点和 M 个实例。输出是一个 2D 张量,其大小为(N, num_class)。
它的主要组成部分包括:
- 一个 Graph 类,构建了一个时空图;
- 一个自定义的 ConvTemporalGraphical 类,实现了空间卷积和时间卷积;
- 一系列 st_gcn 类,利用 ConvTemporalGraphical 对输入执行空间卷积和时间卷积,并包含一个可选的残差机制;
- 一个全局池化层,将所有节点的特征进行平均池化;
- 最后是一个简单的全连接层。
其中 forward()函数接收输入并执行前向传递,extract_feature()函数也接收输入并执行前向传递,但同时还返回中间特征。st_gcn()类负责网络中一个 st-gcn 层的定义。
分段理解
# load graphself.graph = Graph(graph_args)A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)self.register_buffer('A', A)
这部分代码定义了一个 Graph 对象,用于构建时空图。
- graph_args 表示将传递给 Graph 类的所有参数打包为字典类型。
- 然后,将 Graph.A 属性(表示邻接矩阵)转换为 torch.tensor,数据类型设置为float32,并且 requires_grad 属性设置为 False。
- 之后,使用 register_buffer()函数将邻接矩阵作为固定参数存储下来。
register_buffer()函数的作用是将输入作为持久缓冲区进行注册。在模型的前向函数被调用时,存储在这里的张量会被自动将移到与参数一样的设备上,而不会被视为模型参数,也不会被包含在优化器中。
这里将邻接矩阵存储为固定参数,主要是因为该邻接矩阵通常在整个训练期间都是不变的。
# build networksspatial_kernel_size = A.size(0)temporal_kernel_size = 9kernel_size = (temporal_kernel_size, spatial_kernel_size)self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}self.st_gcn_networks = nn.ModuleList((st_gcn(in_channels, 64, kernel_size, 1, residual=False, kwargs0),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 64, kernel_size, 1, kwargs),st_gcn(64, 128, kernel_size, 2, kwargs),st_gcn(128, 128, kernel_size, 1, kwargs),st_gcn(128, 128, kernel_size, 1, kwargs),st_gcn(128, 256, kernel_size, 2, kwargs),st_gcn(256, 256, kernel_size, 1, kwargs),st_gcn(256, 256, kernel_size, 1, kwargs),))
这段代码是用于构建基于 st-gcn 模型的网络结构。具体来说,
- 它首先根据给定的邻接矩阵 A 尺寸计算出空间卷积核的大小 spatial_kernel_size,然后设置了时间卷积核的大小 temporal_kernel_size 为9。
- 接着,它使用了 nn.BatchNorm1d 对输入数据进行归一化。
- 接下来,它使用 nn.ModuleList 创建了10个 st_gcn 模块,并将这些模块放入列表中。其中,第一个 st_gcn 模块的输入通道数为 in_channels,输出通道数为64,使用的卷积核大小为(temporal_kernel_size,
spatial_kernel_size),不带有残差结构;其余各层 st_gcn 模块的输入通道数和输出通道数依次为64, 64, 64,
128, 128, 128, 256, 256, 256,每层都使用了(temporal_kernel_size,
spatial_kernel_size)大小的卷积核,其中第5和第8层使用了步长为2的卷积操作。 - 最后,它返回了由所有 st_gcn 模块组成的列表 self.st_gcn_networks,该列表即为整个基于 st-gcn 模型的网络结构。