
6.3 Training GNN for Link Prediction with Neighborhood Sampling


Define a data loader with neighbor and negative sampling

You can still use the same data loader as in node/edge classification. The only difference is that you need to add a negative sampling stage before the neighbor sampling stage. The following line picks 5 negative destination nodes uniformly for each source node of an edge.

datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

For details about the built-in uniform negative sampler, please see UniformNegativeSampler.
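
The sample_uniform_negative() call above is the functional form registered for UniformNegativeSampler, so the same stage can also be written by constructing the class directly. A minimal sketch, assuming the constructor takes (datapipe, graph, negative_ratio); check the API reference of your DGL version:

# Hedged sketch: class-based equivalent of datapipe.sample_uniform_negative(graph, 5).
datapipe = gb.UniformNegativeSampler(datapipe, graph, 5)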

You can also provide your own negative sampler, as long as it inherits from NegativeSampler and overrides the _sample_with_etype() method, which takes the node pairs in the minibatch and returns the negative node pairs.

The following is an example of a custom negative sampler that samples negative destination nodes according to a probability distribution proportional to a power of node degrees.

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = node_degrees ** 0.75
        self.k = k

    def _sample_with_etype(self, seeds, etype=None):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights.multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)
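
Note that node_degrees is not defined in the snippet above. A minimal sketch of one way to obtain it, assuming a homogeneous DGLGraph g is available alongside the GraphBolt pipeline (the 0.75 power mirrors the word2vec-style negative sampling distribution):

# Hedged sketch: derive per-node weights from the in-degrees of a DGLGraph `g`.
node_degrees = g.in_degrees().float()  # shape: (num_nodes,)
datapipe = datapipe.customized_sample_negative(5, node_degrees)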

Define a GraphSAGE model for minibatch training

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x

When a negative sampler is provided, the data loader generates positive and negative node pairs for each minibatch, in addition to the Message Flow Graphs (MFGs). Use compacted_seeds and labels to get the compacted node pairs and their corresponding labels.
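
For example, you can peek at a single minibatch to see these fields (a hedged sketch; the names mirror the training loop below and the exact layout may vary across DGL versions):

# Hedged sketch: inspect one MiniBatch from the data loader defined earlier.
data = next(iter(dataloader))
print(data.compacted_seeds.shape)  # (num_positive + num_negative, 2) compacted node pairs
print(data.labels)                 # 1 for positive pairs, 0 for sampled negatives
print(len(data.blocks))            # one MFG per GNN layer, here 2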

Training loop

The training loop simply iterates over the data loader and feeds the sampled blocks and input features into the model defined above.

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds.T
        labels = data.labels
        node_feature = data.node_features["feat"]
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks

        # Get the embeddings of the input nodes.
        y = model(blocks, node_feature)
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()

DGL provides an unsupervised GraphSAGE example that demonstrates link prediction on homogeneous graphs.

For heterogeneous graphs

The previous model can easily be extended to heterogeneous graphs. The only difference is that you need to wrap SAGEConv with HeteroGraphConv according to the edge types.

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(in_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.layers.append(dglnn.HeteroGraphConv({
                rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                for rel in rel_names
            }))
        self.hidden_size = hidden_size
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
        return hidden_x
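
Note that with HeteroGraphConv, the features passed to forward() form a dictionary keyed by node type rather than a single tensor, and each layer returns a dictionary of the same form. A minimal sketch of calling the model, where the node type names "user" and "item" and the tensors user_feat and item_feat are illustrative assumptions:

# Hedged sketch: heterogeneous inputs and outputs are dicts keyed by node type.
node_features = {"user": user_feat, "item": item_feat}  # per-type input feature tensors
y = model(blocks, node_features)  # y is a dict, e.g. {"user": h_user, "item": h_item}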

The data loader definition is also very similar to that for homogeneous graphs. The only difference is that you need to specify the node types when fetching features.

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_uniform_negative(graph, 5)
datapipe = datapipe.sample_neighbor(graph, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
datapipe = datapipe.fetch_feature(
    feature,
    node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

If you want to provide your own negative sampler, just inherit from the NegativeSampler class and override the _sample_with_etype() method.

@functional_datapipe("customized_sample_negative")
class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
    def __init__(self, datapipe, k, node_degrees):
        super().__init__(datapipe, k)
        # caches the probability distribution
        self.weights = {
            etype: node_degrees[etype] ** 0.75 for etype in node_degrees
        }
        self.k = k

    def _sample_with_etype(self, seeds, etype):
        src, _ = seeds.T
        src = src.repeat_interleave(self.k)
        dst = self.weights[etype].multinomial(len(src), replacement=True)
        return src, dst

datapipe = datapipe.customized_sample_negative(5, node_degrees)
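
As in the homogeneous case, node_degrees is not defined above; here it needs to be a dictionary of per-edge-type weight tensors whose keys match the etype values that GraphBolt passes to _sample_with_etype(). A hedged sketch, assuming a heterogeneous DGLGraph hg is available and that its edge type identifiers match (adjust the keys to whatever your pipeline passes as etype):

# Hedged sketch: per-edge-type degree tensors built from a DGLGraph `hg`.
# The dict keys must match the `etype` argument of _sample_with_etype().
node_degrees = {
    etype: hg.in_degrees(etype=etype).float() for etype in hg.canonical_etypes
}
datapipe = datapipe.customized_sample_negative(5, node_degrees)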

For heterogeneous graphs, node pairs are grouped by edge type. The training loop is again almost the same as that for homogeneous graphs, except that the loss is computed on a specific edge type.

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

category = "user"
for epoch in tqdm.trange(args.epochs):
    model.train()
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Unpack MiniBatch.
        compacted_seeds = data.compacted_seeds
        labels = data.labels
        node_features = {
            ntype: data.node_features[(ntype, "feat")]
            for ntype in data.blocks[0].srctypes
        }
        # Convert sampled subgraphs to DGL blocks.
        blocks = data.blocks
        # Get the embeddings of the input nodes.
        y = model(blocks, node_features)
        logits = model.predictor(
            y[category][compacted_seeds[category][:, 0]]
            * y[category][compacted_seeds[category][:, 1]]
        ).squeeze()

        # Compute loss.
        loss = F.binary_cross_entropy_with_logits(logits, labels[category])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    end_epoch_time = time.time()