import os
from collections import OrderedDict

import networkx as nx
import numpy as np

# Module-level imports this class relies on (per dgl/data/tree.py).
from .. import backend as F
from ..convert import from_networkx
from .dgl_dataset import DGLBuiltinDataset
from .utils import (
    _get_dgl_url,
    load_graphs,
    load_info,
    save_graphs,
    save_info,
)


class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an int value stored in the ``x`` feature
    field. Non-leaf nodes store a special value ``PAD_WORD`` in the ``x``
    field. Each node also has a sentiment annotation: 5 classes (very
    negative, negative, neutral, positive and very positive). The sentiment
    label is an int value stored in the ``y`` feature field.

    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny']
        Default: train
    glove_embed_file : str, optional
        The path to a pretrained GloVe embedding file.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file. If not given, the default vocabulary file
        is used.
        Default: None
    raw_dir : str
        Directory that will store (or already contains) the downloaded
        input data.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset.
        Default: False
    verbose : bool
        Whether to print out progress information.
        Default: False
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    num_classes : int
        Number of classes for each node
    pretrained_emb : Tensor
        Pretrained GloVe embedding with respect to the vocabulary.
    vocab_size : int
        The size of the vocabulary

    Notes
    -----
    All the samples will be loaded and preprocessed in memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
          ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64),
                         'y': Scheme(shape=(), dtype=torch.int64),
                         'mask': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
# your code here """PAD_WORD=-1# special pad word idUNK_WORD=-1# out-of-vocabulary word iddef__init__(self,mode="train",glove_embed_file=None,vocab_file=None,raw_dir=None,force_reload=False,verbose=False,transform=None,):assertmodein["train","dev","test","tiny"]_url=_get_dgl_url("dataset/sst.zip")self._glove_embed_file=glove_embed_fileifmode=="train"elseNoneself.mode=modeself._vocab_file=vocab_filesuper(SSTDataset,self).__init__(name="sst",url=_url,raw_dir=raw_dir,force_reload=force_reload,verbose=verbose,transform=transform,)defprocess(self):fromnltk.corpus.readerimportBracketParseCorpusReader# load vocab fileself._vocab=OrderedDict()vocab_file=(self._vocab_fileifself._vocab_fileisnotNoneelseos.path.join(self.raw_path,"vocab.txt"))withopen(vocab_file,encoding="utf-8")asvf:forlineinvf.readlines():line=line.strip()self._vocab[line]=len(self._vocab)# filter gloveifself._glove_embed_fileisnotNoneandos.path.exists(self._glove_embed_file):glove_emb={}withopen(self._glove_embed_file,"r",encoding="utf-8")aspf:forlineinpf.readlines():sp=line.split(" ")ifsp[0].lower()inself._vocab:glove_emb[sp[0].lower()]=np.asarray([float(x)forxinsp[1:]])files=["{}.txt".format(self.mode)]corpus=BracketParseCorpusReader(self.raw_path,files)sents=corpus.parsed_sents(files[0])# initialize with glovepretrained_emb=[]fail_cnt=0forlineinself._vocab.keys():ifself._glove_embed_fileisnotNoneandos.path.exists(self._glove_embed_file):ifnotline.lower()inglove_emb:fail_cnt+=1pretrained_emb.append(glove_emb.get(line.lower(),np.random.uniform(-0.05,0.05,300)))self._pretrained_emb=Noneifself._glove_embed_fileisnotNoneandos.path.exists(self._glove_embed_file):self._pretrained_emb=F.tensor(np.stack(pretrained_emb,0))print("Miss word in GloVe {0:.4f}".format(1.0*fail_cnt/len(self._pretrained_emb)))# build treesself._trees=[]forsentinsents:self._trees.append(self._build_tree(sent))def_build_tree(self,root):g=nx.DiGraph()def_rec_build(nid,node):forchildinnode:cid=g.number_of_nodes()ifisinstance(child[0],str)orisinstance(child[0],bytes):# leaf nodeword=self.vocab.get(child[0].lower(),self.UNK_WORD)g.add_node(cid,x=word,y=int(child.label()),mask=1)else:g.add_node(cid,x=SSTDataset.PAD_WORD,y=int(child.label()),mask=0)_rec_build(cid,child)g.add_edge(cid,nid)# add rootg.add_node(0,x=SSTDataset.PAD_WORD,y=int(root.label()),mask=0)_rec_build(0,root)ret=from_networkx(g,node_attrs=["x","y","mask"])returnret@propertydefgraph_path(self):returnos.path.join(self.save_path,self.mode+"_dgl_graph.bin")@propertydefvocab_path(self):returnos.path.join(self.save_path,"vocab.pkl")defhas_cache(self):returnos.path.exists(self.graph_path)andos.path.exists(self.vocab_path)defsave(self):save_graphs(self.graph_path,self._trees)save_info(self.vocab_path,{"vocab":self.vocab})ifself.pretrained_emb:emb_path=os.path.join(self.save_path,"emb.pkl")save_info(emb_path,{"embed":self.pretrained_emb})defload(self):emb_path=os.path.join(self.save_path,"emb.pkl")self._trees=load_graphs(self.graph_path)[0]self._vocab=load_info(self.vocab_path)["vocab"]self._pretrained_emb=Noneifos.path.exists(emb_path):self._pretrained_emb=load_info(emb_path)["embed"]@propertydefvocab(self):r"""Vocabulary Returns ------- OrderedDict """returnself._vocab@propertydefpretrained_emb(self):r"""Pre-trained word embedding, if given."""returnself._pretrained_emb
    def __getitem__(self, idx):
        r"""Get graph by index

        Parameters
        ----------
        idx : int

        Returns
        -------
        :class:`dgl.DGLGraph`

            Graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        if self._transform is None:
            return self._trees[idx]
        else:
            return self._transform(self._trees[idx])
    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)
    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5
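
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: batch every tree in
# a split and embed its leaf words. It assumes the PyTorch backend;
# ``dgl.batch`` and ``torch.nn.Embedding`` are standard APIs, while
# ``embed_leaves`` and the 300-dim embedding size are illustrative choices.
if __name__ == "__main__":
    import dgl
    import torch.nn as nn

    def embed_leaves(dataset, embed_dim=300):
        """Batch all trees in the split and embed their leaf words."""
        batch = dgl.batch([dataset[i] for i in range(len(dataset))])
        embed = nn.Embedding(dataset.vocab_size, embed_dim)
        # ``mask`` is 1 exactly on leaf nodes.
        leaf_mask = batch.ndata["mask"].bool()
        # OOV leaves hold UNK_WORD (-1), so clamp ids into the valid
        # range before the table lookup.
        leaf_ids = batch.ndata["x"][leaf_mask].clamp(min=0)
        return embed(leaf_ids), batch.ndata["y"][leaf_mask]

    tiny = SSTDataset(mode="tiny")
    leaf_emb, leaf_labels = embed_leaves(tiny)
    print(leaf_emb.shape, leaf_labels.shape)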