.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/4_link_predict.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_4_link_predict.py:


Link Prediction using Graph Neural Networks
===========================================

In the :doc:`introduction <1_introduction>`, you have already learned
the basic workflow of using GNNs for node classification,
i.e. predicting the category of a node in a graph. This tutorial will
teach you how to train a GNN for link prediction, i.e. predicting the
existence of an edge between two arbitrary nodes in a graph.

By the end of this tutorial you will be able to

-  Build a GNN-based link prediction model.
-  Train and evaluate the model on a small DGL-provided dataset.

(Time estimate: 28 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 19-33

.. code-block:: Python

    import itertools
    import os

    os.environ["DGLBACKEND"] = "pytorch"

    import dgl
    import dgl.data
    import numpy as np
    import scipy.sparse as sp
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 34-73

Overview of Link Prediction with GNN
------------------------------------

Many applications such as social recommendation, item recommendation,
knowledge graph completion, etc., can be formulated as link prediction,
which predicts whether an edge exists between two particular nodes. This
tutorial shows an example of predicting whether a citation relationship,
either citing or being cited, between two papers exists in a citation
network.

This tutorial formulates the link prediction problem as a binary
classification problem as follows:

-  Treat the edges in the graph as *positive examples*.
-  Sample a number of non-existent edges (i.e.
   node pairs with no edges between them) as *negative* examples.
-  Divide the positive examples and negative examples into a training
   set and a test set.
-  Evaluate the model with any binary classification metric such as the
   area under the ROC curve (AUC).

.. note::

   The practice comes from
   `SEAL `__,
   although the model here does not use their idea of node labeling.

In some domains such as large-scale recommender systems or information
retrieval, you may favor metrics that emphasize good performance of
top-K predictions. In these cases you may want to consider other metrics
such as mean average precision, and use other negative sampling methods,
which are beyond the scope of this tutorial.

Loading graph and features
--------------------------

Following the :doc:`introduction <1_introduction>`, this tutorial
first loads the Cora dataset.

.. GENERATED FROM PYTHON SOURCE LINES 73-79

.. code-block:: Python

    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    NumNodes: 2708
    NumEdges: 10556
    NumFeats: 1433
    NumClasses: 7
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000
    Done loading data from cached files.

.. GENERATED FROM PYTHON SOURCE LINES 80-87

Prepare training and testing sets
---------------------------------

This tutorial randomly picks 10% of the edges as positive examples for
the test set and leaves the rest for the training set. It then samples
the same number of non-existent edges as negative examples for each set.
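
Before the Cora-specific code, the same recipe can be tried in isolation
on a toy graph. The sketch below is NumPy-only and is not part of DGL:
``split_edges`` is a hypothetical helper, and it assumes the edges arrive
as two integer arrays ``src`` and ``dst``. Like the tutorial, it
enumerates candidate negatives from a dense adjacency matrix, so it only
suits small graphs. It also draws the negatives without replacement, so
no node pair can appear in both the training and the test negatives.

.. code-block:: Python

    import numpy as np


    def split_edges(src, dst, num_nodes, test_frac=0.1, seed=0):
        """Hypothetical helper: split positive edges into train/test sets and
        sample the same number of negative (non-existent) node pairs.

        Builds a dense num_nodes x num_nodes boolean matrix, so it is only
        meant for small graphs.
        """
        rng = np.random.default_rng(seed)
        num_edges = len(src)

        # Shuffle the edge ids and use the first `test_frac` of them for testing.
        eids = rng.permutation(num_edges)
        test_size = int(num_edges * test_frac)
        test_pos = (src[eids[:test_size]], dst[eids[:test_size]])
        train_pos = (src[eids[test_size:]], dst[eids[test_size:]])

        # Candidate negatives: every pair (u, v) with no edge u -> v and u != v.
        adj = np.zeros((num_nodes, num_nodes), dtype=bool)
        adj[src, dst] = True
        np.fill_diagonal(adj, True)  # mark self-pairs so they are never sampled
        neg_u, neg_v = np.nonzero(~adj)

        # One negative per positive edge, drawn without replacement so that no
        # pair appears in both the training and the test negatives.
        neg_eids = rng.choice(len(neg_u), size=num_edges, replace=False)
        test_neg = (neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]])
        train_neg = (neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]])
        return train_pos, test_pos, train_neg, test_neg


    # Usage on a toy 10-node ring stored with both edge directions (20 edges).
    ring_src = np.arange(10)
    ring_dst = (ring_src + 1) % 10
    src = np.concatenate([ring_src, ring_dst])
    dst = np.concatenate([ring_dst, ring_src])
    train_pos, test_pos, train_neg, test_neg = split_edges(src, dst, num_nodes=10)
    print(len(train_pos[0]), len(test_pos[0]), len(train_neg[0]), len(test_neg[0]))
    # prints: 18 2 18 2

On the 20-edge toy ring, ``int(20 * 0.1)`` gives 2 test edges, which is
why both the positives and the negatives split 18/2.
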

.. GENERATED FROM PYTHON SOURCE LINES 87-114

.. code-block:: Python

    # Split edge set for training and testing
    u, v = g.edges()

    eids = np.arange(g.num_edges())
    eids = np.random.permutation(eids)
    test_size = int(len(eids) * 0.1)
    train_size = g.num_edges() - test_size
    test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
    train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

    # Find all negative edges, i.e. node pairs that are neither an edge nor a
    # self-loop, and split them for training and testing
    adj = sp.coo_matrix(
        (np.ones(len(u)), (u.numpy(), v.numpy())),
        shape=(g.num_nodes(), g.num_nodes()),
    )
    adj_neg = 1 - adj.todense() - np.eye(g.num_nodes())
    neg_u, neg_v = np.where(adj_neg > 0)

    # Sample without replacement so that no negative pair is drawn twice
    # (e.g. once for training and once for testing)
    neg_eids = np.random.choice(len(neg_u), g.num_edges(), replace=False)
    test_neg_u, test_neg_v = (
        neg_u[neg_eids[:test_size]],
        neg_v[neg_eids[:test_size]],
    )
    train_neg_u, train_neg_v = (
        neg_u[neg_eids[test_size:]],
        neg_v[neg_eids[test_size:]],
    )

.. GENERATED FROM PYTHON SOURCE LINES 115-125

When training, you will need to remove the edges in the test set from
the original graph. You can do this via ``dgl.remove_edges``.

.. note::

   ``dgl.remove_edges`` works by creating a subgraph from the original
   graph, resulting in a copy and therefore could be slow for large
   graphs. If so, you could save the training and test graphs to disk, as
   you would do for preprocessing.

.. GENERATED FROM PYTHON SOURCE LINES 125-129

.. code-block:: Python

    train_g = dgl.remove_edges(g, eids[:test_size])

.. GENERATED FROM PYTHON SOURCE LINES 130-138

Define a GraphSAGE model
------------------------

This tutorial builds a model consisting of two
`GraphSAGE `__ layers, each of which computes
new node representations by averaging neighbor information. DGL provides
``dgl.nn.SAGEConv`` that conveniently creates a GraphSAGE layer.

.. GENERATED FROM PYTHON SOURCE LINES 138-157

.. code-block:: Python

    from dgl.nn import SAGEConv


    # ----------- 2. create model -------------- #
    # build a two-layer GraphSAGE model
    class GraphSAGE(nn.Module):
        def __init__(self, in_feats, h_feats):
            super(GraphSAGE, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats, "mean")
            self.conv2 = SAGEConv(h_feats, h_feats, "mean")

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h

.. GENERATED FROM PYTHON SOURCE LINES 158-168

The model then predicts whether an edge exists by computing a score
between the representations of the two incident nodes with a function
(e.g. an MLP or a dot product), which you will see in the next section.

.. math::

   \hat{y}_{u\sim v} = f(h_u, h_v)

.. GENERATED FROM PYTHON SOURCE LINES 171-192

Positive graph, negative graph, and ``apply_edges``
---------------------------------------------------

In previous tutorials you have learned how to compute node
representations with a GNN. However, link prediction requires you to
compute representations of *pairs of nodes*.

DGL recommends treating the pairs of nodes as another graph, since you
can describe a pair of nodes with an edge. In link prediction, you will
have a *positive graph* consisting of all the positive examples as
edges, and a *negative graph* consisting of all the negative examples.
The *positive graph* and the *negative graph* will contain the same set
of nodes as the original graph. This makes it easier to pass node
features among multiple graphs for computation. As you will see later,
you can directly feed the node representations computed on the entire
graph to the positive and the negative graphs for computing pairwise
scores.

The following code constructs the positive graph and the negative graph
for the training set and the test set respectively.

.. GENERATED FROM PYTHON SOURCE LINES 192-200

.. code-block:: Python

    train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.num_nodes())
    train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.num_nodes())

    test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.num_nodes())
    test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.num_nodes())

.. GENERATED FROM PYTHON SOURCE LINES 201-211

The benefit of treating the pairs of nodes as a graph is that you can
use the ``DGLGraph.apply_edges`` method, which conveniently computes new
edge features based on the incident nodes’ features and the original
edge features (if applicable).

DGL provides a set of optimized builtin functions to compute new edge
features based on the original node/edge features. For example,
``dgl.function.u_dot_v`` computes a dot product of the incident nodes’
representations for each edge.

.. GENERATED FROM PYTHON SOURCE LINES 211-226

.. code-block:: Python

    import dgl.function as fn


    class DotPredictor(nn.Module):
        def forward(self, g, h):
            with g.local_scope():
                g.ndata["h"] = h
                # Compute a new edge feature named 'score' by a dot-product between the
                # source node feature 'h' and destination node feature 'h'.
                g.apply_edges(fn.u_dot_v("h", "h", "score"))
                # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
                return g.edata["score"][:, 0]

.. GENERATED FROM PYTHON SOURCE LINES 227-231

You can also write your own function if the scoring logic is more
complex. For instance, the following module produces a scalar score on
each edge by concatenating the incident nodes’ features and passing the
result to an MLP.

.. GENERATED FROM PYTHON SOURCE LINES 231-266

.. code-block:: Python

    class MLPPredictor(nn.Module):
        def __init__(self, h_feats):
            super().__init__()
            self.W1 = nn.Linear(h_feats * 2, h_feats)
            self.W2 = nn.Linear(h_feats, 1)

        def apply_edges(self, edges):
            """
            Computes a scalar score for each edge of the given graph.

            Parameters
            ----------
            edges :
                Has three members ``src``, ``dst`` and ``data``, each of
                which is a dictionary representing the features of the
                source nodes, the destination nodes, and the edges
                themselves.

            Returns
            -------
            dict
                A dictionary of new edge features.
            """
            h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
            return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}

        def forward(self, g, h):
            with g.local_scope():
                g.ndata["h"] = h
                g.apply_edges(self.apply_edges)
                return g.edata["score"]

.. GENERATED FROM PYTHON SOURCE LINES 267-279

.. note::

   The builtin functions are optimized for both speed and memory.
   We recommend using builtin functions whenever possible.

.. note::

   If you have read the :doc:`message passing
   tutorial <3_message_passing>`, you will notice that the
   argument ``apply_edges`` takes has exactly the same form as a message
   function in ``update_all``.

.. GENERATED FROM PYTHON SOURCE LINES 282-298

Training loop
-------------

After you have defined the node representation computation and the edge
score computation, you can go ahead and define the overall model, loss
function, and evaluation metric.

The loss function is simply binary cross entropy loss.

.. math::

   \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)

Here :math:`y_{u\sim v}` is 1 for a positive example and 0 for a negative
one, and :math:`\mathcal{D}` is the set of all positive and negative
examples. The code that follows treats the predictor's outputs as logits,
so :math:`\hat{y}_{u\sim v}` is the sigmoid of the score, and
``F.binary_cross_entropy_with_logits`` averages the loss over the
examples instead of summing it.

The evaluation metric in this tutorial is AUC, computed with
scikit-learn's ``roc_auc_score``.

.. GENERATED FROM PYTHON SOURCE LINES 298-321

.. code-block:: Python

    model = GraphSAGE(train_g.ndata["feat"].shape[1], 16)
    # You can replace DotPredictor with MLPPredictor.
    # pred = MLPPredictor(16)
    pred = DotPredictor()


    def compute_loss(pos_score, neg_score):
        scores = torch.cat([pos_score, neg_score])
        labels = torch.cat(
            [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
        )
        return F.binary_cross_entropy_with_logits(scores, labels)


    # roc_auc_score is needed by compute_auc below
    from sklearn.metrics import roc_auc_score


    def compute_auc(pos_score, neg_score):
        scores = torch.cat([pos_score, neg_score]).numpy()
        labels = torch.cat(
            [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
        ).numpy()
        return roc_auc_score(labels, scores)

.. GENERATED FROM PYTHON SOURCE LINES 322-330

The training loop goes as follows:

.. note::

   This tutorial does not include evaluation on a validation set. In
   practice you should save and evaluate the best model based on
   performance on the validation set.

.. GENERATED FROM PYTHON SOURCE LINES 330-365

.. code-block:: Python

    # ----------- 3. set up loss and optimizer -------------- #
    # in this case, the loss is computed inside the training loop
    optimizer = torch.optim.Adam(
        itertools.chain(model.parameters(), pred.parameters()), lr=0.01
    )

    # ----------- 4. training -------------------------------- #
    for e in range(100):
        # forward
        h = model(train_g, train_g.ndata["feat"])
        pos_score = pred(train_pos_g, h)
        neg_score = pred(train_neg_g, h)
        loss = compute_loss(pos_score, neg_score)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print("In epoch {}, loss: {}".format(e, loss.item()))

    # ----------- 5. check results ------------------------ #
    with torch.no_grad():
        # Recompute node representations with the final model parameters
        h = model(train_g, train_g.ndata["feat"])
        pos_score = pred(test_pos_g, h)
        neg_score = pred(test_neg_g, h)
        print("AUC", compute_auc(pos_score, neg_score))


    # Thumbnail credits: Link Prediction with Neo4j, Mark Needham
    # sphinx_gallery_thumbnail_path = '_static/blitz_4_link_predict.png'

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    In epoch 0, loss: 0.7149215340614319
    In epoch 5, loss: 0.6916221976280212
    In epoch 10, loss: 0.6746882200241089
    In epoch 15, loss: 0.6267638206481934
    In epoch 20, loss: 0.5622928738594055
    In epoch 25, loss: 0.5335882306098938
    In epoch 30, loss: 0.5172640085220337
    In epoch 35, loss: 0.4941767156124115
    In epoch 40, loss: 0.4754476547241211
    In epoch 45, loss: 0.4552139639854431
    In epoch 50, loss: 0.434887558221817
    In epoch 55, loss: 0.41928672790527344
    In epoch 60, loss: 0.4020548164844513
    In epoch 65, loss: 0.3857075870037079
    In epoch 70, loss: 0.368327260017395
    In epoch 75, loss: 0.3507421016693115
    In epoch 80, loss: 0.3327259123325348
    In epoch 85, loss: 0.31475234031677246
    In epoch 90, loss: 0.2969270646572113
    In epoch 95, loss: 0.2797468602657318
    AUC 0.8425893398620876


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 1.206 seconds)


.. _sphx_glr_download_tutorials_blitz_4_link_predict.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 4_link_predict.ipynb <4_link_predict.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 4_link_predict.py <4_link_predict.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 4_link_predict.zip <4_link_predict.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_