"""Torch Module for SubgraphX"""
import math
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from .... import to_heterogeneous, to_homogeneous
from ....base import NID
from ....convert import to_networkx
from ....subgraph import node_subgraph
from ....transforms.functional import remove_nodes
__all__ = ["SubgraphX", "HeteroSubgraphX"]
class MCTSNode:
    r"""A single node of the Monte Carlo search tree.

    Parameters
    ----------
    nodes : Tensor or dict[str, Tensor]
        The node IDs of the graph that are associated with this tree node.
        The heterogeneous explainer in this module stores a dictionary that
        maps node types to node ID tensors.
    """

    def __init__(self, nodes):
        # Graph nodes covered by this search-tree node.
        self.nodes = nodes
        # Visit count and accumulated reward, updated by the rollouts.
        self.num_visit, self.total_reward = 0, 0.0
        # Reward of this subgraph itself; 0.0 until a value is assigned.
        self.immediate_reward = 0.0
        # Child tree nodes; empty until the node is expanded.
        self.children = list()

    def __repr__(self):
        r"""Get the string representation of the node.

        The explainers use this string as the dictionary key under which
        the tree node is registered.

        Returns
        -------
        str
            The string form of :attr:`nodes`
        """
        return str(self.nodes)
class SubgraphX(nn.Module):
    r"""SubgraphX from `On Explainability of Graph Neural Networks via Subgraph
    Explorations <https://arxiv.org/abs/2102.05152>`__
    It identifies the most important subgraph from the original graph that
    plays a critical role in GNN-based graph classification.
    It employs Monte Carlo tree search (MCTS) in efficiently exploring
    different subgraphs for explanation and uses Shapley values as the measure
    of subgraph importance.
    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification
        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat)`.
        * The output of its forward function is the logits.
    num_hops : int
        Number of message passing layers in the model
    coef : float, optional
        This hyperparameter controls the trade-off between exploration and
        exploitation. A higher value encourages the algorithm to explore
        relatively unvisited nodes. Default: 10.0
    high2low : bool, optional
        If True, it will use the "High2low" strategy for pruning actions,
        expanding children nodes from high degree to low degree when extending
        the children nodes in the search tree. Otherwise, it will use the
        "Low2high" strategy. Default: True
    num_child : int, optional
        This is the number of children nodes to expand when extending the
        children nodes in the search tree. Default: 12
    num_rollouts : int, optional
        This is the number of rollouts for MCTS. Default: 20
    node_min : int, optional
        This is the threshold to define a leaf node based on the number of
        nodes in a subgraph. Default: 3
    shapley_steps : int, optional
        This is the number of steps for Monte Carlo sampling in estimating
        Shapley values. Default: 100
    log : bool, optional
        If True, it will log the progress. Default: False
    """
    def __init__(
        self,
        model,
        num_hops,
        coef=10.0,
        high2low=True,
        num_child=12,
        num_rollouts=20,
        node_min=3,
        shapley_steps=100,
        log=False,
    ):
        super().__init__()
        self.num_hops = num_hops
        self.coef = coef
        self.high2low = high2low
        self.num_child = num_child
        self.num_rollouts = num_rollouts
        self.node_min = node_min
        self.shapley_steps = shapley_steps
        self.log = log
        self.model = model
    def shapley(self, subgraph_nodes):
        r"""Compute Shapley value with Monte Carlo approximation.
        Parameters
        ----------
        subgraph_nodes : tensor
            The tensor node ids of the subgraph that are associated with this
            tree node
        Returns
        -------
        float
            Shapley value
        """
        num_nodes = self.graph.num_nodes()
        subgraph_nodes = subgraph_nodes.tolist()
        # Obtain neighboring nodes of the subgraph g_i, P'.
        local_region = subgraph_nodes
        for _ in range(self.num_hops - 1):
            in_neighbors, _ = self.graph.in_edges(local_region)
            _, out_neighbors = self.graph.out_edges(local_region)
            neighbors = torch.cat([in_neighbors, out_neighbors]).tolist()
            local_region = list(set(local_region + neighbors))
        split_point = num_nodes
        coalition_space = list(set(local_region) - set(subgraph_nodes)) + [
            split_point
        ]
        marginal_contributions = []
        device = self.feat.device
        for _ in range(self.shapley_steps):
            permuted_space = np.random.permutation(coalition_space)
            split_idx = int(np.where(permuted_space == split_point)[0])
            selected_nodes = permuted_space[:split_idx]
            # Mask for coalition set S_i
            exclude_mask = torch.ones(num_nodes)
            exclude_mask[local_region] = 0.0
            exclude_mask[selected_nodes] = 1.0
            # Mask for set S_i and g_i
            include_mask = exclude_mask.clone()
            include_mask[subgraph_nodes] = 1.0
            exclude_feat = self.feat * exclude_mask.unsqueeze(1).to(device)
            include_feat = self.feat * include_mask.unsqueeze(1).to(device)
            with torch.no_grad():
                exclude_probs = self.model(
                    self.graph, exclude_feat, **self.kwargs
                ).softmax(dim=-1)
                exclude_value = exclude_probs[:, self.target_class]
                include_probs = self.model(
                    self.graph, include_feat, **self.kwargs
                ).softmax(dim=-1)
                include_value = include_probs[:, self.target_class]
            marginal_contributions.append(include_value - exclude_value)
        return torch.cat(marginal_contributions).mean().item()
    def get_mcts_children(self, mcts_node):
        r"""Get the children of the MCTS node for the search.
        Parameters
        ----------
        mcts_node : MCTSNode
            Node in MCTS
        Returns
        -------
        list
            Children nodes after pruning
        """
        if len(mcts_node.children) > 0:
            return mcts_node.children
        subg = node_subgraph(self.graph, mcts_node.nodes)
        node_degrees = subg.out_degrees() + subg.in_degrees()
        k = min(subg.num_nodes(), self.num_child)
        chosen_nodes = torch.topk(
            node_degrees, k, largest=self.high2low
        ).indices
        mcts_children_maps = dict()
        for node in chosen_nodes:
            new_subg = remove_nodes(subg, node.to(subg.idtype), store_ids=True)
            # Get the largest weakly connected component in the subgraph.
            nx_graph = to_networkx(new_subg.cpu())
            largest_cc_nids = list(
                max(nx.weakly_connected_components(nx_graph), key=len)
            )
            # Map to the original node IDs.
            largest_cc_nids = new_subg.ndata[NID][largest_cc_nids].long()
            largest_cc_nids = subg.ndata[NID][largest_cc_nids].sort().values
            if str(largest_cc_nids) not in self.mcts_node_maps:
                child_mcts_node = MCTSNode(largest_cc_nids)
                self.mcts_node_maps[str(child_mcts_node)] = child_mcts_node
            else:
                child_mcts_node = self.mcts_node_maps[str(largest_cc_nids)]
            if str(child_mcts_node) not in mcts_children_maps:
                mcts_children_maps[str(child_mcts_node)] = child_mcts_node
        mcts_node.children = list(mcts_children_maps.values())
        for child_mcts_node in mcts_node.children:
            if child_mcts_node.immediate_reward == 0:
                child_mcts_node.immediate_reward = self.shapley(
                    child_mcts_node.nodes
                )
        return mcts_node.children
    def mcts_rollout(self, mcts_node):
        r"""Perform a MCTS rollout.
        Parameters
        ----------
        mcts_node : MCTSNode
            Starting node for MCTS
        Returns
        -------
        float
            Reward for visiting the node this time
        """
        if len(mcts_node.nodes) <= self.node_min:
            return mcts_node.immediate_reward
        children_nodes = self.get_mcts_children(mcts_node)
        children_visit_sum = sum([child.num_visit for child in children_nodes])
        children_visit_sum_sqrt = math.sqrt(children_visit_sum)
        chosen_child = max(
            children_nodes,
            key=lambda c: c.total_reward / max(c.num_visit, 1)
            + self.coef
            * c.immediate_reward
            * children_visit_sum_sqrt
            / (1 + c.num_visit),
        )
        reward = self.mcts_rollout(chosen_child)
        chosen_child.num_visit += 1
        chosen_child.total_reward += reward
        return reward
    def explain_graph(self, graph, feat, target_class, **kwargs):
        r"""Find the most important subgraph from the original graph for the
        model to classify the graph into the target class.
        Parameters
        ----------
        graph : DGLGraph
            A homogeneous graph
        feat : Tensor
            The input node feature of shape :math:`(N, D)`, :math:`N` is the
            number of nodes, and :math:`D` is the feature size
        target_class : int
            The target class to explain
        kwargs : dict
            Additional arguments passed to the GNN model
        Returns
        -------
        Tensor
            Nodes that represent the most important subgraph
        Examples
        --------
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.functional as F
        >>> from dgl.data import GINDataset
        >>> from dgl.dataloading import GraphDataLoader
        >>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
        >>> # Define the model
        >>> class Model(nn.Module):
        ...     def __init__(self, in_dim, n_classes, hidden_dim=128):
        ...         super().__init__()
        ...         self.conv1 = GraphConv(in_dim, hidden_dim)
        ...         self.conv2 = GraphConv(hidden_dim, n_classes)
        ...         self.pool = AvgPooling()
        ...
        ...     def forward(self, g, h):
        ...         h = F.relu(self.conv1(g, h))
        ...         h = self.conv2(g, h)
        ...         return self.pool(g, h)
        >>> # Load dataset
        >>> data = GINDataset('MUTAG', self_loop=True)
        >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
        >>> # Train the model
        >>> feat_size = data[0][0].ndata['attr'].shape[1]
        >>> model = Model(feat_size, data.gclasses)
        >>> criterion = nn.CrossEntropyLoss()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        >>> for bg, labels in dataloader:
        ...     logits = model(bg, bg.ndata['attr'])
        ...     loss = criterion(logits, labels)
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()
        >>> # Initialize the explainer
        >>> explainer = SubgraphX(model, num_hops=2)
        >>> # Explain the prediction for graph 0
        >>> graph, l = data[0]
        >>> graph_feat = graph.ndata.pop("attr")
        >>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
        ...                                           target_class=l)
        """
        self.model.eval()
        assert (
            graph.num_nodes() > self.node_min
        ), f"The number of nodes in the\
            graph {graph.num_nodes()} should be bigger than {self.node_min}."
        self.graph = graph
        self.feat = feat
        self.target_class = target_class
        self.kwargs = kwargs
        # book all nodes in MCTS
        self.mcts_node_maps = dict()
        root = MCTSNode(graph.nodes())
        self.mcts_node_maps[str(root)] = root
        for i in range(self.num_rollouts):
            if self.log:
                print(
                    f"Rollout {i}/{self.num_rollouts}, \
                    {len(self.mcts_node_maps)} subgraphs have been explored."
                )
            self.mcts_rollout(root)
        best_leaf = None
        best_immediate_reward = float("-inf")
        for mcts_node in self.mcts_node_maps.values():
            if len(mcts_node.nodes) > self.node_min:
                continue
            if mcts_node.immediate_reward > best_immediate_reward:
                best_leaf = mcts_node
                best_immediate_reward = best_leaf.immediate_reward
        return best_leaf.nodes 
 
class HeteroSubgraphX(nn.Module):
    r"""SubgraphX from `On Explainability of Graph Neural Networks via Subgraph
    Explorations <https://arxiv.org/abs/2102.05152>`__, adapted for heterogeneous graphs
    It identifies the most important subgraph from the original graph that
    plays a critical role in GNN-based graph classification.
    It employs Monte Carlo tree search (MCTS) in efficiently exploring
    different subgraphs for explanation and uses Shapley values as the measure
    of subgraph importance.
    Parameters
    ----------
    model : nn.Module
        The GNN model to explain that tackles multiclass graph classification
        * Its forward function must have the form
          :attr:`forward(self, graph, nfeat)`.
        * The output of its forward function is the logits.
    num_hops : int
        Number of message passing layers in the model
    coef : float, optional
        This hyperparameter controls the trade-off between exploration and
        exploitation. A higher value encourages the algorithm to explore
        relatively unvisited nodes. Default: 10.0
    high2low : bool, optional
        If True, it will use the "High2low" strategy for pruning actions,
        expanding children nodes from high degree to low degree when extending
        the children nodes in the search tree. Otherwise, it will use the
        "Low2high" strategy. Default: True
    num_child : int, optional
        This is the number of children nodes to expand when extending the
        children nodes in the search tree. Default: 12
    num_rollouts : int, optional
        This is the number of rollouts for MCTS. Default: 20
    node_min : int, optional
        This is the threshold to define a leaf node based on the number of
        nodes in a subgraph. Default: 3
    shapley_steps : int, optional
        This is the number of steps for Monte Carlo sampling in estimating
        Shapley values. Default: 100
    log : bool, optional
        If True, it will log the progress. Default: False
    """
    def __init__(
        self,
        model,
        num_hops,
        coef=10.0,
        high2low=True,
        num_child=12,
        num_rollouts=20,
        node_min=3,
        shapley_steps=100,
        log=False,
    ):
        super().__init__()
        self.num_hops = num_hops
        self.coef = coef
        self.high2low = high2low
        self.num_child = num_child
        self.num_rollouts = num_rollouts
        self.node_min = node_min
        self.shapley_steps = shapley_steps
        self.log = log
        self.model = model
    def shapley(self, subgraph_nodes):
        r"""Compute Shapley value with Monte Carlo approximation.
        Parameters
        ----------
        subgraph_nodes : dict[str, Tensor]
            subgraph_nodes[nty] gives the tensor node IDs of node type nty
            in the subgraph, which are associated with this tree node
        Returns
        -------
        float
            Shapley value
        """
        # Obtain neighboring nodes of the subgraph g_i, P'.
        local_regions = {
            ntype: nodes.tolist() for ntype, nodes in subgraph_nodes.items()
        }
        for _ in range(self.num_hops - 1):
            for c_etype in self.graph.canonical_etypes:
                src_ntype, _, dst_ntype = c_etype
                if (
                    src_ntype not in local_regions
                    or dst_ntype not in local_regions
                ):
                    continue
                in_neighbors, _ = self.graph.in_edges(
                    local_regions[dst_ntype], etype=c_etype
                )
                _, out_neighbors = self.graph.out_edges(
                    local_regions[src_ntype], etype=c_etype
                )
                local_regions[src_ntype] = list(
                    set(local_regions[src_ntype] + in_neighbors.tolist())
                )
                local_regions[dst_ntype] = list(
                    set(local_regions[dst_ntype] + out_neighbors.tolist())
                )
        split_point = self.graph.num_nodes()
        coalition_space = {
            ntype: list(
                set(local_regions[ntype]) - set(subgraph_nodes[ntype].tolist())
            )
            + [split_point]
            for ntype in subgraph_nodes.keys()
        }
        marginal_contributions = []
        for _ in range(self.shapley_steps):
            selected_node_map = dict()
            for ntype, nodes in coalition_space.items():
                permuted_space = np.random.permutation(nodes)
                split_idx = int(np.where(permuted_space == split_point)[0])
                selected_node_map[ntype] = permuted_space[:split_idx]
            # Mask for coalition set S_i
            exclude_mask = {
                ntype: torch.ones(self.graph.num_nodes(ntype))
                for ntype in self.graph.ntypes
            }
            for ntype, region in local_regions.items():
                exclude_mask[ntype][region] = 0.0
            for ntype, selected_nodes in selected_node_map.items():
                exclude_mask[ntype][selected_nodes] = 1.0
            # Mask for set S_i and g_i
            include_mask = {
                ntype: exclude_mask[ntype].clone()
                for ntype in self.graph.ntypes
            }
            for ntype, subgn in subgraph_nodes.items():
                exclude_mask[ntype][subgn] = 1.0
            exclude_feat = {
                ntype: self.feat[ntype]
                * exclude_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)
                for ntype in self.graph.ntypes
            }
            include_feat = {
                ntype: self.feat[ntype]
                * include_mask[ntype].unsqueeze(1).to(self.feat[ntype].device)
                for ntype in self.graph.ntypes
            }
            with torch.no_grad():
                exclude_probs = self.model(
                    self.graph, exclude_feat, **self.kwargs
                ).softmax(dim=-1)
                exclude_value = exclude_probs[:, self.target_class]
                include_probs = self.model(
                    self.graph, include_feat, **self.kwargs
                ).softmax(dim=-1)
                include_value = include_probs[:, self.target_class]
            marginal_contributions.append(include_value - exclude_value)
        return torch.cat(marginal_contributions).mean().item()
    def get_mcts_children(self, mcts_node):
        r"""Get the children of the MCTS node for the search.
        Parameters
        ----------
        mcts_node : MCTSNode
            Node in MCTS
        Returns
        -------
        list
            Children nodes after pruning
        """
        if len(mcts_node.children) > 0:
            return mcts_node.children
        subg = node_subgraph(self.graph, mcts_node.nodes)
        # Choose k nodes based on the highest degree in the subgraph
        node_degrees_map = {
            ntype: torch.zeros(
                subg.num_nodes(ntype), device=subg.nodes(ntype).device
            )
            for ntype in subg.ntypes
        }
        for c_etype in subg.canonical_etypes:
            src_ntype, _, dst_ntype = c_etype
            node_degrees_map[src_ntype] += subg.out_degrees(etype=c_etype)
            node_degrees_map[dst_ntype] += subg.in_degrees(etype=c_etype)
        node_degrees_list = [
            ((ntype, i), degree)
            for ntype, node_degrees in node_degrees_map.items()
            for i, degree in enumerate(node_degrees)
        ]
        node_degrees = torch.stack([v for _, v in node_degrees_list])
        k = min(subg.num_nodes(), self.num_child)
        chosen_node_indicies = torch.topk(
            node_degrees, k, largest=self.high2low
        ).indices
        chosen_nodes = [node_degrees_list[i][0] for i in chosen_node_indicies]
        mcts_children_maps = dict()
        for ntype, node in chosen_nodes:
            new_subg = remove_nodes(subg, node, ntype, store_ids=True)
            if new_subg.num_edges() > 0:
                new_subg_homo = to_homogeneous(new_subg)
                # Get the largest weakly connected component in the subgraph.
                nx_graph = to_networkx(new_subg_homo.cpu())
                largest_cc_nids = list(
                    max(nx.weakly_connected_components(nx_graph), key=len)
                )
                largest_cc_homo = node_subgraph(new_subg_homo, largest_cc_nids)
                largest_cc_hetero = to_heterogeneous(
                    largest_cc_homo, new_subg.ntypes, new_subg.etypes
                )
                # Follow steps for backtracking to original graph node ids
                # 1. retrieve instanced homograph from connected-component homograph
                # 2. retrieve instanced heterograph from instanced homograph
                # 3. retrieve hetero-subgraph from instanced heterograph
                # 4. retrieve orignal graph ids from subgraph node ids
                cc_nodes = {
                    ntype: subg.ndata[NID][ntype][
                        new_subg.ndata[NID][ntype][
                            new_subg_homo.ndata[NID][
                                largest_cc_homo.ndata[NID][indicies]
                            ]
                        ]
                    ]
                    for ntype, indicies in largest_cc_hetero.ndata[NID].items()
                }
            else:
                available_ntypes = [
                    ntype
                    for ntype in new_subg.ntypes
                    if new_subg.num_nodes(ntype) > 0
                ]
                chosen_ntype = np.random.choice(available_ntypes)
                # backtrack from subgraph node ids to entire graph
                chosen_node = subg.ndata[NID][chosen_ntype][
                    np.random.choice(new_subg.nodes[chosen_ntype].data[NID])
                ]
                cc_nodes = {
                    chosen_ntype: torch.tensor(
                        [chosen_node],
                        device=subg.device,
                    )
                }
            if str(cc_nodes) not in self.mcts_node_maps:
                child_mcts_node = MCTSNode(cc_nodes)
                self.mcts_node_maps[str(child_mcts_node)] = child_mcts_node
            else:
                child_mcts_node = self.mcts_node_maps[str(cc_nodes)]
            if str(child_mcts_node) not in mcts_children_maps:
                mcts_children_maps[str(child_mcts_node)] = child_mcts_node
        mcts_node.children = list(mcts_children_maps.values())
        for child_mcts_node in mcts_node.children:
            if child_mcts_node.immediate_reward == 0:
                child_mcts_node.immediate_reward = self.shapley(
                    child_mcts_node.nodes
                )
        return mcts_node.children
    def mcts_rollout(self, mcts_node):
        r"""Perform a MCTS rollout.
        Parameters
        ----------
        mcts_node : MCTSNode
            Starting node for MCTS
        Returns
        -------
        float
            Reward for visiting the node this time
        """
        if (
            sum(len(nodes) for nodes in mcts_node.nodes.values())
            <= self.node_min
        ):
            return mcts_node.immediate_reward
        children_nodes = self.get_mcts_children(mcts_node)
        children_visit_sum = sum([child.num_visit for child in children_nodes])
        children_visit_sum_sqrt = math.sqrt(children_visit_sum)
        chosen_child = max(
            children_nodes,
            key=lambda c: c.total_reward / max(c.num_visit, 1)
            + self.coef
            * c.immediate_reward
            * children_visit_sum_sqrt
            / (1 + c.num_visit),
        )
        reward = self.mcts_rollout(chosen_child)
        chosen_child.num_visit += 1
        chosen_child.total_reward += reward
        return reward
    def explain_graph(self, graph, feat, target_class, **kwargs):
        r"""Find the most important subgraph from the original graph for the
        model to classify the graph into the target class.
        Parameters
        ----------
        graph : DGLGraph
            A heterogeneous graph
        feat : dict[str, Tensor]
            The dictionary that associates input node features (values) with
            the respective node types (keys) present in the graph.
            The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is the
            number of nodes for node type :math:`t`, and :math:`D_t` is the feature size for
            node type :math:`t`
        target_class : int
            The target class to explain
        kwargs : dict
            Additional arguments passed to the GNN model
        Returns
        -------
        dict[str, Tensor]
            The dictionary associating tensor node ids (values) to
            node types (keys) that represents the most important subgraph
        Examples
        --------
        >>> import dgl
        >>> import dgl.function as fn
        >>> import torch as th
        >>> import torch.nn as nn
        >>> import torch.nn.functional as F
        >>> from dgl.nn import HeteroSubgraphX
        >>> class Model(nn.Module):
        ...     def __init__(self, in_dim, num_classes, canonical_etypes):
        ...         super(Model, self).__init__()
        ...         self.etype_weights = nn.ModuleDict(
        ...             {
        ...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
        ...                 for c_etype in canonical_etypes
        ...             }
        ...         )
        ...
        ...     def forward(self, graph, feat):
        ...         with graph.local_scope():
        ...             c_etype_func_dict = {}
        ...             for c_etype in graph.canonical_etypes:
        ...                 src_type, etype, dst_type = c_etype
        ...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
        ...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
        ...                 c_etype_func_dict[c_etype] = (
        ...                     fn.copy_u(f"h_{c_etype}", "m"),
        ...                     fn.mean("m", "h"),
        ...                 )
        ...             graph.multi_update_all(c_etype_func_dict, "sum")
        ...             hg = 0
        ...             for ntype in graph.ntypes:
        ...                 if graph.num_nodes(ntype):
        ...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
        ...             return hg
        >>> input_dim = 5
        >>> num_classes = 2
        >>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
        >>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
        >>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
        >>> transform = dgl.transforms.AddReverse()
        >>> g = transform(g)
        >>> # define and train the model
        >>> model = Model(input_dim, num_classes, g.canonical_etypes)
        >>> feat = g.ndata["h"]
        >>> optimizer = th.optim.Adam(model.parameters())
        >>> for epoch in range(10):
        ...     logits = model(g, feat)
        ...     loss = F.cross_entropy(logits, th.tensor([1]))
        ...     optimizer.zero_grad()
        ...     loss.backward()
        ...     optimizer.step()
        >>> # Explain for the graph
        >>> explainer = HeteroSubgraphX(model, num_hops=1)
        >>> explainer.explain_graph(g, feat, target_class=1)
        {'game': tensor([0, 1]), 'user': tensor([1, 2])}
        """
        self.model.eval()
        assert (
            graph.num_nodes() > self.node_min
        ), f"The number of nodes in the\
            graph {graph.num_nodes()} should be bigger than {self.node_min}."
        self.graph = graph
        self.feat = feat
        self.target_class = target_class
        self.kwargs = kwargs
        # book all nodes in MCTS
        self.mcts_node_maps = dict()
        root_dict = {ntype: graph.nodes(ntype) for ntype in graph.ntypes}
        root = MCTSNode(root_dict)
        self.mcts_node_maps[str(root)] = root
        for i in range(self.num_rollouts):
            if self.log:
                print(
                    f"Rollout {i}/{self.num_rollouts}, \
                    {len(self.mcts_node_maps)} subgraphs have been explored."
                )
            self.mcts_rollout(root)
        best_leaf = None
        best_immediate_reward = float("-inf")
        for mcts_node in self.mcts_node_maps.values():
            if len(mcts_node.nodes) > self.node_min:
                continue
            if mcts_node.immediate_reward > best_immediate_reward:
                best_leaf = mcts_node
                best_immediate_reward = best_leaf.immediate_reward
        return best_leaf.nodes