When Kernel Fusion meets Graph Neural Networks

In DGL’s first release last December, we focused on usability by introducing a set of carefully designed, easy-to-use APIs that support a variety of Graph Neural Network model implementations. We decided to keep DGL framework-agnostic so that it can engage users from different platforms (PyTorch, MXNet, …). As a result, our earlier releases largely leveraged the functionality provided by these frameworks, and the valuable feedback from our users has made us well aware of the room for improvement, particularly for models defined on sparse, irregular graphs.

With DGL’s APIs from the first release being well received and gradually stabilizing, we have been working in high gear on boosting performance. Our next major release (v0.3) focuses on much faster training, lower memory consumption, and better scalability.

The performance gains of the upcoming DGL v0.3 are substantial. Compared to the current version, DGL v0.3 achieves up to 19x the training throughput and can train graphs 8x larger on a single GPU. DGL’s training speed is now competitive with alternative frameworks such as PyTorch Geometric (PyG), but with much better scalability: DGL allows training on considerably larger graphs, up to 500M nodes and 25B edges. For a more concrete performance evaluation and comparison, check out our workshop paper.

In the rest of this post, we’ll get technical and describe fused message passing, the key technique enabling these performance improvements. We will address the following questions:

  • Why can’t basic message passing scale to large graphs?
  • How does fused message passing help?
  • How to enable fused message passing in DGL?

Why can’t basic message passing scale to large graphs?

Most GNN models perform message-passing style computation over a graph. Such a computation constitutes two main user-defined functions:

  • Message function specifies how to generate a message along an edge from a node to its neighbor. This is edge-wise computation since it is performed for all (or a subset of) edges.
  • Reduce function specifies how to aggregate the incoming messages of a node and update the node feature. This is node-wise computation since it is performed for all (or a subset of) nodes.

The following figure gives an illustration. The user-defined message function is used to generate a message (yellow boxes) on each edge. To create the message for an edge, the message function takes into account the edge feature as well as the features of its two endpoint nodes. At each node, incoming messages are aggregated using the user-defined reduce function, and another user-defined function updates the node feature. In DGL, one can easily implement such message passing by calling the send and recv APIs (see our tutorial on message passing for more details).
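In equation form, the scheme above can be written as follows (the symbols here are generic stand-ins for the user-defined functions, not notation taken from the figure):

\begin{align*}
m_{u\to v} &= \phi\big(h_u,\ h_v,\ e_{uv}\big) && \text{message function } \phi \text{ applied per edge } (u, v)\\
h_v' &= \psi\Big(h_v,\ \rho\big(\{\, m_{u\to v} : u \in \mathcal{N}(v) \,\}\big)\Big) && \text{reduce } \rho \text{ over in-neighbors, then node update } \psi
\end{align*}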

The basic strategy for implementing message passing is straightforward. First, we send messages by invoking edge-wise message functions. Then, we recv messages by applying node-wise computation to aggregate messages according to their destination nodes (and updating node features). In the DGL example below, we implement a Graph Convolutional Network (GCN) by specifying its message and reduce functions as lambdas, and DGL uses basic message passing underneath.

# A GCN example with user-defined message function.
# Using user-defined message function causes DGL to use 
# the basic message passing strategy.
G.update_all(lambda edges: {'m': edges.src['h']},
             lambda nodes: {'h': nodes.mailbox['m'].sum(dim=1)})

The problem with basic message passing is that messages are explicitly materialized and stored, which leads to a memory explosion as the graph grows. As a concrete example, consider the Reddit graph dataset introduced in the GraphSAGE paper, which has 232K nodes and 114M edges. If we train a GCN model whose message function copies the source node feature, the messages consume a whopping ~500x the memory needed to store the node features themselves! It gets even worse: the frequent memory accesses can bottleneck the computation, leading to under-utilization of the GPU.
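As a back-of-the-envelope check (a rough sketch; the feature size and data type below are illustrative assumptions, not numbers from the paper):

# Rough estimate of why materialized messages blow up memory on the Reddit graph.
num_nodes = 232_000
num_edges = 114_000_000
feat_size = 256                 # hypothetical hidden size
bytes_per_float = 4             # float32

node_storage_gb = num_nodes * feat_size * bytes_per_float / 1e9      # ~0.24 GB
message_storage_gb = num_edges * feat_size * bytes_per_float / 1e9   # ~117 GB
print(message_storage_gb / node_storage_gb)                          # ~491, i.e. roughly 500x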

Fused message passing == NO explicit messages

To avoid the overhead of materializing messages, we implement fused message passing, i.e., combining send and recv into one operation, send_and_recv (shown in the following figure). Under the hood, the fused send_and_recv operation is implemented as a CUDA kernel in which each thread loads the source node feature into per-thread local memory, computes the message, aggregates it directly into the buffer of the destination node, and immediately discards the message afterwards.
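Conceptually, the fused operation behaves like the sketch below for a copy-source/sum combination (plain Python pseudocode for a single destination node; the real implementation is a parallel CUDA kernel):

# Sketch of fused message passing for one destination node: no tensor of
# O(#edges) messages is ever allocated.
def fused_copy_sum(in_neighbors_of_dst, node_feat):
    acc = 0
    for src in in_neighbors_of_dst:   # source node of each incoming edge
        msg = node_feat[src]          # compute the message on the fly...
        acc = acc + msg               # ...aggregate it immediately...
        # ...and discard it instead of storing a per-edge tensor
    return acc                        # becomes the new feature of dst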

To enable fused message passing, there are two challenges to be solved:

  • How do we fuse the user-defined message and reduce functions? In DGL, we provide a set of pre-defined functions, called DGL built-ins, and let users choose from them. This restricts which message and reduce functions can be used for fused message passing, but we provide a variety of common functions so that most GNN models can still be implemented. UDFs are still allowed, in which case DGL falls back to basic message passing.
  • How do we backpropagate gradients without explicit messages? The trick is to recompute the messages during backward propagation, similar to the technique used when training very deep NN models. In practice, many message functions do not even require explicit messages to compute the gradient (such as copying node features as messages), and our implementation leverages this optimization (see the sketch after this list).
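To illustrate the second point for the copy-source/sum combination (a hand-written sketch, not DGL’s actual autograd code): the gradient with respect to a source feature is just the sum of the output gradients of every destination it sent a message to, so the backward pass is another fused aggregation over the reversed edges and needs no stored messages at all.

# Sketch: backward pass of copy_src + sum without materializing messages.
# grad_out[v] holds dLoss/d(new_feat[v]); the result is dLoss/d(node_feat[u]).
def fused_copy_sum_backward(out_neighbors_of_u, grad_out):
    grad = 0
    for dst in out_neighbors_of_u:    # every node u sent a message to
        grad = grad + grad_out[dst]   # each copied message contributed identically
    return grad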

For example, we can re-write our earlier GCN implementation to use the built-in copy_src message function and sum reduce function, as shown below.

import dgl.function as fn
G = ... # some graph
# copy src feature 'h' as the message and sum it as new 'h'
G.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'h'))
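For this particular combination, the fused operation is mathematically just a sparse-dense matrix multiplication between the graph’s adjacency matrix and the node-feature matrix. The snippet below is an illustrative sketch of that equivalence (using PyTorch sparse ops, not DGL internals):

import torch

# Tiny illustrative graph with edges 0->1, 0->2, 1->2 (src -> dst).
src = torch.tensor([0, 0, 1])
dst = torch.tensor([1, 2, 2])
A = torch.sparse_coo_tensor(torch.stack([dst, src]),    # A[v, u] = 1 iff edge u -> v
                            torch.ones(3), size=(3, 3))
H = torch.randn(3, 4)                                   # node features 'h'
H_new = torch.sparse.mm(A, H)                           # same result as copy_src + sum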

A Graph Attention Network (GAT) can also be implemented using the built-in src_mul_edge message function and sum reduce function.

# the src features 'h' are weighted by the attention score 'e'
# on the edges and are summed as the new feature 'h'
G.update_all(fn.src_mul_edge('h', 'e', 'm'), fn.sum('m', 'h'))

DGL v0.3 includes built-in functions that support the following combinations:

  • The message function can be any one of add/sub/mul/div between two given features among src, edge, dst.
  • Broadcasting is supported on the feature dimensions, as in a multi-head attention module where the node feature has shape (NUM_HEADS, NUM_FEATS) and the attention score has shape (NUM_HEADS, 1) (see the sketch after this list).
  • The reduce function can be any one of sum/max/min/prod.
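For instance, the broadcasting rule above is what makes the weighted aggregation in a multi-head attention layer a one-liner. A sketch (the field names 'h' and 'e' and the shapes are only illustrative):

import dgl.function as fn
G = ...  # a DGLGraph whose nodes carry 'h' and whose edges carry 'e'
# Node feature 'h' has shape (N, NUM_HEADS, NUM_FEATS);
# edge attention score 'e' has shape (E, NUM_HEADS, 1).
# The trailing size-1 dimension of 'e' broadcasts against NUM_FEATS.
G.update_all(fn.src_mul_edge('h', 'e', 'm'),   # per-head weighted messages
             fn.sum('m', 'h'))                 # per-head aggregation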

The rule of thumb is to use built-in functions as much as possible so that DGL can use fused message passing for better performance. We realize this might introduce some programming burden, but the payoff is definitely worth it (see the next section for evaluation results).

Convince me with numbers

We compare against our v0.2 release to see how much we have improved the system. In addition, we compare against PyG (PyTorch Geometric v1.2.0). PyG implements basic message passing by first gathering node features as edge messages and then scattering them for message aggregation, which materializes explicit messages.

We benchmark both GCN and GAT models on several popular datasets following their original settings. The testbed is one AWS p3.2xlarge instance with an NVIDIA V100 GPU (16GB memory).

Dataset  | #V   | #E   | Model | DGL (v0.2) Time(s) | PyG Time(s) | DGL (v0.3) Time(s)
Cora     | 3K   | 11K  | GCN   |                    |             |
Citeseer | 3K   | 9K   | GCN   |                    |             |
Pubmed   | 20K  | 889K | GCN   |                    |             |
Reddit   | 232K | 114M | GCN   | OOM                | OOM         | 25.30

The upcoming release delivers a huge performance boost, especially on the GAT model (19x faster), thanks to kernel fusion. Compared with PyG on small graphs (i.e., Cora, Citeseer, and Pubmed), computation time and memory consumption stay roughly the same, relatively invariant to the graph size. In this regime, the graph computation does not bottleneck training, and DGL has a slight, constant overhead compared with PyG. However, when evaluating on the much larger Reddit graph, PyG runs out of memory while DGL fits the graph quite easily.

We further perform some ablation studies using synthetic graphs.

We first vary the number of nodes in the graph while keeping the graph density fixed (0.0008), and test the training speed of GCN and GAT. DGL can train GCN on graphs with up to 500K nodes, twice as large as the largest graph PyG can handle, and PyG is 3.4x slower than DGL on the largest graph it can fit.

We then fix the number of nodes (32K) and vary the density of the graph. The result clearly shows the advantage of fused message passing: for both GCN and GAT, DGL can train on graphs with up to 8x more edges and is 7.5x faster than PyG on the largest graph PyG can fit.

We also vary the hidden layer size and compare performance on a medium-sized graph (32K nodes, density 0.0008). For GCN, although PyG can fit the largest hidden size we tested, it is 4x slower than DGL. For GAT, PyG cannot train with a hidden size larger than 32.

Finally, we push the limit to see how large a graph can be trained on one machine with large CPU memory (an AWS x1.32xlarge instance with 2TB memory).

#Nodes | #Edges | Time(s) | Memory(GB)
5M     | 250M   | 4.7     | 8
50M    | 2.5B   | 46      | 75
500M   | 25B    | 505     | 740

The result shows that DGL can train a GCN model on graphs with up to 500M nodes and 25B edges.

What’s next?

The DGL team is very passionate about the set of features we plan to deliver on our future agenda. In fact, many of them have been on our mind since the inception of the project. For example, the built-in functions have been a priority all along, but they only shine with the help of kernel fusion. To shed some light on the directions we are working on:

  • A detailed demo and tutorial on how to reproduce the large-graph experiment in this post on a large CPU machine.
  • Support for heterogeneous graphs.
  • Accelerating graph traversal and querying on the GPU.

DGL always stays close to its users, and we greatly value your feedback! Trying out this new feature is simple: clone DGL’s repository, switch to the kernel branch, and build the library from source.

Please stay tuned for our next major release! More good stuff coming soon.