AsNodePredDataset
- class dgl.data.AsNodePredDataset(dataset, split_ratio=None, target_ntype=None, **kwargs)
Bases: DGLDataset
Repurpose a dataset for a standard semi-supervised transductive node prediction task.
The class converts a given dataset into a new dataset object such that:
- It contains only one graph, accessible from dataset[0].
- The graph stores:
  - Node labels in g.ndata['label'].
  - Train/val/test masks in g.ndata['train_mask'], g.ndata['val_mask'], and g.ndata['test_mask'] respectively.
- In addition, the dataset contains the following attributes:
  - num_classes, the number of classes to predict.
  - train_idx, val_idx, test_idx, the train/val/test indexes.
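The masks and the index attributes describe the same split in two forms. The following is a minimal sketch of that relationship, assuming each index tensor simply lists the positions where the corresponding mask is true. Plain Python lists stand in for tensors, and `mask_to_idx` is a hypothetical helper for illustration, not part of DGL.

```python
# Illustrative only: plain lists stand in for the tensors stored in g.ndata.
def mask_to_idx(mask):
    """Return the node IDs whose mask entry is True, in ascending order."""
    return [node_id for node_id, flag in enumerate(mask) if flag]

# Hypothetical masks for a 5-node graph.
train_mask = [True, False, True, False, False]
val_mask = [False, True, False, False, False]
test_mask = [False, False, False, True, True]

train_idx = mask_to_idx(train_mask)
val_idx = mask_to_idx(val_mask)
test_idx = mask_to_idx(test_mask)
```

Masks are convenient for selecting rows of a full-graph output, while index lists are convenient for sampling or batching node IDs.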
If the input dataset contains heterogeneous graphs, users need to specify the target_ntype argument to indicate which node type to make predictions for. In this case:
- Node labels are stored in g.nodes[target_ntype].data['label'].
- Training masks are stored in g.nodes[target_ntype].data['train_mask']. Validation and test masks are stored the same way, under 'val_mask' and 'test_mask'.
The class keeps only the first graph in the provided dataset and generates train/val/test masks according to the given split ratio. The generated masks are cached to disk for fast re-loading. If the provided split ratio differs from the cached one, the dataset is re-processed with the new ratio.
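To make the split step concrete, here is a minimal sketch of how a split ratio could be turned into three disjoint boolean masks. It assumes a seeded random shuffle of node IDs, with any rounding remainder assigned to the test set. The source does not describe how DGL actually samples the split, so `make_split_masks` and its `seed` parameter are hypothetical and shown only for illustration.

```python
import math
import random

def make_split_masks(num_nodes, split_ratio, seed=0):
    """Hypothetical helper: build boolean train/val/test masks from ratios."""
    if not math.isclose(sum(split_ratio), 1.0):
        raise ValueError("split_ratio must sum to one")
    train_ratio, val_ratio, _ = split_ratio

    node_ids = list(range(num_nodes))
    random.Random(seed).shuffle(node_ids)  # seeded so the split is repeatable

    num_train = int(num_nodes * train_ratio)
    num_val = int(num_nodes * val_ratio)
    groups = {
        "train_mask": node_ids[:num_train],
        "val_mask": node_ids[num_train:num_train + num_val],
        "test_mask": node_ids[num_train + num_val:],  # remainder goes to test
    }

    masks = {}
    for name, ids in groups.items():
        mask = [False] * num_nodes
        for node_id in ids:
            mask[node_id] = True
        masks[name] = mask
    return masks

masks = make_split_masks(10, (0.8, 0.1, 0.1))
```

With 10 nodes and ratios (0.8, 0.1, 0.1), this yields 8 training nodes, 1 validation node, and 1 test node, and every node belongs to exactly one mask.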
- Parameters:
dataset (DGLDataset) – The dataset to be converted.
split_ratio ((float, float, float), optional) – Split ratios for training, validation and test sets. They must sum to one.
target_ntype (str, optional) – The node type to add split mask for.
- train_idx
A 1-D integer tensor of training node IDs.
- Type:
Tensor
- val_idx
A 1-D integer tensor of validation node IDs.
- Type:
Tensor
- test_idx
A 1-D integer tensor of test node IDs.
- Type:
Tensor
Examples
>>> import dgl
>>> ds = dgl.data.AmazonCoBuyComputerDataset()
>>> print(ds)
Dataset("amazon_co_buy_computer", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds)
Dataset("amazon_co_buy_computer-as-nodepred", num_graphs=1, save_path=...)
>>> print('train_mask' in new_ds[0].ndata)
True