.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/3_message_passing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_blitz_3_message_passing.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_3_message_passing.py:

Write your own GNN module
=========================

Sometimes, your model goes beyond simply stacking existing GNN modules.
For example, you might want to invent a new way of aggregating neighbor
information that takes node importance or edge weights into account.

By the end of this tutorial you will be able to:

- Understand DGL’s message passing APIs.
- Implement a GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 20-30

.. code-block:: Python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.function as fn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 31-59

Message passing and GNNs
------------------------

DGL follows the *message passing paradigm* inspired by the Message Passing
Neural Network proposed by `Gilmer et al. <https://arxiv.org/abs/1704.01212>`__.
Essentially, they found that many GNN models can fit into the following
framework:

.. math::

   m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)

.. math::

   m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}

.. math::

   h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)

where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
*reduce function*, and :math:`U^{(l)}` the *update function*. Note that
:math:`\sum` here can represent any function and is not necessarily a
summation.

.. GENERATED FROM PYTHON SOURCE LINES 62-85

For example, the `GraphSAGE convolution (Hamilton et al., 2017)
<https://arxiv.org/abs/1706.02216>`__ takes the following mathematical form:

.. math::

   h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}

.. math::

   h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)

You can see that message passing is directional: the message sent from node
:math:`u` to node :math:`v` is not necessarily the same as the message sent
from node :math:`v` to node :math:`u` in the opposite direction.

Although DGL has built-in support for GraphSAGE via
:class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`, here is how you can
implement GraphSAGE convolution in DGL on your own.

.. GENERATED FROM PYTHON SOURCE LINES 85-125

.. code-block:: Python

    class SAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(SAGEConv, self).__init__()
            # A linear submodule for projecting the input and neighbor feature to the output.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            """
            with g.local_scope():
                g.ndata["h"] = h
                # update_all is a message passing API.
                g.update_all(
                    message_func=fn.copy_u("h", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)

..
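Before plugging the module into a full model, it can help to sanity-check it
on a tiny hand-made graph. The following sketch is illustrative only: the
four-node cycle graph and the feature sizes are arbitrary choices, not part
of this tutorial's dataset.

.. code-block:: Python

    # Sanity check on a toy graph (illustrative only; the graph structure
    # and feature sizes below are arbitrary choices).
    toy_g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # a directed 4-node cycle
    toy_feat = torch.randn(4, 10)                    # 4 nodes, 10 input features
    toy_conv = SAGEConv(in_feat=10, out_feat=5)
    print(toy_conv(toy_g, toy_feat).shape)           # torch.Size([4, 5])

..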
GENERATED FROM PYTHON SOURCE LINES 126-142

The central piece of the ``SAGEConv`` module above is the
:func:`g.update_all <dgl.DGLGraph.update_all>` call, which gathers and
averages the neighbor features. There are three concepts here:

* Message function ``fn.copy_u('h', 'm')``, which copies the node feature
  under the name ``'h'`` as *messages* with the name ``'m'`` sent to
  neighbors.

* Reduce function ``fn.mean('m', 'h_N')``, which averages all the received
  messages under the name ``'m'`` and saves the result as a new node
  feature ``'h_N'``.

* ``update_all`` tells DGL to trigger the message and reduce functions for
  all the nodes and edges.

.. GENERATED FROM PYTHON SOURCE LINES 145-148

Afterwards, you can stack your own GraphSAGE convolution layers to form a
multi-layer GraphSAGE network.

.. GENERATED FROM PYTHON SOURCE LINES 148-163

.. code-block:: Python

    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats)
            self.conv2 = SAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h

.. GENERATED FROM PYTHON SOURCE LINES 164-169

Training loop
~~~~~~~~~~~~~

The following data-loading and training-loop code is copied directly from
the introduction tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 169-227

.. code-block:: Python

    import dgl.data

    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]


    def train(g, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        all_logits = []
        best_val_acc = 0
        best_test_acc = 0

        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        for e in range(200):
            # Forward
            logits = model(g, features)

            # Compute prediction
            pred = logits.argmax(1)

            # Compute loss
            # Note that we should only compute the losses of the nodes in the training set,
            # i.e., the nodes whose train_mask is True.
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])

            # Compute accuracy on training/validation/test
            train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

            # Save the best validation accuracy and the corresponding test accuracy.
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Kept from the introduction tutorial; not used further here.
            all_logits.append(logits.detach())

            if e % 5 == 0:
                print(
                    "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                        e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                    )
                )


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    NumNodes: 2708
    NumEdges: 10556
    NumFeats: 1433
    NumClasses: 7
    NumTrainingSamples: 140
    NumValidationSamples: 500
    NumTestSamples: 1000
    Done loading data from cached files.
In epoch 0, loss: 1.948, val acc: 0.122 (best 0.122), test acc: 0.130 (best 0.130) In epoch 5, loss: 1.885, val acc: 0.106 (best 0.180), test acc: 0.113 (best 0.196) In epoch 10, loss: 1.754, val acc: 0.352 (best 0.352), test acc: 0.351 (best 0.351) In epoch 15, loss: 1.555, val acc: 0.434 (best 0.434), test acc: 0.437 (best 0.437) In epoch 20, loss: 1.295, val acc: 0.510 (best 0.510), test acc: 0.502 (best 0.502) In epoch 25, loss: 1.002, val acc: 0.548 (best 0.548), test acc: 0.567 (best 0.567) In epoch 30, loss: 0.714, val acc: 0.624 (best 0.624), test acc: 0.625 (best 0.625) In epoch 35, loss: 0.469, val acc: 0.662 (best 0.662), test acc: 0.676 (best 0.676) In epoch 40, loss: 0.287, val acc: 0.704 (best 0.704), test acc: 0.712 (best 0.712) In epoch 45, loss: 0.171, val acc: 0.732 (best 0.732), test acc: 0.726 (best 0.726) In epoch 50, loss: 0.102, val acc: 0.746 (best 0.746), test acc: 0.735 (best 0.735) In epoch 55, loss: 0.064, val acc: 0.746 (best 0.746), test acc: 0.736 (best 0.735) In epoch 60, loss: 0.042, val acc: 0.746 (best 0.746), test acc: 0.738 (best 0.735) In epoch 65, loss: 0.030, val acc: 0.746 (best 0.748), test acc: 0.741 (best 0.738) In epoch 70, loss: 0.022, val acc: 0.748 (best 0.750), test acc: 0.742 (best 0.742) In epoch 75, loss: 0.018, val acc: 0.750 (best 0.750), test acc: 0.743 (best 0.742) In epoch 80, loss: 0.015, val acc: 0.750 (best 0.750), test acc: 0.743 (best 0.742) In epoch 85, loss: 0.012, val acc: 0.750 (best 0.750), test acc: 0.747 (best 0.742) In epoch 90, loss: 0.011, val acc: 0.752 (best 0.752), test acc: 0.747 (best 0.747) In epoch 95, loss: 0.010, val acc: 0.754 (best 0.754), test acc: 0.748 (best 0.748) In epoch 100, loss: 0.009, val acc: 0.754 (best 0.754), test acc: 0.748 (best 0.748) In epoch 105, loss: 0.008, val acc: 0.754 (best 0.754), test acc: 0.751 (best 0.748) In epoch 110, loss: 0.007, val acc: 0.754 (best 0.754), test acc: 0.751 (best 0.748) In epoch 115, loss: 0.007, val acc: 0.752 (best 0.754), test acc: 0.753 (best 0.748) In epoch 120, loss: 0.006, val acc: 0.750 (best 0.754), test acc: 0.753 (best 0.748) In epoch 125, loss: 0.006, val acc: 0.750 (best 0.754), test acc: 0.752 (best 0.748) In epoch 130, loss: 0.005, val acc: 0.748 (best 0.754), test acc: 0.753 (best 0.748) In epoch 135, loss: 0.005, val acc: 0.748 (best 0.754), test acc: 0.753 (best 0.748) In epoch 140, loss: 0.005, val acc: 0.748 (best 0.754), test acc: 0.753 (best 0.748) In epoch 145, loss: 0.005, val acc: 0.748 (best 0.754), test acc: 0.752 (best 0.748) In epoch 150, loss: 0.004, val acc: 0.746 (best 0.754), test acc: 0.752 (best 0.748) In epoch 155, loss: 0.004, val acc: 0.748 (best 0.754), test acc: 0.752 (best 0.748) In epoch 160, loss: 0.004, val acc: 0.746 (best 0.754), test acc: 0.753 (best 0.748) In epoch 165, loss: 0.004, val acc: 0.746 (best 0.754), test acc: 0.753 (best 0.748) In epoch 170, loss: 0.003, val acc: 0.746 (best 0.754), test acc: 0.753 (best 0.748) In epoch 175, loss: 0.003, val acc: 0.748 (best 0.754), test acc: 0.754 (best 0.748) In epoch 180, loss: 0.003, val acc: 0.748 (best 0.754), test acc: 0.755 (best 0.748) In epoch 185, loss: 0.003, val acc: 0.750 (best 0.754), test acc: 0.755 (best 0.748) In epoch 190, loss: 0.003, val acc: 0.754 (best 0.754), test acc: 0.756 (best 0.748) In epoch 195, loss: 0.003, val acc: 0.754 (best 0.754), test acc: 0.756 (best 0.748) .. 
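After training, you can run the model in evaluation mode to obtain final
predictions. The following is a minimal sketch using standard PyTorch
idioms; it is not part of the original script.

.. code-block:: Python

    # Minimal inference sketch (standard PyTorch usage, not part of the
    # original training script).
    model.eval()
    with torch.no_grad():
        pred = model(g, g.ndata["feat"]).argmax(1)
    test_mask = g.ndata["test_mask"]
    test_acc = (pred[test_mask] == g.ndata["label"][test_mask]).float().mean()
    print("Final test accuracy: {:.3f}".format(test_acc.item()))

..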
GENERATED FROM PYTHON SOURCE LINES 228-235

More customization
------------------

In DGL, we provide many built-in message and reduce functions under the
``dgl.function`` package. You can find more details in
:ref:`the API doc <apifunction>`.

.. GENERATED FROM PYTHON SOURCE LINES 238-244

These APIs allow one to quickly implement new graph convolution modules.
For example, the following implements a new ``SAGEConv`` that aggregates
neighbor representations using a weighted average. Note that the ``edata``
member can hold edge features, which can also take part in message passing.

.. GENERATED FROM PYTHON SOURCE LINES 244-286

.. code-block:: Python

    class WeightedSAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model with edge weights.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(WeightedSAGEConv, self).__init__()
            # A linear submodule for projecting the input and neighbor feature to the output.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h, w):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            w : Tensor
                The edge weight.
            """
            with g.local_scope():
                g.ndata["h"] = h
                g.edata["w"] = w
                g.update_all(
                    message_func=fn.u_mul_e("h", "w", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)

.. GENERATED FROM PYTHON SOURCE LINES 287-291

Because the graph in this dataset does not have edge weights, we manually
set all edge weights to one in the ``forward()`` function of the model.
You can replace them with your own edge weights; a sketch of one possible
weighting scheme follows the training output below.

.. GENERATED FROM PYTHON SOURCE LINES 291-310

.. code-block:: Python

    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = WeightedSAGEConv(in_feats, h_feats)
            self.conv2 = WeightedSAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
            h = F.relu(h)
            h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
            return h


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)

.. rst-class:: sphx-glr-script-out

..
code-block:: none In epoch 0, loss: 1.952, val acc: 0.072 (best 0.072), test acc: 0.091 (best 0.091) In epoch 5, loss: 1.863, val acc: 0.156 (best 0.156), test acc: 0.181 (best 0.181) In epoch 10, loss: 1.702, val acc: 0.432 (best 0.504), test acc: 0.418 (best 0.498) In epoch 15, loss: 1.466, val acc: 0.448 (best 0.504), test acc: 0.425 (best 0.498) In epoch 20, loss: 1.170, val acc: 0.536 (best 0.536), test acc: 0.504 (best 0.504) In epoch 25, loss: 0.852, val acc: 0.650 (best 0.650), test acc: 0.638 (best 0.638) In epoch 30, loss: 0.559, val acc: 0.738 (best 0.738), test acc: 0.728 (best 0.728) In epoch 35, loss: 0.334, val acc: 0.766 (best 0.766), test acc: 0.768 (best 0.768) In epoch 40, loss: 0.189, val acc: 0.770 (best 0.770), test acc: 0.774 (best 0.774) In epoch 45, loss: 0.107, val acc: 0.776 (best 0.776), test acc: 0.775 (best 0.774) In epoch 50, loss: 0.064, val acc: 0.770 (best 0.776), test acc: 0.771 (best 0.774) In epoch 55, loss: 0.040, val acc: 0.770 (best 0.776), test acc: 0.770 (best 0.774) In epoch 60, loss: 0.028, val acc: 0.770 (best 0.776), test acc: 0.771 (best 0.774) In epoch 65, loss: 0.021, val acc: 0.774 (best 0.776), test acc: 0.769 (best 0.774) In epoch 70, loss: 0.016, val acc: 0.772 (best 0.776), test acc: 0.769 (best 0.774) In epoch 75, loss: 0.013, val acc: 0.770 (best 0.776), test acc: 0.766 (best 0.774) In epoch 80, loss: 0.011, val acc: 0.770 (best 0.776), test acc: 0.765 (best 0.774) In epoch 85, loss: 0.010, val acc: 0.770 (best 0.776), test acc: 0.764 (best 0.774) In epoch 90, loss: 0.009, val acc: 0.772 (best 0.776), test acc: 0.765 (best 0.774) In epoch 95, loss: 0.008, val acc: 0.770 (best 0.776), test acc: 0.765 (best 0.774) In epoch 100, loss: 0.007, val acc: 0.768 (best 0.776), test acc: 0.764 (best 0.774) In epoch 105, loss: 0.007, val acc: 0.768 (best 0.776), test acc: 0.764 (best 0.774) In epoch 110, loss: 0.006, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 115, loss: 0.006, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 120, loss: 0.005, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 125, loss: 0.005, val acc: 0.766 (best 0.776), test acc: 0.766 (best 0.774) In epoch 130, loss: 0.005, val acc: 0.764 (best 0.776), test acc: 0.766 (best 0.774) In epoch 135, loss: 0.005, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 140, loss: 0.004, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 145, loss: 0.004, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 150, loss: 0.004, val acc: 0.764 (best 0.776), test acc: 0.765 (best 0.774) In epoch 155, loss: 0.004, val acc: 0.762 (best 0.776), test acc: 0.766 (best 0.774) In epoch 160, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.766 (best 0.774) In epoch 165, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.766 (best 0.774) In epoch 170, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.766 (best 0.774) In epoch 175, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.766 (best 0.774) In epoch 180, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.767 (best 0.774) In epoch 185, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.767 (best 0.774) In epoch 190, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.767 (best 0.774) In epoch 195, loss: 0.003, val acc: 0.760 (best 0.776), test acc: 0.767 (best 0.774) .. 
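As mentioned above, you can derive a weight per edge from the graph itself
instead of using the all-ones placeholder. The sketch below weights each
edge by the reciprocal of its destination node's in-degree; this particular
scheme is only an illustrative assumption, not something the tutorial
prescribes.

.. code-block:: Python

    # Illustrative only: build a per-edge weight from graph structure
    # instead of all ones. Weighting by the reciprocal of the destination
    # node's in-degree is an arbitrary example scheme.
    src, dst = g.edges()
    deg = g.in_degrees().float().clamp(min=1)  # avoid division by zero
    w = (1.0 / deg[dst]).unsqueeze(1)          # shape (num_edges, 1)
    conv = WeightedSAGEConv(g.ndata["feat"].shape[1], 16)
    h_out = conv(g, g.ndata["feat"], w)

..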
GENERATED FROM PYTHON SOURCE LINES 311-318

Even more customization by user-defined function
------------------------------------------------

DGL allows user-defined message and reduce functions for maximal
expressiveness. Here is a user-defined message function that is equivalent
to ``fn.u_mul_e('h', 'w', 'm')``.

.. GENERATED FROM PYTHON SOURCE LINES 318-324

.. code-block:: Python

    def u_mul_e_udf(edges):
        return {"m": edges.src["h"] * edges.data["w"]}

.. GENERATED FROM PYTHON SOURCE LINES 325-329

The ``edges`` argument has three members: ``src``, ``data``, and ``dst``,
representing the source node features, edge features, and destination node
features of all edges.

.. GENERATED FROM PYTHON SOURCE LINES 332-336

You can also write your own reduce function. For example, the following is
equivalent to the built-in ``fn.mean('m', 'h_N')`` function that averages
the incoming messages:

.. GENERATED FROM PYTHON SOURCE LINES 336-342

.. code-block:: Python

    def mean_udf(nodes):
        return {"h_N": nodes.mailbox["m"].mean(1)}

.. GENERATED FROM PYTHON SOURCE LINES 343-352

In short, DGL groups the nodes by their in-degrees, and for each group
stacks the incoming messages along the second dimension. You can then
perform a reduction along the second dimension to aggregate the messages.

For more details on customizing message and reduce functions with
user-defined functions, please refer to the :ref:`API reference <apiudf>`.

.. GENERATED FROM PYTHON SOURCE LINES 355-367

Best practices for writing custom GNN modules
---------------------------------------------

DGL recommends the following practices, ranked by preference:

- Use ``dgl.nn`` modules.
- Use ``dgl.nn.functional`` functions, which contain lower-level complex
  operations such as computing a softmax for each node over incoming edges.
- Use ``update_all`` with built-in message and reduce functions.
- Use user-defined message or reduce functions.

.. GENERATED FROM PYTHON SOURCE LINES 370-376

What’s next?
------------

- :ref:`Writing Efficient Message Passing Code <guide-message-passing-efficient>`.

.. GENERATED FROM PYTHON SOURCE LINES 376-380

.. code-block:: Python

    # Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
    # sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.886 seconds)

.. _sphx_glr_download_tutorials_blitz_3_message_passing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 3_message_passing.ipynb <3_message_passing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 3_message_passing.py <3_message_passing.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 3_message_passing.zip <3_message_passing.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_