.. _guide-minibatch-node-classification-sampler:

6.1 Training GNN for Node Classification with Neighborhood Sampling
-----------------------------------------------------------------------

:ref:`(中文版) `

To train your model stochastically, you need to do the following:

-  Define a neighborhood sampler.
-  Adapt your model for minibatch training.
-  Modify your training loop.

The following sub-subsections address these steps one by one.

Define a neighborhood sampler and data loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DGL provides several neighborhood sampler classes that generate the
computation dependencies needed for each layer, given the nodes we wish
to compute on.

The simplest neighborhood sampler is :class:`~dgl.graphbolt.NeighborSampler`
or the equivalent function-like interface :func:`~dgl.graphbolt.sample_neighbor`,
which makes each node gather messages from its neighbors.

To use a sampler provided by DGL, one also needs to combine it with
:class:`~dgl.graphbolt.DataLoader`, which iterates over a set of indices
(nodes in this case) in minibatches.

For example, the following code creates a DataLoader that iterates over the
training node ID set of ``ogbn-arxiv`` in batches and puts the list of
generated MFGs onto the GPU when one is available.

.. code:: python

    import dgl
    import dgl.graphbolt as gb
    import dgl.nn as dglnn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = gb.BuiltinDataset("ogbn-arxiv").load()
    g = dataset.graph
    feature = dataset.feature
    train_set = dataset.tasks[0].train_set
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)

Iterating over the DataLoader yields :class:`~dgl.graphbolt.MiniBatch`
objects, each of which contains a list of specially created graphs
representing the computation dependencies on each layer. To train with DGL,
you can access these *message flow graphs* (MFGs) via ``mini_batch.blocks``.

.. code:: python

    mini_batch = next(iter(dataloader))
    print(mini_batch.blocks)

.. note::

   See the `Stochastic Training Tutorial <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__
   for the concept of message flow graph.

If you wish to develop your own neighborhood sampler or you want a more
detailed explanation of the concept of MFGs, please refer to
:ref:`guide-minibatch-customizing-neighborhood-sampler`.

.. _guide-minibatch-node-classification-model:

Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If your message passing modules are all provided by DGL, the changes
required to adapt your model to minibatch training are minimal. Take a
multi-layer GCN as an example. If your model on the full graph is
implemented as follows:

.. code:: python

    class TwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)

        def forward(self, g, x):
            x = F.relu(self.conv1(g, x))
            x = F.relu(self.conv2(g, x))
            return x

Then all you need to do is replace ``g`` with the ``blocks`` generated above,
passing one block to each layer.

.. code:: python

    class StochasticTwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)

        def forward(self, blocks, x):
            # blocks[0] feeds the first layer, blocks[1] the second.
            x = F.relu(self.conv1(blocks[0], x))
            x = F.relu(self.conv2(blocks[1], x))
            return x

The DGL ``GraphConv`` modules above accept an element of the ``blocks``
generated by the data loader as an argument.

:ref:`The API reference of each NN module ` will tell you
whether it supports accepting an MFG as an argument.

If you wish to use your own message passing module, please refer to
:ref:`guide-minibatch-custom-gnn-module`.

Training Loop
~~~~~~~~~~~~~

The training loop simply consists of iterating over the dataset with the
customized batching iterator. During each iteration, which yields a
:class:`~dgl.graphbolt.MiniBatch`, we:

1. Access the node features corresponding to the input nodes via
   ``data.node_features["feat"]``. These features are already moved to the
   target device (CPU or GPU) by the data loader.

2. Access the node labels corresponding to the output nodes via
   ``data.labels``. These labels are already moved to the target device
   (CPU or GPU) by the data loader.

3. Feed the list of MFGs and the input node features to the multilayer
   GNN and get the outputs.

4. Compute the loss and backpropagate.

.. code:: python

    # in_features, hidden_features, out_features and compute_loss are
    # user-defined and not shown here.
    model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        input_features = data.node_features["feat"]
        output_labels = data.labels
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

DGL provides an end-to-end stochastic training example `GraphSAGE
implementation `__.
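To build intuition for what ``blocks[0]`` and ``blocks[1]`` hold in the model above, the following pure-Python sketch mimics fan-out sampling on a toy adjacency list. It does not use DGL: the function ``sample_blocks``, the toy graph, and the choice to list fanouts from the input layer to the output layer are all assumptions made for illustration, and the real sampler returns MFG objects rather than plain edge lists.

```python
import random


def sample_blocks(in_neighbors, seeds, fanouts, rng):
    """Sample up to `fanout` in-neighbors per node, layer by layer.

    Hypothetical helper, not part of DGL. Returns a list of edge lists
    ordered from the input layer to the output layer, together with the
    input nodes whose features the first layer would need.
    """
    blocks = []
    frontier = list(seeds)
    # Sampling starts at the seed (output) nodes and moves outward, so the
    # last layer's fanout is used first.
    for fanout in reversed(fanouts):
        edges = []
        # Destination nodes are kept as source nodes of the same layer.
        next_frontier = set(frontier)
        for dst in frontier:
            neighbors = in_neighbors.get(dst, [])
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            edges.extend((src, dst) for src in picked)
            next_frontier.update(picked)
        # Prepend so that index 0 ends up being the input-side layer.
        blocks.insert(0, edges)
        frontier = sorted(next_frontier)
    return blocks, frontier


# Toy graph: node -> list of in-neighbors (made up for this example).
toy_in_neighbors = {0: [1, 2, 3], 1: [2, 4], 2: [0, 5], 3: [4], 4: [5], 5: []}
blocks, input_nodes = sample_blocks(
    toy_in_neighbors, seeds=[0, 1], fanouts=[2, 2], rng=random.Random(0)
)
print(len(blocks), "layers; input nodes:", input_nodes)
```

In this sketch the edges of the last entry all point at the seed nodes, while the first entry reaches one hop further out, which is why a two-layer model consumes the list in order.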
For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

Training a graph neural network for node classification on heterogeneous
graphs is similar.

For instance, we have previously seen
:ref:`how to train a 2-layer RGCN on full graph `.
The RGCN implementation for minibatch training looks very similar
(with self-loops, non-linearity and basis decomposition removed for
simplicity):

.. code:: python

    class StochasticTwoLayerRGCN(nn.Module):
        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                    for rel in rel_names
                })
            self.conv2 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                    for rel in rel_names
                })

        def forward(self, blocks, x):
            x = self.conv1(blocks[0], x)
            x = self.conv2(blocks[1], x)
            return x

The samplers provided by DGL also support heterogeneous graphs.
For example, one can still use the provided
:class:`~dgl.graphbolt.NeighborSampler` class and
:class:`~dgl.graphbolt.DataLoader` class for
stochastic training. The only difference is that the itemset is now an
instance of :class:`~dgl.graphbolt.ItemSetDict`, which is a dictionary
of node types to node IDs.

.. code:: python

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = gb.BuiltinDataset("ogbn-mag").load()
    g = dataset.graph
    feature = dataset.feature
    train_set = dataset.tasks[0].train_set
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    # For heterogeneous graphs, we need to specify the node feature keys
    # for each node type.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)

The training loop is almost the same as that of homogeneous graphs,
except for the implementation of ``compute_loss``, which here receives the
labels of the ``paper`` nodes and a dictionary mapping node types to
predictions.

.. code:: python

    # in_features, hidden_features, out_features, etypes and compute_loss
    # are user-defined and not shown here.
    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        # For heterogeneous graphs, we need to specify the node type and
        # feature name when accessing the node features. The labels are
        # likewise indexed by node type.
        input_features = {
            "author": data.node_features[("author", "feat")],
            "paper": data.node_features[("paper", "feat")]
        }
        output_labels = data.labels["paper"]
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

DGL provides an end-to-end stochastic training example `RGCN
implementation `__.
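The loop above builds ``input_features`` by hand for two node types. When a graph has more node types, the same regrouping can be written generically. The sketch below assumes only that the features behave like a mapping keyed by ``(node_type, feature_name)`` tuples, as the indexing in the loop suggests. The helper ``group_features_by_type`` is hypothetical and not a DGL function, and plain lists stand in for tensors.

```python
def group_features_by_type(node_features, feature_name):
    """Collect `feature_name` for every node type into {node_type: value}.

    Hypothetical helper for illustration; `node_features` is assumed to be
    a mapping keyed by (node_type, feature_name) tuples.
    """
    return {
        ntype: value
        for (ntype, name), value in node_features.items()
        if name == feature_name
    }


# Stand-in for the per-minibatch features; lists replace tensors here.
fake_node_features = {
    ("author", "feat"): [[0.1, 0.2]],
    ("paper", "feat"): [[0.3, 0.4], [0.5, 0.6]],
    ("paper", "year"): [[2015], [2019]],
}
input_features = group_features_by_type(fake_node_features, "feat")
print(sorted(input_features))
```

The resulting dictionary has the same shape as the hand-written ``input_features`` in the loop, so it could be passed to the model in the same way, provided the assumption about the key layout holds.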