.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/3_message_passing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_blitz_3_message_passing.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_3_message_passing.py:

Write your own GNN module
=========================

Sometimes, your model goes beyond simply stacking existing GNN modules.
For example, you may want to invent a new way of aggregating neighbor
information by considering node importance or edge weights.

By the end of this tutorial you will be able to

-  Understand DGL’s message passing APIs.
-  Implement the GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 20-30

.. code-block:: Python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.function as fn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 31-59

Message passing and GNNs
------------------------

DGL follows the *message passing paradigm* inspired by the Message Passing
Neural Network proposed by `Gilmer et
al. <https://arxiv.org/abs/1704.01212>`__ Essentially, they found that
many GNN models fit into the following framework:

.. math::

   m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)

.. math::

   m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}

.. math::

   h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)

where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
*reduce function* and :math:`U^{(l)}` the *update function*. Note that
:math:`\sum` here can represent any function and is not necessarily a
summation.

.. GENERATED FROM PYTHON SOURCE LINES 62-85

For example, the `GraphSAGE convolution (Hamilton et al.,
2017) <https://arxiv.org/abs/1706.02216>`__ takes the following
mathematical form:

.. math::

   h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}

.. math::

   h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)

You can see that message passing is directional: the message sent from a
node :math:`u` to a node :math:`v` is not necessarily the same as the
message sent from :math:`v` to :math:`u` in the opposite direction.

Although DGL has built-in support for GraphSAGE via
:class:`dgl.nn.SAGEConv <dgl.nn.pytorch.conv.SAGEConv>`, here is how you
can implement GraphSAGE convolution in DGL on your own.

.. GENERATED FROM PYTHON SOURCE LINES 85-125

.. code-block:: Python

    class SAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(SAGEConv, self).__init__()
            # A linear submodule for projecting the concatenated input and
            # neighbor features to the output size.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            """
            with g.local_scope():
                g.ndata["h"] = h
                # update_all is a message passing API.
                g.update_all(
                    message_func=fn.copy_u("h", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)

.. GENERATED FROM PYTHON SOURCE LINES 126-142

The central piece in this code is the
:func:`g.update_all <dgl.DGLGraph.update_all>` function, which gathers and
averages the neighbor features. There are three concepts here:

* Message function ``fn.copy_u('h', 'm')`` that copies the node feature
  under name ``'h'`` as *messages* with name ``'m'`` sent to neighbors.

* Reduce function ``fn.mean('m', 'h_N')`` that averages all the received
  messages under name ``'m'`` and saves the result as a new node feature
  ``'h_N'``.

* ``update_all`` tells DGL to trigger the message and reduce functions on
  all nodes and edges.
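To see these three pieces in action, here is a quick sanity check on a
toy graph (a minimal sketch, not part of the original tutorial):

.. code-block:: Python

    # A hypothetical 3-node graph with edges 0->2 and 1->2.
    toy_g = dgl.graph(([0, 1], [2, 2]), num_nodes=3)
    toy_g.ndata["h"] = torch.tensor([[1.0], [3.0], [0.0]])
    # Copy each source feature onto its out-edges, then average per node.
    toy_g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
    # Node 2 receives (1.0 + 3.0) / 2 = 2.0; nodes with no incoming
    # edges are filled with zeros.
    print(toy_g.ndata["h_N"])  # tensor([[0.], [0.], [2.]])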
.. GENERATED FROM PYTHON SOURCE LINES 145-148

Afterwards, you can stack your own GraphSAGE convolution layers to form a
multi-layer GraphSAGE network.

.. GENERATED FROM PYTHON SOURCE LINES 148-163

.. code-block:: Python

    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats)
            self.conv2 = SAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h
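Before training, a quick shape check can confirm the model is wired
correctly. This is a sketch on an arbitrary random graph; the sizes here
are made up and not part of the original tutorial:

.. code-block:: Python

    rand_g = dgl.rand_graph(8, 20)  # 8 nodes, 20 random edges
    rand_feat = torch.randn(8, 5)
    rand_model = Model(in_feats=5, h_feats=4, num_classes=3)
    print(rand_model(rand_g, rand_feat).shape)  # torch.Size([8, 3])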
.. GENERATED FROM PYTHON SOURCE LINES 164-169

Training loop
~~~~~~~~~~~~~

The following data loading and training loop code is copied directly from
the introduction tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 169-227

.. code-block:: Python

    import dgl.data

    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]


    def train(g, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        all_logits = []
        best_val_acc = 0
        best_test_acc = 0

        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        for e in range(200):
            # Forward
            logits = model(g, features)

            # Compute prediction
            pred = logits.argmax(1)

            # Compute loss
            # Note that we should only compute the losses of the nodes in the
            # training set, i.e. with train_mask 1.
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])

            # Compute accuracy on training/validation/test
            train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

            # Save the best validation accuracy and the corresponding test accuracy.
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            all_logits.append(logits.detach())

            if e % 5 == 0:
                print(
                    "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                        e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                    )
                )


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    NumNodes: 2708
    NumEdges: 10556
    NumFeats: 1433
    NumClasses: 7
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000
    Done loading data from cached files.
    In epoch 0, loss: 1.951, val acc: 0.162 (best 0.162), test acc: 0.149 (best 0.149)
    In epoch 5, loss: 1.875, val acc: 0.328 (best 0.328), test acc: 0.312 (best 0.312)
    In epoch 10, loss: 1.731, val acc: 0.628 (best 0.654), test acc: 0.585 (best 0.609)
    In epoch 15, loss: 1.512, val acc: 0.594 (best 0.654), test acc: 0.571 (best 0.609)
    In epoch 20, loss: 1.223, val acc: 0.666 (best 0.666), test acc: 0.625 (best 0.625)
    In epoch 25, loss: 0.896, val acc: 0.728 (best 0.728), test acc: 0.698 (best 0.698)
    In epoch 30, loss: 0.586, val acc: 0.758 (best 0.758), test acc: 0.736 (best 0.736)
    In epoch 35, loss: 0.348, val acc: 0.770 (best 0.770), test acc: 0.762 (best 0.762)
    In epoch 40, loss: 0.197, val acc: 0.782 (best 0.782), test acc: 0.771 (best 0.771)
    In epoch 45, loss: 0.112, val acc: 0.772 (best 0.782), test acc: 0.770 (best 0.771)
    In epoch 50, loss: 0.067, val acc: 0.766 (best 0.782), test acc: 0.764 (best 0.771)
    In epoch 55, loss: 0.043, val acc: 0.762 (best 0.782), test acc: 0.762 (best 0.771)
    In epoch 60, loss: 0.030, val acc: 0.760 (best 0.782), test acc: 0.760 (best 0.771)
    In epoch 65, loss: 0.022, val acc: 0.758 (best 0.782), test acc: 0.759 (best 0.771)
    In epoch 70, loss: 0.017, val acc: 0.758 (best 0.782), test acc: 0.761 (best 0.771)
    In epoch 75, loss: 0.014, val acc: 0.760 (best 0.782), test acc: 0.761 (best 0.771)
    In epoch 80, loss: 0.012, val acc: 0.760 (best 0.782), test acc: 0.760 (best 0.771)
    In epoch 85, loss: 0.011, val acc: 0.760 (best 0.782), test acc: 0.761 (best 0.771)
    In epoch 90, loss: 0.009, val acc: 0.760 (best 0.782), test acc: 0.761 (best 0.771)
    In epoch 95, loss: 0.009, val acc: 0.760 (best 0.782), test acc: 0.760 (best 0.771)
    In epoch 100, loss: 0.008, val acc: 0.758 (best 0.782), test acc: 0.759 (best 0.771)
    In epoch 105, loss: 0.007, val acc: 0.758 (best 0.782), test acc: 0.759 (best 0.771)
    In epoch 110, loss: 0.007, val acc: 0.758 (best 0.782), test acc: 0.758 (best 0.771)
    In epoch 115, loss: 0.006, val acc: 0.758 (best 0.782), test acc: 0.757 (best 0.771)
    In epoch 120, loss: 0.006, val acc: 0.760 (best 0.782), test acc: 0.758 (best 0.771)
    In epoch 125, loss: 0.005, val acc: 0.758 (best 0.782), test acc: 0.758 (best 0.771)
    In epoch 130, loss: 0.005, val acc: 0.758 (best 0.782), test acc: 0.758 (best 0.771)
    In epoch 135, loss: 0.005, val acc: 0.758 (best 0.782), test acc: 0.758 (best 0.771)
    In epoch 140, loss: 0.005, val acc: 0.758 (best 0.782), test acc: 0.756 (best 0.771)
    In epoch 145, loss: 0.004, val acc: 0.760 (best 0.782), test acc: 0.756 (best 0.771)
    In epoch 150, loss: 0.004, val acc: 0.760 (best 0.782), test acc: 0.756 (best 0.771)
    In epoch 155, loss: 0.004, val acc: 0.758 (best 0.782), test acc: 0.755 (best 0.771)
    In epoch 160, loss: 0.004, val acc: 0.756 (best 0.782), test acc: 0.755 (best 0.771)
    In epoch 165, loss: 0.004, val acc: 0.756 (best 0.782), test acc: 0.755 (best 0.771)
    In epoch 170, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.753 (best 0.771)
    In epoch 175, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.753 (best 0.771)
    In epoch 180, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.752 (best 0.771)
    In epoch 185, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.753 (best 0.771)
    In epoch 190, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.753 (best 0.771)
    In epoch 195, loss: 0.003, val acc: 0.756 (best 0.782), test acc: 0.753 (best 0.771)

.. GENERATED FROM PYTHON SOURCE LINES 228-235

More customization
------------------

In DGL, we provide many built-in message and reduce functions under the
``dgl.function`` package. You can find more details in :ref:`the API
doc <apifunction>`.
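For instance, swapping in a different built-in pair changes the
aggregation without any custom code. The combinations below are an
illustrative sketch on the Cora graph loaded above; ``fn.sum``,
``fn.max``, and ``fn.u_add_v`` all come from ``dgl.function``:

.. code-block:: Python

    with g.local_scope():
        # Sum (rather than average) the neighbor features.
        g.update_all(fn.copy_u("feat", "m"), fn.sum("m", "feat_sum"))
        # Message is the element-wise sum of the source and destination
        # features on each edge; reduce with an element-wise max.
        g.update_all(fn.u_add_v("feat", "feat", "m"), fn.max("m", "feat_max"))
        print(g.ndata["feat_sum"].shape, g.ndata["feat_max"].shape)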
.. GENERATED FROM PYTHON SOURCE LINES 238-244

These APIs allow you to quickly implement new graph convolution modules.
For example, the following implements a new ``SAGEConv`` that aggregates
neighbor representations using a weighted average. Note that the
``edata`` member can hold edge features, which can also take part in
message passing.

.. GENERATED FROM PYTHON SOURCE LINES 244-286

.. code-block:: Python

    class WeightedSAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model with edge weights.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(WeightedSAGEConv, self).__init__()
            # A linear submodule for projecting the concatenated input and
            # neighbor features to the output size.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h, w):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            w : Tensor
                The edge weight.
            """
            with g.local_scope():
                g.ndata["h"] = h
                g.edata["w"] = w
                g.update_all(
                    message_func=fn.u_mul_e("h", "w", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)

.. GENERATED FROM PYTHON SOURCE LINES 287-291

Because the graph in this dataset does not have edge weights, we manually
assign all edge weights to one in the ``forward()`` function of the
model. You can replace them with your own edge weights.

.. GENERATED FROM PYTHON SOURCE LINES 291-310

.. code-block:: Python

    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = WeightedSAGEConv(in_feats, h_feats)
            self.conv2 = WeightedSAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
            h = F.relu(h)
            h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
            return h


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    In epoch 0, loss: 1.955, val acc: 0.114 (best 0.114), test acc: 0.104 (best 0.104)
    In epoch 5, loss: 1.891, val acc: 0.224 (best 0.224), test acc: 0.219 (best 0.219)
    In epoch 10, loss: 1.771, val acc: 0.400 (best 0.400), test acc: 0.380 (best 0.380)
    In epoch 15, loss: 1.585, val acc: 0.652 (best 0.652), test acc: 0.615 (best 0.615)
    In epoch 20, loss: 1.338, val acc: 0.668 (best 0.670), test acc: 0.649 (best 0.639)
    In epoch 25, loss: 1.050, val acc: 0.666 (best 0.672), test acc: 0.657 (best 0.654)
    In epoch 30, loss: 0.761, val acc: 0.676 (best 0.676), test acc: 0.667 (best 0.667)
    In epoch 35, loss: 0.512, val acc: 0.704 (best 0.704), test acc: 0.684 (best 0.684)
    In epoch 40, loss: 0.326, val acc: 0.706 (best 0.706), test acc: 0.700 (best 0.692)
    In epoch 45, loss: 0.203, val acc: 0.716 (best 0.716), test acc: 0.709 (best 0.709)
    In epoch 50, loss: 0.127, val acc: 0.712 (best 0.720), test acc: 0.715 (best 0.712)
    In epoch 55, loss: 0.082, val acc: 0.712 (best 0.720), test acc: 0.712 (best 0.712)
    In epoch 60, loss: 0.055, val acc: 0.710 (best 0.720), test acc: 0.709 (best 0.712)
    In epoch 65, loss: 0.040, val acc: 0.716 (best 0.720), test acc: 0.709 (best 0.712)
    In epoch 70, loss: 0.030, val acc: 0.716 (best 0.720), test acc: 0.710 (best 0.712)
    In epoch 75, loss: 0.024, val acc: 0.716 (best 0.720), test acc: 0.709 (best 0.712)
    In epoch 80, loss: 0.020, val acc: 0.714 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 85, loss: 0.017, val acc: 0.712 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 90, loss: 0.015, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 95, loss: 0.013, val acc: 0.708 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 100, loss: 0.012, val acc: 0.708 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 105, loss: 0.011, val acc: 0.708 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 110, loss: 0.010, val acc: 0.706 (best 0.720), test acc: 0.706 (best 0.712)
    In epoch 115, loss: 0.009, val acc: 0.706 (best 0.720), test acc: 0.706 (best 0.712)
    In epoch 120, loss: 0.008, val acc: 0.706 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 125, loss: 0.008, val acc: 0.708 (best 0.720), test acc: 0.706 (best 0.712)
    In epoch 130, loss: 0.007, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 135, loss: 0.007, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 140, loss: 0.006, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 145, loss: 0.006, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 150, loss: 0.006, val acc: 0.708 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 155, loss: 0.005, val acc: 0.704 (best 0.720), test acc: 0.706 (best 0.712)
    In epoch 160, loss: 0.005, val acc: 0.706 (best 0.720), test acc: 0.706 (best 0.712)
    In epoch 165, loss: 0.005, val acc: 0.706 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 170, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 175, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 180, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 185, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.707 (best 0.712)
    In epoch 190, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.708 (best 0.712)
    In epoch 195, loss: 0.004, val acc: 0.706 (best 0.720), test acc: 0.708 (best 0.712)
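As a variation (a sketch, not part of the original tutorial), you could
replace the all-one weights with something data-dependent. The weighting
scheme below, one over the in-degree of the destination node, is purely
illustrative:

.. code-block:: Python

    # Purely illustrative edge weights: 1 / in-degree of the destination node.
    src, dst = g.edges()
    w = 1.0 / g.in_degrees(dst).clamp(min=1).float()
    w = w.unsqueeze(1)  # shape (num_edges, 1), broadcasts against node features
    conv = WeightedSAGEConv(g.ndata["feat"].shape[1], 16)
    print(conv(g, g.ndata["feat"], w).shape)  # (num_nodes, 16)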
.. GENERATED FROM PYTHON SOURCE LINES 311-318

Even more customization by user-defined function
------------------------------------------------

DGL allows user-defined message and reduce functions for maximal
expressiveness. Here is a user-defined message function that is
equivalent to ``fn.u_mul_e('h', 'w', 'm')``.

.. GENERATED FROM PYTHON SOURCE LINES 318-324

.. code-block:: Python

    def u_mul_e_udf(edges):
        return {"m": edges.src["h"] * edges.data["w"]}

.. GENERATED FROM PYTHON SOURCE LINES 325-329

``edges`` has three members: ``src``, ``data`` and ``dst``, representing
the source node features, edge features, and destination node features
of all edges.

.. GENERATED FROM PYTHON SOURCE LINES 332-336

You can also write your own reduce function. For example, the following
is equivalent to the built-in ``fn.mean('m', 'h_N')`` function that
averages the incoming messages:

.. GENERATED FROM PYTHON SOURCE LINES 336-342

.. code-block:: Python

    def mean_udf(nodes):
        return {"h_N": nodes.mailbox["m"].mean(1)}

.. GENERATED FROM PYTHON SOURCE LINES 343-352

In short, DGL will group the nodes by their in-degrees, and for each
group DGL stacks the incoming messages along the second dimension. You
can then perform a reduction along the second dimension to aggregate the
messages.
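Both user-defined functions plug into ``update_all`` exactly like the
built-ins. Here is a sketch (reusing the Cora graph with all-one edge
weights) that reproduces the weighted aggregation from earlier:

.. code-block:: Python

    with g.local_scope():
        g.ndata["h"] = g.ndata["feat"]
        g.edata["w"] = torch.ones(g.num_edges(), 1)
        # Equivalent to update_all(fn.u_mul_e("h", "w", "m"), fn.mean("m", "h_N")).
        g.update_all(u_mul_e_udf, mean_udf)
        print(g.ndata["h_N"].shape)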
For more details on customizing message and reduce functions with
user-defined functions, please refer to the
:ref:`API reference <apiudf>`.

.. GENERATED FROM PYTHON SOURCE LINES 355-367

Best practice of writing custom GNN modules
-------------------------------------------

DGL recommends the following practices, ranked by preference:

- Use ``dgl.nn`` modules.
- Use ``dgl.nn.functional`` functions, which contain lower-level complex
  operations such as computing a softmax for each node over incoming
  edges.
- Use ``update_all`` with built-in message and reduce functions.
- Use user-defined message or reduce functions.

.. GENERATED FROM PYTHON SOURCE LINES 370-376

What’s next?
------------

- :ref:`Writing Efficient Message Passing
  Code <guide-message-passing-efficient>`.

.. GENERATED FROM PYTHON SOURCE LINES 376-380

.. code-block:: Python

    # Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
    # sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.655 seconds)

.. _sphx_glr_download_tutorials_blitz_3_message_passing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 3_message_passing.ipynb <3_message_passing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 3_message_passing.py <3_message_passing.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 3_message_passing.zip <3_message_passing.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_