.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/3_message_passing.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_blitz_3_message_passing.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_3_message_passing.py:


Write your own GNN module
=========================

Sometimes your model goes beyond simply stacking existing GNN modules.
For example, you may want to invent a new way of aggregating neighbor
information that takes node importance or edge weights into account.

By the end of this tutorial you will be able to

-  Understand DGL’s message passing APIs.
-  Implement the GraphSAGE convolution module on your own.

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>`.

(Time estimate: 10 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 20-30

.. code-block:: Python


    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.function as fn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F








.. GENERATED FROM PYTHON SOURCE LINES 31-59

Message passing and GNNs
------------------------

DGL follows the *message passing paradigm* inspired by the Message
Passing Neural Network proposed by `Gilmer et
al. <https://arxiv.org/abs/1704.01212>`__. Essentially, they found that
many GNN models fit into the following framework:

.. math::


   m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)

.. math::


   m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}

.. math::


   h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)

where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
*reduce function* and :math:`U^{(l)}` the *update function*. Note that
:math:`\sum` here can represent any function and is not necessarily a
summation.
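
To make the correspondence concrete, here is a minimal sketch (an
illustration added here, not part of the original tutorial code) that
maps each piece of the framework onto DGL's APIs on a tiny three-node
graph. The feature name ``"h"`` and the intermediate names ``"m"`` and
``"m_sum"`` are arbitrary choices:

.. code-block:: Python

    # A tiny directed graph with edges 0->1 and 1->2.
    tiny_g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
    tiny_g.ndata["h"] = torch.randn(3, 4)
    # M: copy the source feature "h" as the message "m".
    # Reduce: sum the incoming messages into "m_sum".
    tiny_g.update_all(fn.copy_u("h", "m"), fn.sum("m", "m_sum"))
    # U: combine the old feature with the aggregated message using
    # ordinary tensor operations.
    h_new = tiny_g.ndata["h"] + tiny_g.ndata["m_sum"]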


.. GENERATED FROM PYTHON SOURCE LINES 62-85

For example, the `GraphSAGE convolution (Hamilton et al.,
2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__
takes the following mathematical form:

.. math::


   h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}

.. math::


   h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)

You can see that message passing is directional: the message sent from
node :math:`u` to node :math:`v` is not necessarily the same as the
message sent from node :math:`v` to node :math:`u` in the opposite
direction.

Although DGL has built-in support for GraphSAGE via
:class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,
here is how you can implement GraphSAGE convolution in DGL on your own.


.. GENERATED FROM PYTHON SOURCE LINES 85-125

.. code-block:: Python



    class SAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(SAGEConv, self).__init__()
            # A linear submodule for projecting the input and neighbor feature to the output.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            """
            with g.local_scope():
                g.ndata["h"] = h
                # update_all is a message passing API.
                g.update_all(
                    message_func=fn.copy_u("h", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)









.. GENERATED FROM PYTHON SOURCE LINES 126-142

The central piece in this code is the
:func:`g.update_all <dgl.DGLGraph.update_all>`
function, which gathers and averages the neighbor features. There are
three concepts here:

* Message function ``fn.copy_u('h', 'm')`` that
  copies the node feature under name ``'h'`` as *messages* with name
  ``'m'`` sent to neighbors.

* Reduce function ``fn.mean('m', 'h_N')`` that averages
  all the received messages under name ``'m'`` and saves the result as a
  new node feature ``'h_N'``.

* ``update_all`` tells DGL to trigger the
  message and reduce functions for all the nodes and edges.
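
As a quick sanity check (a sketch added here for illustration, not part
of the original tutorial), you can verify on a small graph that this
pair of functions indeed computes the neighbor average:

.. code-block:: Python

    # A star graph with edges 1->0 and 2->0: node 0 receives two
    # messages, while nodes 1 and 2 receive none.
    check_g = dgl.graph((torch.tensor([1, 2]), torch.tensor([0, 0])))
    check_g.ndata["h"] = torch.tensor([[0.0], [2.0], [4.0]])
    check_g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
    # Node 0 gets (2.0 + 4.0) / 2 = 3.0; nodes without in-edges get 0.
    print(check_g.ndata["h_N"])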


.. GENERATED FROM PYTHON SOURCE LINES 145-148

Afterwards, you can stack your own GraphSAGE convolution layers to form
a multi-layer GraphSAGE network.


.. GENERATED FROM PYTHON SOURCE LINES 148-163

.. code-block:: Python



    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = SAGEConv(in_feats, h_feats)
            self.conv2 = SAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            return h









.. GENERATED FROM PYTHON SOURCE LINES 164-169

Training loop
~~~~~~~~~~~~~
The following data loading and training loop code is copied directly
from the introduction tutorial.


.. GENERATED FROM PYTHON SOURCE LINES 169-227

.. code-block:: Python


    import dgl.data

    dataset = dgl.data.CoraGraphDataset()
    g = dataset[0]


    def train(g, model):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        all_logits = []
        best_val_acc = 0
        best_test_acc = 0

        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        for e in range(200):
            # Forward
            logits = model(g, features)

            # Compute prediction
            pred = logits.argmax(1)

            # Compute loss
            # Note that we should only compute the losses of the nodes in the training set,
            # i.e. the nodes whose train_mask is True.
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])

            # Compute accuracy on training/validation/test
            train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

            # Save the best validation accuracy and the corresponding test accuracy.
            if best_val_acc < val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            all_logits.append(logits.detach())

            if e % 5 == 0:
                print(
                    "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                        e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                    )
                )


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      NumNodes: 2708
      NumEdges: 10556
      NumFeats: 1433
      NumClasses: 7
      NumTrainingSamples: 140
      NumValidationSamples: 500
      NumTestSamples: 1000
    Done loading data from cached files.
    In epoch 0, loss: 1.949, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103)
    In epoch 5, loss: 1.884, val acc: 0.316 (best 0.316), test acc: 0.334 (best 0.334)
    In epoch 10, loss: 1.736, val acc: 0.658 (best 0.660), test acc: 0.688 (best 0.675)
    In epoch 15, loss: 1.505, val acc: 0.696 (best 0.696), test acc: 0.694 (best 0.694)
    In epoch 20, loss: 1.202, val acc: 0.708 (best 0.710), test acc: 0.708 (best 0.707)
    In epoch 25, loss: 0.869, val acc: 0.728 (best 0.728), test acc: 0.732 (best 0.732)
    In epoch 30, loss: 0.565, val acc: 0.744 (best 0.744), test acc: 0.746 (best 0.746)
    In epoch 35, loss: 0.337, val acc: 0.746 (best 0.746), test acc: 0.750 (best 0.747)
    In epoch 40, loss: 0.192, val acc: 0.750 (best 0.750), test acc: 0.757 (best 0.757)
    In epoch 45, loss: 0.110, val acc: 0.748 (best 0.750), test acc: 0.757 (best 0.757)
    In epoch 50, loss: 0.066, val acc: 0.750 (best 0.750), test acc: 0.756 (best 0.757)
    In epoch 55, loss: 0.042, val acc: 0.748 (best 0.752), test acc: 0.759 (best 0.758)
    In epoch 60, loss: 0.029, val acc: 0.742 (best 0.752), test acc: 0.757 (best 0.758)
    In epoch 65, loss: 0.021, val acc: 0.740 (best 0.752), test acc: 0.754 (best 0.758)
    In epoch 70, loss: 0.017, val acc: 0.738 (best 0.752), test acc: 0.749 (best 0.758)
    In epoch 75, loss: 0.014, val acc: 0.738 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 80, loss: 0.012, val acc: 0.738 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 85, loss: 0.010, val acc: 0.738 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 90, loss: 0.009, val acc: 0.738 (best 0.752), test acc: 0.749 (best 0.758)
    In epoch 95, loss: 0.008, val acc: 0.734 (best 0.752), test acc: 0.746 (best 0.758)
    In epoch 100, loss: 0.008, val acc: 0.734 (best 0.752), test acc: 0.747 (best 0.758)
    In epoch 105, loss: 0.007, val acc: 0.736 (best 0.752), test acc: 0.747 (best 0.758)
    In epoch 110, loss: 0.007, val acc: 0.736 (best 0.752), test acc: 0.746 (best 0.758)
    In epoch 115, loss: 0.006, val acc: 0.734 (best 0.752), test acc: 0.747 (best 0.758)
    In epoch 120, loss: 0.006, val acc: 0.734 (best 0.752), test acc: 0.747 (best 0.758)
    In epoch 125, loss: 0.005, val acc: 0.734 (best 0.752), test acc: 0.747 (best 0.758)
    In epoch 130, loss: 0.005, val acc: 0.732 (best 0.752), test acc: 0.749 (best 0.758)
    In epoch 135, loss: 0.005, val acc: 0.730 (best 0.752), test acc: 0.748 (best 0.758)
    In epoch 140, loss: 0.004, val acc: 0.730 (best 0.752), test acc: 0.748 (best 0.758)
    In epoch 145, loss: 0.004, val acc: 0.730 (best 0.752), test acc: 0.749 (best 0.758)
    In epoch 150, loss: 0.004, val acc: 0.728 (best 0.752), test acc: 0.748 (best 0.758)
    In epoch 155, loss: 0.004, val acc: 0.728 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 160, loss: 0.004, val acc: 0.728 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 165, loss: 0.003, val acc: 0.728 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 170, loss: 0.003, val acc: 0.730 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 175, loss: 0.003, val acc: 0.730 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 180, loss: 0.003, val acc: 0.730 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 185, loss: 0.003, val acc: 0.730 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 190, loss: 0.003, val acc: 0.730 (best 0.752), test acc: 0.750 (best 0.758)
    In epoch 195, loss: 0.003, val acc: 0.732 (best 0.752), test acc: 0.750 (best 0.758)




.. GENERATED FROM PYTHON SOURCE LINES 228-235

More customization
------------------

In DGL, we provide many built-in message and reduce functions under the
``dgl.function`` package. You can find more details in :ref:`the API
doc <apifunction>`.
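
For instance, the following sketch (an illustration added here; see the
API doc for the authoritative list) shows the naming pattern the
built-ins follow, where ``u`` refers to an edge's source node, ``v`` to
its destination node, and ``e`` to the edge itself:

.. code-block:: Python

    with g.local_scope():
        g.ndata["h"] = g.ndata["feat"]
        # Message: add the source and destination features of each edge.
        # Reduce: take the element-wise max over incoming messages.
        g.update_all(fn.u_add_v("h", "h", "m"), fn.max("m", "h_max"))
        h_max = g.ndata["h_max"]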


.. GENERATED FROM PYTHON SOURCE LINES 238-244

These APIs allow one to quickly implement new graph convolution modules.
For example, the following implements a new ``SAGEConv`` that aggregates
neighbor representations using a weighted average. Note that the
``edata`` member can hold edge features, which can also take part in
message passing.


.. GENERATED FROM PYTHON SOURCE LINES 244-286

.. code-block:: Python



    class WeightedSAGEConv(nn.Module):
        """Graph convolution module used by the GraphSAGE model with edge weights.

        Parameters
        ----------
        in_feat : int
            Input feature size.
        out_feat : int
            Output feature size.
        """

        def __init__(self, in_feat, out_feat):
            super(WeightedSAGEConv, self).__init__()
            # A linear submodule for projecting the input and neighbor feature to the output.
            self.linear = nn.Linear(in_feat * 2, out_feat)

        def forward(self, g, h, w):
            """Forward computation

            Parameters
            ----------
            g : Graph
                The input graph.
            h : Tensor
                The input node feature.
            w : Tensor
                The edge weight.
            """
            with g.local_scope():
                g.ndata["h"] = h
                g.edata["w"] = w
                g.update_all(
                    message_func=fn.u_mul_e("h", "w", "m"),
                    reduce_func=fn.mean("m", "h_N"),
                )
                h_N = g.ndata["h_N"]
                h_total = torch.cat([h, h_N], dim=1)
                return self.linear(h_total)









.. GENERATED FROM PYTHON SOURCE LINES 287-291

Because the graph in this dataset does not have edge weights, we
manually set all edge weights to one in the ``forward()`` function of
the model. You can replace them with your own edge weights.


.. GENERATED FROM PYTHON SOURCE LINES 291-310

.. code-block:: Python



    class Model(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(Model, self).__init__()
            self.conv1 = WeightedSAGEConv(in_feats, h_feats)
            self.conv2 = WeightedSAGEConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
            h = F.relu(h)
            h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
            return h


    model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
    train(g, model)






.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    In epoch 0, loss: 1.949, val acc: 0.072 (best 0.072), test acc: 0.091 (best 0.091)
    In epoch 5, loss: 1.869, val acc: 0.088 (best 0.088), test acc: 0.112 (best 0.112)
    In epoch 10, loss: 1.713, val acc: 0.530 (best 0.530), test acc: 0.500 (best 0.500)
    In epoch 15, loss: 1.470, val acc: 0.648 (best 0.648), test acc: 0.618 (best 0.618)
    In epoch 20, loss: 1.155, val acc: 0.678 (best 0.678), test acc: 0.653 (best 0.653)
    In epoch 25, loss: 0.813, val acc: 0.700 (best 0.700), test acc: 0.686 (best 0.686)
    In epoch 30, loss: 0.512, val acc: 0.720 (best 0.720), test acc: 0.713 (best 0.713)
    In epoch 35, loss: 0.295, val acc: 0.720 (best 0.722), test acc: 0.734 (best 0.724)
    In epoch 40, loss: 0.164, val acc: 0.732 (best 0.732), test acc: 0.738 (best 0.738)
    In epoch 45, loss: 0.093, val acc: 0.730 (best 0.734), test acc: 0.745 (best 0.740)
    In epoch 50, loss: 0.055, val acc: 0.736 (best 0.736), test acc: 0.743 (best 0.743)
    In epoch 55, loss: 0.035, val acc: 0.734 (best 0.736), test acc: 0.748 (best 0.743)
    In epoch 60, loss: 0.024, val acc: 0.734 (best 0.736), test acc: 0.748 (best 0.743)
    In epoch 65, loss: 0.018, val acc: 0.740 (best 0.740), test acc: 0.749 (best 0.749)
    In epoch 70, loss: 0.014, val acc: 0.742 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 75, loss: 0.012, val acc: 0.738 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 80, loss: 0.010, val acc: 0.736 (best 0.742), test acc: 0.748 (best 0.750)
    In epoch 85, loss: 0.009, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 90, loss: 0.008, val acc: 0.734 (best 0.742), test acc: 0.748 (best 0.750)
    In epoch 95, loss: 0.007, val acc: 0.734 (best 0.742), test acc: 0.747 (best 0.750)
    In epoch 100, loss: 0.007, val acc: 0.734 (best 0.742), test acc: 0.747 (best 0.750)
    In epoch 105, loss: 0.006, val acc: 0.734 (best 0.742), test acc: 0.747 (best 0.750)
    In epoch 110, loss: 0.006, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 115, loss: 0.005, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 120, loss: 0.005, val acc: 0.728 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 125, loss: 0.005, val acc: 0.728 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 130, loss: 0.004, val acc: 0.724 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 135, loss: 0.004, val acc: 0.724 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 140, loss: 0.004, val acc: 0.724 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 145, loss: 0.004, val acc: 0.724 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 150, loss: 0.003, val acc: 0.724 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 155, loss: 0.003, val acc: 0.724 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 160, loss: 0.003, val acc: 0.724 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 165, loss: 0.003, val acc: 0.724 (best 0.742), test acc: 0.750 (best 0.750)
    In epoch 170, loss: 0.003, val acc: 0.726 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 175, loss: 0.003, val acc: 0.726 (best 0.742), test acc: 0.749 (best 0.750)
    In epoch 180, loss: 0.003, val acc: 0.728 (best 0.742), test acc: 0.748 (best 0.750)
    In epoch 185, loss: 0.002, val acc: 0.728 (best 0.742), test acc: 0.748 (best 0.750)
    In epoch 190, loss: 0.002, val acc: 0.730 (best 0.742), test acc: 0.748 (best 0.750)
    In epoch 195, loss: 0.002, val acc: 0.730 (best 0.742), test acc: 0.748 (best 0.750)




.. GENERATED FROM PYTHON SOURCE LINES 311-318

Even more customization with user-defined functions
---------------------------------------------------

DGL allows user-defined message and reduce functions for maximal
expressiveness. Here is a user-defined message function that is
equivalent to ``fn.u_mul_e('h', 'w', 'm')``.


.. GENERATED FROM PYTHON SOURCE LINES 318-324

.. code-block:: Python



    def u_mul_e_udf(edges):
        return {"m": edges.src["h"] * edges.data["w"]}









.. GENERATED FROM PYTHON SOURCE LINES 325-329

``edges`` has three members: ``src``, ``data`` and ``dst``, representing
the source node features, edge features, and destination node features
of all edges, respectively.
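
A message function can read from all three members at once. The
following sketch (a hypothetical UDF added for illustration, not part
of the original tutorial) builds a message from the source feature, the
destination feature, and the edge weight together:

.. code-block:: Python

    def src_dst_edge_udf(edges):
        # Combine source and destination features, then scale each
        # edge's message by its edge weight.
        return {"m": (edges.src["h"] + edges.dst["h"]) * edges.data["w"]}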


.. GENERATED FROM PYTHON SOURCE LINES 332-336

You can also write your own reduce function. For example, the following
is equivalent to the built-in ``fn.mean('m', 'h_N')`` function that
averages the incoming messages:


.. GENERATED FROM PYTHON SOURCE LINES 336-342

.. code-block:: Python



    def mean_udf(nodes):
        return {"h_N": nodes.mailbox["m"].mean(1)}









.. GENERATED FROM PYTHON SOURCE LINES 343-352

In short, DGL groups the nodes by their in-degrees so that each group's
messages form a regular tensor, and for each group it stacks the
incoming messages along the second dimension. You can then perform a
reduction along the second dimension to aggregate messages.
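
Because the mailbox stacks messages along the second dimension, any
reduction over that dimension works. For example, here is a sketch of a
max-pooling reducer (an illustration added here, equivalent to the
built-in ``fn.max('m', 'h_N')``):

.. code-block:: Python

    def max_udf(nodes):
        # Within each degree bucket, nodes.mailbox["m"] has shape
        # (num_nodes, in_degree, feat_dim); max-pool along dimension 1.
        return {"h_N": nodes.mailbox["m"].max(1)[0]}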

For more details on customizing message and reduce function with
user-defined function, please refer to the :ref:`API
reference <apiudf>`.


.. GENERATED FROM PYTHON SOURCE LINES 355-367

Best practices for writing custom GNN modules
---------------------------------------------

DGL recommends the following practices, ranked by preference:

-  Use ``dgl.nn`` modules.
-  Use ``dgl.nn.functional`` functions, which provide lower-level
   operations such as computing a softmax for each node over its
   incoming edges (see the sketch after this list).
-  Use ``update_all`` with built-in message and reduce functions.
-  Use user-defined message or reduce functions.
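
As an example of the second option, here is a sketch (an illustration
added here, assuming the Cora graph ``g`` loaded above) that uses
``edge_softmax`` to normalize one unnormalized score per edge into
attention-like weights over each node's incoming edges:

.. code-block:: Python

    from dgl.nn.functional import edge_softmax

    # One unnormalized score per edge; edge_softmax normalizes the
    # scores over the incoming edges of each destination node.
    scores = torch.randn(g.num_edges(), 1)
    attn = edge_softmax(g, scores)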


.. GENERATED FROM PYTHON SOURCE LINES 370-376

What’s next?
------------

-  :ref:`Writing Efficient Message Passing
   Code <guide-message-passing-efficient>`.


.. GENERATED FROM PYTHON SOURCE LINES 376-380

.. code-block:: Python



    # Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
    # sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'








.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.959 seconds)


.. _sphx_glr_download_tutorials_blitz_3_message_passing.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 3_message_passing.ipynb <3_message_passing.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 3_message_passing.py <3_message_passing.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 3_message_passing.zip <3_message_passing.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_