OnDiskDataset for Heterogeneous Graph
This tutorial shows how to create an OnDiskDataset for a heterogeneous graph that can be used in the GraphBolt framework. The major difference from creating a dataset for a homogeneous graph is that we need to specify node/edge types for edges, feature data, and training/validation/test sets.
By the end of this tutorial, you will be able to:
- organize graph structure data.
- organize feature data.
- organize training/validation/test sets for specific tasks.
To create an OnDiskDataset object, you need to organize all the data, including graph structure, feature data and tasks, into a directory. The directory should contain a metadata.yaml file that describes the metadata of the dataset.
Now let’s generate the various data step by step, organize them together, and finally instantiate OnDiskDataset.
Install DGL package
[1]:
# Install required packages.
import os
import torch
import numpy as np
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"
# Install the CPU version.
device = torch.device("cpu")
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html
try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
Looking in links: https://data.dgl.ai/wheels-test/repo.html
Requirement already satisfied: dgl in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (2.2a240410)
Requirement already satisfied: numpy>=1.14.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.26.4)
Requirement already satisfied: scipy>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (1.14.0)
Requirement already satisfied: networkx>=2.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (3.3)
Requirement already satisfied: requests>=2.19.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.32.3)
Requirement already satisfied: tqdm in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (4.66.5)
Requirement already satisfied: psutil>=5.8.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (6.0.0)
Requirement already satisfied: torchdata>=0.5.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (0.7.1)
Requirement already satisfied: pandas in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from dgl) (2.2.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (3.8)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2.2.2)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from requests>=2.19.0->dgl) (2024.7.4)
Requirement already satisfied: torch>=2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torchdata>=0.5.0->dgl) (2.4.0+cpu)
Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from pandas->dgl) (2024.1)
Requirement already satisfied: six>=1.5 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->dgl) (1.16.0)
Requirement already satisfied: filelock in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.15.4)
Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (4.12.2)
Requirement already satisfied: sympy in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (1.13.2)
Requirement already satisfied: jinja2 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (3.1.4)
Requirement already satisfied: fsspec in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from torch>=2->torchdata>=0.5.0->dgl) (2024.6.1)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from jinja2->torch>=2->torchdata>=0.5.0->dgl) (2.1.5)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/envs/dgl-dev-cpu/lib/python3.10/site-packages (from sympy->torch>=2->torchdata>=0.5.0->dgl) (1.3.0)
/home/ubuntu/regression_test/dgl/python/dgl/graphbolt/base.py:81: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("graphbolt::expand_indptr")
DGL installed!
Data preparation
In order to demonstrate how to organize various data, let’s create a base directory first.
[2]:
base_dir = './ondisk_dataset_heterograph'
os.makedirs(base_dir, exist_ok=True)
print(f"Created base directory: {base_dir}")
Created base directory: ./ondisk_dataset_heterograph
Generate graph structure data
For a heterogeneous graph, we need to save the edges (namely, seeds) of each edge type into separate Numpy or CSV files.
Note:
- When saving to Numpy, the array is required to be in shape (2, N). This format is recommended because constructing a graph from it is much faster than from a CSV file.
- When saving to a CSV file, do not save the index and header.
[3]:
import numpy as np
import pandas as pd
# For simplicity, we create a heterogeneous graph with
# 2 node types: `user`, `item`
# 2 edge types: `user:like:item`, `user:follow:user`
# And each node/edge type has the same number of nodes/edges.
num_nodes = 1000
num_edges = 10 * num_nodes
# Edge type: "user:like:item"
like_edges_path = os.path.join(base_dir, "like-edges.csv")
like_edges = np.random.randint(0, num_nodes, size=(num_edges, 2))
print(f"Part of [user:like:item] edges: {like_edges[:5, :]}\n")
df = pd.DataFrame(like_edges)
df.to_csv(like_edges_path, index=False, header=False)
print(f"[user:like:item] edges are saved into {like_edges_path}\n")
# Edge type: "user:follow:user"
follow_edges_path = os.path.join(base_dir, "follow-edges.csv")
follow_edges = np.random.randint(0, num_nodes, size=(num_edges, 2))
print(f"Part of [user:follow:user] edges: {follow_edges[:5, :]}\n")
df = pd.DataFrame(follow_edges)
df.to_csv(follow_edges_path, index=False, header=False)
print(f"[user:follow:user] edges are saved into {follow_edges_path}\n")
Part of [user:like:item] edges: [[853 84]
[770 310]
[584 243]
[ 68 835]
[981 622]]
[user:like:item] edges are saved into ./ondisk_dataset_heterograph/like-edges.csv
Part of [user:follow:user] edges: [[328 971]
[194 594]
[823 854]
[ 51 808]
[758 464]]
[user:follow:user] edges are saved into ./ondisk_dataset_heterograph/follow-edges.csv
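The cell above writes CSV files with one (src, dst) pair per row, i.e. an (N, 2) layout, while the note says a Numpy edge file must hold a (2, N) array. A minimal sketch of writing the same kind of edges in the Numpy layout, assuming you would rather use format: numpy than csv; the temporary directory and the file name like-edges.npy are illustrative only.

```python
import os
import tempfile

import numpy as np

num_nodes = 1000
num_edges = 10 * num_nodes

# Edges generated row-wise as (src, dst) pairs, shape (num_edges, 2),
# the same layout that is written to the CSV files above.
like_edges = np.random.randint(0, num_nodes, size=(num_edges, 2))

# The Numpy format expects shape (2, N): row 0 holds source node IDs and
# row 1 holds destination node IDs, so transpose first. ascontiguousarray
# stores the result in ordinary C order instead of as a transposed view.
like_edges_2n = np.ascontiguousarray(like_edges.T)

# Illustrative output location and file name for this sketch only.
out_dir = tempfile.mkdtemp()
like_edges_npy_path = os.path.join(out_dir, "like-edges.npy")
np.save(like_edges_npy_path, like_edges_2n)

loaded = np.load(like_edges_npy_path)
print(loaded.shape)  # (2, 10000)
```

With such a file, the corresponding edge entry in metadata.yaml would presumably use format: numpy and point its path at the .npy file instead of the CSV file.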
Generate feature data for graph
For feature data, numpy arrays and torch tensors are supported for now. Let’s generate feature data for each node/edge type.
[4]:
# Generate node[user] feature in numpy array.
node_user_feat_0_path = os.path.join(base_dir, "node-user-feat-0.npy")
node_user_feat_0 = np.random.rand(num_nodes, 5)
print(f"Part of node[user] feature [feat_0]: {node_user_feat_0[:3, :]}")
np.save(node_user_feat_0_path, node_user_feat_0)
print(f"Node[user] feature [feat_0] is saved to {node_user_feat_0_path}\n")
# Generate another node[user] feature in torch tensor
node_user_feat_1_path = os.path.join(base_dir, "node-user-feat-1.pt")
node_user_feat_1 = torch.rand(num_nodes, 5)
print(f"Part of node[user] feature [feat_1]: {node_user_feat_1[:3, :]}")
torch.save(node_user_feat_1, node_user_feat_1_path)
print(f"Node[user] feature [feat_1] is saved to {node_user_feat_1_path}\n")
# Generate node[item] feature in numpy array.
node_item_feat_0_path = os.path.join(base_dir, "node-item-feat-0.npy")
node_item_feat_0 = np.random.rand(num_nodes, 5)
print(f"Part of node[item] feature [feat_0]: {node_item_feat_0[:3, :]}")
np.save(node_item_feat_0_path, node_item_feat_0)
print(f"Node[item] feature [feat_0] is saved to {node_item_feat_0_path}\n")
# Generate another node[item] feature in torch tensor
node_item_feat_1_path = os.path.join(base_dir, "node-item-feat-1.pt")
node_item_feat_1 = torch.rand(num_nodes, 5)
print(f"Part of node[item] feature [feat_1]: {node_item_feat_1[:3, :]}")
torch.save(node_item_feat_1, node_item_feat_1_path)
print(f"Node[item] feature [feat_1] is saved to {node_item_feat_1_path}\n")
# Generate edge[user:like:item] feature in numpy array.
edge_like_feat_0_path = os.path.join(base_dir, "edge-like-feat-0.npy")
edge_like_feat_0 = np.random.rand(num_edges, 5)
print(f"Part of edge[user:like:item] feature [feat_0]: {edge_like_feat_0[:3, :]}")
np.save(edge_like_feat_0_path, edge_like_feat_0)
print(f"Edge[user:like:item] feature [feat_0] is saved to {edge_like_feat_0_path}\n")
# Generate another edge[user:like:item] feature in torch tensor
edge_like_feat_1_path = os.path.join(base_dir, "edge-like-feat-1.pt")
edge_like_feat_1 = torch.rand(num_edges, 5)
print(f"Part of edge[user:like:item] feature [feat_1]: {edge_like_feat_1[:3, :]}")
torch.save(edge_like_feat_1, edge_like_feat_1_path)
print(f"Edge[user:like:item] feature [feat_1] is saved to {edge_like_feat_1_path}\n")
# Generate edge[user:follow:user] feature in numpy array.
edge_follow_feat_0_path = os.path.join(base_dir, "edge-follow-feat-0.npy")
edge_follow_feat_0 = np.random.rand(num_edges, 5)
print(f"Part of edge[user:follow:user] feature [feat_0]: {edge_follow_feat_0[:3, :]}")
np.save(edge_follow_feat_0_path, edge_follow_feat_0)
print(f"Edge[user:follow:user] feature [feat_0] is saved to {edge_follow_feat_0_path}\n")
# Generate another edge[user:follow:user] feature in torch tensor
edge_follow_feat_1_path = os.path.join(base_dir, "edge-follow-feat-1.pt")
edge_follow_feat_1 = torch.rand(num_edges, 5)
print(f"Part of edge[user:follow:user] feature [feat_1]: {edge_follow_feat_1[:3, :]}")
torch.save(edge_follow_feat_1, edge_follow_feat_1_path)
print(f"Edge[user:follow:user] feature [feat_1] is saved to {edge_follow_feat_1_path}\n")
Part of node[user] feature [feat_0]: [[0.43357282 0.05457615 0.71786967 0.68109853 0.67253698]
[0.76741477 0.46045833 0.50752018 0.35179977 0.34801423]
[0.11332403 0.51885244 0.25792239 0.14995902 0.50440764]]
Node[user] feature [feat_0] is saved to ./ondisk_dataset_heterograph/node-user-feat-0.npy
Part of node[user] feature [feat_1]: tensor([[0.8393, 0.6778, 0.1173, 0.8037, 0.3057],
[0.4282, 0.1495, 0.9322, 0.3118, 0.9123],
[0.9314, 0.6421, 0.6010, 0.9509, 0.9876]])
Node[user] feature [feat_1] is saved to ./ondisk_dataset_heterograph/node-user-feat-1.pt
Part of node[item] feature [feat_0]: [[0.40049266 0.24327858 0.39298428 0.32445332 0.96998347]
[0.38646057 0.35916219 0.48855499 0.77800221 0.82044517]
[0.20834231 0.88941697 0.18970113 0.40857924 0.57111527]]
Node[item] feature [feat_0] is saved to ./ondisk_dataset_heterograph/node-item-feat-0.npy
Part of node[item] feature [feat_1]: tensor([[0.5202, 0.3933, 0.0078, 0.8738, 0.4788],
[0.5636, 0.8503, 0.5413, 0.5927, 0.1807],
[0.9967, 0.3166, 0.5953, 0.2362, 0.3988]])
Node[item] feature [feat_1] is saved to ./ondisk_dataset_heterograph/node-item-feat-1.pt
Part of edge[user:like:item] feature [feat_0]: [[0.78795306 0.07268911 0.67751723 0.29304806 0.97868884]
[0.59203378 0.00721705 0.80395019 0.46796979 0.17003983]
[0.62935353 0.46624539 0.42923252 0.23098918 0.81724238]]
Edge[user:like:item] feature [feat_0] is saved to ./ondisk_dataset_heterograph/edge-like-feat-0.npy
Part of edge[user:like:item] feature [feat_1]: tensor([[0.0350, 0.9093, 0.5759, 0.9654, 0.3551],
[0.8878, 0.9265, 0.6586, 0.9531, 0.5601],
[0.3325, 0.7989, 0.0130, 0.7637, 0.0281]])
Edge[user:like:item] feature [feat_1] is saved to ./ondisk_dataset_heterograph/edge-like-feat-1.pt
Part of edge[user:follow:user] feature [feat_0]: [[0.77249067 0.84666916 0.32531136 0.42676095 0.5096722 ]
[0.50097483 0.97074567 0.07944273 0.6614282 0.04410161]
[0.71713693 0.50693416 0.06863488 0.16582206 0.04175442]]
Edge[user:follow:user] feature [feat_0] is saved to ./ondisk_dataset_heterograph/edge-follow-feat-0.npy
Part of edge[user:follow:user] feature [feat_1]: tensor([[0.5491, 0.0466, 0.0218, 0.5377, 0.6500],
[0.7820, 0.9557, 0.3029, 0.9574, 0.0498],
[0.9156, 0.5180, 0.8775, 0.6983, 0.8005]])
Edge[user:follow:user] feature [feat_1] is saved to ./ondisk_dataset_heterograph/edge-follow-feat-1.pt
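The feature cell above repeats one pattern six times: create an array, save it with np.save (format numpy) or torch.save (format torch), and later describe it in metadata.yaml by domain, type, name, format and path. A minimal sketch that folds this into one function; save_feature is a hypothetical helper, not a GraphBolt API, and its returned dict merely mirrors the feature_data fields used in this tutorial.

```python
import os
import tempfile

import numpy as np
import torch


def save_feature(base_dir, domain, type_name, name, data):
    """Hypothetical helper: save one feature and return its metadata entry.

    numpy arrays are written with np.save (format "numpy"),
    torch tensors with torch.save (format "torch").
    """
    if isinstance(data, np.ndarray):
        fmt, ext = "numpy", "npy"
    elif isinstance(data, torch.Tensor):
        fmt, ext = "torch", "pt"
    else:
        raise TypeError(f"Unsupported feature type: {type(data)}")
    # Canonical edge types contain ':', which is replaced for a safe file name.
    safe_type = type_name.replace(":", "-")
    filename = f"{domain}-{safe_type}-{name}.{ext}"
    path = os.path.join(base_dir, filename)
    if fmt == "numpy":
        np.save(path, data)
    else:
        torch.save(data, path)
    # Paths in metadata.yaml are relative to the YAML file, so keep only the name.
    return {"domain": domain, "type": type_name, "name": name,
            "format": fmt, "path": filename}


base_dir = tempfile.mkdtemp()
entries = [
    save_feature(base_dir, "node", "user", "feat_0", np.random.rand(1000, 5)),
    save_feature(base_dir, "node", "user", "feat_1", torch.rand(1000, 5)),
    save_feature(base_dir, "edge", "user:like:item", "feat_0", np.random.rand(10000, 5)),
]
for entry in entries:
    print(entry)
```

Returning the entry alongside saving the file keeps the file name and the metadata in one place, so the two cannot drift apart when features are added or renamed.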
Generate tasks
OnDiskDataset supports multiple tasks. For each task, we need to prepare training/validation/test sets. Such sets usually vary among different tasks. In this tutorial, let’s create a node classification task and a link prediction task.
Node Classification Task
For the node classification task, we need node IDs and corresponding labels for each training/validation/test set. As with feature data, numpy arrays and torch tensors are supported for these sets.
[5]:
# For illustration, let's generate item sets for each node type.
num_trains = int(num_nodes * 0.6)
num_vals = int(num_nodes * 0.2)
num_tests = num_nodes - num_trains - num_vals
user_ids = np.arange(num_nodes)
np.random.shuffle(user_ids)
item_ids = np.arange(num_nodes)
np.random.shuffle(item_ids)
# Train IDs for user.
nc_train_user_ids_path = os.path.join(base_dir, "nc-train-user-ids.npy")
nc_train_user_ids = user_ids[:num_trains]
print(f"Part of train ids[user] for node classification: {nc_train_user_ids[:3]}")
np.save(nc_train_user_ids_path, nc_train_user_ids)
print(f"NC train ids[user] are saved to {nc_train_user_ids_path}\n")
# Train labels for user.
nc_train_user_labels_path = os.path.join(base_dir, "nc-train-user-labels.pt")
nc_train_user_labels = torch.randint(0, 10, (num_trains,))
print(f"Part of train labels[user] for node classification: {nc_train_user_labels[:3]}")
torch.save(nc_train_user_labels, nc_train_user_labels_path)
print(f"NC train labels[user] are saved to {nc_train_user_labels_path}\n")
# Train IDs for item.
nc_train_item_ids_path = os.path.join(base_dir, "nc-train-item-ids.npy")
nc_train_item_ids = item_ids[:num_trains]
print(f"Part of train ids[item] for node classification: {nc_train_item_ids[:3]}")
np.save(nc_train_item_ids_path, nc_train_item_ids)
print(f"NC train ids[item] are saved to {nc_train_item_ids_path}\n")
# Train labels for item.
nc_train_item_labels_path = os.path.join(base_dir, "nc-train-item-labels.pt")
nc_train_item_labels = torch.randint(0, 10, (num_trains,))
print(f"Part of train labels[item] for node classification: {nc_train_item_labels[:3]}")
torch.save(nc_train_item_labels, nc_train_item_labels_path)
print(f"NC train labels[item] are saved to {nc_train_item_labels_path}\n")
# Val IDs for user.
nc_val_user_ids_path = os.path.join(base_dir, "nc-val-user-ids.npy")
nc_val_user_ids = user_ids[num_trains:num_trains+num_vals]
print(f"Part of val ids[user] for node classification: {nc_val_user_ids[:3]}")
np.save(nc_val_user_ids_path, nc_val_user_ids)
print(f"NC val ids[user] are saved to {nc_val_user_ids_path}\n")
# Val labels for user.
nc_val_user_labels_path = os.path.join(base_dir, "nc-val-user-labels.pt")
nc_val_user_labels = torch.randint(0, 10, (num_vals,))
print(f"Part of val labels[user] for node classification: {nc_val_user_labels[:3]}")
torch.save(nc_val_user_labels, nc_val_user_labels_path)
print(f"NC val labels[user] are saved to {nc_val_user_labels_path}\n")
# Val IDs for item.
nc_val_item_ids_path = os.path.join(base_dir, "nc-val-item-ids.npy")
nc_val_item_ids = item_ids[num_trains:num_trains+num_vals]
print(f"Part of val ids[item] for node classification: {nc_val_item_ids[:3]}")
np.save(nc_val_item_ids_path, nc_val_item_ids)
print(f"NC val ids[item] are saved to {nc_val_item_ids_path}\n")
# Val labels for item.
nc_val_item_labels_path = os.path.join(base_dir, "nc-val-item-labels.pt")
nc_val_item_labels = torch.randint(0, 10, (num_vals,))
print(f"Part of val labels[item] for node classification: {nc_val_item_labels[:3]}")
torch.save(nc_val_item_labels, nc_val_item_labels_path)
print(f"NC val labels[item] are saved to {nc_val_item_labels_path}\n")
# Test IDs for user.
nc_test_user_ids_path = os.path.join(base_dir, "nc-test-user-ids.npy")
nc_test_user_ids = user_ids[-num_tests:]
print(f"Part of test ids[user] for node classification: {nc_test_user_ids[:3]}")
np.save(nc_test_user_ids_path, nc_test_user_ids)
print(f"NC test ids[user] are saved to {nc_test_user_ids_path}\n")
# Test labels for user.
nc_test_user_labels_path = os.path.join(base_dir, "nc-test-user-labels.pt")
nc_test_user_labels = torch.randint(0, 10, (num_tests,))
print(f"Part of test labels[user] for node classification: {nc_test_user_labels[:3]}")
torch.save(nc_test_user_labels, nc_test_user_labels_path)
print(f"NC test labels[user] are saved to {nc_test_user_labels_path}\n")
# Test IDs for item.
nc_test_item_ids_path = os.path.join(base_dir, "nc-test-item-ids.npy")
nc_test_item_ids = item_ids[-num_tests:]
print(f"Part of test ids[item] for node classification: {nc_test_item_ids[:3]}")
np.save(nc_test_item_ids_path, nc_test_item_ids)
print(f"NC test ids[item] are saved to {nc_test_item_ids_path}\n")
# Test labels for item.
nc_test_item_labels_path = os.path.join(base_dir, "nc-test-item-labels.pt")
nc_test_item_labels = torch.randint(0, 10, (num_tests,))
print(f"Part of test labels[item] for node classification: {nc_test_item_labels[:3]}")
torch.save(nc_test_item_labels, nc_test_item_labels_path)
print(f"NC test labels[item] are saved to {nc_test_item_labels_path}\n")
Part of train ids[user] for node classification: [ 75 949 385]
NC train ids[user] are saved to ./ondisk_dataset_heterograph/nc-train-user-ids.npy
Part of train labels[user] for node classification: tensor([9, 8, 7])
NC train labels[user] are saved to ./ondisk_dataset_heterograph/nc-train-user-labels.pt
Part of train ids[item] for node classification: [724 954 984]
NC train ids[item] are saved to ./ondisk_dataset_heterograph/nc-train-item-ids.npy
Part of train labels[item] for node classification: tensor([0, 8, 2])
NC train labels[item] are saved to ./ondisk_dataset_heterograph/nc-train-item-labels.pt
Part of val ids[user] for node classification: [283 592 300]
NC val ids[user] are saved to ./ondisk_dataset_heterograph/nc-val-user-ids.npy
Part of val labels[user] for node classification: tensor([2, 6, 7])
NC val labels[user] are saved to ./ondisk_dataset_heterograph/nc-val-user-labels.pt
Part of val ids[item] for node classification: [731 855 179]
NC val ids[item] are saved to ./ondisk_dataset_heterograph/nc-val-item-ids.npy
Part of val labels[item] for node classification: tensor([3, 0, 4])
NC val labels[item] are saved to ./ondisk_dataset_heterograph/nc-val-item-labels.pt
Part of test ids[user] for node classification: [128 17 898]
NC test ids[user] are saved to ./ondisk_dataset_heterograph/nc-test-user-ids.npy
Part of test labels[user] for node classification: tensor([8, 0, 5])
NC test labels[user] are saved to ./ondisk_dataset_heterograph/nc-test-user-labels.pt
Part of test ids[item] for node classification: [487 902 630]
NC test ids[item] are saved to ./ondisk_dataset_heterograph/nc-test-item-ids.npy
Part of test labels[item] for node classification: tensor([2, 1, 2])
NC test labels[item] are saved to ./ondisk_dataset_heterograph/nc-test-item-labels.pt
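Because the IDs of each node type are shuffled once and then sliced, the training, validation and test sets are disjoint and together cover every node. A quick sanity check one might run on such a 60/20/20 split; it regenerates the split the same way so that it runs on its own.

```python
import numpy as np

num_nodes = 1000
num_trains = int(num_nodes * 0.6)
num_vals = int(num_nodes * 0.2)
num_tests = num_nodes - num_trains - num_vals

# Shuffle once, then slice, as in the cell above.
user_ids = np.arange(num_nodes)
np.random.shuffle(user_ids)
train_ids = user_ids[:num_trains]
val_ids = user_ids[num_trains:num_trains + num_vals]
test_ids = user_ids[-num_tests:]

# No node should appear in more than one split.
assert np.intersect1d(train_ids, val_ids).size == 0
assert np.intersect1d(train_ids, test_ids).size == 0
assert np.intersect1d(val_ids, test_ids).size == 0

# Together, the splits should cover every node ID exactly once.
all_ids = np.sort(np.concatenate([train_ids, val_ids, test_ids]))
assert np.array_equal(all_ids, np.arange(num_nodes))
print(len(train_ids), len(val_ids), len(test_ids))  # 600 200 200
```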
Link Prediction Task
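For the link prediction task, each training/validation/test set needs seeds, i.e. (source, destination) node pairs. In this example, the validation and test sets also contain negative seeds, together with labels, which mark each seed as positive (1) or negative (0), and indexes, which assign each seed to the group of the positive edge it belongs to. As with feature data, numpy arrays and torch tensors are supported for these sets.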
[6]:
# For illustration, let's generate item sets for each edge type.
num_trains = int(num_edges * 0.6)
num_vals = int(num_edges * 0.2)
num_tests = num_edges - num_trains - num_vals
# Train seeds for user:like:item.
lp_train_like_seeds_path = os.path.join(base_dir, "lp-train-like-seeds.npy")
lp_train_like_seeds = like_edges[:num_trains, :]
print(f"Part of train seeds[user:like:item] for link prediction: {lp_train_like_seeds[:3]}")
np.save(lp_train_like_seeds_path, lp_train_like_seeds)
print(f"LP train seeds[user:like:item] are saved to {lp_train_like_seeds_path}\n")
# Train seeds for user:follow:user.
lp_train_follow_seeds_path = os.path.join(base_dir, "lp-train-follow-seeds.npy")
lp_train_follow_seeds = follow_edges[:num_trains, :]
print(f"Part of train seeds[user:follow:user] for link prediction: {lp_train_follow_seeds[:3]}")
np.save(lp_train_follow_seeds_path, lp_train_follow_seeds)
print(f"LP train seeds[user:follow:user] are saved to {lp_train_follow_seeds_path}\n")
# Val seeds for user:like:item.
lp_val_like_seeds_path = os.path.join(base_dir, "lp-val-like-seeds.npy")
lp_val_like_seeds = like_edges[num_trains:num_trains+num_vals, :]
lp_val_like_neg_dsts = np.random.randint(0, num_nodes, (num_vals, 10)).reshape(-1)
lp_val_like_neg_srcs = np.repeat(lp_val_like_seeds[:,0], 10)
lp_val_like_neg_seeds = np.concatenate((lp_val_like_neg_srcs, lp_val_like_neg_dsts)).reshape(2,-1).T
lp_val_like_seeds = np.concatenate((lp_val_like_seeds, lp_val_like_neg_seeds))
print(f"Part of val seeds[user:like:item] for link prediction: {lp_val_like_seeds[:3]}")
np.save(lp_val_like_seeds_path, lp_val_like_seeds)
print(f"LP val seeds[user:like:item] are saved to {lp_val_like_seeds_path}\n")
# Val labels for user:like:item.
lp_val_like_labels_path = os.path.join(base_dir, "lp-val-like-labels.npy")
lp_val_like_labels = np.empty(num_vals * (10 + 1))
lp_val_like_labels[:num_vals] = 1
lp_val_like_labels[num_vals:] = 0
print(f"Part of val labels[user:like:item] for link prediction: {lp_val_like_labels[:3]}")
np.save(lp_val_like_labels_path, lp_val_like_labels)
print(f"LP val labels[user:like:item] are saved to {lp_val_like_labels_path}\n")
# Val indexes for user:like:item.
lp_val_like_indexes_path = os.path.join(base_dir, "lp-val-like-indexes.npy")
lp_val_like_indexes = np.arange(0, num_vals)
lp_val_like_neg_indexes = np.repeat(lp_val_like_indexes, 10)
lp_val_like_indexes = np.concatenate([lp_val_like_indexes, lp_val_like_neg_indexes])
print(f"Part of val indexes[user:like:item] for link prediction: {lp_val_like_indexes[:3]}")
np.save(lp_val_like_indexes_path, lp_val_like_indexes)
print(f"LP val indexes[user:like:item] are saved to {lp_val_like_indexes_path}\n")
# Val seeds for user:follow:user.
lp_val_follow_seeds_path = os.path.join(base_dir, "lp-val-follow-seeds.npy")
lp_val_follow_seeds = follow_edges[num_trains:num_trains+num_vals, :]
lp_val_follow_neg_dsts = np.random.randint(0, num_nodes, (num_vals, 10)).reshape(-1)
lp_val_follow_neg_srcs = np.repeat(lp_val_follow_seeds[:,0], 10)
lp_val_follow_neg_seeds = np.concatenate((lp_val_follow_neg_srcs, lp_val_follow_neg_dsts)).reshape(2,-1).T
lp_val_follow_seeds = np.concatenate((lp_val_follow_seeds, lp_val_follow_neg_seeds))
print(f"Part of val seeds[user:follow:user] for link prediction: {lp_val_follow_seeds[:3]}")
np.save(lp_val_follow_seeds_path, lp_val_follow_seeds)
print(f"LP val seeds[user:follow:user] are saved to {lp_val_follow_seeds_path}\n")
# Val labels for user:follow:user.
lp_val_follow_labels_path = os.path.join(base_dir, "lp-val-follow-labels.npy")
lp_val_follow_labels = np.empty(num_vals * (10 + 1))
lp_val_follow_labels[:num_vals] = 1
lp_val_follow_labels[num_vals:] = 0
print(f"Part of val labels[user:follow:user] for link prediction: {lp_val_follow_labels[:3]}")
np.save(lp_val_follow_labels_path, lp_val_follow_labels)
print(f"LP val labels[user:follow:user] are saved to {lp_val_follow_labels_path}\n")
# Val indexes for user:follow:user.
lp_val_follow_indexes_path = os.path.join(base_dir, "lp-val-follow-indexes.npy")
lp_val_follow_indexes = np.arange(0, num_vals)
lp_val_follow_neg_indexes = np.repeat(lp_val_follow_indexes, 10)
lp_val_follow_indexes = np.concatenate([lp_val_follow_indexes, lp_val_follow_neg_indexes])
print(f"Part of val indexes[user:follow:user] for link prediction: {lp_val_follow_indexes[:3]}")
np.save(lp_val_follow_indexes_path, lp_val_follow_indexes)
print(f"LP val indexes[user:follow:user] are saved to {lp_val_follow_indexes_path}\n")
# Test seeds for user:like:item.
lp_test_like_seeds_path = os.path.join(base_dir, "lp-test-like-seeds.npy")
lp_test_like_seeds = like_edges[-num_tests:, :]
lp_test_like_neg_dsts = np.random.randint(0, num_nodes, (num_tests, 10)).reshape(-1)
lp_test_like_neg_srcs = np.repeat(lp_test_like_seeds[:,0], 10)
lp_test_like_neg_seeds = np.concatenate((lp_test_like_neg_srcs, lp_test_like_neg_dsts)).reshape(2,-1).T
lp_test_like_seeds = np.concatenate((lp_test_like_seeds, lp_test_like_neg_seeds))
print(f"Part of test seeds[user:like:item] for link prediction: {lp_test_like_seeds[:3]}")
np.save(lp_test_like_seeds_path, lp_test_like_seeds)
print(f"LP test seeds[user:like:item] are saved to {lp_test_like_seeds_path}\n")
# Test labels for user:like:item.
lp_test_like_labels_path = os.path.join(base_dir, "lp-test-like-labels.npy")
lp_test_like_labels = np.empty(num_tests * (10 + 1))
lp_test_like_labels[:num_tests] = 1
lp_test_like_labels[num_tests:] = 0
print(f"Part of test labels[user:like:item] for link prediction: {lp_test_like_labels[:3]}")
np.save(lp_test_like_labels_path, lp_test_like_labels)
print(f"LP test labels[user:like:item] are saved to {lp_test_like_labels_path}\n")
# Test indexes for user:like:item.
lp_test_like_indexes_path = os.path.join(base_dir, "lp-test-like-indexes.npy")
lp_test_like_indexes = np.arange(0, num_tests)
lp_test_like_neg_indexes = np.repeat(lp_test_like_indexes, 10)
lp_test_like_indexes = np.concatenate([lp_test_like_indexes, lp_test_like_neg_indexes])
print(f"Part of test indexes[user:like:item] for link prediction: {lp_test_like_indexes[:3]}")
np.save(lp_test_like_indexes_path, lp_test_like_indexes)
print(f"LP test indexes[user:like:item] are saved to {lp_test_like_indexes_path}\n")
# Test seeds for user:follow:user.
lp_test_follow_seeds_path = os.path.join(base_dir, "lp-test-follow-seeds.npy")
lp_test_follow_seeds = follow_edges[-num_tests:, :]
lp_test_follow_neg_dsts = np.random.randint(0, num_nodes, (num_tests, 10)).reshape(-1)
lp_test_follow_neg_srcs = np.repeat(lp_test_follow_seeds[:,0], 10)
lp_test_follow_neg_seeds = np.concatenate((lp_test_follow_neg_srcs, lp_test_follow_neg_dsts)).reshape(2,-1).T
lp_test_follow_seeds = np.concatenate((lp_test_follow_seeds, lp_test_follow_neg_seeds))
print(f"Part of test seeds[user:follow:user] for link prediction: {lp_test_follow_seeds[:3]}")
np.save(lp_test_follow_seeds_path, lp_test_follow_seeds)
print(f"LP test seeds[user:follow:user] are saved to {lp_test_follow_seeds_path}\n")
# Test labels for user:follow:user.
lp_test_follow_labels_path = os.path.join(base_dir, "lp-test-follow-labels.npy")
lp_test_follow_labels = np.empty(num_tests * (10 + 1))
lp_test_follow_labels[:num_tests] = 1
lp_test_follow_labels[num_tests:] = 0
print(f"Part of test labels[user:follow:user] for link prediction: {lp_test_follow_labels[:3]}")
np.save(lp_test_follow_labels_path, lp_test_follow_labels)
print(f"LP test labels[user:follow:user] are saved to {lp_test_follow_labels_path}\n")
# Test indexes for user:follow:user.
lp_test_follow_indexes_path = os.path.join(base_dir, "lp-test-follow-indexes.npy")
lp_test_follow_indexes = np.arange(0, num_tests)
lp_test_follow_neg_indexes = np.repeat(lp_test_follow_indexes, 10)
lp_test_follow_indexes = np.concatenate([lp_test_follow_indexes, lp_test_follow_neg_indexes])
print(f"Part of test indexes[user:follow:user] for link prediction: {lp_test_follow_indexes[:3]}")
np.save(lp_test_follow_indexes_path, lp_test_follow_indexes)
print(f"LP test indexes[user:follow:user] are saved to {lp_test_follow_indexes_path}\n")
Part of train seeds[user:like:item] for link prediction: [[853 84]
[770 310]
[584 243]]
LP train seeds[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-train-like-seeds.npy
Part of train seeds[user:follow:user] for link prediction: [[328 971]
[194 594]
[823 854]]
LP train seeds[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-train-follow-seeds.npy
Part of val seeds[user:like:item] for link prediction: [[915 146]
[ 21 366]
[988 707]]
LP val seeds[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-val-like-seeds.npy
Part of val labels[user:like:item] for link prediction: [1. 1. 1.]
LP val labels[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-val-like-labels.npy
Part of val indexes[user:like:item] for link prediction: [0 1 2]
LP val indexes[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-val-like-indexes.npy
Part of val seeds[user:follow:user] for link prediction: [[664 92]
[940 843]
[474 154]]
LP val seeds[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-val-follow-seeds.npy
Part of val labels[user:follow:user] for link prediction: [1. 1. 1.]
LP val labels[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-val-follow-labels.npy
Part of val indexes[user:follow:user] for link prediction: [0 1 2]
LP val indexes[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-val-follow-indexes.npy
Part of test seeds[user:like:item] for link prediction: [[937 782]
[541 259]
[954 432]]
LP test seeds[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-test-like-seeds.npy
Part of test labels[user:like:item] for link prediction: [1. 1. 1.]
LP test labels[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-test-like-labels.npy
Part of test indexes[user:like:item] for link prediction: [0 1 2]
LP test indexes[user:like:item] are saved to ./ondisk_dataset_heterograph/lp-test-like-indexes.npy
Part of test seeds[user:follow:user] for link prediction: [[431 378]
[121 986]
[308 665]]
LP test seeds[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-test-follow-seeds.npy
Part of test labels[user:follow:user] for link prediction: [1. 1. 1.]
LP test labels[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-test-follow-labels.npy
Part of test indexes[user:follow:user] for link prediction: [0 1 2]
LP test indexes[user:follow:user] are saved to ./ondisk_dataset_heterograph/lp-test-follow-indexes.npy
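To make the layout of the validation/test files concrete, here is the same construction applied to a toy case with 2 positive edges and 3 negatives each (the cell above uses 10 negatives). Positive seeds come first, followed by all negatives; each negative keeps its positive's source node; labels mark positives with 1 and negatives with 0; indexes give, for every row, the position of the positive seed whose group it belongs to. The toy sizes and node IDs are arbitrary.

```python
import numpy as np

num_pos = 2     # toy number of positive edges
num_neg = 3     # negatives per positive (the tutorial uses 10)
num_nodes = 1000

pos_seeds = np.array([[5, 7], [9, 3]])  # (num_pos, 2) source/destination pairs

# Negatives reuse the positive's source and draw random destinations.
neg_dsts = np.random.randint(0, num_nodes, (num_pos, num_neg)).reshape(-1)
neg_srcs = np.repeat(pos_seeds[:, 0], num_neg)
neg_seeds = np.concatenate((neg_srcs, neg_dsts)).reshape(2, -1).T

# Positives first, then negatives: shape (num_pos * (1 + num_neg), 2).
seeds = np.concatenate((pos_seeds, neg_seeds))
labels = np.empty(num_pos * (num_neg + 1))
labels[:num_pos] = 1
labels[num_pos:] = 0
indexes = np.concatenate([np.arange(num_pos), np.repeat(np.arange(num_pos), num_neg)])

print(seeds[:, 0])  # [5 9 5 5 5 9 9 9]
print(labels)       # [1. 1. 0. 0. 0. 0. 0. 0.]
print(indexes)      # [0 1 0 0 0 1 1 1]
```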
Organize Data into YAML File
Now we need to create a metadata.yaml file which contains the paths and data formats of the graph structure, feature data, and training/validation/test sets. Please note that all paths should be relative to metadata.yaml.
For a heterogeneous graph, we need to specify the node/edge type in the type fields. For an edge type, the canonical etype is required: a string formed by joining the source node type, the etype, and the destination node type with a colon (:), e.g. "user:like:item".
Notes:
- All paths should be relative to metadata.yaml.
- The fields below are optional and are not specified in the example below.
  - in_memory: indicates whether to load data into memory or to access it via mmap. Default is True.
Please refer to YAML specification for more details.
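The two rules above (canonical etypes of the form src_type:etype:dst_type, and paths relative to metadata.yaml) are easy to check before writing the file. A minimal, stdlib-only sketch; check_edge_entry is a hypothetical helper, and the entries are plain dicts shaped like the edges section of the YAML in the next cell, not anything parsed by GraphBolt itself.

```python
import os


def check_edge_entry(entry, node_types):
    """Hypothetical check for one entry of graph.edges in metadata.yaml."""
    parts = entry["type"].split(":")
    # A canonical etype joins exactly three non-empty names with ':'.
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"Not a canonical etype: {entry['type']!r}")
    src_type, _, dst_type = parts
    # Both end points must be declared node types.
    for ntype in (src_type, dst_type):
        if ntype not in node_types:
            raise ValueError(f"Unknown node type {ntype!r} in {entry['type']!r}")
    # Paths are resolved relative to metadata.yaml, so reject absolute ones.
    if os.path.isabs(entry["path"]):
        raise ValueError(f"Path must be relative: {entry['path']!r}")


node_types = {"user", "item"}
edges = [
    {"type": "user:like:item", "format": "csv", "path": "like-edges.csv"},
    {"type": "user:follow:user", "format": "csv", "path": "follow-edges.csv"},
]
for e in edges:
    check_edge_entry(e, node_types)
print("edge entries look valid")
```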
[7]:
yaml_content = f"""
dataset_name: heterogeneous_graph_nc_lp
graph:
  nodes:
    - type: user
      num: {num_nodes}
    - type: item
      num: {num_nodes}
  edges:
    - type: "user:like:item"
      format: csv
      path: {os.path.basename(like_edges_path)}
    - type: "user:follow:user"
      format: csv
      path: {os.path.basename(follow_edges_path)}
feature_data:
  - domain: node
    type: user
    name: feat_0
    format: numpy
    path: {os.path.basename(node_user_feat_0_path)}
  - domain: node
    type: user
    name: feat_1
    format: torch
    path: {os.path.basename(node_user_feat_1_path)}
  - domain: node
    type: item
    name: feat_0
    format: numpy
    path: {os.path.basename(node_item_feat_0_path)}
  - domain: node
    type: item
    name: feat_1
    format: torch
    path: {os.path.basename(node_item_feat_1_path)}
  - domain: edge
    type: "user:like:item"
    name: feat_0
    format: numpy
    path: {os.path.basename(edge_like_feat_0_path)}
  - domain: edge
    type: "user:like:item"
    name: feat_1
    format: torch
    path: {os.path.basename(edge_like_feat_1_path)}
  - domain: edge
    type: "user:follow:user"
    name: feat_0
    format: numpy
    path: {os.path.basename(edge_follow_feat_0_path)}
  - domain: edge
    type: "user:follow:user"
    name: feat_1
    format: torch
    path: {os.path.basename(edge_follow_feat_1_path)}
tasks:
  - name: node_classification
    num_classes: 10
    train_set:
      - type: user
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_train_user_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_train_user_labels_path)}
      - type: item
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_train_item_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_train_item_labels_path)}
    validation_set:
      - type: user
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_val_user_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_val_user_labels_path)}
      - type: item
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_val_item_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_val_item_labels_path)}
    test_set:
      - type: user
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_test_user_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_test_user_labels_path)}
      - type: item
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(nc_test_item_ids_path)}
          - name: labels
            format: torch
            path: {os.path.basename(nc_test_item_labels_path)}
  - name: link_prediction
    num_classes: 10
    train_set:
      - type: "user:like:item"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_train_like_seeds_path)}
      - type: "user:follow:user"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_train_follow_seeds_path)}
    validation_set:
      - type: "user:like:item"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_val_like_seeds_path)}
          - name: labels
            format: numpy
            path: {os.path.basename(lp_val_like_labels_path)}
          - name: indexes
            format: numpy
            path: {os.path.basename(lp_val_like_indexes_path)}
      - type: "user:follow:user"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_val_follow_seeds_path)}
          - name: labels
            format: numpy
            path: {os.path.basename(lp_val_follow_labels_path)}
          - name: indexes
            format: numpy
            path: {os.path.basename(lp_val_follow_indexes_path)}
    test_set:
      - type: "user:like:item"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_test_like_seeds_path)}
          - name: labels
            format: numpy
            path: {os.path.basename(lp_test_like_labels_path)}
          - name: indexes
            format: numpy
            path: {os.path.basename(lp_test_like_indexes_path)}
      - type: "user:follow:user"
        data:
          - name: seeds
            format: numpy
            path: {os.path.basename(lp_test_follow_seeds_path)}
          - name: labels
            format: numpy
            path: {os.path.basename(lp_test_follow_labels_path)}
          - name: indexes
            format: numpy
            path: {os.path.basename(lp_test_follow_indexes_path)}
"""
metadata_path = os.path.join(base_dir, "metadata.yaml")
with open(metadata_path, "w") as f:
    f.write(yaml_content)
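Before instantiating the dataset, it can be worth checking that every path in the generated metadata.yaml resolves to a file next to it. The function below is a hypothetical, line-based check and not a YAML parser. It assumes each path sits on its own "path: <value>" line, as in the metadata written above, and it also flags absolute paths because the paths are supposed to be relative.

```python
import os
import re

# Matches lines such as "    path: foo.npy" or "  - path: foo.npy".
_PATH_LINE = re.compile(r"^\s*(?:-\s*)?path:\s*(\S+)\s*$")


def find_missing_paths(metadata_path):
    """Return path values that are absolute or do not exist beside metadata.yaml."""
    base = os.path.dirname(os.path.abspath(metadata_path))
    missing = []
    with open(metadata_path) as f:
        for line in f:
            m = _PATH_LINE.match(line)
            if not m:
                continue
            rel = m.group(1).strip("\"'")  # tolerate quoted values
            if os.path.isabs(rel) or not os.path.isfile(os.path.join(base, rel)):
                missing.append(rel)
    return missing


# For example, right after writing metadata.yaml:
# assert not find_missing_paths(metadata_path), "fix these paths first"
```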
Instantiate OnDiskDataset
Now we’re ready to load the dataset via dgl.graphbolt.OnDiskDataset. When instantiating it, we just pass in the base directory where the metadata.yaml file lies.
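During the first instantiation, GraphBolt preprocesses the raw data, for example by constructing a FusedCSCSamplingGraph from the edges. After preprocessing, all data, including the graph, feature data, and training/validation/test sets, is placed in the preprocessed directory. Later loads reuse it and skip the preprocessing stage. If GraphBolt determines that existing preprocessed data has to be rebuilt, it removes that data and preprocesses again, which is what the first lines of the output below report.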
After preprocessing, load() must be called explicitly to load the graph, feature data, and tasks.
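The paragraphs above describe a preprocess-once pattern. This tutorial does not spell out GraphBolt's actual change-detection logic, so the sketch below is only a generic illustration of the idea. It hashes the raw input files and reruns an expensive preprocessing step only when that hash differs from the one stored by the previous run. The names fingerprint, maybe_preprocess, and fingerprint.json are hypothetical and not part of GraphBolt.

```python
import hashlib
import json
import os


def fingerprint(base_dir, rel_paths):
    """SHA-256 over the names and bytes of the raw input files."""
    h = hashlib.sha256()
    for rel in sorted(rel_paths):
        h.update(rel.encode("utf-8"))
        with open(os.path.join(base_dir, rel), "rb") as f:
            h.update(f.read())
    return h.hexdigest()


def maybe_preprocess(base_dir, rel_paths, preprocess, out_dir="preprocessed"):
    """Run `preprocess(out_path)` only if the raw inputs changed since the last run.

    Returns True if preprocessing ran, False if the cached output was reused.
    """
    stamp = os.path.join(base_dir, out_dir, "fingerprint.json")
    current = fingerprint(base_dir, rel_paths)
    if os.path.exists(stamp):
        with open(stamp) as f:
            if json.load(f).get("hash") == current:
                return False  # cache hit: skip the expensive stage
    os.makedirs(os.path.dirname(stamp), exist_ok=True)
    preprocess(os.path.join(base_dir, out_dir))
    with open(stamp, "w") as f:
        json.dump({"hash": current}, f)
    return True
```

Hashing file contents catches edits that keep the same file names. A cheaper variant could compare modification times instead, at the risk of missing some changes.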
[8]:
dataset = gb.OnDiskDataset(base_dir).load()
graph = dataset.graph
print(f"Loaded graph: {graph}\n")
feature = dataset.feature
print(f"Loaded feature store: {feature}\n")
tasks = dataset.tasks
nc_task = tasks[0]
print(f"Loaded node classification task: {nc_task}\n")
lp_task = tasks[1]
print(f"Loaded link prediction task: {lp_task}\n")
The on-disk dataset is re-preprocessing, so the existing preprocessed dataset has been removed.
Start to preprocess the on-disk dataset.
Finish preprocessing the on-disk dataset.
Loaded graph: FusedCSCSamplingGraph(csc_indptr=tensor([ 0, 13, 22, ..., 19976, 19985, 20000], dtype=torch.int32),
indices=tensor([1454, 1635, 1011, ..., 1777, 1645, 1053], dtype=torch.int32),
total_num_nodes=2000, num_edges={'user:follow:user': 10000, 'user:like:item': 10000},
node_type_offset=tensor([ 0, 1000, 2000], dtype=torch.int32),
type_per_edge=tensor([1, 1, 1, ..., 0, 0, 0], dtype=torch.uint8),
node_type_to_id={'item': 0, 'user': 1},
edge_type_to_id={'user:follow:user': 0, 'user:like:item': 1},)
Loaded feature store: TorchBasedFeatureStore(
{(<OnDiskFeatureDataDomain.NODE: 'node'>, 'user', 'feat_0'): TorchBasedFeature(
feature=tensor([[0.4336, 0.0546, 0.7179, 0.6811, 0.6725],
[0.7674, 0.4605, 0.5075, 0.3518, 0.3480],
[0.1133, 0.5189, 0.2579, 0.1500, 0.5044],
...,
[0.2405, 0.4958, 0.9371, 0.3089, 0.2998],
[0.5247, 0.4104, 0.1667, 0.0378, 0.0939],
[0.6411, 0.5869, 0.7613, 0.0387, 0.3582]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.NODE: 'node'>, 'user', 'feat_1'): TorchBasedFeature(
feature=tensor([[0.8393, 0.6778, 0.1173, 0.8037, 0.3057],
[0.4282, 0.1495, 0.9322, 0.3118, 0.9123],
[0.9314, 0.6421, 0.6010, 0.9509, 0.9876],
...,
[0.0877, 0.8837, 0.7247, 0.0304, 0.5153],
[0.1127, 0.8993, 0.0637, 0.1143, 0.3922],
[0.9678, 0.0740, 0.7835, 0.4773, 0.8176]]),
metadata={},
), (<OnDiskFeatureDataDomain.NODE: 'node'>, 'item', 'feat_0'): TorchBasedFeature(
feature=tensor([[0.4005, 0.2433, 0.3930, 0.3245, 0.9700],
[0.3865, 0.3592, 0.4886, 0.7780, 0.8204],
[0.2083, 0.8894, 0.1897, 0.4086, 0.5711],
...,
[0.3927, 0.8274, 0.6683, 0.8600, 0.1232],
[0.2850, 0.7847, 0.9359, 0.1104, 0.4110],
[0.9090, 0.4440, 0.9736, 0.4690, 0.0241]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.NODE: 'node'>, 'item', 'feat_1'): TorchBasedFeature(
feature=tensor([[0.5202, 0.3933, 0.0078, 0.8738, 0.4788],
[0.5636, 0.8503, 0.5413, 0.5927, 0.1807],
[0.9967, 0.3166, 0.5953, 0.2362, 0.3988],
...,
[0.9278, 0.1584, 0.0219, 0.5824, 0.7768],
[0.5195, 0.9014, 0.9647, 0.0757, 0.4188],
[0.1438, 0.3991, 0.3973, 0.0935, 0.9690]]),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'user:like:item', 'feat_0'): TorchBasedFeature(
feature=tensor([[0.7880, 0.0727, 0.6775, 0.2930, 0.9787],
[0.5920, 0.0072, 0.8040, 0.4680, 0.1700],
[0.6294, 0.4662, 0.4292, 0.2310, 0.8172],
...,
[0.1233, 0.3032, 0.4696, 0.4758, 0.2389],
[0.5241, 0.3112, 0.5188, 0.5255, 0.7225],
[0.7823, 0.3524, 0.2434, 0.2316, 0.5300]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'user:like:item', 'feat_1'): TorchBasedFeature(
feature=tensor([[0.0350, 0.9093, 0.5759, 0.9654, 0.3551],
[0.8878, 0.9265, 0.6586, 0.9531, 0.5601],
[0.3325, 0.7989, 0.0130, 0.7637, 0.0281],
...,
[0.6805, 0.1162, 0.1952, 0.1618, 0.2510],
[0.6160, 0.1648, 0.5871, 0.6500, 0.3206],
[0.9947, 0.7806, 0.0894, 0.6949, 0.2227]]),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'user:follow:user', 'feat_0'): TorchBasedFeature(
feature=tensor([[0.7725, 0.8467, 0.3253, 0.4268, 0.5097],
[0.5010, 0.9707, 0.0794, 0.6614, 0.0441],
[0.7171, 0.5069, 0.0686, 0.1658, 0.0418],
...,
[0.9113, 0.2013, 0.8296, 0.1000, 0.4814],
[0.8449, 0.2914, 0.3920, 0.7812, 0.3770],
[0.7599, 0.3118, 0.5076, 0.0383, 0.8640]], dtype=torch.float64),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'user:follow:user', 'feat_1'): TorchBasedFeature(
feature=tensor([[0.5491, 0.0466, 0.0218, 0.5377, 0.6500],
[0.7820, 0.9557, 0.3029, 0.9574, 0.0498],
[0.9156, 0.5180, 0.8775, 0.6983, 0.8005],
...,
[0.1228, 0.6004, 0.2861, 0.1396, 0.1906],
[0.2869, 0.6629, 0.2565, 0.7929, 0.8186],
[0.2030, 0.1343, 0.9927, 0.5910, 0.8593]]),
metadata={},
)}
)
Loaded node classification task: OnDiskTask(validation_set=ItemSetDict(
itemsets={'user': ItemSet(
items=(tensor([283, 592, 300, 58, 218, 235, 68, 663, 279, 90, 157, 319, 611, 445,
177, 557, 299, 503, 712, 93, 107, 91, 494, 194, 335, 871, 982, 652,
793, 487, 770, 803, 625, 852, 306, 374, 169, 153, 687, 407, 489, 791,
490, 541, 951, 816, 875, 326, 369, 243, 449, 909, 954, 292, 965, 458,
143, 824, 839, 696, 97, 193, 972, 629, 812, 800, 703, 499, 769, 773,
117, 81, 775, 753, 168, 799, 21, 561, 366, 430, 393, 764, 988, 23,
60, 846, 380, 970, 509, 325, 543, 277, 419, 854, 733, 573, 991, 452,
832, 801, 983, 894, 528, 849, 293, 238, 460, 932, 42, 409, 156, 961,
268, 240, 653, 521, 417, 317, 76, 618, 370, 24, 763, 574, 167, 620,
7, 253, 103, 789, 198, 551, 923, 92, 314, 984, 462, 70, 266, 122,
588, 858, 616, 786, 360, 928, 496, 135, 338, 959, 62, 83, 940, 129,
994, 206, 118, 810, 216, 975, 320, 394, 747, 964, 681, 322, 264, 72,
671, 126, 234, 361, 404, 658, 728, 774, 346, 350, 95, 651, 439, 899,
359, 996, 910, 608, 448, 830, 926, 759, 165, 731, 221, 978, 210, 968,
645, 853, 969, 309], dtype=torch.int32), tensor([2, 6, 7, 1, 7, 5, 9, 6, 8, 6, 5, 9, 9, 2, 6, 0, 8, 7, 1, 1, 5, 5, 1, 6,
5, 4, 3, 1, 2, 0, 8, 6, 7, 3, 7, 0, 7, 5, 2, 7, 1, 9, 3, 4, 3, 6, 0, 4,
9, 8, 2, 6, 7, 2, 3, 8, 9, 0, 7, 0, 6, 2, 3, 9, 3, 0, 2, 8, 0, 2, 2, 5,
7, 7, 1, 3, 2, 7, 6, 5, 2, 5, 1, 8, 7, 4, 8, 1, 1, 7, 9, 9, 0, 5, 9, 3,
4, 7, 7, 3, 8, 4, 5, 9, 7, 1, 9, 7, 9, 7, 9, 0, 7, 9, 0, 5, 9, 1, 9, 9,
3, 4, 7, 9, 2, 2, 9, 1, 9, 7, 2, 4, 2, 6, 7, 0, 6, 0, 6, 8, 3, 0, 1, 1,
5, 0, 8, 5, 4, 9, 8, 1, 6, 1, 2, 9, 1, 8, 8, 8, 2, 6, 5, 5, 0, 8, 0, 9,
6, 0, 3, 8, 8, 9, 8, 3, 7, 4, 3, 9, 5, 2, 0, 0, 4, 5, 9, 5, 0, 9, 3, 7,
6, 7, 6, 4, 5, 2, 4, 6])),
names=('seeds', 'labels'),
), 'item': ItemSet(
items=(tensor([731, 855, 179, 206, 670, 560, 671, 532, 192, 311, 868, 52, 296, 683,
820, 188, 341, 903, 679, 635, 571, 738, 992, 410, 907, 841, 769, 539,
570, 946, 182, 214, 53, 962, 751, 264, 916, 490, 19, 531, 244, 784,
408, 676, 497, 730, 622, 302, 26, 858, 481, 911, 709, 714, 319, 79,
96, 223, 213, 813, 500, 367, 495, 840, 844, 518, 870, 582, 431, 93,
898, 233, 37, 359, 448, 417, 18, 154, 462, 43, 287, 280, 300, 327,
878, 650, 384, 882, 846, 450, 309, 761, 276, 286, 360, 12, 551, 204,
365, 857, 245, 675, 612, 138, 328, 436, 929, 616, 512, 766, 692, 625,
153, 739, 779, 783, 685, 613, 646, 986, 350, 639, 304, 905, 494, 626,
109, 115, 973, 888, 581, 589, 6, 873, 770, 104, 229, 289, 833, 802,
125, 5, 447, 452, 572, 710, 563, 956, 578, 507, 374, 485, 480, 880,
234, 136, 928, 259, 953, 598, 606, 838, 3, 799, 316, 292, 972, 933,
224, 282, 574, 27, 642, 884, 877, 174, 75, 219, 644, 226, 466, 503,
515, 195, 274, 378, 132, 536, 818, 629, 798, 982, 313, 411, 723, 64,
260, 318, 875, 590], dtype=torch.int32), tensor([3, 0, 4, 9, 4, 2, 0, 6, 4, 9, 5, 4, 7, 6, 7, 7, 0, 4, 8, 6, 8, 4, 4, 8,
2, 6, 6, 3, 4, 6, 2, 7, 1, 2, 7, 4, 4, 1, 8, 4, 2, 6, 1, 5, 1, 6, 8, 6,
3, 2, 0, 2, 1, 9, 9, 5, 2, 2, 8, 5, 0, 3, 4, 5, 2, 9, 5, 9, 3, 2, 4, 9,
9, 6, 3, 1, 1, 3, 2, 3, 3, 6, 8, 8, 7, 1, 7, 9, 5, 4, 0, 3, 7, 0, 5, 9,
3, 6, 7, 7, 6, 6, 6, 1, 1, 3, 9, 3, 2, 0, 6, 2, 6, 8, 5, 2, 9, 5, 6, 7,
5, 1, 1, 7, 2, 0, 3, 9, 9, 8, 2, 0, 8, 0, 9, 9, 8, 8, 1, 8, 0, 1, 4, 6,
1, 9, 1, 2, 1, 8, 4, 8, 6, 5, 8, 1, 7, 9, 2, 6, 2, 3, 4, 6, 2, 4, 4, 3,
8, 8, 2, 7, 3, 4, 5, 1, 3, 1, 5, 2, 3, 1, 7, 3, 8, 0, 6, 6, 8, 6, 8, 2,
3, 5, 7, 9, 0, 4, 0, 4])),
names=('seeds', 'labels'),
)},
names=('seeds', 'labels'),
),
train_set=ItemSetDict(
itemsets={'user': ItemSet(
items=(tensor([ 75, 949, 385, 227, 874, 893, 809, 840, 454, 263, 990, 662, 682, 971,
656, 798, 673, 440, 707, 151, 664, 729, 639, 825, 596, 841, 36, 815,
628, 255, 784, 138, 109, 470, 467, 739, 564, 783, 736, 184, 203, 804,
414, 522, 797, 942, 161, 412, 125, 120, 512, 546, 310, 261, 938, 892,
205, 179, 743, 690, 593, 540, 886, 163, 455, 53, 585, 49, 328, 559,
977, 244, 443, 765, 141, 619, 451, 105, 752, 316, 343, 418, 813, 641,
672, 441, 465, 52, 582, 147, 848, 332, 271, 245, 999, 71, 768, 604,
925, 357, 230, 749, 939, 828, 806, 819, 863, 684, 933, 591, 834, 473,
870, 416, 758, 771, 336, 9, 889, 860, 29, 755, 327, 136, 847, 348,
257, 275, 223, 900, 195, 642, 947, 836, 381, 856, 776, 43, 748, 344,
630, 901, 907, 569, 704, 579, 252, 254, 438, 379, 833, 140, 284, 303,
331, 294, 146, 583, 74, 727, 547, 669, 410, 601, 115, 934, 552, 383,
600, 479, 497, 65, 790, 534, 484, 424, 475, 285, 720, 491, 674, 556,
158, 478, 471, 214, 202, 280, 549, 461, 730, 411, 879, 778, 584, 87,
459, 94, 868, 98, 32, 241, 802, 207, 144, 811, 862, 54, 262, 675,
124, 25, 171, 876, 236, 347, 911, 390, 304, 890, 877, 307, 287, 572,
745, 587, 686, 602, 888, 513, 356, 581, 432, 883, 869, 457, 415, 515,
586, 905, 130, 861, 199, 699, 246, 649, 256, 155, 953, 667, 708, 121,
378, 855, 84, 792, 212, 867, 232, 498, 472, 229, 239, 183, 529, 823,
353, 278, 835, 906, 553, 89, 635, 35, 112, 79, 701, 985, 391, 281,
413, 272, 537, 535, 962, 63, 738, 247, 182, 50, 41, 334, 197, 51,
935, 558, 249, 523, 976, 941, 788, 152, 967, 5, 805, 637, 814, 446,
222, 209, 921, 908, 850, 550, 196, 844, 865, 695, 922, 481, 382, 685,
873, 700, 735, 516, 437, 613, 19, 67, 450, 760, 714, 590, 86, 219,
989, 647, 345, 779, 578, 1, 603, 507, 693, 190, 145, 545, 530, 186,
927, 960, 291, 666, 594, 887, 571, 502, 820, 340, 10, 665, 762, 2,
751, 78, 339, 313, 808, 904, 872, 204, 101, 61, 123, 376, 80, 654,
403, 742, 488, 77, 781, 829, 421, 388, 997, 987, 469, 952, 485, 273,
831, 826, 215, 638, 715, 420, 780, 159, 358, 624, 164, 845, 149, 818,
903, 82, 160, 12, 162, 258, 518, 362, 668, 859, 944, 453, 560, 364,
501, 386, 718, 18, 843, 636, 548, 554, 392, 349, 822, 995, 4, 26,
170, 181, 756, 580, 565, 794, 302, 14, 466, 106, 676, 711, 796, 955,
957, 609, 655, 384, 878, 174, 697, 678, 368, 917, 851, 351, 400, 192,
28, 208, 617, 355, 772, 142, 289, 914, 632, 732, 722, 34, 943, 267,
217, 116, 104, 640, 387, 821, 614, 702, 434, 974, 290, 998, 323, 912,
185, 401, 757, 723, 659, 483, 706, 447, 827, 719, 311, 670, 226, 200,
173, 761, 6, 46, 916, 251, 371, 913, 524, 312, 857, 657, 176, 225,
527, 807, 754, 544, 679, 365, 533, 321, 427, 305, 626, 717, 508, 64,
172, 329, 866, 47, 795, 308, 132, 539, 40, 56, 633, 597, 100, 919,
442, 44, 694, 570, 746, 8, 606, 891, 137, 102, 661, 133, 298, 422,
979, 506, 622, 39, 930, 958, 88, 220, 431, 242, 233, 782, 276, 915,
610, 224, 918, 648, 520, 463, 884, 599, 69, 945, 180, 274, 3, 563,
595, 785, 363, 842, 33, 265, 295, 269, 405, 11, 519, 677],
dtype=torch.int32), tensor([9, 8, 7, 0, 6, 8, 0, 7, 9, 5, 1, 0, 2, 4, 8, 6, 4, 3, 9, 5, 1, 4, 5, 8,
0, 4, 1, 8, 2, 3, 9, 6, 7, 5, 3, 1, 8, 7, 3, 9, 8, 5, 2, 7, 1, 0, 8, 6,
7, 0, 6, 3, 3, 4, 6, 9, 9, 9, 2, 3, 9, 8, 8, 0, 1, 1, 1, 4, 2, 4, 4, 1,
0, 3, 7, 0, 6, 3, 4, 2, 5, 3, 0, 8, 7, 2, 3, 4, 0, 1, 7, 6, 8, 5, 2, 5,
1, 0, 6, 2, 3, 3, 2, 2, 7, 1, 8, 0, 8, 5, 6, 9, 6, 6, 3, 2, 5, 2, 7, 8,
5, 3, 0, 1, 2, 8, 1, 2, 4, 8, 4, 1, 4, 2, 1, 4, 4, 3, 2, 3, 6, 7, 1, 1,
3, 4, 8, 4, 0, 5, 2, 3, 1, 0, 0, 8, 6, 4, 5, 8, 9, 5, 0, 7, 5, 5, 9, 7,
0, 2, 5, 4, 2, 7, 7, 0, 3, 8, 4, 2, 6, 5, 6, 6, 2, 7, 4, 2, 8, 2, 2, 6,
6, 1, 2, 9, 7, 3, 0, 6, 5, 5, 4, 0, 8, 8, 9, 1, 6, 1, 8, 3, 1, 5, 8, 4,
7, 4, 4, 1, 1, 5, 1, 0, 6, 1, 4, 1, 5, 0, 0, 3, 3, 2, 6, 7, 3, 2, 2, 1,
3, 6, 4, 9, 3, 5, 9, 8, 4, 7, 9, 5, 4, 2, 2, 5, 6, 4, 5, 8, 2, 2, 7, 3,
9, 1, 3, 8, 6, 4, 8, 3, 9, 6, 1, 0, 3, 5, 4, 7, 1, 9, 4, 8, 5, 8, 4, 2,
0, 3, 6, 1, 2, 3, 8, 4, 8, 1, 3, 7, 9, 0, 6, 9, 1, 5, 3, 5, 0, 6, 7, 7,
2, 1, 5, 6, 7, 1, 6, 2, 0, 1, 6, 5, 4, 4, 6, 3, 5, 8, 1, 2, 9, 7, 8, 4,
8, 4, 3, 2, 5, 7, 7, 3, 3, 1, 0, 4, 1, 3, 9, 7, 8, 2, 7, 5, 7, 5, 1, 6,
9, 4, 6, 5, 3, 3, 6, 3, 4, 4, 5, 4, 9, 5, 2, 0, 8, 9, 6, 6, 0, 4, 9, 7,
7, 5, 6, 1, 3, 4, 2, 7, 0, 9, 5, 3, 1, 9, 6, 4, 6, 5, 8, 5, 0, 6, 2, 4,
0, 5, 3, 9, 1, 2, 0, 9, 1, 8, 4, 6, 1, 2, 8, 4, 7, 2, 5, 9, 3, 8, 7, 4,
6, 4, 6, 4, 2, 2, 8, 6, 7, 9, 9, 9, 3, 2, 2, 4, 1, 9, 2, 7, 9, 9, 7, 6,
4, 5, 7, 6, 3, 6, 1, 6, 6, 1, 3, 6, 4, 7, 6, 4, 0, 3, 3, 4, 8, 6, 1, 3,
4, 8, 8, 5, 9, 5, 6, 3, 8, 5, 1, 2, 1, 8, 9, 6, 8, 1, 7, 7, 1, 5, 9, 1,
6, 7, 6, 1, 3, 3, 0, 0, 0, 5, 6, 5, 5, 8, 6, 0, 1, 0, 9, 9, 8, 0, 7, 4,
7, 1, 5, 5, 5, 8, 7, 2, 7, 9, 5, 1, 5, 2, 0, 5, 8, 5, 5, 7, 6, 0, 9, 0,
3, 9, 9, 1, 5, 1, 8, 0, 5, 4, 7, 5, 6, 1, 9, 6, 7, 6, 3, 4, 9, 9, 1, 0,
8, 3, 2, 3, 2, 8, 7, 2, 8, 9, 2, 6, 8, 7, 1, 6, 1, 7, 7, 1, 7, 1, 6, 4])),
names=('seeds', 'labels'),
), 'item': ItemSet(
items=(tensor([724, 954, 984, 623, 148, 549, 712, 719, 346, 651, 576, 966, 16, 470,
395, 42, 155, 806, 918, 673, 538, 394, 199, 596, 253, 768, 67, 733,
342, 790, 446, 379, 335, 221, 657, 433, 658, 747, 275, 941, 991, 752,
420, 424, 157, 848, 835, 510, 945, 121, 409, 652, 530, 143, 13, 727,
418, 32, 593, 614, 414, 239, 854, 517, 895, 200, 940, 891, 236, 896,
475, 23, 269, 886, 323, 15, 483, 525, 334, 661, 290, 381, 76, 968,
647, 580, 931, 708, 50, 980, 648, 97, 331, 667, 165, 943, 887, 735,
457, 893, 31, 178, 173, 25, 949, 28, 800, 252, 249, 345, 908, 326,
198, 521, 983, 498, 152, 63, 718, 386, 422, 967, 879, 191, 725, 548,
609, 775, 144, 171, 599, 118, 979, 634, 142, 860, 299, 288, 265, 444,
479, 169, 398, 344, 849, 435, 87, 124, 668, 922, 608, 543, 994, 486,
909, 468, 373, 772, 10, 484, 4, 970, 781, 805, 238, 315, 22, 329,
542, 298, 834, 796, 957, 897, 889, 812, 322, 86, 944, 255, 177, 716,
354, 726, 60, 881, 741, 439, 375, 535, 737, 963, 110, 94, 170, 910,
965, 914, 352, 205, 389, 592, 797, 20, 126, 405, 380, 985, 987, 698,
759, 924, 847, 832, 123, 680, 151, 624, 317, 874, 402, 0, 866, 66,
376, 95, 308, 951, 755, 999, 107, 619, 722, 9, 271, 69, 706, 147,
700, 225, 193, 48, 937, 120, 62, 544, 808, 364, 38, 534, 35, 912,
88, 251, 677, 541, 362, 883, 859, 14, 958, 183, 845, 397, 197, 218,
976, 189, 519, 704, 865, 254, 146, 150, 729, 102, 113, 575, 404, 715,
442, 566, 740, 867, 463, 445, 429, 638, 210, 656, 932, 678, 765, 301,
850, 863, 237, 681, 108, 810, 145, 904, 297, 645, 247, 190, 196, 45,
119, 792, 460, 471, 774, 467, 732, 659, 412, 695, 663, 632, 114, 839,
925, 817, 34, 990, 923, 283, 399, 823, 215, 934, 438, 547, 351, 156,
391, 426, 383, 654, 2, 526, 492, 194, 669, 508, 464, 332, 369, 552,
696, 423, 371, 21, 176, 662, 807, 785, 172, 74, 330, 294, 314, 427,
921, 935, 637, 91, 687, 307, 828, 794, 55, 947, 842, 92, 235, 458,
131, 754, 579, 40, 697, 791, 277, 159, 969, 355, 496, 333, 185, 279,
100, 162, 804, 801, 227, 473, 977, 773, 453, 627, 382, 141, 915, 567,
852, 825, 853, 482, 586, 753, 8, 263, 129, 338, 266, 372, 504, 771,
557, 160, 106, 357, 443, 981, 358, 377, 577, 809, 29, 778, 633, 80,
33, 180, 819, 285, 58, 71, 610, 636, 655, 212, 693, 869, 388, 643,
407, 278, 256, 699, 955, 672, 363, 702, 749, 68, 281, 786, 713, 324,
432, 829, 406, 533, 222, 528, 920, 77, 831, 959, 529, 939, 337, 167,
998, 856, 756, 505, 83, 312, 305, 537, 272, 455, 440, 862, 434, 184,
787, 59, 469, 604, 135, 56, 890, 127, 917, 477, 997, 795, 163, 117,
583, 140, 461, 664, 827, 111, 343, 620, 913, 814, 988, 465, 892, 901,
491, 993, 602, 284, 133, 690, 49, 41, 689, 240, 974, 665, 703, 98,
36, 187, 684, 762, 270, 565, 514, 175, 166, 591, 691, 701, 705, 168,
569, 587, 428, 295, 82, 456, 564, 413, 89, 653, 607, 721, 85, 130,
501, 262, 506, 926, 186, 273, 46, 361, 556, 748, 303, 7, 257, 509,
728, 181, 561, 437, 321, 743, 948, 403, 231, 339, 57, 425, 241, 978,
116, 975, 666, 717, 370, 686, 573, 788, 230, 72, 520, 393],
dtype=torch.int32), tensor([0, 8, 2, 3, 9, 9, 9, 1, 6, 2, 2, 4, 5, 1, 0, 3, 2, 3, 5, 6, 4, 8, 2, 1,
3, 8, 7, 3, 3, 5, 9, 5, 1, 5, 7, 0, 0, 6, 5, 2, 4, 3, 8, 1, 4, 8, 9, 6,
6, 6, 2, 3, 2, 7, 2, 8, 1, 7, 7, 4, 5, 5, 1, 8, 0, 2, 4, 0, 1, 8, 5, 9,
1, 7, 8, 4, 0, 0, 7, 3, 2, 1, 0, 3, 0, 1, 3, 1, 6, 4, 3, 6, 6, 0, 3, 6,
4, 3, 0, 9, 9, 2, 1, 2, 4, 9, 4, 2, 4, 9, 7, 4, 3, 3, 7, 7, 4, 1, 5, 4,
7, 5, 2, 2, 1, 4, 0, 9, 0, 6, 2, 0, 6, 5, 1, 3, 7, 9, 7, 6, 4, 8, 4, 7,
6, 4, 9, 8, 7, 4, 7, 0, 0, 5, 0, 5, 4, 2, 2, 7, 4, 6, 8, 7, 5, 2, 2, 8,
0, 2, 3, 9, 0, 1, 8, 0, 6, 3, 7, 4, 1, 2, 2, 4, 1, 2, 1, 0, 5, 0, 0, 6,
7, 5, 8, 1, 9, 4, 9, 9, 9, 4, 1, 3, 4, 8, 4, 0, 8, 3, 0, 2, 8, 0, 6, 2,
0, 0, 5, 2, 8, 9, 2, 2, 7, 0, 3, 9, 2, 3, 8, 4, 4, 4, 7, 9, 5, 2, 4, 3,
4, 4, 5, 6, 1, 5, 0, 9, 7, 2, 1, 1, 9, 1, 2, 7, 1, 7, 0, 5, 7, 9, 9, 0,
8, 1, 6, 2, 2, 2, 9, 2, 8, 6, 5, 6, 9, 4, 5, 0, 1, 3, 7, 1, 9, 7, 9, 9,
8, 9, 1, 0, 8, 8, 1, 6, 8, 0, 3, 6, 7, 1, 7, 4, 2, 7, 8, 2, 9, 6, 2, 1,
0, 1, 4, 9, 9, 8, 5, 8, 2, 8, 0, 3, 9, 2, 2, 0, 0, 0, 5, 7, 5, 0, 3, 1,
6, 9, 3, 4, 4, 6, 2, 1, 5, 7, 2, 3, 7, 6, 5, 1, 5, 2, 8, 7, 1, 9, 7, 7,
4, 1, 4, 0, 4, 8, 8, 9, 3, 2, 1, 3, 0, 0, 2, 2, 9, 1, 9, 7, 2, 2, 3, 0,
4, 5, 6, 0, 2, 0, 6, 6, 2, 9, 0, 8, 4, 6, 9, 9, 1, 7, 2, 9, 1, 2, 8, 5,
0, 6, 3, 3, 4, 4, 1, 4, 0, 6, 4, 4, 9, 7, 5, 7, 8, 7, 3, 2, 9, 0, 6, 4,
5, 5, 7, 3, 7, 7, 7, 8, 5, 2, 5, 2, 8, 6, 6, 6, 9, 9, 7, 4, 5, 9, 1, 7,
7, 1, 4, 6, 6, 6, 8, 9, 8, 4, 6, 3, 8, 3, 4, 7, 9, 7, 8, 7, 1, 2, 6, 1,
0, 6, 4, 8, 5, 1, 3, 0, 5, 4, 4, 0, 3, 4, 2, 2, 0, 8, 5, 5, 1, 3, 1, 1,
4, 4, 7, 6, 8, 9, 1, 9, 6, 8, 1, 1, 3, 0, 9, 5, 5, 1, 8, 4, 6, 6, 8, 8,
4, 1, 3, 1, 2, 1, 5, 2, 3, 0, 4, 7, 2, 8, 3, 5, 4, 6, 3, 7, 8, 4, 4, 8,
0, 9, 0, 1, 3, 7, 7, 7, 4, 9, 1, 5, 1, 4, 4, 6, 1, 8, 3, 9, 4, 2, 7, 9,
8, 5, 9, 9, 8, 4, 9, 6, 3, 5, 1, 6, 2, 7, 5, 9, 8, 4, 1, 4, 5, 2, 9, 6])),
names=('seeds', 'labels'),
)},
names=('seeds', 'labels'),
),
test_set=ItemSetDict(
itemsets={'user': ItemSet(
items=(tensor([128, 17, 898, 148, 108, 741, 881, 429, 96, 114, 986, 175, 396, 213,
402, 566, 315, 324, 367, 766, 22, 211, 408, 740, 621, 607, 73, 956,
444, 500, 406, 342, 187, 154, 514, 435, 767, 231, 680, 38, 397, 395,
330, 372, 724, 777, 532, 895, 119, 531, 297, 495, 504, 27, 705, 950,
536, 30, 713, 589, 787, 37, 505, 650, 605, 634, 567, 631, 692, 456,
436, 577, 931, 99, 166, 542, 937, 526, 924, 897, 880, 486, 480, 660,
352, 134, 721, 538, 896, 615, 45, 425, 228, 337, 973, 110, 948, 57,
333, 576, 259, 433, 354, 85, 643, 188, 726, 482, 885, 191, 562, 710,
377, 750, 15, 282, 709, 248, 838, 0, 13, 468, 389, 150, 555, 178,
902, 48, 59, 423, 716, 477, 474, 296, 963, 920, 318, 286, 511, 517,
127, 250, 476, 980, 946, 688, 111, 993, 992, 492, 131, 698, 981, 373,
929, 737, 260, 270, 139, 644, 464, 20, 936, 428, 568, 598, 16, 744,
426, 683, 201, 510, 646, 966, 189, 399, 237, 398, 817, 623, 525, 627,
689, 341, 882, 375, 288, 734, 113, 575, 691, 66, 31, 837, 864, 612,
725, 301, 493, 55], dtype=torch.int32), tensor([8, 0, 5, 8, 3, 2, 9, 1, 2, 7, 5, 7, 3, 8, 9, 1, 8, 0, 0, 9, 0, 0, 6, 1,
9, 8, 9, 1, 4, 6, 3, 0, 0, 2, 3, 6, 6, 6, 8, 4, 8, 0, 2, 9, 0, 3, 9, 3,
7, 8, 4, 6, 2, 5, 5, 3, 4, 6, 9, 0, 0, 1, 9, 8, 9, 2, 8, 4, 3, 3, 0, 3,
4, 7, 3, 1, 4, 6, 0, 2, 1, 3, 9, 3, 3, 9, 5, 4, 3, 9, 5, 3, 4, 6, 2, 6,
5, 5, 5, 9, 7, 2, 8, 3, 1, 9, 1, 0, 3, 4, 8, 4, 5, 1, 1, 2, 9, 0, 5, 9,
9, 8, 7, 8, 6, 0, 5, 9, 2, 7, 2, 8, 0, 3, 9, 6, 4, 9, 7, 4, 1, 3, 5, 9,
3, 3, 2, 2, 0, 4, 2, 6, 6, 5, 3, 9, 8, 9, 0, 2, 8, 2, 0, 9, 1, 0, 4, 5,
9, 0, 1, 5, 7, 5, 0, 1, 0, 4, 5, 0, 8, 7, 5, 8, 0, 2, 5, 3, 9, 8, 6, 3,
2, 6, 4, 9, 0, 3, 2, 3])),
names=('seeds', 'labels'),
), 'item': ItemSet(
items=(tensor([487, 902, 630, 816, 11, 51, 640, 44, 103, 202, 821, 385, 744, 232,
996, 836, 546, 694, 99, 336, 760, 400, 401, 971, 811, 830, 815, 459,
885, 84, 617, 707, 603, 553, 995, 894, 758, 472, 502, 211, 777, 248,
24, 112, 899, 720, 73, 449, 149, 861, 488, 611, 824, 291, 803, 595,
554, 30, 927, 588, 930, 207, 258, 70, 392, 122, 209, 451, 757, 217,
961, 310, 601, 843, 105, 161, 872, 989, 348, 228, 750, 745, 101, 499,
415, 416, 54, 524, 1, 621, 441, 736, 478, 293, 540, 871, 597, 347,
268, 964, 513, 474, 826, 419, 711, 216, 134, 306, 864, 562, 906, 605,
789, 938, 81, 493, 527, 584, 594, 220, 243, 746, 47, 674, 396, 128,
368, 585, 942, 17, 764, 545, 734, 523, 682, 208, 90, 660, 366, 793,
201, 600, 628, 876, 516, 137, 822, 349, 158, 421, 615, 340, 559, 767,
164, 454, 203, 900, 261, 568, 39, 950, 618, 649, 387, 641, 320, 919,
558, 511, 356, 688, 250, 390, 776, 65, 952, 139, 780, 555, 242, 267,
851, 936, 960, 353, 742, 476, 550, 246, 837, 522, 631, 489, 763, 782,
61, 78, 430, 325], dtype=torch.int32), tensor([2, 1, 2, 5, 6, 9, 7, 9, 4, 5, 6, 7, 6, 6, 8, 9, 0, 5, 3, 0, 4, 9, 7, 1,
7, 6, 9, 1, 9, 0, 5, 0, 9, 9, 4, 0, 4, 1, 8, 7, 4, 2, 0, 3, 6, 9, 3, 1,
0, 9, 6, 3, 0, 1, 8, 3, 6, 0, 9, 1, 5, 1, 7, 7, 5, 2, 8, 9, 1, 3, 9, 2,
7, 3, 6, 6, 8, 1, 3, 5, 8, 5, 0, 4, 8, 3, 3, 1, 8, 0, 0, 3, 9, 5, 5, 6,
6, 2, 2, 3, 6, 9, 8, 7, 1, 6, 1, 1, 4, 8, 1, 8, 6, 6, 5, 2, 6, 5, 2, 1,
7, 4, 3, 0, 4, 1, 4, 7, 2, 9, 5, 3, 8, 3, 0, 8, 6, 4, 8, 9, 2, 9, 3, 6,
4, 3, 6, 6, 4, 5, 9, 7, 3, 4, 3, 7, 7, 8, 3, 1, 6, 0, 9, 1, 1, 7, 5, 8,
6, 0, 8, 5, 9, 7, 7, 1, 8, 9, 1, 0, 9, 5, 0, 9, 6, 0, 7, 0, 8, 2, 7, 0,
4, 6, 9, 9, 6, 3, 8, 0])),
names=('seeds', 'labels'),
)},
names=('seeds', 'labels'),
),
metadata={'name': 'node_classification', 'num_classes': 10},)
Loaded link prediction task: OnDiskTask(validation_set=ItemSetDict(
itemsets={'user:like:item': ItemSet(
items=(tensor([[915, 146],
[ 21, 366],
[988, 707],
...,
[592, 121],
[592, 661],
[592, 681]], dtype=torch.int32), tensor([1., 1., 1., ..., 0., 0., 0.], dtype=torch.float64), tensor([ 0, 1, 2, ..., 1999, 1999, 1999])),
names=('seeds', 'labels', 'indexes'),
), 'user:follow:user': ItemSet(
items=(tensor([[664, 92],
[940, 843],
[474, 154],
...,
[144, 753],
[144, 14],
[144, 315]], dtype=torch.int32), tensor([1., 1., 1., ..., 0., 0., 0.], dtype=torch.float64), tensor([ 0, 1, 2, ..., 1999, 1999, 1999])),
names=('seeds', 'labels', 'indexes'),
)},
names=('seeds', 'labels', 'indexes'),
),
train_set=ItemSetDict(
itemsets={'user:like:item': ItemSet(
items=(tensor([[853, 84],
[770, 310],
[584, 243],
...,
[ 88, 264],
[939, 143],
[364, 908]], dtype=torch.int32),),
names=('seeds',),
), 'user:follow:user': ItemSet(
items=(tensor([[328, 971],
[194, 594],
[823, 854],
...,
[320, 502],
[965, 451],
[228, 843]], dtype=torch.int32),),
names=('seeds',),
)},
names=('seeds',),
),
test_set=ItemSetDict(
itemsets={'user:like:item': ItemSet(
items=(tensor([[937, 782],
[541, 259],
[954, 432],
...,
[103, 85],
[103, 605],
[103, 744]], dtype=torch.int32), tensor([1., 1., 1., ..., 0., 0., 0.], dtype=torch.float64), tensor([ 0, 1, 2, ..., 1999, 1999, 1999])),
names=('seeds', 'labels', 'indexes'),
), 'user:follow:user': ItemSet(
items=(tensor([[431, 378],
[121, 986],
[308, 665],
...,
[250, 832],
[250, 147],
[250, 518]], dtype=torch.int32), tensor([1., 1., 1., ..., 0., 0., 0.], dtype=torch.float64), tensor([ 0, 1, 2, ..., 1999, 1999, 1999])),
names=('seeds', 'labels', 'indexes'),
)},
names=('seeds', 'labels', 'indexes'),
),
metadata={'name': 'link_prediction', 'num_classes': 10},)
/home/ubuntu/regression_test/dgl/python/dgl/graphbolt/impl/ondisk_dataset.py:460: DGLWarning: Edge feature is stored, but edge IDs are not saved.
dgl_warning("Edge feature is stored, but edge IDs are not saved.")