"""Torch Module for PGExplainer"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .... import batch, ETYPE, khop_in_subgraph, NID, to_homogeneous
__all__ = ["PGExplainer", "HeteroPGExplainer"]
[docs]
class PGExplainer(nn.Module):
    r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
    <https://arxiv.org/pdf/2011.04573>`__
    PGExplainer adopts a deep neural network (explanation network) to
    parameterize the generation process of explanations, which enables it to
    explain multiple instances collectively. PGExplainer models the underlying
    structure as edge distributions, from which the explanatory graph is
    sampled.
    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification
        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat, embed, edge_weight)`.
        * The output of its forward function is the logits if embed=False else
          the intermediate node embeddings.
    num_features : int
        Node embedding size used by :attr:`model`.
    num_hops : int, optional
        The number of hops for GNN information aggregation, which must match the
        number of message passing layers employed by the GNN to be explained.
    explain_graph : bool, optional
        Whether to initialize the model for graph-level or node-level predictions.
    coff_budget : float, optional
        Size regularization to constrain the explanation size. Default: 0.01.
    coff_connect : float, optional
        Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
    sample_bias : float, optional
        Bias applied to the uniform noise drawn when sampling the edge mask
        during training: the noise is rescaled to lie between ``sample_bias``
        and ``1 - sample_bias``. Default: 0.0.
    """
    def __init__(
        self,
        model,
        num_features,
        num_hops=None,
        explain_graph=True,
        coff_budget=0.01,
        coff_connect=5e-4,
        sample_bias=0.0,
    ):
        super(PGExplainer, self).__init__()
        self.model = model
        self.graph_explanation = explain_graph
        # Node explanation requires additional self-embedding data.
        self.num_features = num_features * (2 if self.graph_explanation else 3)
        self.num_hops = num_hops
        # training hyperparameters for PGExplainer
        self.coff_budget = coff_budget
        self.coff_connect = coff_connect
        self.sample_bias = sample_bias
        self.init_bias = 0.0
        # Explanation network in PGExplainer
        self.elayers = nn.Sequential(
            nn.Linear(self.num_features, 64), nn.ReLU(), nn.Linear(64, 1)
        )
    def set_masks(self, graph, edge_mask=None):
        r"""Set the edge mask that plays a crucial role to explain the
        prediction made by the GNN for a graph. Initialize learnable edge
        mask if it is None.
        Parameters
        ----------
        graph : DGLGraph
            A homogeneous graph.
        edge_mask : Tensor, optional
            Learned importance mask of the edges in the graph, which is a tensor
            of shape :math:`(E)`, where :math:`E` is the number of edges in the
            graph. The values are within range :math:`(0, 1)`. The higher,
            the more important. Default: None.
        """
        if edge_mask is None:
            num_nodes = graph.num_nodes()
            num_edges = graph.num_edges()
            init_bias = self.init_bias
            std = nn.init.calculate_gain("relu") * math.sqrt(
                2.0 / (2 * num_nodes)
            )
            self.edge_mask = torch.randn(num_edges) * std + init_bias
        else:
            self.edge_mask = edge_mask
        self.edge_mask = self.edge_mask.to(graph.device)
    def clear_masks(self):
        r"""Clear the edge mask that play a crucial role to explain the
        prediction made by the GNN for a graph.
        """
        self.edge_mask = None
    def parameters(self):
        r"""
        Returns an iterator over the `Parameter` objects of the `nn.Linear`
        layers in the `self.elayers` sequential module. Each `Parameter`
        object contains the weight and bias parameters of an `nn.Linear`
        layer, as learned during training.
        Returns
        -------
        iterator
            An iterator over the `Parameter` objects of the `nn.Linear`
            layers in the `self.elayers` sequential module.
        """
        return self.elayers.parameters()
    def loss(self, prob, ori_pred):
        r"""The loss function that is used to learn the edge
        distribution.
        Parameters
        ----------
        prob: Tensor
            Tensor contains a set of probabilities for each possible
            class label of some model for all the batched graphs,
            which is of shape :math:`(B, L)`, where :math:`L` is the
            different types of label in the dataset and :math:`B` is
            the batch size.
        ori_pred: Tensor
            Tensor of shape :math:`(B, 1)`, representing the original prediction
            for the graph, where :math:`B` is the batch size.
        Returns
        -------
        float
            The function that returns the sum of the three loss components,
            which is a scalar tensor representing the total loss.
        """
        target_prob = prob.gather(-1, ori_pred.unsqueeze(-1))
        # 1e-6 added to prob to avoid taking the logarithm of zero
        target_prob += 1e-6
        # computing the log likelihood for a single prediction
        pred_loss = torch.mean(-torch.log(target_prob))
        # size
        edge_mask = self.sparse_mask_values
        if self.coff_budget <= 0:
            size_loss = self.coff_budget * torch.sum(edge_mask)
        else:
            size_loss = self.coff_budget * F.relu(
                torch.sum(edge_mask) - self.coff_budget
            )
        # entropy
        scale = 0.99
        edge_mask = self.edge_mask * (2 * scale - 1.0) + (1.0 - scale)
        mask_ent = -edge_mask * torch.log(edge_mask) - (
            1 - edge_mask
        ) * torch.log(1 - edge_mask)
        mask_ent_loss = self.coff_connect * torch.mean(mask_ent)
        loss = pred_loss + size_loss + mask_ent_loss
        return loss
    def concrete_sample(self, w, beta=1.0, training=True):
        r"""Sample from the instantiation of concrete distribution when training.
        Parameters
        ----------
        w : Tensor
            A tensor representing the log of the prior probability of choosing the edges.
        beta : float, optional
            Controls the degree of randomness in the output of the sigmoid function.
        training : bool, optional
            Randomness is injected during training.
        Returns
        -------
        Tensor
            If training is set to True, the output is a tensor of probabilities that
            represent the probability of activating the gate for each input element.
            If training is set to False, the output is also a tensor of probabilities,
            but they are determined solely by the log_alpha values, without adding any
            random noise.
        """
        if training:
            bias = self.sample_bias
            random_noise = torch.rand(w.size()).to(w.device)
            random_noise = bias + (1 - 2 * bias) * random_noise
            gate_inputs = torch.log(random_noise) - torch.log(
                1.0 - random_noise
            )
            gate_inputs = (gate_inputs + w) / beta
            gate_inputs = torch.sigmoid(gate_inputs)
        else:
            gate_inputs = torch.sigmoid(w)
        return gate_inputs
[docs]
    def train_step(self, graph, feat, temperature, **kwargs):
        r"""Compute the loss of the explanation network for graph classification
        Parameters
        ----------
        graph : DGLGraph
            Input batched homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        kwargs : dict
            Additional arguments passed to the GNN model.
        Returns
        -------
        Tensor
            A scalar tensor representing the loss.
        """
        assert (
            self.graph_explanation
        ), '"explain_graph" must be True when initializing the module.'
        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)
        pred = self.model(graph, feat, embed=False, **kwargs)
        pred = pred.argmax(-1).data
        prob, _ = self.explain_graph(
            graph, feat, temperature, training=True, **kwargs
        )
        loss = self.loss(prob, pred)
        return loss 
[docs]
    def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
        r"""Compute the loss of the explanation network for node
        classification.

        :meth:`explain_node` is run in training mode on the given nodes; the
        wrapped model is then run on the resulting batched subgraphs without
        a mask to obtain the target labels, and both are compared at the
        center nodes only.

        Parameters
        ----------
        nodes : int, iterable[int], tensor
            The nodes from the graph used to train the explanation network,
            which cannot have any duplicate value.
        graph : DGLGraph
            Input homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            A scalar tensor representing the loss.
        """
        assert (
            not self.graph_explanation
        ), '"explain_graph" must be False when initializing the module.'

        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)

        # Normalize the node argument to a plain list of IDs.
        node_ids = nodes.tolist() if isinstance(nodes, torch.Tensor) else nodes
        if isinstance(node_ids, int):
            node_ids = [node_ids]

        masked_prob, _, subgraphs, centers = self.explain_node(
            node_ids, graph, feat, temperature, training=True, **kwargs
        )

        # Labels predicted without a mask, using the features that
        # explain_node stored on the explainer in training mode.
        target = self.model(
            subgraphs, self.batched_feats, embed=False, **kwargs
        )
        target = target.argmax(-1).data

        return self.loss(masked_prob[centers], target[centers])
[docs]
    def explain_graph(
        self, graph, feat, temperature=1.0, training=False, **kwargs
    ):
        r"""Learn and return an edge mask that plays a crucial role to
        explain the prediction made by the GNN for a graph. Also, return
        the prediction made with the edges chosen based on the edge mask.
        Parameters
        ----------
        graph : DGLGraph
            A homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        training : bool
            Training the explanation network.
        kwargs : dict
            Additional arguments passed to the GNN model.
        Returns
        -------
        Tensor
            Classification probabilities given the masked graph. It is a tensor
            of shape :math:`(B, L)`, where :math:`L` is the different types of
            label in the dataset, and :math:`B` is the batch size.
        Tensor
            Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
            is the number of edges in the graph. A higher weight suggests a
            larger contribution of the edge.
        Examples
        --------
        >>> import torch as th
        >>> import torch.nn as nn
        >>> import dgl
        >>> from dgl.data import GINDataset
        >>> from dgl.dataloading import GraphDataLoader
        >>> from dgl.nn import GraphConv, PGExplainer
        >>> import numpy as np
        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_feats, out_feats):
        ...         super().__init__()
        ...         self.conv = GraphConv(in_feats, out_feats)
        ...         self.fc = nn.Linear(out_feats, out_feats)
        ...         nn.init.xavier_uniform_(self.fc.weight)
        ...
        ...     def forward(self, g, h, embed=False, edge_weight=None):
        ...         h = self.conv(g, h, edge_weight=edge_weight)
        ...
        ...         if embed:
        ...             return h
        ...
        ...         with g.local_scope():
        ...             g.ndata['h'] = h
        ...             hg = dgl.mean_nodes(g, 'h')
        ...             return self.fc(hg)
        >>> # Load dataset
        >>> data = GINDataset('MUTAG', self_loop=True)
        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
        >>> # Train the model
        >>> feat_size = data[0][0].ndata['attr'].shape[1]
        >>> model = Model(feat_size, data.gclasses)
        >>> criterion = nn.CrossEntropyLoss()
        >>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
        >>> for bg, labels in dataloader:
        ...     preds = model(bg, bg.ndata['attr'])
        ...     loss = criterion(preds, labels)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()
        >>> # Initialize the explainer
        >>> explainer = PGExplainer(model, data.gclasses)
        >>> # Train the explainer
        >>> # Define explainer temperature parameter
        >>> init_tmp, final_tmp = 5.0, 1.0
        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
        >>> for epoch in range(20):
        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
        ...     for bg, labels in dataloader:
        ...          loss = explainer.train_step(bg, bg.ndata['attr'], tmp)
        ...          optimizer_exp.zero_grad()
        ...          loss.backward()
        ...          optimizer_exp.step()
        >>> # Explain the prediction for graph 0
        >>> graph, l = data[0]
        >>> graph_feat = graph.ndata.pop("attr")
        >>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
        """
        assert (
            self.graph_explanation
        ), '"explain_graph" must be True when initializing the module.'
        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)
        embed = self.model(graph, feat, embed=True, **kwargs)
        embed = embed.data
        col, row = graph.edges()
        col_emb = embed[col.long()]
        row_emb = embed[row.long()]
        emb = torch.cat([col_emb, row_emb], dim=-1)
        emb = self.elayers(emb)
        values = emb.reshape(-1)
        values = self.concrete_sample(
            values, beta=temperature, training=training
        )
        self.sparse_mask_values = values
        reverse_eids = graph.edge_ids(row, col).long()
        edge_mask = (values + values[reverse_eids]) / 2
        self.set_masks(graph, edge_mask)
        # the model prediction with the updated edge mask
        logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)
        probs = F.softmax(logits, dim=-1)
        if training:
            probs = probs.data
        else:
            self.clear_masks()
        return (probs, edge_mask) 
[docs]
    def explain_node(
        self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
    ):
        r"""Learn and return an edge mask that plays a crucial role to
        explain the prediction made by the GNN for provided set of node IDs.
        Also, return the prediction made with the graph and edge mask.

        For every requested node, the ``num_hops``-hop in-subgraph around it
        is extracted; the subgraphs are batched and the edge mask is computed
        on the batched graph.

        Parameters
        ----------
        nodes : int, iterable[int], tensor
            The nodes from the graph, which cannot have any duplicate value.
        graph : DGLGraph
            A homogeneous graph.
        feat : Tensor
            The input feature of shape :math:`(N, D)`. :math:`N` is the
            number of nodes, and :math:`D` is the feature size.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        training : bool
            Training the explanation network. When True, the returned
            probabilities stay attached to the autograd graph so that the
            prediction term of :meth:`loss` can propagate gradients to the
            explanation network; the edge mask is kept on the explainer and
            the batched node features are stored as ``self.batched_feats``.
            When False, the stored edge mask is cleared before returning.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            Classification probabilities given the masked graph, as produced
            by the model on the batched subgraphs. It is a tensor
            of shape :math:`(N, L)`, where :math:`L` is the different types
            of node labels in the dataset, and :math:`N` is the number of nodes
            in the batched graph.
        Tensor
            Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
            is the number of edges in the batched graph. A higher weight
            suggests a larger contribution of the edge.
        DGLGraph
            The batched set of subgraphs induced on the k-hop in-neighborhood
            of the input center nodes.
        Tensor
            The new IDs of the subgraph center nodes.

        Examples
        --------

        >>> import dgl
        >>> import numpy as np
        >>> import torch

        >>> # Define the model
        >>> class Model(torch.nn.Module):
        ...     def __init__(self, in_feats, out_feats):
        ...         super().__init__()
        ...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
        ...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
        ...
        ...     def forward(self, g, h, embed=False, edge_weight=None):
        ...         h = self.conv1(g, h, edge_weight=edge_weight)
        ...         if embed:
        ...             return h
        ...         return self.conv2(g, h)

        >>> # Load dataset
        >>> data = dgl.data.CoraGraphDataset(verbose=False)
        >>> g = data[0]
        >>> features = g.ndata["feat"]
        >>> labels = g.ndata["label"]

        >>> # Train the model
        >>> model = Model(features.shape[1], data.num_classes)
        >>> criterion = torch.nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        >>> for epoch in range(20):
        ...     logits = model(g, features)
        ...     loss = criterion(logits, labels)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()

        >>> # Initialize the explainer
        >>> explainer = dgl.nn.PGExplainer(
        ...     model, data.num_classes, num_hops=2, explain_graph=False
        ... )

        >>> # Train the explainer
        >>> # Define explainer temperature parameter
        >>> init_tmp, final_tmp = 5.0, 1.0
        >>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
        >>> epochs = 10
        >>> for epoch in range(epochs):
        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
        ...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
        ...     optimizer_exp.zero_grad()
        ...     loss.backward()
        ...     optimizer_exp.step()

        >>> # Explain the prediction for node 0
        >>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
        ...     0, g, features
        ... )
        """
        assert (
            not self.graph_explanation
        ), '"explain_graph" must be False when initializing the module.'
        assert (
            self.num_hops is not None
        ), '"num_hops" must be provided when initializing the module.'

        # Normalize the node argument to a plain list of IDs.
        if isinstance(nodes, torch.Tensor):
            nodes = nodes.tolist()
        if isinstance(nodes, int):
            nodes = [nodes]

        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)

        batched_graph = []
        batched_embed = []
        for node_id in nodes:
            sg, inverse_indices = khop_in_subgraph(
                graph, node_id, self.num_hops
            )
            sg.ndata["feat"] = feat[sg.ndata[NID].long()]
            # Boolean flag marking the center node(s) of this subgraph, so
            # they can be located again after batching.
            sg.ndata["train"] = torch.isin(sg.nodes(), inverse_indices)

            # Node embeddings are a fixed input of the explanation network,
            # so they are detached from the GNN's autograd graph.
            embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
            embed = embed.data

            # One row per edge: source embedding, destination embedding and
            # the embedding of the center node, concatenated.
            col, row = sg.edges()
            col_emb = embed[col.long()]
            row_emb = embed[row.long()]
            self_emb = embed[inverse_indices[0]].repeat(sg.num_edges(), 1)
            emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)

            batched_embed.append(emb)
            batched_graph.append(sg)

        batched_graph = batch(batched_graph)

        batched_embed = torch.cat(batched_embed)
        batched_embed = self.elayers(batched_embed)
        values = batched_embed.reshape(-1)

        values = self.concrete_sample(
            values, beta=temperature, training=training
        )
        # Kept for the size regularizer in `loss`.
        self.sparse_mask_values = values

        # Average every edge with its reverse edge so the mask is symmetric.
        col, row = batched_graph.edges()
        reverse_eids = batched_graph.edge_ids(row, col).long()
        edge_mask = (values + values[reverse_eids]) / 2

        self.set_masks(batched_graph, edge_mask)

        batched_feats = batched_graph.ndata["feat"]
        # the model prediction with the updated edge mask
        logits = self.model(
            batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
        )
        probs = F.softmax(logits, dim=-1)

        batched_inverse_indices = (
            batched_graph.ndata["train"].nonzero().squeeze(1)
        )

        if training:
            # The probabilities must NOT be detached here: detaching them
            # would cut the gradient of the prediction loss, leaving only the
            # regularizers to train the explanation network.
            self.batched_feats = batched_feats
        else:
            self.clear_masks()

        return (
            probs,
            edge_mask,
            batched_graph,
            batched_inverse_indices,
        )
 
[docs]
class HeteroPGExplainer(PGExplainer):
    r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
    <https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs
    PGExplainer adopts a deep neural network (explanation network) to
    parameterize the generation process of explanations, which enables it to
    explain multiple instances collectively. PGExplainer models the underlying
    structure as edge distributions, from which the explanatory graph is
    sampled.
    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification
        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat, embed, edge_weight)`.
        * The output of its forward function is the logits if embed=False else
          the intermediate node embeddings.
    num_features : int
        Node embedding size used by :attr:`model`.
    coff_budget : float, optional
        Size regularization to constrain the explanation size. Default: 0.01.
    coff_connect : float, optional
        Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
    sample_bias : float, optional
        Bias applied to the uniform noise drawn when sampling the edge mask
        during training: the noise is rescaled to lie between ``sample_bias``
        and ``1 - sample_bias``. Default: 0.0.
    """
[docs]
    def train_step(self, graph, feat, temperature, **kwargs):
        # pylint: disable=useless-super-delegation
        r"""Compute the loss of the explanation network for graph
        classification.

        This delegates to the parent class implementation; it is redefined
        here to document the heterogeneous argument types.

        Parameters
        ----------
        graph : DGLGraph
            Input batched heterogeneous graph.
        feat : dict[str, Tensor]
            A dict mapping node types (keys) to feature tensors (values).
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
            the number of nodes for node type :math:`t`, and :math:`D_t` is the
            feature size for node type :math:`t`.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            A scalar tensor representing the loss.
        """
        loss = super().train_step(graph, feat, temperature, **kwargs)
        return loss
[docs]
    def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
        r"""Compute the loss of the explanation network for node
        classification.

        :meth:`explain_node` is run in training mode; the wrapped model is
        then run on the returned batched graph without a mask to obtain the
        target labels, and both are compared at the center nodes of every
        node type.

        Parameters
        ----------
        nodes : dict[str, Iterable[int]]
            A dict mapping node types (keys) to an iterable set of node ids (values).
        graph : DGLGraph
            Input heterogeneous graph.
        feat : dict[str, Tensor]
            A dict mapping node types (keys) to feature tensors (values).
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
            the number of nodes for node type :math:`t`, and :math:`D_t` is the
            feature size for node type :math:`t`.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            A scalar tensor representing the loss.
        """
        assert (
            not self.graph_explanation
        ), '"explain_graph" must be False when initializing the module.'

        device = graph.device
        self.model = self.model.to(device)
        self.elayers = self.elayers.to(device)

        masked_prob, _, batched_graph, centers = self.explain_node(
            nodes, graph, feat, temperature, training=True, **kwargs
        )

        # Labels predicted without a mask; `self.batched_feats` is expected
        # to have been stored by `explain_node` in training mode.
        logits = self.model(
            batched_graph, self.batched_feats, embed=False, **kwargs
        )
        target = {
            ntype: logits[ntype].argmax(-1).data for ntype in logits.keys()
        }

        # Gather the center-node rows of every node type, in the same order
        # for probabilities and labels.
        picked_prob = []
        picked_target = []
        for ntype, nid in centers.items():
            picked_prob.append(masked_prob[ntype][nid])
            picked_target.append(target[ntype][nid])

        return self.loss(torch.cat(picked_prob), torch.cat(picked_target))
[docs]
    def explain_graph(
        self, graph, feat, temperature=1.0, training=False, **kwargs
    ):
        r"""Learn and return an edge mask that plays a crucial role to
        explain the prediction made by the GNN for a graph. Also, return
        the prediction made with the edges chosen based on the edge mask.

        The node embeddings are written to the ``"emb"`` node feature of
        :attr:`graph`, which is then converted to a homogeneous graph on
        which the edge mask is computed. The edge IDs of the reversed
        endpoints of every edge are looked up to symmetrize the mask, so the
        graph is expected to contain both directions of each edge.

        Parameters
        ----------
        graph : DGLGraph
            A heterogeneous graph.
        feat : dict[str, Tensor]
            A dict mapping node types (keys) to feature tensors (values).
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
            the number of nodes for node type :math:`t`, and :math:`D_t` is the
            feature size for node type :math:`t`.
        temperature : float
            The temperature parameter fed to the sampling procedure.
        training : bool
            Training the explanation network. When True, the returned
            probabilities stay attached to the autograd graph so that the
            prediction term of :meth:`loss` can propagate gradients to the
            explanation network, and the edge mask is kept on the explainer.
            When False, the stored edge mask is cleared before returning.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        Tensor
            Classification probabilities given the masked graph. It is a tensor
            of shape :math:`(B, L)`, where :math:`L` is the different types of
            label in the dataset, and :math:`B` is the batch size.
        dict[str, Tensor]
            A dict mapping edge types (keys) to edge tensors (values) of shape
            :math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
            for edge type :math:`t`.  A higher weight suggests a larger
            contribution of the edge.

        Examples
        --------

        >>> import dgl
        >>> import torch as th
        >>> import torch.nn as nn
        >>> import numpy as np

        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        ...         super().__init__()
        ...         self.conv = dgl.nn.HeteroGraphConv(
        ...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
        ...             aggregate="sum",
        ...         )
        ...         self.fc = nn.Linear(hid_feats, out_feats)
        ...         nn.init.xavier_uniform_(self.fc.weight)
        ...
        ...     def forward(self, g, h, embed=False, edge_weight=None):
        ...         if edge_weight:
        ...             mod_kwargs = {
        ...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
        ...             }
        ...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
        ...         else:
        ...             h = self.conv(g, h)
        ...
        ...         if embed:
        ...             return h
        ...
        ...         with g.local_scope():
        ...             g.ndata["h"] = h
        ...             hg = 0
        ...             for ntype in g.ntypes:
        ...                 hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
        ...             return self.fc(hg)

        >>> # Load dataset
        >>> input_dim = 5
        >>> hidden_dim = 5
        >>> num_classes = 2
        >>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
        >>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
        >>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

        >>> transform = dgl.transforms.AddReverse()
        >>> g = transform(g)

        >>> # define and train the model
        >>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
        >>> optimizer = th.optim.Adam(model.parameters())
        >>> for epoch in range(10):
        ...     logits = model(g, g.ndata["h"])
        ...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()

        >>> # Initialize the explainer
        >>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)

        >>> # Train the explainer
        >>> # Define explainer temperature parameter
        >>> init_tmp, final_tmp = 5.0, 1.0
        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
        >>> for epoch in range(20):
        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
        ...     loss = explainer.train_step(g, g.ndata["h"], tmp)
        ...     optimizer_exp.zero_grad()
        ...     loss.backward()
        ...     optimizer_exp.step()

        >>> # Explain the graph
        >>> feat = g.ndata.pop("h")
        >>> probs, edge_mask = explainer.explain_graph(g, feat)
        """
        assert (
            self.graph_explanation
        ), '"explain_graph" must be True when initializing the module.'

        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)

        # Node embeddings are a fixed input of the explanation network, so
        # they are detached from the GNN's autograd graph. NOTE: they are
        # stored under the "emb" node feature of the input graph and are not
        # removed afterwards.
        embed = self.model(graph, feat, embed=True, **kwargs)
        for ntype, emb in embed.items():
            graph.nodes[ntype].data["emb"] = emb.data
        homo_graph = to_homogeneous(graph, ndata=["emb"])
        homo_embed = homo_graph.ndata["emb"]

        # One row per edge: the embeddings of its two endpoints, concatenated.
        col, row = homo_graph.edges()
        col_emb = homo_embed[col.long()]
        row_emb = homo_embed[row.long()]
        emb = torch.cat([col_emb, row_emb], dim=-1)
        emb = self.elayers(emb)
        values = emb.reshape(-1)

        values = self.concrete_sample(
            values, beta=temperature, training=training
        )
        # Kept for the size regularizer in `loss`.
        self.sparse_mask_values = values

        # Average every edge with its reverse edge so the mask is symmetric.
        reverse_eids = homo_graph.edge_ids(row, col).long()
        edge_mask = (values + values[reverse_eids]) / 2

        self.set_masks(homo_graph, edge_mask)

        # convert the edge mask back into heterogeneous format
        hetero_edge_mask = self._edge_mask_to_heterogeneous(
            edge_mask=edge_mask,
            homograph=homo_graph,
            heterograph=graph,
        )

        # the model prediction with the updated edge mask
        logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
        probs = F.softmax(logits, dim=-1)

        # In training mode the probabilities must NOT be detached: detaching
        # them would cut the gradient of the prediction loss, leaving only
        # the regularizers to train the explanation network. The mask is also
        # kept because `loss` reads it.
        if not training:
            self.clear_masks()

        return (probs, hetero_edge_mask)
[docs]
    def explain_node(
        self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
    ):
        r"""Learn and return an edge mask that plays a crucial role to
        explain the prediction made by the GNN for provided set of node IDs.
        Also, return the prediction made with the batched graph and edge mask.

        For every requested node, the k-hop in-subgraph around it is
        extracted, the model embeddings on that subgraph are computed, and one
        row per subgraph edge is built by concatenating the source embedding,
        the destination embedding and the embedding of the center node. All
        subgraphs are then batched and scored by the explanation network.

        Parameters
        ----------
        nodes : dict[str, Iterable[int]]
            A dict mapping node types (keys) to an iterable set of node ids (values).
        graph : DGLGraph
            A heterogeneous graph. The mask of every edge is averaged with the
            mask of its reverse edge, so each edge needs a reverse counterpart
            in the graph (the example below adds them with ``AddReverse``).
        feat : dict[str, Tensor]
            A dict mapping node types (keys) to feature tensors (values).
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
            the number of nodes for node type :math:`t`, and :math:`D_t` is the
            feature size for node type :math:`t`
        temperature : float
            The temperature parameter fed to the sampling procedure.
        training : bool
            Training the explanation network. When True, the batched features
            are cached on the module, the installed masks are kept and the
            returned probabilities are detached; otherwise the masks are
            cleared before returning.
        kwargs : dict
            Additional arguments passed to the GNN model.

        Returns
        -------
        dict[str, Tensor]
            A dict mapping node types (keys) to classification probabilities
            for node labels (values). The values are tensors of shape
            :math:`(N_t, L)`, where :math:`L` is the number of node label
            classes produced by the model, and :math:`N_t` is the number of
            nodes of type :math:`t` in the batched subgraph.
        dict[str, Tensor]
            A dict mapping edge types (keys) to edge tensors (values) of shape
            :math:`(E_t)`, where :math:`E_t` is the number of edges of type
            :math:`t` in the batched subgraph. A higher weight suggests a
            larger contribution of the edge.
        DGLGraph
            The batched set of subgraphs induced on the k-hop in-neighborhood
            of the input center nodes.
        dict[str, Tensor]
            A dict mapping node types (keys) to a tensor of node IDs (values)
            which correspond to the subgraph center nodes.

        Examples
        --------

        >>> import dgl
        >>> import torch as th
        >>> import torch.nn as nn
        >>> import numpy as np

        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        ...         super().__init__()
        ...         self.conv = dgl.nn.HeteroGraphConv(
        ...             {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
        ...             aggregate="sum",
        ...         )
        ...         self.fc = nn.Linear(hid_feats, out_feats)
        ...         nn.init.xavier_uniform_(self.fc.weight)
        ...
        ...     def forward(self, g, h, embed=False, edge_weight=None):
        ...         if edge_weight:
        ...             mod_kwargs = {
        ...                 etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
        ...             }
        ...             h = self.conv(g, h, mod_kwargs=mod_kwargs)
        ...         else:
        ...             h = self.conv(g, h)
        ...
        ...         return h

        >>> # Load dataset
        >>> input_dim = 5
        >>> hidden_dim = 5
        >>> num_classes = 2
        >>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
        >>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
        >>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

        >>> transform = dgl.transforms.AddReverse()
        >>> g = transform(g)

        >>> # define and train the model
        >>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
        >>> optimizer = th.optim.Adam(model.parameters())
        >>> for epoch in range(10):
        ...     logits = model(g, g.ndata["h"])['user']
        ...     loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()

        >>> # Initialize the explainer
        >>> explainer = dgl.nn.HeteroPGExplainer(
        ...     model, hidden_dim, num_hops=2, explain_graph=False
        ... )

        >>> # Train the explainer
        >>> # Define explainer temperature parameter
        >>> init_tmp, final_tmp = 5.0, 1.0
        >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
        >>> for epoch in range(20):
        ...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
        ...     loss = explainer.train_step_node(
        ...         { ntype: g.nodes(ntype) for ntype in g.ntypes },
        ...         g, g.ndata["h"], tmp
        ...     )
        ...     optimizer_exp.zero_grad()
        ...     loss.backward()
        ...     optimizer_exp.step()

        >>> # Explain the graph
        >>> feat = g.ndata.pop("h")
        >>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
        ...     { "user": [0] }, g, feat
        ... )
        """
        assert (
            not self.graph_explanation
        ), '"explain_graph" must be False when initializing the module.'
        assert (
            self.num_hops is not None
        ), '"num_hops" must be provided when initializing the module.'

        self.model = self.model.to(graph.device)
        self.elayers = self.elayers.to(graph.device)

        batched_embed = []
        batched_homo_graph = []
        batched_hetero_graph = []
        for target_ntype, target_nids in nodes.items():
            if isinstance(target_nids, torch.Tensor):
                target_nids = target_nids.tolist()

            for target_nid in target_nids:
                # One subgraph per center node; inverse_indices holds the
                # center node's relabeled ID, keyed by its node type.
                sg, inverse_indices = khop_in_subgraph(
                    graph, {target_ntype: target_nid}, self.num_hops
                )

                for sg_ntype in sg.ntypes:
                    sg_feat = feat[sg_ntype][sg.ndata[NID][sg_ntype].long()]

                    # Boolean marker of the center node(s). Built with an
                    # explicit bool dtype so that node types with zero nodes
                    # still yield a bool tensor and the "train" feature has
                    # the same dtype in every subgraph that gets batched.
                    train_mask = torch.zeros(
                        sg.num_nodes(sg_ntype),
                        dtype=torch.bool,
                        device=sg.device,
                    )
                    if sg_ntype in inverse_indices:
                        train_mask[inverse_indices[sg_ntype].long()] = True

                    sg.nodes[sg_ntype].data["feat"] = sg_feat
                    sg.nodes[sg_ntype].data["train"] = train_mask

                embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
                for ntype in embed.keys():
                    sg.nodes[ntype].data["emb"] = embed[ntype].data

                homo_sg = to_homogeneous(sg, ndata=["emb"])
                homo_sg_embed = homo_sg.ndata["emb"]

                col, row = homo_sg.edges()
                col_emb = homo_sg_embed[col.long()]
                row_emb = homo_sg_embed[row.long()]

                # inverse_indices is expressed in the type-local ID space of
                # the heterogeneous subgraph, so the center node embedding is
                # read from the per-type "emb" feature. Indexing the
                # homogeneous embedding matrix with that ID would select a
                # different node whenever target_ntype is not the first type.
                center_emb = sg.nodes[target_ntype].data["emb"][
                    inverse_indices[target_ntype][0].long()
                ]
                self_emb = center_emb.repeat(sg.num_edges(), 1)

                emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
                batched_embed.append(emb)
                batched_homo_graph.append(homo_sg)
                batched_hetero_graph.append(sg)

        batched_homo_graph = batch(batched_homo_graph)
        batched_hetero_graph = batch(batched_hetero_graph)

        batched_embed = torch.cat(batched_embed)
        batched_embed = self.elayers(batched_embed)
        values = batched_embed.reshape(-1)

        values = self.concrete_sample(
            values, beta=temperature, training=training
        )
        self.sparse_mask_values = values

        # Symmetrize: average the value of each edge with the value of the
        # edge going in the opposite direction.
        col, row = batched_homo_graph.edges()
        reverse_eids = batched_homo_graph.edge_ids(row, col).long()
        edge_mask = (values + values[reverse_eids]) / 2

        self.set_masks(batched_homo_graph, edge_mask)

        # Convert the edge mask back into heterogeneous format.
        hetero_edge_mask = self._edge_mask_to_heterogeneous(
            edge_mask=edge_mask,
            homograph=batched_homo_graph,
            heterograph=batched_hetero_graph,
        )

        batched_feats = {
            ntype: batched_hetero_graph.nodes[ntype].data["feat"]
            for ntype in batched_hetero_graph.ntypes
        }

        # The model prediction with the updated edge mask.
        logits = self.model(
            batched_hetero_graph,
            batched_feats,
            edge_weight=hetero_edge_mask,
            **kwargs,
        )

        probs = {
            ntype: F.softmax(logits[ntype], dim=-1) for ntype in logits.keys()
        }

        # Positions of the center nodes inside the batched graph, per type.
        batched_inverse_indices = {
            ntype: batched_hetero_graph.nodes[ntype]
            .data["train"]
            .nonzero()
            .squeeze(1)
            for ntype in batched_hetero_graph.ntypes
        }

        if training:
            self.batched_feats = batched_feats
            probs = {ntype: probs[ntype].data for ntype in probs.keys()}
        else:
            self.clear_masks()

        return (
            probs,
            hetero_edge_mask,
            batched_hetero_graph,
            batched_inverse_indices,
        )
    def _edge_mask_to_heterogeneous(self, edge_mask, homograph, heterograph):
        r"""Split a homogeneous edge mask into one mask per canonical edge
        type of the heterogeneous graph.

        The edge type ID stored on every edge of :attr:`homograph` is compared
        with the edge type IDs of :attr:`heterograph`; the entries of
        :attr:`edge_mask` whose edge carries a given type ID are gathered, in
        homogeneous edge order, under that edge type.

        Parameters
        ----------
        edge_mask : Tensor
            Edge weights indexed by the edge IDs of :attr:`homograph`.
        homograph : DGLGraph
            The homogeneous form of the source graph.
        heterograph : DGLGraph
            The heterogeneous form of the source graph.

        Returns
        -------
        dict[tuple, Tensor]
            A dict mapping every canonical edge type of :attr:`heterograph`
            (keys) to the edge weights of the edges of that type (values).
        """
        homo_etype_ids = homograph.edata[ETYPE]

        mask_by_etype = {}
        for canonical_etype in heterograph.canonical_etypes:
            etype_id = heterograph.get_etype_id(canonical_etype)
            # Homogeneous edge positions that belong to this edge type.
            positions = (homo_etype_ids == etype_id).nonzero().squeeze(1)
            mask_by_etype[canonical_etype] = edge_mask[positions]
        return mask_by_etype