"""Dataset for stochastic block model."""
import math
import os
import random

import numpy as np
import numpy.random as npr
import scipy as sp

from .. import batch
from ..convert import from_scipy
from .dgl_dataset import DGLDataset
from .utils import load_graphs, load_info, save_graphs, save_info


def sbm(n_blocks, block_size, p, q, rng=None):
    """(Symmetric) Stochastic Block Model

    Parameters
    ----------
    n_blocks : int
        Number of blocks.
    block_size : int
        Block size.
    p : float
        Intra-community edge parameter. It is divided by the total number
        of nodes (``n_blocks * block_size``) and the result is used as the
        sampling density of every diagonal block.
    q : float
        Inter-community edge parameter. It is divided by the total number
        of nodes and the result is used as the sampling density of every
        off-diagonal block.
    rng : numpy.random.RandomState, optional
        Random number generator. A new ``numpy.random.RandomState()`` is
        created when it is not given.

    Returns
    -------
    scipy sparse matrix
        The adjacency matrix of generated graph. It is symmetric: the upper
        triangle (diagonal included) is sampled and its strict part is
        mirrored into the lower triangle.

    Raises
    ------
    ValueError
        If ``n_blocks`` or ``block_size`` is not positive.
    """
    # Import the subpackage explicitly: ``import scipy as sp`` on its own
    # only exposes ``sp.sparse`` when the subpackage has already been loaded.
    from scipy import sparse

    if n_blocks <= 0 or block_size <= 0:
        raise ValueError(
            "n_blocks and block_size must be positive, got "
            "n_blocks={} and block_size={}".format(n_blocks, block_size)
        )

    n = n_blocks * block_size
    intra_density = p / n
    inter_density = q / n
    rng = np.random.RandomState() if rng is None else rng

    rows = []
    cols = []
    # Only blocks on or above the block diagonal are sampled; the lower
    # part is obtained by symmetry below.
    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = intra_density if i == j else inter_density
            block = sparse.random(
                block_size,
                block_size,
                density,
                random_state=rng,
                # Every sampled entry gets the value 1.
                data_rvs=lambda count: np.ones(count),
            )
            # Shift block-local coordinates to their place in the full matrix.
            rows.append(block.row + i * block_size)
            cols.append(block.col + j * block_size)

    rows = np.hstack(rows)
    cols = np.hstack(cols)
    a = sparse.coo_matrix(
        (np.ones(rows.shape[0]), (rows, cols)), shape=(n, n)
    )
    # Entries sampled below the diagonal inside diagonal blocks are dropped
    # by ``triu``; the strict upper triangle is then mirrored.
    adj = sparse.triu(a) + sparse.triu(a, 1).transpose()
    return adj
class SBMMixtureDataset(DGLDataset):
    r"""Symmetric Stochastic Block Model Mixture

    Reference: Appendix C of `Supervised Community Detection with Hierarchical
    Graph Neural Networks <https://arxiv.org/abs/1705.08415>`_

    Parameters
    ----------
    n_graphs : int
        Number of graphs.
    n_nodes : int
        Number of nodes. Must be divisible by ``n_communities``.
    n_communities : int
        Number of communities.
    k : int, optional
        Multiplier.
        Default: 2
    avg_deg : int, optional
        Average degree.
        Default: 3
    pq : list of pair of nonnegative float or str, optional
        Random densities. Either a list holding one ``(p, q)`` pair per
        graph (its length must equal ``n_graphs``), or the name of a
        registered generator; ``"Appendix_C"`` is the only registered name.
        Default: Appendix_C
    rng : numpy.random.RandomState, optional
        Random number generator forwarded to :func:`sbm` for every generated
        graph. If not given, it's numpy.random.RandomState() with
        `seed=None`, which read data from /dev/urandom (or the Windows
        analogue) if available or seed from the clock otherwise.
        Default: None

    Raises
    ------
    RuntimeError is raised if pq is not a list or string.

    Examples
    --------
    >>> data = SBMMixtureDataset(n_graphs=16, n_nodes=10000, n_communities=2)
    >>> from torch.utils.data import DataLoader
    >>> dataloader = DataLoader(data, batch_size=1, collate_fn=data.collate_fn)
    >>> for graph, line_graph, graph_degrees, line_graph_degrees, pm_pd in dataloader:
    ...     # your code here
    """

    def __init__(
        self,
        n_graphs,
        n_nodes,
        n_communities,
        k=2,
        avg_deg=3,
        pq="Appendix_C",
        rng=None,
    ):
        self._n_graphs = n_graphs
        self._n_nodes = n_nodes
        self._n_communities = n_communities
        assert (
            n_nodes % n_communities == 0
        ), "n_nodes must be divisible by n_communities"
        self._block_size = n_nodes // n_communities
        self._k = k
        self._avg_deg = avg_deg
        self._pq = pq
        self._rng = rng
        super(SBMMixtureDataset, self).__init__(
            name="sbmmixture",
            hash_key=(n_graphs, n_nodes, n_communities, k, avg_deg, pq, rng),
        )

    def process(self):
        """Generate the graphs, their line graphs and the derived tensors."""
        pq = self._pq
        if isinstance(pq, list):
            assert (
                len(pq) == self._n_graphs
            ), "pq must contain one (p, q) pair per graph"
        elif isinstance(pq, str):
            # NOTE(review): `_appendix_c` is not defined in the visible part
            # of this class; it has to be provided elsewhere for the default
            # `pq` to work -- confirm.
            generator = {"Appendix_C": self._appendix_c}[pq]
            pq = [generator() for _ in range(self._n_graphs)]
        else:
            raise RuntimeError(
                "pq must be a list or a str, got {}".format(type(pq).__name__)
            )

        # Forward the user-supplied generator; without it the `rng`
        # constructor argument would have no effect on the sampled graphs.
        self._graphs = [
            from_scipy(
                sbm(
                    self._n_communities,
                    self._block_size,
                    *pair,
                    rng=self._rng
                )
            )
            for pair in pq
        ]
        self._line_graphs = [
            g.line_graph(backtracking=False) for g in self._graphs
        ]
        self._graph_degrees = [g.in_degrees().float() for g in self._graphs]
        self._line_graph_degrees = [
            lg.in_degrees().float() for lg in self._line_graphs
        ]
        # First element of `g.edges()` for every graph. Built directly so
        # that a dataset without graphs yields an empty tuple instead of
        # raising IndexError.
        self._pm_pds = tuple(g.edges()[0] for g in self._graphs)

    @property
    def graph_path(self):
        """Path of the file that caches the graphs."""
        return os.path.join(self.save_path, "graphs_{}.bin".format(self.hash))

    @property
    def line_graph_path(self):
        """Path of the file that caches the line graphs."""
        return os.path.join(
            self.save_path, "line_graphs_{}.bin".format(self.hash)
        )

    @property
    def info_path(self):
        """Path of the file that caches the degree and edge information."""
        return os.path.join(self.save_path, "info_{}.pkl".format(self.hash))

    def has_cache(self):
        """Return True only if all three cache files exist."""
        return (
            os.path.exists(self.graph_path)
            and os.path.exists(self.line_graph_path)
            and os.path.exists(self.info_path)
        )

    def save(self):
        """Write the graphs, the line graphs and the derived info to disk."""
        save_graphs(self.graph_path, self._graphs)
        save_graphs(self.line_graph_path, self._line_graphs)
        save_info(
            self.info_path,
            {
                "graph_degree": self._graph_degrees,
                "line_graph_degree": self._line_graph_degrees,
                "pm_pds": self._pm_pds,
            },
        )

    def load(self):
        """Restore the graphs, the line graphs and the derived info from disk."""
        self._graphs, _ = load_graphs(self.graph_path)
        self._line_graphs, _ = load_graphs(self.line_graph_path)
        info = load_info(self.info_path)
        self._graph_degrees = info["graph_degree"]
        self._line_graph_degrees = info["line_graph_degree"]
        self._pm_pds = info["pm_pds"]

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._graphs)

    def __getitem__(self, idx):
        r"""Get one example by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        graph: :class:`dgl.DGLGraph`
            The original graph
        line_graph: :class:`dgl.DGLGraph`
            The line graph of `graph`
        graph_degree:
            In degrees for each node in `graph`, as returned by
            ``graph.in_degrees().float()``
        line_graph_degree:
            In degrees for each node in `line_graph`, as returned by
            ``line_graph.in_degrees().float()``
        pm_pd:
            First element of ``graph.edges()`` (presumably the source node
            IDs of the edges -- confirm against ``DGLGraph.edges``)
        """
        return (
            self._graphs[idx],
            self._line_graphs[idx],
            self._graph_degrees[idx],
            self._line_graph_degrees[idx],
            self._pm_pds[idx],
        )

    def collate_fn(self, x):
        r"""The `collate` function for dataloader

        Parameters
        ----------
        x : tuple
            a batch of data, each item being the 5-tuple returned by
            ``__getitem__``:

            - graph: :class:`dgl.DGLGraph`
                The original graph
            - line_graph: :class:`dgl.DGLGraph`
                The line graph of `graph`
            - graph_degree:
                In degrees for each node in `graph`
            - line_graph_degree:
                In degrees for each node in `line_graph`
            - pm_pd:
                First element of ``graph.edges()``

        Returns
        -------
        g_batch: :class:`dgl.DGLGraph`
            Batched graphs
        lg_batch: :class:`dgl.DGLGraph`
            Batched line graphs
        degg_batch: numpy.ndarray
            A batch of in degrees for each node in `g_batch`
        deglg_batch: numpy.ndarray
            A batch of in degrees for each node in `lg_batch`
        pm_pd_batch: numpy.ndarray
            Concatenation of the per-graph `pm_pd` entries, where the
            entry of the i-th graph in the batch is offset by
            ``i * n_nodes``
        """
        g, lg, deg_g, deg_lg, pm_pd = zip(*x)
        g_batch = batch.batch(g)
        lg_batch = batch.batch(lg)
        degg_batch = np.concatenate(deg_g, axis=0)
        deglg_batch = np.concatenate(deg_lg, axis=0)
        # Shift the entry of the i-th graph by i * n_nodes before joining.
        pm_pd_batch = np.concatenate(
            [ids + i * self._n_nodes for i, ids in enumerate(pm_pd)], axis=0
        )
        return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch