.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/models/1_gnn/1_gcn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_models_1_gnn_1_gcn.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_models_1_gnn_1_gcn.py:


.. _model-gcn:

Graph Convolutional Network
====================================

**Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang  <https://jermainewang.github.io/>`_,
Yu Gai, Quan Gan, Zheng Zhang

.. warning::

    This tutorial aims to provide insight into the paper, with code as a means
    of explanation. The implementation is therefore NOT optimized for running
    efficiency. For a recommended implementation, please refer to the `official
    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

This is a gentle introduction to using DGL to implement Graph Convolutional
Networks (Kipf & Welling, `Semi-Supervised Classification with Graph
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We explain
what is under the hood of the :class:`~dgl.nn.GraphConv` module, and the
reader should come away knowing how to define a new GNN layer using DGL's
message passing APIs.

.. GENERATED FROM PYTHON SOURCE LINES 26-47

Model Overview
------------------------------------------
GCN from the perspective of message passing
```````````````````````````````````````````````
We describe a layer of a graph convolutional network from a message
passing perspective; the math can be found `here <math_>`_.
It boils down to the following steps, for each node :math:`u`:

1) Aggregate the neighbors' representations :math:`h_{v}` to produce an
   intermediate representation :math:`\hat{h}_u`.
2) Transform the aggregated representation :math:`\hat{h}_{u}` with a linear
   projection followed by a non-linearity: :math:`h_{u} = f(W \hat{h}_u)`,
   where the weight matrix :math:`W` is shared across all nodes.

We will implement step 1 with DGL message passing, and step 2 with a
PyTorch ``nn.Module``.

GCN implementation with DGL
``````````````````````````````````````````
We first define the message and reduce functions as usual. Since the
aggregation on a node :math:`u` only involves summing over the neighbors'
representations :math:`h_v`, we can simply use DGL's built-in functions:

.. GENERATED FROM PYTHON SOURCE LINES 47-61

.. code-block:: Python


    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.function as fn
    import torch as th
    import torch.nn as nn
    import torch.nn.functional as F

    # Message function: copy each source node's feature "h" into a message "m".
    gcn_msg = fn.copy_u(u="h", out="m")
    # Reduce function: sum the incoming messages "m" into the new feature "h".
    gcn_reduce = fn.sum(msg="m", out="h")
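
For illustration, the two built-ins above are equivalent to the following
user-defined message and reduce functions, written against DGL's UDF
conventions (a sketch; the built-ins are preferred because they map to
fused, optimized kernels):

.. code-block:: Python

    # Equivalent UDFs, for illustration only (slower than the built-ins).
    def gcn_msg_udf(edges):
        # Send each source node's feature "h" as the message "m".
        return {"m": edges.src["h"]}

    def gcn_reduce_udf(nodes):
        # Sum the stacked messages in each node's mailbox over the
        # neighbor dimension.
        return {"h": nodes.mailbox["m"].sum(dim=1)}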








.. GENERATED FROM PYTHON SOURCE LINES 62-70

We then proceed to define the GCNLayer module. A GCNLayer essentially performs
message passing on all the nodes and then applies a fully-connected layer.

.. note::

   This tutorial shows how to implement a GCN from scratch. DGL provides a
   more efficient :class:`built-in GCN layer module <dgl.nn.pytorch.conv.GraphConv>`.
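
For reference, here is a minimal sketch of how the built-in module would be
used in place of the ``GCNLayer`` defined below (``GraphConv`` also performs
the symmetric normalization discussed in the last section of this tutorial):

.. code-block:: Python

    from dgl.nn.pytorch import GraphConv

    # A drop-in alternative to the from-scratch layer below.
    conv = GraphConv(in_feats=1433, out_feats=16)
    # out = conv(g, features)  # maps (N, 1433) node features to (N, 16)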


.. GENERATED FROM PYTHON SOURCE LINES 70-88

.. code-block:: Python



    class GCNLayer(nn.Module):
        def __init__(self, in_feats, out_feats):
            super(GCNLayer, self).__init__()
            self.linear = nn.Linear(in_feats, out_feats)

        def forward(self, g, feature):
            # Creating a local scope so that all the stored ndata and edata
            # (such as the `'h'` ndata below) are automatically popped out
            # when the scope exits.
            with g.local_scope():
                g.ndata["h"] = feature
                g.update_all(gcn_msg, gcn_reduce)
                h = g.ndata["h"]
                return self.linear(h)
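
As a quick sanity check, the layer can be applied to a small toy graph (a
sketch; the 3-node cycle graph and feature sizes here are made up for
illustration):

.. code-block:: Python

    tiny_g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # a directed 3-node cycle
    tiny_feat = th.randn(3, 5)                  # random 5-dim node features
    tiny_layer = GCNLayer(in_feats=5, out_feats=2)
    print(tiny_layer(tiny_g, tiny_feat).shape)  # torch.Size([3, 2])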









.. GENERATED FROM PYTHON SOURCE LINES 89-95

The forward function is essentially the same as that of any other common NN
model in PyTorch. We can initialize a GCN like any ``nn.Module``. For example,
let's define a simple neural network consisting of two GCN layers. Suppose we
are training a classifier for the Cora dataset (the input feature size is
1433 and the number of classes is 7). The last GCN layer computes node
embeddings, so it generally does not apply an activation.

.. GENERATED FROM PYTHON SOURCE LINES 95-112

.. code-block:: Python



    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.layer1 = GCNLayer(1433, 16)
            self.layer2 = GCNLayer(16, 7)

        def forward(self, g, features):
            x = F.relu(self.layer1(g, features))
            x = self.layer2(g, x)
            return x


    net = Net()
    print(net)





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Net(
      (layer1): GCNLayer(
        (linear): Linear(in_features=1433, out_features=16, bias=True)
      )
      (layer2): GCNLayer(
        (linear): Linear(in_features=16, out_features=7, bias=True)
      )
    )




.. GENERATED FROM PYTHON SOURCE LINES 113-114

We load the Cora dataset using DGL's built-in data module.

.. GENERATED FROM PYTHON SOURCE LINES 114-128

.. code-block:: Python


    from dgl.data import CoraGraphDataset


    def load_cora_data():
        dataset = CoraGraphDataset()
        g = dataset[0]  # Cora consists of a single citation graph
        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        test_mask = g.ndata["test_mask"]
        return g, features, labels, train_mask, test_mask









.. GENERATED FROM PYTHON SOURCE LINES 129-131

Once the model is trained, we can use the following function to evaluate
its performance on the test dataset:

.. GENERATED FROM PYTHON SOURCE LINES 131-144

.. code-block:: Python



    def evaluate(model, g, features, labels, mask):
        model.eval()
        with th.no_grad():
            logits = model(g, features)
            logits = logits[mask]
            labels = labels[mask]
            # Accuracy: the fraction of masked nodes whose predicted class
            # (argmax over the logits) matches the ground-truth label.
            _, indices = th.max(logits, dim=1)
            correct = th.sum(indices == labels)
            return correct.item() * 1.0 / len(labels)









.. GENERATED FROM PYTHON SOURCE LINES 145-146

We then train the network as follows:

.. GENERATED FROM PYTHON SOURCE LINES 146-176

.. code-block:: Python


    import time

    import numpy as np

    g, features, labels, train_mask, test_mask = load_cora_data()
    # Add edges between each node and itself to preserve old node representations
    g.add_edges(g.nodes(), g.nodes())
    optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
    dur = []
    for epoch in range(50):
        if epoch >= 3:
            t0 = time.time()
        net.train()
        logits = net(g, features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)
        acc = evaluate(net, g, features, labels, test_mask)
        print(
            "Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
                # np.mean of an empty list emits a RuntimeWarning and yields
                # nan; guard against it during the warm-up epochs.
                epoch, loss.item(), acc, np.mean(dur) if dur else float("nan")
            )
        )




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      NumNodes: 2708
      NumEdges: 10556
      NumFeats: 1433
      NumClasses: 7
      NumTrainingSamples: 140
      NumValidationSamples: 500
      NumTestSamples: 1000
    Done loading data from cached files.
    Epoch 00000 | Loss 1.9446 | Test Acc 0.4060 | Time(s) nan
    Epoch 00001 | Loss 1.7764 | Test Acc 0.5710 | Time(s) nan
    Epoch 00002 | Loss 1.5769 | Test Acc 0.6260 | Time(s) nan
    Epoch 00003 | Loss 1.4453 | Test Acc 0.6760 | Time(s) 0.0083
    Epoch 00004 | Loss 1.3253 | Test Acc 0.7270 | Time(s) 0.0084
    Epoch 00005 | Loss 1.2157 | Test Acc 0.7520 | Time(s) 0.0084
    Epoch 00006 | Loss 1.1186 | Test Acc 0.7590 | Time(s) 0.0084
    Epoch 00007 | Loss 1.0321 | Test Acc 0.7650 | Time(s) 0.0084
    Epoch 00008 | Loss 0.9506 | Test Acc 0.7560 | Time(s) 0.0084
    Epoch 00009 | Loss 0.8720 | Test Acc 0.7520 | Time(s) 0.0084
    Epoch 00010 | Loss 0.7984 | Test Acc 0.7430 | Time(s) 0.0084
    Epoch 00011 | Loss 0.7302 | Test Acc 0.7450 | Time(s) 0.0083
    Epoch 00012 | Loss 0.6676 | Test Acc 0.7420 | Time(s) 0.0083
    Epoch 00013 | Loss 0.6103 | Test Acc 0.7420 | Time(s) 0.0083
    Epoch 00014 | Loss 0.5568 | Test Acc 0.7420 | Time(s) 0.0083
    Epoch 00015 | Loss 0.5067 | Test Acc 0.7450 | Time(s) 0.0083
    Epoch 00016 | Loss 0.4606 | Test Acc 0.7450 | Time(s) 0.0083
    Epoch 00017 | Loss 0.4186 | Test Acc 0.7470 | Time(s) 0.0083
    Epoch 00018 | Loss 0.3802 | Test Acc 0.7460 | Time(s) 0.0083
    Epoch 00019 | Loss 0.3451 | Test Acc 0.7450 | Time(s) 0.0083
    Epoch 00020 | Loss 0.3132 | Test Acc 0.7500 | Time(s) 0.0083
    Epoch 00021 | Loss 0.2844 | Test Acc 0.7520 | Time(s) 0.0083
    Epoch 00022 | Loss 0.2584 | Test Acc 0.7580 | Time(s) 0.0082
    Epoch 00023 | Loss 0.2349 | Test Acc 0.7630 | Time(s) 0.0082
    Epoch 00024 | Loss 0.2135 | Test Acc 0.7590 | Time(s) 0.0081
    Epoch 00025 | Loss 0.1938 | Test Acc 0.7590 | Time(s) 0.0081
    Epoch 00026 | Loss 0.1758 | Test Acc 0.7570 | Time(s) 0.0081
    Epoch 00027 | Loss 0.1594 | Test Acc 0.7550 | Time(s) 0.0080
    Epoch 00028 | Loss 0.1446 | Test Acc 0.7540 | Time(s) 0.0080
    Epoch 00029 | Loss 0.1312 | Test Acc 0.7490 | Time(s) 0.0080
    Epoch 00030 | Loss 0.1190 | Test Acc 0.7470 | Time(s) 0.0080
    Epoch 00031 | Loss 0.1081 | Test Acc 0.7460 | Time(s) 0.0080
    Epoch 00032 | Loss 0.0983 | Test Acc 0.7450 | Time(s) 0.0079
    Epoch 00033 | Loss 0.0895 | Test Acc 0.7410 | Time(s) 0.0079
    Epoch 00034 | Loss 0.0816 | Test Acc 0.7430 | Time(s) 0.0079
    Epoch 00035 | Loss 0.0745 | Test Acc 0.7420 | Time(s) 0.0079
    Epoch 00036 | Loss 0.0680 | Test Acc 0.7430 | Time(s) 0.0079
    Epoch 00037 | Loss 0.0623 | Test Acc 0.7430 | Time(s) 0.0079
    Epoch 00038 | Loss 0.0571 | Test Acc 0.7430 | Time(s) 0.0079
    Epoch 00039 | Loss 0.0524 | Test Acc 0.7450 | Time(s) 0.0079
    Epoch 00040 | Loss 0.0482 | Test Acc 0.7460 | Time(s) 0.0078
    Epoch 00041 | Loss 0.0444 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00042 | Loss 0.0410 | Test Acc 0.7460 | Time(s) 0.0078
    Epoch 00043 | Loss 0.0380 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00044 | Loss 0.0353 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00045 | Loss 0.0328 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00046 | Loss 0.0306 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00047 | Loss 0.0286 | Test Acc 0.7470 | Time(s) 0.0078
    Epoch 00048 | Loss 0.0267 | Test Acc 0.7460 | Time(s) 0.0078
    Epoch 00049 | Loss 0.0250 | Test Acc 0.7440 | Time(s) 0.0078




.. GENERATED FROM PYTHON SOURCE LINES 177-206

.. _math:

GCN in one formula
------------------
Mathematically, the GCN model follows this formula:

:math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`

Here, :math:`H^{(l)}` denotes the matrix of node representations at the
:math:`l^{th}` layer of the network, :math:`\sigma` is the non-linearity, and
:math:`W^{(l)}` is the weight matrix for this layer. :math:`\tilde{A}` and
:math:`\tilde{D}` are respectively the adjacency and degree matrices of the
graph. The tilde denotes the variant where an edge is added between each node
and itself (:math:`\tilde{A} = A + I`), so that a node's own representation is
preserved in each graph convolution. The shape of the input :math:`H^{(0)}` is
:math:`N \times D`, where :math:`N` is the number of nodes and :math:`D` is
the number of input features. We can chain multiple such layers to produce a
node-level output with shape :math:`N \times F`, where :math:`F` is the
dimension of the output node feature vector.

The equation can be implemented efficiently using sparse matrix
multiplication kernels (as in Kipf's
`pygcn <https://github.com/tkipf/pygcn>`_ code). The DGL implementation above
in fact already uses this trick, because the built-in functions map to such
kernels.
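
To make the formula concrete, here is a minimal sketch of one GCN layer in
matrix form (dense math for clarity; real implementations use sparse kernels,
and the function name and shapes are made up for illustration):

.. code-block:: Python

    def gcn_layer_matrix_form(A, H, W):
        # A: (N, N) adjacency, H: (N, D) features, W: (D, F) weights;
        # uses torch as th, imported earlier in this tutorial.
        A_tilde = A + th.eye(A.shape[0])  # add self-loops: A~ = A + I
        # D~^{-1/2}; the self-loops guarantee every degree is at least 1.
        d_inv_sqrt = th.diag(A_tilde.sum(dim=1).pow(-0.5))
        A_hat = d_inv_sqrt @ A_tilde @ d_inv_sqrt  # symmetric normalization
        return th.relu(A_hat @ H @ W)  # sigma(D~^-1/2 A~ D~^-1/2 H W)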

Note that the tutorial code implements a simplified version of GCN where we
replace :math:`\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}` with
:math:`\tilde{A}`. For a full implementation, see the official example
`here <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn>`_.
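
For reference, the symmetric normalization can also be recovered in the
message-passing version by scaling the node features by
:math:`\tilde{D}^{-\frac{1}{2}}` before and after aggregation. Below is a
minimal sketch of such a forward function (it mirrors what
``dgl.nn.GraphConv`` does with ``norm="both"``; it is not the tutorial's
code):

.. code-block:: Python

    def normalized_forward(self, g, feature):
        with g.local_scope():
            # In-degrees of the (self-loop-augmented) graph; clamping
            # avoids 0 ** -0.5 on isolated nodes.
            deg = g.in_degrees().float().clamp(min=1)
            norm = deg.pow(-0.5).unsqueeze(1)  # D~^{-1/2} as a column vector
            g.ndata["h"] = feature * norm      # scale before aggregation
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata["h"] * norm            # scale after aggregation
            return self.linear(h)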


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.743 seconds)


.. _sphx_glr_download_tutorials_models_1_gnn_1_gcn.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 1_gcn.ipynb <1_gcn.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 1_gcn.py <1_gcn.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 1_gcn.zip <1_gcn.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_