LabelPropagation
- class dgl.nn.pytorch.utils.LabelPropagation(k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False)
Bases: torch.nn.Module
Label Propagation from Learning from Labeled and Unlabeled Data with Label Propagation:
$$\mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}$$
where unlabeled data is initially set to zero and inferred from labeled data via propagation. $\alpha$ is a weight parameter for balancing between updated labels and initial labels, and $\tilde{A}$ denotes the normalized adjacency matrix.
- Parameters:
k (int) – The number of propagation steps.
alpha (float) – The $\alpha$ coefficient in range [0, 1].
norm_type (str, optional) –
The type of normalization applied to the adjacency matrix $A$, where $D$ is its diagonal degree matrix. Must be one of the following choices:
row: row-normalized adjacency as $D^{-1}A$
sym: symmetrically normalized adjacency as $D^{-1/2}AD^{-1/2}$
Default: ‘sym’.
clamp (bool, optional) – A bool flag to indicate whether to clamp the labels to [0, 1] after propagation. Default: True.
normalize (bool, optional) – A bool flag to indicate whether to apply row-normalization after propagation. Default: False.
reset (bool, optional) – A bool flag to indicate whether to reset the known labels after each propagation step. Default: False.
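The update rule and the constructor flags can be sketched in NumPy on a small dense adjacency matrix. This is an illustration under stated assumptions, not DGL's implementation (which runs message passing on a DGLGraph). The sketch assumes a symmetric (undirected) adjacency, clamps degrees to at least 1, expects labels already as a float (N, D) matrix, and uses the hypothetical helper names `normalized_adjacency` and `propagate`.

```python
import numpy as np


def normalized_adjacency(adj, norm_type="sym"):
    """Return D^{-1} A ("row") or D^{-1/2} A D^{-1/2} ("sym") for a dense adjacency."""
    # Assumption of this sketch: degrees are clamped to at least 1 so isolated
    # nodes do not cause a division by zero.
    deg = np.maximum(adj.sum(axis=1), 1.0)
    if norm_type == "row":
        return adj / deg[:, None]
    if norm_type == "sym":
        inv_sqrt = deg ** -0.5
        return inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    raise ValueError(f"norm_type must be 'row' or 'sym', got {norm_type!r}")


def propagate(adj, labels, mask, k, alpha, norm_type="sym",
              clamp=True, normalize=False, reset=False):
    """Hypothetical helper: k steps of Y <- alpha * A_tilde @ Y + (1 - alpha) * Y0."""
    a_tilde = normalized_adjacency(adj, norm_type)
    # Y0 keeps the known labels on labeled rows and zeros on unlabeled rows.
    y0 = np.where(mask[:, None], labels, 0.0)
    y = y0.copy()
    for _ in range(k):
        y = alpha * (a_tilde @ y) + (1 - alpha) * y0
        if clamp:
            y = np.clip(y, 0.0, 1.0)  # keep scores in [0, 1]
        if normalize:
            # Row-wise L1 normalization; all-zero rows stay zero.
            row_l1 = np.abs(y).sum(axis=1, keepdims=True)
            y = y / np.maximum(row_l1, 1e-12)
        if reset:
            y[mask] = y0[mask]  # restore the known labels after this step
    return y
```

On the path graph 0-1-2 with node 0 labeled as class 0 and node 2 as class 1, one step with alpha=0.5 and norm_type='row' gives the unlabeled middle node equal scores of 0.25 for both classes. The labeled end nodes keep half of their initial label mass from the $(1 - \alpha)\mathbf{Y}^{(0)}$ term.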
Examples
>>> import torch
>>> import dgl
>>> from dgl.nn import LabelPropagation
>>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
>>> g = dgl.rand_graph(5, 10)  # random graph with 5 nodes and 10 edges
>>> labels = torch.tensor([0, 2, 1, 3, 0]).long()  # class index per node
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()  # nodes 1-3 are labeled
>>> new_labels = label_propagation(g, labels, mask)
- forward(g, labels, mask=None)
Compute the label propagation process.
- Parameters:
g (DGLGraph) – The input graph.
labels (torch.Tensor) –
The input node labels. There are three cases supported.
A LongTensor of shape $(N, 1)$ or $(N,)$ for node class labels in multiclass classification, where $N$ is the number of nodes.
A LongTensor of shape $(N, C)$ for one-hot encoding of node class labels in multiclass classification, where $C$ is the number of classes.
A LongTensor of shape $(N, L)$ for node labels in multilabel binary classification, where $L$ is the number of labels.
mask (torch.Tensor, optional) – The bool indicators of shape $(N,)$, with True denoting labeled nodes. Default: None, indicating all nodes are labeled.
- Returns:
The propagated node labels of shape $(N, D)$ with float type, where $D$ is the number of classes or labels.
- Return type:
torch.Tensor
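To make the accepted `labels` layouts and the float output concrete, here is a NumPy sketch built from hypothetical helpers (`initial_label_matrix`, `predict_classes`, and `predict_multilabel` are not part of DGL). It mirrors the shapes described above and is not DGL's own preprocessing. The -1 marker for nodes that received no label mass and the 0.5 multilabel threshold are assumptions of this sketch.

```python
import numpy as np


def initial_label_matrix(labels, mask=None, num_classes=None):
    """Hypothetical helper: map the three supported label layouts to a float (N, D) matrix."""
    labels = np.asarray(labels)
    if labels.ndim == 2 and labels.shape[1] == 1:
        labels = labels[:, 0]  # (N, 1) class indices -> (N,)
    if labels.ndim == 1:
        # Class indices -> one-hot rows; D defaults to the largest index + 1.
        if num_classes is None:
            num_classes = int(labels.max()) + 1
        y = np.eye(num_classes)[labels]
    else:
        # An (N, C) one-hot or (N, L) multilabel 0/1 matrix is used as-is.
        y = labels.astype(float)
    if mask is not None:
        # Unlabeled rows start at zero; mask=None keeps every row (all nodes labeled).
        y = np.where(np.asarray(mask, dtype=bool)[:, None], y, 0.0)
    return y


def predict_classes(propagated):
    """Multiclass readout: arg-max per row, -1 for rows that received no label mass."""
    pred = propagated.argmax(axis=1)
    # A node that no labeled node reaches within k steps keeps an all-zero row;
    # argmax would silently report class 0 for it, so mark it as -1 instead.
    pred[~propagated.any(axis=1)] = -1
    return pred


def predict_multilabel(propagated, threshold=0.5):
    """Multilabel readout: each label is thresholded independently (assumed cut-off)."""
    return (propagated >= threshold).astype(int)
```

For the class-index labels `[0, 2, 1, 3, 0]` with mask `[0, 1, 1, 1, 0]` from the example, `initial_label_matrix` yields a (5, 4) matrix whose first and last rows are zero.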