"""A mini synthetic dataset for graph classification benchmark."""importmathimportosimportnetworkxasnximportnumpyasnpfrom..importbackendasFfrom..convertimportfrom_networkxfrom..transformsimportadd_self_loopfrom.dgl_datasetimportDGLDatasetfrom.utilsimportload_graphs,makedirs,save_graphs__all__=["MiniGCDataset"]
class MiniGCDataset(DGLDataset):
    """The synthetic graph classification dataset class.

    The dataset contains 8 different types of graphs.

    - class 0 : cycle graph
    - class 1 : star graph
    - class 2 : wheel graph
    - class 3 : lollipop graph
    - class 4 : hypercube graph
    - class 5 : grid graph
    - class 6 : clique graph
    - class 7 : circular ladder graph

    Each of the first seven classes receives ``num_graphs // 8`` graphs; the
    last class receives whatever is left so that the total is ``num_graphs``.

    Parameters
    ----------
    num_graphs: int
        Number of graphs in this dataset.
    min_num_v: int
        Minimum number of nodes for graphs
    max_num_v: int
        Maximum number of nodes for graphs. Sizes are drawn with
        ``np.random.randint(min_num_v, max_num_v)``, so this bound is
        exclusive and must be strictly greater than ``min_num_v``.
    seed: int, default is 0
        Random seed for data generation. ``None`` leaves the NumPy global
        random state untouched.
    save_graph: bool, default is True
        Whether :meth:`save` writes the generated graphs and labels to disk.
    force_reload: bool, default is False
        Forwarded unchanged to the :class:`DGLDataset` constructor.
    verbose: bool, default is False
        Forwarded unchanged to the :class:`DGLDataset` constructor.
    transform: callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    num_graphs : int
        Number of graphs
    min_num_v : int
        The minimum number of nodes
    max_num_v : int
        The maximum number of nodes
    num_classes : int
        The number of classes

    Examples
    --------
    >>> data = MiniGCDataset(100, 16, 32, seed=0)

    The dataset instance is an iterable

    >>> len(data)
    100
    >>> g, label = data[64]
    >>> g
    Graph(num_nodes=20, num_edges=82,
          ndata_schemes={}
          edata_schemes={})
    >>> label
    tensor(5)

    Batch the graphs and labels for mini-batch training

    >>> graphs, labels = zip(*[data[i] for i in range(16)])
    >>> batched_graphs = dgl.batch(graphs)
    >>> batched_labels = torch.tensor(labels)
    >>> batched_graphs
    Graph(num_nodes=356, num_edges=1060,
          ndata_schemes={}
          edata_schemes={})
    """

    def __init__(
        self,
        num_graphs,
        min_num_v,
        max_num_v,
        seed=0,
        save_graph=True,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        # These attributes are assigned before the base constructor runs
        # because the hooks defined below (process/save/load) read them.
        self.num_graphs = num_graphs
        self.min_num_v = min_num_v
        self.max_num_v = max_num_v
        self.seed = seed
        self.save_graph = save_graph

        super(MiniGCDataset, self).__init__(
            name="minigc",
            hash_key=(num_graphs, min_num_v, max_num_v, seed),
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def process(self):
        """Generate all graphs and labels from scratch."""
        self.graphs = []
        self.labels = []
        self._generate(self.seed)

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ---------
        idx : int
            The sample index.

        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
            The graph and its label.
        """
        if self._transform is None:
            g = self.graphs[idx]
        else:
            g = self._transform(self.graphs[idx])
        return g, self.labels[idx]

    def _graph_path(self):
        """Return the path of the serialized graph file for this dataset."""
        return os.path.join(
            self.save_path, "dgl_graph_{}.bin".format(self.hash)
        )

    def has_cache(self):
        """Return whether the serialized graph file already exists."""
        return os.path.exists(self._graph_path())

    def save(self):
        """save the graph list and the labels"""
        if self.save_graph:
            save_graphs(
                str(self._graph_path()), self.graphs, {"labels": self.labels}
            )

    def load(self):
        """Load the graph list and the labels from the serialized file."""
        graphs, label_dict = load_graphs(self._graph_path())
        self.graphs = graphs
        self.labels = label_dict["labels"]

    @property
    def num_classes(self):
        """Number of classes."""
        return 8

    def _generate(self, seed):
        """Populate ``self.graphs`` and ``self.labels``.

        Parameters
        ----------
        seed : int or None
            Seed for the NumPy global random state; ``None`` skips seeding.
        """
        if seed is not None:
            np.random.seed(seed)

        # The generators must run in this fixed order: they all draw from the
        # same global random stream, so reordering would change the dataset
        # produced for a given seed.
        per_class = self.num_graphs // 8
        self._gen_cycle(per_class)
        self._gen_star(per_class)
        self._gen_wheel(per_class)
        self._gen_lollipop(per_class)
        self._gen_hypercube(per_class)
        self._gen_grid(per_class)
        self._gen_clique(per_class)
        # The last class absorbs the remainder of the integer division.
        self._gen_circular_ladder(self.num_graphs - len(self.graphs))

        # preprocess: convert to DGLGraph, and add self loops
        self.graphs = [
            add_self_loop(from_networkx(nx_graph)) for nx_graph in self.graphs
        ]
        self.labels = F.tensor(np.array(self.labels).astype(np.int64))

    def _sample_num_v(self):
        """Draw a node count from ``[min_num_v, max_num_v)``."""
        return np.random.randint(self.min_num_v, self.max_num_v)

    def _gen_cycle(self, n):
        """Append ``n`` cycle graphs with label 0."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.cycle_graph(num_v)
            self.graphs.append(g)
            self.labels.append(0)

    def _gen_star(self, n):
        """Append ``n`` star graphs with label 1."""
        for _ in range(n):
            num_v = self._sample_num_v()
            # nx.star_graph(N) gives a star graph with N+1 nodes
            g = nx.star_graph(num_v - 1)
            self.graphs.append(g)
            self.labels.append(1)

    def _gen_wheel(self, n):
        """Append ``n`` wheel graphs with label 2."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.wheel_graph(num_v)
            self.graphs.append(g)
            self.labels.append(2)

    def _gen_lollipop(self, n):
        """Append ``n`` lollipop graphs with label 3.

        Raises
        ------
        ValueError
            If a drawn node count is below 6, in which case no path length
            can be drawn from ``[2, num_v // 2)``.
        """
        for _ in range(n):
            num_v = self._sample_num_v()
            path_len_bound = num_v // 2
            if path_len_bound <= 2:
                # Fail with an explicit message instead of NumPy's generic
                # "low >= high" error from the randint call below.
                raise ValueError(
                    "We require a lollipop graph to contain at least 6 "
                    "nodes, got {:d} nodes".format(num_v)
                )
            path_len = np.random.randint(2, path_len_bound)
            g = nx.lollipop_graph(m=num_v - path_len, n=path_len)
            self.graphs.append(g)
            self.labels.append(3)

    def _gen_hypercube(self, n):
        """Append ``n`` hypercube graphs with label 4.

        The hypercube dimension is ``floor(log2(num_v))``.
        """
        for _ in range(n):
            num_v = self._sample_num_v()
            # math.log2 avoids the rounding error of math.log(x, 2), which
            # could otherwise be truncated to the wrong dimension.
            g = nx.hypercube_graph(int(math.log2(num_v)))
            g = nx.convert_node_labels_to_integers(g)
            self.graphs.append(g)
            self.labels.append(4)

    def _gen_grid(self, n):
        """Append ``n`` grid graphs with label 5.

        Raises
        ------
        ValueError
            If a drawn node count is below 4.
        """
        for _ in range(n):
            num_v = self._sample_num_v()
            if num_v < 4:
                raise ValueError(
                    "We require a grid graph to contain at least two "
                    "rows and two columns, thus 4 nodes, got {:d} "
                    "nodes".format(num_v)
                )
            row_bound = num_v // 2
            if row_bound > 2:
                n_rows = np.random.randint(2, row_bound)
            else:
                # With 4 or 5 nodes the interval [2, num_v // 2) is empty and
                # randint would raise; two rows is the only valid choice.
                n_rows = 2
            n_cols = num_v // n_rows
            g = nx.grid_graph([n_rows, n_cols])
            g = nx.convert_node_labels_to_integers(g)
            self.graphs.append(g)
            self.labels.append(5)

    def _gen_clique(self, n):
        """Append ``n`` complete graphs with label 6."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.complete_graph(num_v)
            self.graphs.append(g)
            self.labels.append(6)

    def _gen_circular_ladder(self, n):
        """Append ``n`` circular ladder graphs with label 7."""
        for _ in range(n):
            num_v = self._sample_num_v()
            g = nx.circular_ladder_graph(num_v // 2)
            self.graphs.append(g)
            self.labels.append(7)