LabelPropagation
- class dgl.nn.pytorch.utils.LabelPropagation(k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False)
- Bases: Module

Label Propagation from "Learning from Labeled and Unlabeled Data with Label Propagation":

\[\mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}\]

where the labels of unlabeled nodes are initialized to zero and then inferred from the labeled nodes via propagation. \(\alpha\) is a weight that balances the propagated labels against the initial labels, and \(\tilde{A}\) denotes the normalized adjacency matrix.

- Parameters:
- k (int) – The number of propagation steps.
- alpha (float) – The \(\alpha\) coefficient, in the range [0, 1].
- norm_type (str, optional) – The type of normalization applied to the adjacency matrix. Must be one of:
  - row: row-normalized adjacency \(D^{-1}A\)
  - sym: symmetrically normalized adjacency \(D^{-1/2}AD^{-1/2}\)
  where \(D\) is the diagonal degree matrix. Default: 'sym'.
- clamp (bool, optional) – Whether to clamp the labels to [0, 1] after propagation. Default: True.
- normalize (bool, optional) – Whether to apply row-normalization after propagation. Default: False.
- reset (bool, optional) – Whether to reset the labeled nodes to their known labels after each propagation step. Default: False.
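Unrolling the update rule shows how \(\alpha\) and k interact. Assuming clamp, normalize, and reset are all disabled, so that every step is linear, k steps give

\[\mathbf{Y}^{(k)} = (\alpha \tilde{A})^{k} \mathbf{Y}^{(0)} + (1 - \alpha) \sum_{i=0}^{k-1} (\alpha \tilde{A})^{i} \mathbf{Y}^{(0)}\]

Each term \((\alpha \tilde{A})^{i} \mathbf{Y}^{(0)}\) aggregates the initial labels along walks of length \(i\), discounted by \(\alpha^{i}\). For \(\alpha < 1\) on an undirected graph without isolated nodes, the spectral radius of \(\tilde{A}\) is at most 1 under both normalizations, so as \(k \to \infty\) the first term vanishes and the iteration approaches the fixed point

\[\mathbf{Y}^{(\infty)} = (1 - \alpha) (I - \alpha \tilde{A})^{-1} \mathbf{Y}^{(0)}\]

where \(I\) is the identity matrix. With \(\alpha = 0\) the output equals \(\mathbf{Y}^{(0)}\); a larger \(\alpha\) lets labels spread further through the graph.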
 
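The update rule and the norm_type, clamp, normalize, and reset options can be sketched with dense NumPy matrices. This is not DGL's implementation (DGL runs message passing on a DGLGraph); the function names `normalize_adjacency` and `propagate`, the dense symmetric adjacency input, clamping zero degrees to 1, treating normalize as L1 row-normalization, and applying clamp, normalize, and reset after every step are all assumptions of this sketch.

```python
import numpy as np


def normalize_adjacency(adj, norm_type="sym"):
    """Normalize a dense, symmetric (N, N) adjacency matrix (hypothetical helper).

    Degrees are clamped to at least 1 (an assumption of this sketch) so
    isolated nodes do not cause a division by zero.
    """
    deg = np.maximum(adj.sum(axis=1), 1.0)
    if norm_type == "row":
        # D^{-1} A: divide row i by the degree of node i
        return adj / deg[:, None]
    if norm_type == "sym":
        # D^{-1/2} A D^{-1/2}: scale rows and columns by 1 / sqrt(degree)
        inv_sqrt = 1.0 / np.sqrt(deg)
        return inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    raise ValueError(f"norm_type must be 'row' or 'sym', got {norm_type!r}")


def propagate(adj, y0, mask, k, alpha, norm_type="sym",
              clamp=True, normalize=False, reset=False):
    """Run k steps of Y <- alpha * A_norm @ Y + (1 - alpha) * Y_init.

    y0 is an (N, D) float label matrix and mask an (N,) bool array; rows of
    unlabeled nodes are zeroed to form Y_init, the Y^(0) of the update rule.
    """
    a_norm = normalize_adjacency(adj, norm_type)
    y_init = np.where(mask[:, None], y0, 0.0)
    y = y_init.copy()
    for _ in range(k):
        y = alpha * (a_norm @ y) + (1.0 - alpha) * y_init
        if clamp:
            y = np.clip(y, 0.0, 1.0)
        if normalize:
            # L1 row-normalization; all-zero rows are left as zeros
            l1 = np.abs(y).sum(axis=1, keepdims=True)
            y = np.divide(y, l1, out=np.zeros_like(y), where=l1 > 0)
        if reset:
            # Restore the known labels of the labeled nodes
            y[mask] = y0[mask]
    return y


# Path graph 0 - 1 - 2: node 0 has class 0, node 2 has class 1, node 1 is unlabeled
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
y0 = np.array([[1, 0], [0, 0], [0, 1]], dtype=float)
mask = np.array([True, False, True])
y = propagate(adj, y0, mask, k=1, alpha=0.5, norm_type="row")
# y[1] is [0.25, 0.25]: the unlabeled middle node is equally close to both classes
```

With clamp=False, normalize=False, and reset=False, running many steps of this sketch should approach the closed form \((1 - \alpha)(I - \alpha \tilde{A})^{-1} \mathbf{Y}^{(0)}\), which is a convenient sanity check.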
- Examples

    >>> import torch
    >>> import dgl
    >>> from dgl.nn import LabelPropagation

    >>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
    >>> g = dgl.rand_graph(5, 10)
    >>> labels = torch.tensor([0, 2, 1, 3, 0]).long()
    >>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
    >>> new_labels = label_propagation(g, labels, mask)

- forward(g, labels, mask=None)
- Compute the label propagation process.
- Parameters:
- g (DGLGraph) – The input graph.
- labels (torch.Tensor) – The input node labels. Three cases are supported:
  - A LongTensor of shape \((N, 1)\) or \((N,)\) holding node class labels for multiclass classification, where \(N\) is the number of nodes.
  - A LongTensor of shape \((N, C)\) holding one-hot encoded node class labels for multiclass classification, where \(C\) is the number of classes.
  - A LongTensor of shape \((N, L)\) holding node labels for multilabel binary classification, where \(L\) is the number of labels.

- mask (torch.Tensor, optional) – Boolean indicators of shape \((N,)\), with True marking labeled nodes. Default: None, meaning all nodes are treated as labeled.
 
- Returns:
- The propagated node labels, a float tensor of shape \((N, D)\), where \(D\) is the number of classes or labels.
- Return type:
- torch.Tensor
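The three accepted label layouts and the optional mask can be mapped onto the \((N, D)\) starting matrix \(\mathbf{Y}^{(0)}\) as sketched below in NumPy. The helper name `initial_labels` is hypothetical, and so is its rule that a 1-D or \((N, 1)\) input holds class indices while any wider input is already one-hot or multilabel; when the number of classes is not given, it is inferred here as the largest index plus one. This illustrates the documented shapes, not DGL's own code.

```python
import numpy as np


def initial_labels(labels, mask=None, num_classes=None):
    """Build the (N, D) float matrix Y^(0) from labels and an optional mask.

    Hypothetical helper for illustration, not part of DGL.
    """
    labels = np.asarray(labels)
    if labels.ndim == 1 or labels.shape[1] == 1:
        # Class indices of shape (N,) or (N, 1): convert to one-hot rows
        idx = labels.reshape(-1).astype(int)
        if num_classes is None:
            num_classes = int(idx.max()) + 1
        y = np.eye(num_classes)[idx]
    else:
        # (N, C) one-hot or (N, L) multilabel: already one column per class/label
        y = labels.astype(float)
    if mask is None:
        # mask=None: every node is treated as labeled
        return y
    mask = np.asarray(mask, dtype=bool)
    # Unlabeled nodes start from all-zero rows
    return np.where(mask[:, None], y, 0.0)


# Same labels and mask as in the Examples section
y0 = initial_labels([0, 2, 1, 3, 0], mask=[False, True, True, True, False])
# y0.shape == (5, 4); rows 0 and 4 are all zeros; row 1 is [0, 0, 1, 0]
```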