.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/blitz/5_graph_classification.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_blitz_5_graph_classification.py:


Training a GNN for Graph Classification
=======================================

By the end of this tutorial, you will be able to

-  Load a DGL-provided graph classification dataset.
-  Understand what a *readout* function does.
-  Understand how to create and use a minibatch of graphs.
-  Build a GNN-based graph classification model.
-  Train and evaluate the model on a DGL-provided dataset.

(Time estimate: 18 minutes)

.. GENERATED FROM PYTHON SOURCE LINES 15-25

.. code-block:: Python

    import os

    os.environ["DGLBACKEND"] = "pytorch"
    import dgl
    import dgl.data
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

.. GENERATED FROM PYTHON SOURCE LINES 26-40

Overview of Graph Classification with GNN
-----------------------------------------

Graph classification or regression requires a model to predict certain
graph-level properties of a single graph given its node and edge
features. Molecular property prediction is one particular application.

This tutorial shows how to train a graph classification model on a
small dataset from the paper `How Powerful Are Graph Neural Networks `__.

Loading Data
------------

.. GENERATED FROM PYTHON SOURCE LINES 40-46

.. code-block:: Python

    # Load the PROTEINS dataset from the GIN benchmark collection;
    # self_loop=True adds a self-loop edge to every node of every graph.
    dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading /root/.dgl/GINDataset.zip from https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip...

    /root/.dgl/GINDataset.zip: 0%| | 0.00/33.4M [00:00

`__.
For example, this tutorial creates a training ``GraphDataLoader`` and a
test ``GraphDataLoader``, using ``SubsetRandomSampler`` to tell PyTorch
to sample from only a subset of the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 76-93

.. code-block:: Python

    from dgl.dataloading import GraphDataLoader
    from torch.utils.data.sampler import SubsetRandomSampler

    num_examples = len(dataset)
    num_train = int(num_examples * 0.8)

    # The first 80% of the graphs (in dataset order) form the training set and
    # the remaining 20% form the test set; each sampler shuffles within its subset.
    train_sampler = SubsetRandomSampler(torch.arange(num_train))
    test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

    train_dataloader = GraphDataLoader(
        dataset, sampler=train_sampler, batch_size=5, drop_last=False
    )
    test_dataloader = GraphDataLoader(
        dataset, sampler=test_sampler, batch_size=5, drop_last=False
    )

.. GENERATED FROM PYTHON SOURCE LINES 94-97

You can iterate over the created ``GraphDataLoader`` to see what it
yields:

.. GENERATED FROM PYTHON SOURCE LINES 97-103

.. code-block:: Python

    it = iter(train_dataloader)
    batch = next(it)
    print(batch)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [Graph(num_nodes=236, num_edges=1068,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={}), tensor([0, 1, 0, 0, 0])]

.. GENERATED FROM PYTHON SOURCE LINES 104-123

Each element in ``dataset`` pairs a graph with its label, so the
``GraphDataLoader`` returns two objects per iteration. The first is the
batched graph, and the second is a label vector holding the category of
each graph in the mini-batch.

Next, we’ll talk about the batched graph.

A Batched Graph in DGL
----------------------

In each mini-batch, the sampled graphs are combined into a single bigger
batched graph via ``dgl.batch``. The batched graph holds all the
original graphs as disjoint components (no edge connects one original
graph to another), with their node and edge features concatenated.
This bigger graph is also a ``DGLGraph`` instance, so you can still treat
it as a normal ``DGLGraph`` object, as described `here <2_dglgraph.ipynb>`__.
However, it also contains the information needed to recover the original
graphs, such as the number of nodes and edges of each graph element.

.. GENERATED FROM PYTHON SOURCE LINES 123-140

.. code-block:: Python

    batched_graph, labels = batch
    print(
        "Number of nodes for each graph element in the batch:",
        batched_graph.batch_num_nodes(),
    )
    print(
        "Number of edges for each graph element in the batch:",
        batched_graph.batch_num_edges(),
    )

    # Recover the original graph elements from the minibatch
    graphs = dgl.unbatch(batched_graph)
    print("The original graphs in the minibatch:")
    print(graphs)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Number of nodes for each graph element in the batch: tensor([ 17, 15, 38, 146, 20])
    Number of edges for each graph element in the batch: tensor([ 81, 71, 198, 628, 90])
    The original graphs in the minibatch:
    [Graph(num_nodes=17, num_edges=81,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={}), Graph(num_nodes=15, num_edges=71,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={}), Graph(num_nodes=38, num_edges=198,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={}), Graph(num_nodes=146, num_edges=628,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={}), Graph(num_nodes=20, num_edges=90,
          ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={})]

.. GENERATED FROM PYTHON SOURCE LINES 141-163

Define Model
------------

This tutorial builds a two-layer `Graph Convolutional Network (GCN) `__.
Each of its layers computes new node representations by aggregating
neighbor information. If you have gone through the
:doc:`introduction <1_introduction>`, you will notice two differences:

-  Since the task is to predict a single category for the *entire graph*
   instead of for every node, you need to aggregate the representations
   of all the nodes, and potentially the edges, into a graph-level
   representation. This process is commonly referred to as a *readout*.
   A simple choice is to average the node features of a graph with
   ``dgl.mean_nodes()``.

-  The input graph to the model will be a batched graph yielded by the
   ``GraphDataLoader``. DGL's readout functions handle batched graphs
   and return one representation per graph in the minibatch.

.. GENERATED FROM PYTHON SOURCE LINES 163-181

.. code-block:: Python

    from dgl.nn import GraphConv


    class GCN(nn.Module):
        def __init__(self, in_feats, h_feats, num_classes):
            super(GCN, self).__init__()
            self.conv1 = GraphConv(in_feats, h_feats)
            self.conv2 = GraphConv(h_feats, num_classes)

        def forward(self, g, in_feat):
            # Two rounds of neighbor aggregation produce per-node class scores.
            h = self.conv1(g, in_feat)
            h = F.relu(h)
            h = self.conv2(g, h)
            g.ndata["h"] = h
            # Readout: average the node scores of each graph in the batch.
            return dgl.mean_nodes(g, "h")

.. GENERATED FROM PYTHON SOURCE LINES 182-189

Training Loop
-------------

The training loop iterates over the training set with the
``GraphDataLoader`` object and computes the gradients, just like in
image classification or language modeling.

.. GENERATED FROM PYTHON SOURCE LINES 189-212

.. code-block:: Python

    # Create the model with given dimensions
    model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(20):
        for batched_graph, labels in train_dataloader:
            pred = model(batched_graph, batched_graph.ndata["attr"].float())
            loss = F.cross_entropy(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    num_correct = 0
    num_tests = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batched_graph, labels in test_dataloader:
            pred = model(batched_graph, batched_graph.ndata["attr"].float())
            num_correct += (pred.argmax(1) == labels).sum().item()
            num_tests += len(labels)

    print("Test accuracy:", num_correct / num_tests)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Test accuracy: 0.05829596412556054

.. GENERATED FROM PYTHON SOURCE LINES 213-220

What’s next
-----------

-  See `GIN example `__ for an end-to-end graph classification model.

.. GENERATED FROM PYTHON SOURCE LINES 220-224

.. code-block:: Python

    # Thumbnail credits: DGL
    # sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 55.250 seconds)

.. _sphx_glr_download_tutorials_blitz_5_graph_classification.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 5_graph_classification.ipynb <5_graph_classification.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 5_graph_classification.py <5_graph_classification.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 5_graph_classification.zip <5_graph_classification.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_