> 文章列表 > GAT的基础理论

GAT的基础理论

GAT的基础理论

文章目录

  • GAT原理(理解用)
    • GAT工作流程
      • 计算注意力系数(attention coefficient)
      • 加权求和(aggregate)
    • GAT深入理解
  • GAT的实用基础理论(编代码用)
    • 1. GAT的底层实现(pytorch)
      • PyG中的GATConv实现
    • 2. GAT的实例
  • 引用

GAT原理(理解用)

引入:

GCN的缺点:

  1. 无法完成inductive任务,即处理动态图问题。inductive任务是指:训练阶段与测试阶段需要处理的graph不同。通常是训练阶段只是在子图(subgraph)上进行,测试阶段需要处理未知的顶点。(unseen node)
  2. 处理有向图的瓶颈,不容易实现分配不同的学习权重给不同的neighbor。

GAT:重点在获取其余节点对本节点的影响上。GAT本质上有两种运算方式,

  • mask graph attention:注意力机制的运算只在邻居顶点上进行。目前常用方式。
  • global graph attention:每一个顶点 iii 都对于图上任意顶点都进行attention运算

优点:完全不依赖于图的结构,对于inductive任务无压力
缺点:(1)丢掉了图结构的这个特征,无异于自废武功,效果可能会很差(2)运算面临着高昂的成本

GAT工作流程

计算注意力系数(attention coefficient)

  1. 对于顶点 iii, 逐个计算它的邻居节点 jjj 和它的相关系数:ei,j=score([Whi∣∣Whj]),j∈Nie_{i,j} = score([Wh_i||Wh_j]),j \\in N_iei,j=score([Whi∣∣Whj]),jNi
    • W:共享参数W的线性映射对于顶点的特征进行了增维,这是一种常见的特征增强方法。
    • ||:对于顶点 i,ji,ji,j的变换后的特征进行了concat拼接
    • score:score函数将拼接后的高维特征映射到一个实数上
  2. 根据相关系数,计算注意力系数:αi,j=exp(LeakyReLU(ei,j))∑k∈Niexp(LeakyReLU(eik))\\alpha_{i,j} = \\frac{exp(LeakyReLU(e_{i,j}))}{\\sum_{k\\in N_i}exp(LeakyReLU(e_{ik}))}αi,j=kNiexp(LeakyReLU(eik))exp(LeakyReLU(ei,j))

加权求和(aggregate)

  1. 对于每个顶点,GAT输出为融合了邻域信息的新特征:hi′=σ(∑j∈Niαi,jWhj)h_i' = \\sigma(\\sum_{j\\in N_i}\\alpha_{i,j}Wh_j)hi=σ(jNiαi,jWhj)

  2. 为了获得更多的信息,可以使用多头注意力机制:hi′(K)=concatk=1K[σ(∑j∈Niαi.jkWkhj)]h_i'(K) = concat_{k=1}^{K}[\\sigma(\\sum_{j\\in N_i}\\alpha_{i.j}^kW^kh_j)]hi(K)=concatk=1K[σ(jNiαi.jkWkhj)] 其中,K为多头注意力的头数其中,K为多头注意力的头数其中,K为多头注意力的头数

    特别地,如果是在网络的最终(预测或分类)层,那么concat拼接操作不再是明智的——相反,使用平均就可以解决分类问题,具体公式如下:
    hi′=σ(1K∑k=1K∑j∈Niαi,jkWkhj)h_i' = \\sigma (\\frac{1}{K}\\sum_{k=1}^K\\sum_{j\\in N_i}\\alpha_{i,j}^kW^kh_j)hi=σ(K1k=1KjNiαi,jkWkhj)
    其中,K为多头注意力的头数其中,K为多头注意力的头数其中,K为多头注意力的头数

GAT深入理解

  1. 与GCN的联系与区别

    本质上而言:GCN与GAT都是将邻居顶点的特征聚合到中心顶点上(一种aggregate运算),利用graph上的local stationary学习新的顶点特征表达。不同的是GCN利用了拉普拉斯矩阵,GAT利用attention系数。一定程度上而言,GAT会更强,因为顶点特征之间的相关性被更好地融入到模型中。

  2. 为什么GAT适用于有向图

    GCN 假设图是无向的,因为利用了对称的拉普拉斯矩阵 (只有邻接矩阵 A 是对称的,拉普拉斯矩阵才可以正交分解),不能直接用于有向图。 GCN 的作者为了处理有向图,需要对 Graph 结构进行调整, 要把有向边划分成两个节点放入 Graph 中 。例如 e1、e2 为两个节点,r 为 e1,e2 的有向关系,则需要把 r 划分为两个关系节点 r1 和 r2 放入图中。连接 (e1, r1)、(e2, r2)。

    GAT能处理有向图最根本的原因是它的运算方式是逐顶点的运算(node-wise),这一点可从上面的所有公式中很明显地看出。每一次运算都需要循环遍历图上的所有顶点来完成。逐顶点运算意味着,摆脱了拉普拉斯矩阵的束缚,使得有向图问题迎刃而解。

  3. 为什么GAT适用于inductive任务(动态图问题)

    GAT中重要的学习参数是 W 与 score函数,因为上述的逐顶点运算方式,这两个参数仅与顶点特征相关,与图的结构毫无关系。所以测试任务中改变图的结构,对于GAT影响并不大。

    与此相反的是,GCN是一种全图的计算方式,一次计算就更新全图的节点特征。学习的参数很大程度与图结构相关,这使得GCN在inductive任务上遇到困境。

GAT的实用基础理论(编代码用)

1. GAT的底层实现(pytorch)

PyG中的GATConv实现

论文中的方法:
ei,j=score([Whi∣∣Whj]),j∈Nie_{i,j} = score([Wh_i||Wh_j]),j \\in N_iei,j=score([Whi∣∣Whj]),jNi
αi,j=exp(LeakyReLU(ei,j))∑k∈Niexp(LeakyReLU(eik))\\alpha_{i,j} = \\frac{exp(LeakyReLU(e_{i,j}))}{\\sum_{k\\in N_i}exp(LeakyReLU(e_{ik}))}αi,j=kNiexp(LeakyReLU(eik))exp(LeakyReLU(ei,j))
hi′=σ(∑j∈Niαi,jWhj)h_i' = \\sigma(\\sum_{j\\in N_i}\\alpha_{i,j}Wh_j)hi=σ(jNiαi,jWhj)

PyG中实现方法:
ei,j=score([W1hi∣∣W2hj]),j∈Ni∪{i}e_{i,j} = score([W_1h_i||W_2h_j]),j \\in N_i \\cup \\{i\\}ei,j=score([W1hi∣∣W2hj]),jNi{i}
具体实现方法为:ei,j=aT[W1hi∣∣W2hj],j∈Ni∪{i}具体实现方法为:e_{i,j} = a^T[W_1h_i||W_2h_j],j \\in N_i \\cup \\{i\\}具体实现方法为:ei,j=aT[W1hi∣∣W2hj],jNi{i}
αi,j=exp(LeakyReLU(ei,j))∑k∈Ni∪{i}exp(LeakyReLU(eik))\\alpha_{i,j} = \\frac{exp(LeakyReLU(e_{i,j}))}{\\sum_{k\\in N_i\\cup \\{i\\}}exp(LeakyReLU(e_{ik}))}αi,j=kNi{i}exp(LeakyReLU(eik))exp(LeakyReLU(ei,j))

hi′=αi,iW1hi+∑j∈N(i)αi,jW2hjh_i' = \\alpha_{i,i}W_1h_i + \\sum_{j \\in N(i)}\\alpha_{i,j}W_2h_jhi=αi,iW1hi+jN(i)αi,jW2hj

区别:

  1. PyG求score时,对target节点(即hih_ihi)和source节点(即hjh_jhj)使用了不同的参数 WWW
  2. PyG求score时,对target节点使用的aTa^TaT和source节点使用的aTa^TaT也不相同,公式中为了简便就没有体现。
  3. PyG求score时,j的取值范围包括自己,即j∈Ni∪{i}j \\in N_i \\cup \\{i\\}jNi{i}
  4. PyG求注意力分数时,PyG中的实现方法分母中求和的集合多了一个自身,即∑k∈Ni∪{i}\\sum_{k\\in N_i\\cup \\{i\\}}kNi{i}
  5. PyG中最后聚合的时候同样对自身和邻域节点分别进行了加权求和。
  • init函数

    参数说明:

    • in_channels: Union[int, Tuple[int, int]]:输入原始特征或者隐含层embedding的维度。如果是-1,则根据传入的x来推断特征维度。注意in_channels可以是一个整数,也可以是两个整数组成的tuple,分别对应source节点和target节点的特征维度。
    • source节点: 中心节点的邻居节点。{xj,∀j∈N(i)}\\{x_j, \\forall j\\in N(i)\\}{xj,jN(i)}
    • target节点:中心节点。xix_ixi
    • in_channels[0]:参数W2W_2W2的shape[0],对应source节点(邻域节点)的特征维度
    • in_channels[1]:参数W1W_1W1的shape[0],对应target节点(目标节点)的特征维度
    • out_channels:输出embedding的维度
    • heads:表示注意力头数,默认为1
    • concat:表示multi-head输出后的多个特征向量的处理方法是否需要拼接,默认为True
    • negative_slope:采用leakyRELU的激活函数,x的负半平面斜率系数,默认为0.2
    • dropout:过拟合参数p,默认为0
    • add_self_loops:GAT要求加入自环,即每个节点要与自身连接,默认为True
    • bias:偏差,默认为True
    • kwargs.setdefault('aggr', 'add'):邻域聚合方式,默认aggr='add'
    def __init__(self, in_channels: Union[int, Tuple[int, int]],out_channels: int, heads: int = 1, concat: bool = True,negative_slope: float = 0.2, dropout: float = 0.,add_self_loops: bool = True, bias: bool = True, **kwargs):kwargs.setdefault('aggr', 'add')super(GATConv, self).__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.concat = concatself.negative_slope = negative_slopeself.dropout = dropoutself.add_self_loops = add_self_loopsif isinstance(in_channels, int):# 如果是单个整数,那么邻域节点和目标节点公用同一组参数Wself.lin_l = Linear(in_channels, heads * out_channels, bias=False)self.lin_r = self.lin_lelse:# 如果是tuple,那么邻域节点(source)使用参数W2,维度为in_channels[0]# 目标节点(target)使用参数W1,维度为in_channels[1]self.lin_l = Linear(in_channels[0], heads * out_channels, False)self.lin_r = Linear(in_channels[1], heads * out_channels, False)# att_l和att_r对应公式中的a^T,l和r也是分别用于source node和target nodeself.att_l = Parameter(torch.Tensor(1, heads, out_channels))self.att_r = Parameter(torch.Tensor(1, heads, out_channels))if bias and concat:self.bias = Parameter(torch.Tensor(heads * out_channels))elif bias and not concat:self.bias = Parameter(torch.Tensor(out_channels))else:self.register_parameter('bias', None)self._alpha = Noneself.reset_parameters()
    
  • forward函数:

    参数说明:

    • x:Union[Tensor, OptPairTensor]:可以是Tensor,也可以是OptPairTensor (pyg定义的tuple of Tensor)。

    当图是bipartite的时候,x是OptPairTensor,为了和init函数中定义的in_channel对应,要使得:

    • source节点(邻居节点)特征对应x[0] ,在代码中赋值给x_lin_channel[0]W2W_2W2)定义为lin_l
    • target节点(中心节点)特征对应x[1],在代码中赋值给 x_rin_channel[1]W1W_1W1)定义为lin_r
    • edge_index: Adj: Adj是pyg定义的邻接矩阵类型,可以是Tensor,也可以是SparseTensor。
    • return_attention_weights:是否返回注意力权重,默认为False
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,size: Size = None, return_attention_weights=None):H, C = self.heads, self.out_channelsx_l: OptTensor = Nonex_r: OptTensor = Nonealpha_l: OptTensor = Nonealpha_r: OptTensor = None# 求注意力相关系数alpha_l和alpha_rif isinstance(x, Tensor):assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'x_l = x_r = self.lin_l(x).view(-1, H, C)alpha_l = (x_l * self.att_l).sum(dim=-1)alpha_r = (x_r * self.att_r).sum(dim=-1)else:x_l, x_r = x[0], x[1]assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'x_l = self.lin_l(x_l).view(-1, H, C)alpha_l = (x_l * self.att_l).sum(dim=-1)if x_r is not None:x_r = self.lin_r(x_r).view(-1, H, C)alpha_r = (x_r * self.att_r).sum(dim=-1)assert x_l is not Noneassert alpha_l is not None# 为邻接矩阵添加自环if self.add_self_loops:if isinstance(edge_index, Tensor):num_nodes = x_l.size(0)if x_r is not None:num_nodes = min(num_nodes, x_r.size(0))if size is not None:num_nodes = min(size[0], size[1])edge_index, _ = remove_self_loops(edge_index)edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)elif isinstance(edge_index, SparseTensor):edge_index = set_diag(edge_index)# 最重要的步骤:计算注意力分数,并将注意力分数赋值给self._alpha,从而赋值给_alpha,并将self._alpha清空# propagate_type: (x: OptPairTensor, alpha: OptPairTensor)out = self.propagate(edge_index, x=(x_l, x_r),alpha=(alpha_l, alpha_r), size=size)alpha = self._alphaself._alpha = None# 判断是否concat,如果不concat就表示是最后一层,要用meanif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)# 添加偏差if self.bias is not None:out += self.bias# if isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return out
  • message函数

    参数说明:

    • x_j:邻域节点,即source节点,x_l
    • alpha_j:邻域节点的注意力相关系数
    • alpha_i:目标节点的注意力相关系数
    • index:index是与source node相连的target node的标号,就是edge_index的第二行
    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:alpha = alpha_j if alpha_i is None else alpha_j + alpha_ialpha = F.leaky_relu(alpha, self.negative_slope)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)return x_j * alpha.unsqueeze(-1)
    

2. GAT的实例

import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GATConvfrom torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as Fclass Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.gat1=GATConv(dataset.num_node_features,8,8,dropout=0.6)self.gat2=GATConv(64,7,1,dropout=0.6)def forward(self,data):x,edge_index=data.x, data.edge_indexx=self.gat1(x,edge_index)x=self.gat2(x,edge_index)return F.log_softmax(x,dim=1)ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
x=dataset[0].x
edge_index=dataset[0].edge_indexdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)model.train()
for epoch in range(100):optimizer.zero_grad()out = model(data)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>>Accuracy:0.7960

引用

文章参考了:

  1. https://blog.csdn.net/weixin_44839047/article/details/115724958
  2. https://blog.csdn.net/qq_41995574/article/details/99931294
  3. https://blog.csdn.net/xiao_muyu/article/details/121762806
  4. https://blog.csdn.net/StarfishCu/article/details/109644271