.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/models/1_gnn/9_gat.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_models_1_gnn_9_gat.py:


.. _model-gat:

Understand Graph Attention Network
=======================================

**Authors:** `Hao Zhang `_, `Mufei Li `_, `Minjie Wang `_, `Zheng Zhang `_

.. warning::

   This tutorial aims to provide insight into the paper, using code as a means
   of explanation. The implementation is therefore NOT optimized for running
   efficiency. For the recommended implementation, please refer to the `official examples `_.

In this tutorial, you learn about a graph attention network (GAT) and how it
can be implemented in PyTorch. You also learn to visualize and understand what
the attention mechanism has learned.

The research described in the paper `Graph Convolutional Network (GCN) `_
indicates that combining local graph structure and node-level features yields
good performance on node classification tasks. However, the way GCN aggregates
is structure-dependent, which can hurt its generalizability.

One workaround is to simply average over all neighbor node features, as
described in the research paper `GraphSAGE `_. However, `Graph Attention Network `_
proposes a different type of aggregation: GAT weights neighbor features with a
feature-dependent and structure-free normalization, in the style of attention.

.. GENERATED FROM PYTHON SOURCE LINES 36-106

Introducing attention to GCN
----------------------------

The key difference between GAT and GCN is how the information from the one-hop
neighborhood is aggregated.

For GCN, a graph convolution operation produces the normalized sum of the node
features of neighbors.

.. math::

   h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right)

where :math:`\mathcal{N}(i)` is the set of one-hop neighbors of node :math:`v_i`
(to include :math:`v_i` itself in the set, simply add a self-loop to each node),
:math:`c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}` is a normalization
constant based on graph structure, :math:`\sigma` is an activation function
(GCN uses ReLU), and :math:`W^{(l)}` is a shared weight matrix for node-wise
feature transformation. Another model, proposed in `GraphSAGE `_, employs the
same update rule except that it sets :math:`c_{ij}=|\mathcal{N}(i)|`.

GAT introduces the attention mechanism as a substitute for the statically
normalized convolution operation. Below are the equations to compute the node
embedding :math:`h_i^{(l+1)}` of layer :math:`l+1` from the embeddings of
layer :math:`l`.

.. image:: https://data.dgl.ai/tutorial/gat/gat.png
   :width: 450px
   :align: center

.. math::

   \begin{align}
   z_i^{(l)}&=W^{(l)}h_i^{(l)},&(1) \\
   e_{ij}^{(l)}&=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),&(2)\\
   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
   \end{align}

Explanations:

* Equation (1) is a linear transformation of the lower layer embedding
  :math:`h_i^{(l)}`, and :math:`W^{(l)}` is its learnable weight matrix.
* Equation (2) computes a pair-wise *un-normalized* attention score between two
  neighbors. It first concatenates the :math:`z` embeddings of the two nodes,
  where :math:`||` denotes concatenation, then takes the dot product of the
  result with a learnable weight vector :math:`\vec a^{(l)}`, and finally
  applies a LeakyReLU. This form of attention is usually called *additive
  attention*, in contrast with the dot-product attention in the Transformer
  model.
* Equation (3) applies a softmax to normalize the attention scores on each
  node's incoming edges.
* Equation (4) is similar to GCN. The embeddings from neighbors are aggregated
  together, scaled by the attention scores.

There are other details from the paper, such as dropout and skip connections.
For the purpose of simplicity, those details are left out of this tutorial. To
see more details, download the `full example `_. In essence, GAT is just a
different aggregation function with attention over the features of neighbors,
instead of a simple mean aggregation.

GAT in DGL
----------

DGL provides an off-the-shelf implementation of the GAT layer, ``GATConv``,
under the ``dgl.nn`` package. With the PyTorch backend used in this tutorial,
set the backend and import ``GATConv`` from ``dgl.nn.pytorch`` as follows.

.. GENERATED FROM PYTHON SOURCE LINES 106-110

.. code-block:: Python


    import os

    # Select the backend before DGL is imported for the first time.
    os.environ["DGLBACKEND"] = "pytorch"

    from dgl.nn.pytorch import GATConv

.. GENERATED FROM PYTHON SOURCE LINES 111-123

Readers can skip the following step-by-step explanation of the implementation
and jump to `Put everything together`_ for training and visualization results.

To begin, get an overall impression of how a ``GATLayer`` module is
implemented in DGL. In this section, the four equations above are broken down
one at a time.

.. note::

   This section shows how to implement a GAT layer from scratch. DGL provides a
   more efficient :class:`builtin GAT layer module `.

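Before walking through the DGL code, it may help to see equations (1)-(4)
written with dense matrices. The following is a minimal NumPy sketch, not DGL's
implementation: the helper name ``gat_dense``, the dense boolean adjacency
matrix ``adj`` and the toy graph are illustrative assumptions, and the
LeakyReLU slope of 0.2 is the value used in the GAT paper (PyTorch's
``F.leaky_relu`` defaults to 0.01). It relies on the identity
:math:`\vec a^T(z_i||z_j) = \vec a_{1}^T z_i + \vec a_{2}^T z_j`, where
:math:`\vec a_1` and :math:`\vec a_2` are the two halves of :math:`\vec a`.

.. code-block:: Python

    import numpy as np


    def leaky_relu(x, negative_slope=0.2):
        # Assumed slope of 0.2 (GAT paper); torch's F.leaky_relu defaults to 0.01.
        return np.where(x > 0, x, negative_slope * x)


    def gat_dense(h, adj, W, a):
        """Hypothetical single-head GAT layer in dense form (equations (1)-(4), no sigma).

        h:   (N, F_in) node features
        adj: (N, N) boolean matrix; adj[i, j] is True if node j is in N(i)
        W:   (F_in, F_out) weight matrix, applied as h @ W
        a:   (2 * F_out,) attention vector
        Every node must have at least one neighbor (e.g. a self-loop).
        """
        z = h @ W  # (1): z_i = W h_i for all nodes at once
        f_out = z.shape[1]
        # (2): a^T (z_i || z_j) = a[:F_out]^T z_i + a[F_out:]^T z_j, so the
        # N x N score matrix is a broadcast sum of two length-N vectors.
        e = leaky_relu((z @ a[:f_out])[:, None] + (z @ a[f_out:])[None, :])
        e = np.where(adj, e, -np.inf)  # only j in N(i) takes part in the softmax
        e = e - e.max(axis=1, keepdims=True)  # subtract the row max for stability
        alpha = np.exp(e)
        alpha = alpha / alpha.sum(axis=1, keepdims=True)  # (3): row-wise softmax
        return alpha @ z, alpha  # (4): attention-weighted sum, sigma left out


    rng = np.random.default_rng(0)
    N, F_in, F_out = 4, 3, 2
    h = rng.normal(size=(N, F_in))
    # Hypothetical toy graph: a 4-cycle plus self-loops.
    adj = np.eye(N, dtype=bool)
    for i, j in [(0, 1), (1, 2), (2, 3), (3, 0)]:
        adj[i, j] = adj[j, i] = True
    W = rng.normal(size=(F_in, F_out))
    a = rng.normal(size=2 * F_out)

    out, alpha = gat_dense(h, adj, W, a)
    print(out.shape)  # (4, 2)

    # With a zero attention vector all scores tie, and attention reduces to a
    # plain mean over N(i), i.e. the GraphSAGE-style c_ij = |N(i)| aggregation.
    _, alpha_mean = gat_dense(h, adj, W, np.zeros(2 * F_out))
    print(np.allclose(alpha_mean, adj / adj.sum(axis=1, keepdims=True)))  # True

The zero-vector check makes the earlier point concrete: the only thing GAT
changes relative to mean aggregation is how the weights :math:`\alpha_{ij}` are
computed from the features.
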
.. GENERATED FROM PYTHON SOURCE LINES 123-175

.. code-block:: Python


    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class GATLayer(nn.Module):
        def __init__(self, g, in_dim, out_dim):
            super(GATLayer, self).__init__()
            self.g = g
            # equation (1)
            self.fc = nn.Linear(in_dim, out_dim, bias=False)
            # equation (2)
            self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
            self.reset_parameters()

        def reset_parameters(self):
            """Reinitialize learnable parameters."""
            gain = nn.init.calculate_gain("relu")
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
            nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

        def edge_attention(self, edges):
            # edge UDF for equation (2)
            z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
            a = self.attn_fc(z2)
            # F.leaky_relu uses PyTorch's default negative slope (0.01);
            # the GAT paper uses 0.2.
            return {"e": F.leaky_relu(a)}

        def message_func(self, edges):
            # message UDF for equation (3) & (4)
            return {"z": edges.src["z"], "e": edges.data["e"]}

        def reduce_func(self, nodes):
            # reduce UDF for equation (3) & (4)
            # equation (3)
            alpha = F.softmax(nodes.mailbox["e"], dim=1)
            # equation (4)
            h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
            return {"h": h}

        def forward(self, h):
            # equation (1)
            z = self.fc(h)
            self.g.ndata["z"] = z
            # equation (2)
            self.g.apply_edges(self.edge_attention)
            # equation (3) & (4); the activation sigma is applied outside this layer
            self.g.update_all(self.message_func, self.reduce_func)
            return self.g.ndata.pop("h")

.. GENERATED FROM PYTHON SOURCE LINES 176-198

Equation (1)
^^^^^^^^^^^^

.. math::

   z_i^{(l)}=W^{(l)}h_i^{(l)} \qquad (1)

The first equation is a linear transformation. It is common and can be easily
implemented in PyTorch using ``torch.nn.Linear``.

Equation (2)
^^^^^^^^^^^^

.. math::

   e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})) \qquad (2)

The un-normalized attention score :math:`e_{ij}` is calculated using the
embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
attention scores can be viewed as edge data, which can be calculated by the
``apply_edges`` API.

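Treating the scores as edge data means computing one score per edge, for all
edges in a single batched operation rather than node by node. A minimal NumPy
sketch of that idea, assuming the graph is stored as two index arrays ``src``
and ``dst`` (edge ``k`` goes from ``src[k]`` to ``dst[k]``); ``edge_scores`` is
a hypothetical helper, not a DGL function, and the 0.2 LeakyReLU slope is an
assumption taken from the GAT paper:

.. code-block:: Python

    import numpy as np


    def leaky_relu(x, negative_slope=0.2):
        # Assumed slope of 0.2 (GAT paper); torch's F.leaky_relu defaults to 0.01.
        return np.where(x > 0, x, negative_slope * x)


    def edge_scores(z, src, dst, a):
        """Hypothetical helper: un-normalized score e for every edge src[k] -> dst[k]."""
        # Gather the endpoint embeddings of all E edges at once: (E, 2F), source first.
        z2 = np.concatenate([z[src], z[dst]], axis=1)
        return leaky_relu(z2 @ a)  # (E,): one score per edge


    rng = np.random.default_rng(0)
    z = rng.normal(size=(3, 2))  # embeddings after equation (1)
    src = np.array([0, 1, 2, 1])
    dst = np.array([1, 0, 1, 2])
    a = rng.normal(size=4)  # length 2F
    e = edge_scores(z, src, dst, a)
    print(e.shape)  # (4,)

    # Same result as scoring one edge at a time:
    e_loop = np.array(
        [leaky_relu(np.concatenate([z[s], z[d]]) @ a) for s, d in zip(src, dst)]
    )
    print(np.allclose(e, e_loop))  # True

Like the DGL edge UDF, this sketch puts the source embedding first, so for an
edge :math:`j \to i` it evaluates :math:`\vec a^T(z_j||z_i)`. Because
:math:`\vec a` is learned, the order only decides which half of the vector
multiplies which endpoint.
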
The argument to ``apply_edges`` is an **Edge UDF**, which is defined as below:

.. GENERATED FROM PYTHON SOURCE LINES 198-207

.. code-block:: Python


    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
        a = self.attn_fc(z2)
        return {"e": F.leaky_relu(a)}

.. GENERATED FROM PYTHON SOURCE LINES 208-234

Here, the dot product with the learnable weight vector :math:`\vec a^{(l)}` is
implemented again using PyTorch's linear transformation ``attn_fc``. Note that
``apply_edges`` **batches** all the edge data into one tensor, so ``cat`` and
``attn_fc`` here are applied to all the edges in parallel.

Equation (3) & (4)
^^^^^^^^^^^^^^^^^^

.. math::

   \begin{align}
   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
   \end{align}

Similar to GCN, the ``update_all`` API is used to trigger message passing on
all the nodes. The message function sends out two tensors: the transformed
``z`` embedding of the source node and the un-normalized attention score ``e``
on each edge. The reduce function then performs two tasks:

* Normalize the attention scores using softmax (equation (3)).
* Aggregate neighbor embeddings weighted by the attention scores (equation (4)).

Both tasks first fetch data from the mailbox and then manipulate it on the
second dimension (``dim=1``), along which the messages are batched. Note that
``GATLayer`` returns the aggregated embedding without the activation
:math:`\sigma`; the two-layer model defined later applies ``F.elu`` between
layers.

.. GENERATED FROM PYTHON SOURCE LINES 234-245

.. code-block:: Python


    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox["e"], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
        return {"h": h}

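The reduce function above relies on DGL grouping each node's incoming messages
into a mailbox. One way to picture the same computation without mailboxes is to
work directly on the edge list with scatter operations. The following is a
minimal NumPy sketch under that assumption (graph stored as ``src``/``dst``
index arrays, scores ``e`` already computed); ``edge_softmax_aggregate`` is a
hypothetical name, not a DGL API:

.. code-block:: Python

    import numpy as np


    def edge_softmax_aggregate(z, e, src, dst, num_nodes):
        """Hypothetical helper: equations (3) and (4) on an edge list (sigma left out).

        Softmax of the scores e over each node's incoming edges, followed by an
        attention-weighted sum of the source embeddings.
        """
        # Per-destination maximum score, subtracted for numerical stability.
        e_max = np.full(num_nodes, -np.inf)
        np.maximum.at(e_max, dst, e)
        exp_e = np.exp(e - e_max[dst])
        # Softmax denominator: sum of exp(e) over each node's incoming edges.
        denom = np.zeros(num_nodes)
        np.add.at(denom, dst, exp_e)
        alpha = exp_e / denom[dst]  # (3): one attention weight per edge
        # (4): scatter-add alpha_ij * z_j into row i of the output.
        h = np.zeros((num_nodes, z.shape[1]))
        np.add.at(h, dst, alpha[:, None] * z[src])
        return h, alpha


    rng = np.random.default_rng(0)
    src = np.array([0, 1, 2, 2, 3])
    dst = np.array([1, 0, 0, 1, 2])  # node 3 has no incoming edges
    z = rng.normal(size=(4, 3))
    e = rng.normal(size=len(src))
    h, alpha = edge_softmax_aggregate(z, e, src, dst, num_nodes=4)

    # Mailbox-style check, mirroring reduce_func: softmax per node, then weighted sum.
    for v in range(4):
        m = dst == v
        if m.any():
            w = np.exp(e[m]) / np.exp(e[m]).sum()
            assert np.allclose(h[v], (w[:, None] * z[src[m]]).sum(axis=0))
    print(h[3])  # [0. 0. 0.]

Node 3 receives no messages, so its row stays zero in this sketch; adding
self-loops, as mentioned for GCN above, guarantees that every node attends to
at least itself.
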
.. GENERATED FROM PYTHON SOURCE LINES 246-265

Multi-head attention
^^^^^^^^^^^^^^^^^^^^

Analogous to multiple channels in a ConvNet, GAT introduces **multi-head
attention** to enrich the model capacity and to stabilize the learning process.
Each attention head has its own parameters, and their outputs can be merged in
two ways:

.. math::

   \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)

or

.. math::

   \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)

where :math:`K` is the number of heads. You can use concatenation for
intermediary layers and average for the final layer.

Use the single-head ``GATLayer`` defined above as the building block for the
``MultiHeadGATLayer`` below:

.. GENERATED FROM PYTHON SOURCE LINES 265-285

.. code-block:: Python


    class MultiHeadGATLayer(nn.Module):
        def __init__(self, g, in_dim, out_dim, num_heads, merge="cat"):
            super(MultiHeadGATLayer, self).__init__()
            self.heads = nn.ModuleList()
            for _ in range(num_heads):
                self.heads.append(GATLayer(g, in_dim, out_dim))
            self.merge = merge

        def forward(self, h):
            head_outs = [attn_head(h) for attn_head in self.heads]
            if self.merge == "cat":
                # concat on the output feature dimension (dim=1)
                return torch.cat(head_outs, dim=1)
            else:
                # merge using the average over heads; the stacked tensor has
                # shape (num_heads, num_nodes, out_dim), so average over dim=0
                return torch.mean(torch.stack(head_outs), dim=0)

.. GENERATED FROM PYTHON SOURCE LINES 286-290

Put everything together
^^^^^^^^^^^^^^^^^^^^^^^

Now, you can define a two-layer GAT model.

.. GENERATED FROM PYTHON SOURCE LINES 290-310

.. code-block:: Python


    class GAT(nn.Module):
        def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
            super(GAT, self).__init__()
            self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
            # Be aware that the input dimension is hidden_dim*num_heads since
            # multiple head outputs are concatenated together. Also, the output
            # layer uses only one attention head.
            self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

        def forward(self, h):
            h = self.layer1(h)
            h = F.elu(h)
            h = self.layer2(h)
            return h

.. GENERATED FROM PYTHON SOURCE LINES 311-312

We then load the Cora dataset using DGL's built-in data module.

.. GENERATED FROM PYTHON SOURCE LINES 312-324

.. code-block:: Python


    from dgl.data import citation_graph as citegrh


    def load_cora_data():
        data = citegrh.load_cora()
        g = data[0]
        mask = torch.BoolTensor(g.ndata["train_mask"])
        return g, g.ndata["feat"], g.ndata["label"], mask

.. GENERATED FROM PYTHON SOURCE LINES 325-326

The training loop is exactly the same as in the GCN tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 326-362

.. code-block:: Python


    import time

    import numpy as np

    g, features, labels, mask = load_cora_data()

    # create the model, 2 heads, each head has hidden size 8
    net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=2)

    # create optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    # main loop
    dur = []
    for epoch in range(30):
        if epoch >= 3:
            t0 = time.time()

        logits = net(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[mask], labels[mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        # The first 3 epochs are warm-up and are not timed; report nan until a
        # timing exists instead of calling np.mean on an empty list.
        print(
            "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
                epoch, loss.item(), np.mean(dur) if dur else float("nan")
            )
        )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    NumNodes: 2708
    NumEdges: 10556
    NumFeats: 1433
    NumClasses: 7
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000
    Done loading data from cached files.
    Epoch 00000 | Loss 1.9434 | Time(s) nan
    Epoch 00001 | Loss 1.9414 | Time(s) nan
    Epoch 00002 | Loss 1.9393 | Time(s) nan
    Epoch 00003 | Loss 1.9373 | Time(s) 0.0998
    Epoch 00004 | Loss 1.9352 | Time(s) 0.0962
    Epoch 00005 | Loss 1.9332 | Time(s) 0.0944
    Epoch 00006 | Loss 1.9311 | Time(s) 0.0955
    Epoch 00007 | Loss 1.9290 | Time(s) 0.0948
    Epoch 00008 | Loss 1.9269 | Time(s) 0.0946
    Epoch 00009 | Loss 1.9249 | Time(s) 0.0948
    Epoch 00010 | Loss 1.9228 | Time(s) 0.0945
    Epoch 00011 | Loss 1.9207 | Time(s) 0.0955
    Epoch 00012 | Loss 1.9186 | Time(s) 0.0951
    Epoch 00013 | Loss 1.9165 | Time(s) 0.0952
    Epoch 00014 | Loss 1.9144 | Time(s) 0.0950
    Epoch 00015 | Loss 1.9122 | Time(s) 0.0944
    Epoch 00016 | Loss 1.9101 | Time(s) 0.0941
    Epoch 00017 | Loss 1.9080 | Time(s) 0.0939
    Epoch 00018 | Loss 1.9058 | Time(s) 0.0941
    Epoch 00019 | Loss 1.9036 | Time(s) 0.0938
    Epoch 00020 | Loss 1.9015 | Time(s) 0.0939
    Epoch 00021 | Loss 1.8993 | Time(s) 0.0937
    Epoch 00022 | Loss 1.8971 | Time(s) 0.0937
    Epoch 00023 | Loss 1.8949 | Time(s) 0.0939
    Epoch 00024 | Loss 1.8927 | Time(s) 0.0941
    Epoch 00025 | Loss 1.8904 | Time(s) 0.0945
    Epoch 00026 | Loss 1.8882 | Time(s) 0.0945
    Epoch 00027 | Loss 1.8859 | Time(s) 0.0943
    Epoch 00028 | Loss 1.8837 | Time(s) 0.0945
    Epoch 00029 | Loss 1.8814 | Time(s) 0.0945




.. GENERATED FROM PYTHON SOURCE LINES 363-546

Visualizing and understanding attention learned
-----------------------------------------------

Cora
^^^^

The following table summarizes the model performance on Cora that is reported
in `the GAT paper `_ and obtained with DGL implementations.

.. list-table::
   :header-rows: 1

   * - Model
     - Accuracy
   * - GCN (paper)
     - :math:`81.4\pm 0.5\%`
   * - GCN (dgl)
     - :math:`82.05\pm 0.33\%`
   * - GAT (paper)
     - :math:`83.0\pm 0.7\%`
   * - GAT (dgl)
     - :math:`83.69\pm 0.529\%`

*What kind of attention distribution has our model learned?*

Because the attention weight :math:`\alpha_{ij}` is associated with edges, you
can visualize it by coloring edges. The figure below shows a subgraph of Cora
with the attention weights of the last ``GATLayer``. The nodes are colored
according to their labels, whereas the edges are colored according to the
magnitude of the attention weights, which can be read from the colorbar on the
right.

.. image:: https://data.dgl.ai/tutorial/gat/cora-attention.png
   :width: 600px
   :align: center

You can see that the model seems to learn different attention weights. To
understand the distribution more thoroughly, measure the `entropy `_ of the
attention distribution. For any node :math:`i`,
:math:`\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}` forms a discrete probability
distribution over all its neighbors, with the entropy given by

.. math::

   H(\{\alpha_{ij}\}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij}

A low entropy means a high degree of concentration, and vice versa. An entropy
of 0 means all attention is on one source node. The uniform distribution has the
highest entropy, :math:`\log|\mathcal{N}(i)|`. Ideally, you want to see the
model learn a distribution of lower entropy (i.e., one or two neighbors are much
more important than the others).

Note that since nodes can have different degrees, the maximum entropy also
differs from node to node. Therefore, you plot the aggregated histogram of the
entropy values of all nodes in the entire graph. Below is the histogram of
attention entropy learned by each attention head.

|image2|

As a reference, here is the histogram if all the nodes have a uniform attention
weight distribution.

.. image:: https://data.dgl.ai/tutorial/gat/cora-attention-uniform-hist.png
   :width: 250px
   :align: center

One can see that **the attention values learned are quite similar to a uniform
distribution** (i.e., all neighbors are equally important). This partially
explains why the performance of GAT is close to that of GCN on Cora (according
to the `author's reported result `_, the accuracy difference averaged over 100
runs is less than 2 percent). Attention matters little here because it barely
differentiates between neighbors.

*Does that mean the attention mechanism is not useful?* No! A different dataset
exhibits an entirely different pattern, as you can see next.

Protein-protein interaction (PPI) networks
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The PPI dataset used here consists of :math:`24` graphs corresponding to
different human tissues. Nodes can have up to :math:`121` kinds of labels, so
the label of a node is represented as a binary tensor of size :math:`121`. The
task is to predict node labels.

Use :math:`20` graphs for training, :math:`2` for validation, and :math:`2` for
testing. The average number of nodes per graph is :math:`2372`. Each node has
:math:`50` features that are composed of positional gene sets, motif gene sets,
and immunological signatures. Critically, test graphs remain completely
unobserved during training, a setting called "inductive learning".

Compare the performance of GAT and GCN over :math:`10` random runs on this task,
using hyperparameter search on the validation set to find the best model.

.. list-table::
   :header-rows: 1

   * - Model
     - Micro F1 score
   * - GAT
     - :math:`0.975 \pm 0.006`
   * - GCN
     - :math:`0.509 \pm 0.025`
   * - GAT (paper)
     - :math:`0.973 \pm 0.002`

The table above shows the results of this experiment, where the micro
`F1 score `_ is used to evaluate model performance.

.. note::

   Below is how the micro F1 score is calculated:

   .. math::

      \text{precision}=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FP_{t})}

      \text{recall}=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FN_{t})}

      F1_{\text{micro}}=2\frac{\text{precision}\cdot\text{recall}}{\text{precision}+\text{recall}}

   * :math:`TP_{t}` is the number of nodes that both have and are predicted to
     have label :math:`t`.
   * :math:`FP_{t}` is the number of nodes that do not have but are predicted
     to have label :math:`t`.
   * :math:`FN_{t}` is the number of nodes that have label :math:`t` but are
     not predicted to have it.
   * :math:`n` is the number of labels, i.e., :math:`121` in our case.

During training, use ``BCEWithLogitsLoss`` as the loss function. The learning
curves of GAT and GCN are presented below; what is evident is the dramatic
performance advantage of GAT over GCN.

.. image:: https://data.dgl.ai/tutorial/gat/ppi-curve.png
   :width: 300px
   :align: center

As before, you can gain a statistical understanding of the learned attention by
plotting the histogram of the node-wise attention entropy. Below are the
attention histograms learned by different attention layers.

*Attention learned in layer 1:*

|image5|

*Attention learned in layer 2:*

|image6|

*Attention learned in final layer:*

|image7|

Again, compare with the uniform distribution:

.. image:: https://data.dgl.ai/tutorial/gat/ppi-uniform-hist.png
   :width: 250px
   :align: center

Clearly, **GAT does learn sharp attention weights**! There is a clear pattern
over the layers as well: **the attention gets sharper in higher layers**.

Unlike the Cora dataset, where GAT's gain is minimal at best, for PPI there is
a significant performance gap between GAT and the other GNN variants compared
in `the GAT paper `_ (at least 20 percent), and the attention distributions
of the two datasets clearly differ. While this deserves further research, one
immediate conclusion is that GAT's advantage perhaps lies more in its ability
to handle graphs with more complex neighborhood structure.

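The node-wise entropy histograms shown for Cora and PPI can be computed from
per-edge attention weights. A minimal NumPy sketch, assuming the weights of one
layer (or one head) are available as a flat array ``alpha`` aligned with a
destination-index array ``dst``, as in an edge list; ``attention_entropy`` is a
hypothetical helper, and extracting ``alpha`` from a trained model is not shown:

.. code-block:: Python

    import numpy as np


    def attention_entropy(alpha, dst, num_nodes):
        """Hypothetical helper: H_i = -sum_j alpha_ij * log(alpha_ij) per node."""
        # Use the convention 0 * log(0) = 0 by replacing zeros before the log.
        terms = -alpha * np.log(np.where(alpha > 0, alpha, 1.0))
        ent = np.zeros(num_nodes)
        np.add.at(ent, dst, terms)  # sum the per-edge terms for each destination
        return ent


    # Hypothetical toy example: node 0 has 3 incoming edges, node 1 has 2, node 2 has 1.
    dst = np.array([0, 0, 0, 1, 1, 2])
    deg = np.bincount(dst, minlength=3)

    uniform = 1.0 / deg[dst]  # every neighbor equally important
    peaked = np.array([1.0, 0.0, 0.0, 0.5, 0.5, 1.0])  # node 0 attends to one neighbor

    print(attention_entropy(uniform, dst, 3))  # log(3), log(2), 0: the per-node maxima
    print(attention_entropy(peaked, dst, 3))  # 0, log(2), 0

    # Aggregate the entropies of all nodes into one histogram.
    counts, bin_edges = np.histogram(attention_entropy(uniform, dst, 3), bins=5)

Because each node's maximum entropy is :math:`\log|\mathcal{N}(i)|`, comparing a
learned histogram against the one produced by the ``uniform`` weights is what
separates genuinely sharp attention from the effect of varying degrees.
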
What's next?
------------

So far, you have seen how to use DGL to implement GAT. Some details are left
out, such as dropout, skip connections, and hyperparameter tuning; these
practices do not involve DGL-related concepts. For more information, check out
the full example.

* See the optimized `full example `_.
* The next tutorial describes how to speed up GAT models by parallelizing
  multiple attention heads and using SpMV optimization.

.. |image2| image:: https://data.dgl.ai/tutorial/gat/cora-attention-hist.png
.. |image5| image:: https://data.dgl.ai/tutorial/gat/ppi-first-layer-hist.png
.. |image6| image:: https://data.dgl.ai/tutorial/gat/ppi-second-layer-hist.png
.. |image7| image:: https://data.dgl.ai/tutorial/gat/ppi-final-layer-hist.png


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.888 seconds)


.. _sphx_glr_download_tutorials_models_1_gnn_9_gat.py: