HeteroGNNExplainer
- class dgl.nn.pytorch.explain.HeteroGNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]
Bases: Module

GNNExplainer model from GNNExplainer: Generating Explanations for Graph Neural Networks, adapted for heterogeneous graphs.

It identifies compact subgraph structures and small subsets of node features that play a critical role in GNN-based node classification and graph classification.

To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by optimizing the following objective function.

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature mask applied, and \(H\) is the entropy function. A minimal sketch of this objective follows the parameter list below.

Parameters:
- model (nn.Module) – The GNN model to explain. The required arguments of its forward function are graph and feat. The latter is for input node features.
- It should also optionally take an eweight argument for edge weights and multiply the messages by it in message passing. 
- The output of its forward function is the logits for the predicted node/graph classes. 
- See also the examples in explain_node() and explain_graph().
- num_hops (int) – The number of hops for GNN information aggregation.
- lr (float, optional) – The learning rate to use, defaults to 0.01.
- num_epochs (int, optional) – The number of epochs to train, defaults to 100.
- alpha1 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the sum of the edge mask.
- alpha2 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the entropy of the edge mask.
- beta1 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the mean of the node feature mask.
- beta2 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the entropy of the node feature mask.
- log (bool, optional) – If True, it will log the computation process, defaults to True.
 
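To make the four coefficients concrete, here is a minimal sketch of the objective for a single edge mask and feature mask. It is illustrative only, not the library's internal implementation; in the heterogeneous setting the mask terms are accumulated over every node type and canonical edge type, and the helper names below are hypothetical.

>>> import torch as th
>>> def mask_entropy(m, eps=1e-15):
...     # Mean element-wise binary entropy of a mask with entries in (0, 1)
...     return (-m * th.log(m + eps) - (1 - m) * th.log(1 - m + eps)).mean()
>>> def explainer_objective(pred_loss, edge_mask, feat_mask,
...                         alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1):
...     # l(y, y_hat) + alpha1 ||M||_1 + alpha2 H(M) + beta1 ||F||_1 + beta2 H(F)
...     return (pred_loss
...             + alpha1 * edge_mask.sum() + alpha2 * mask_entropy(edge_mask)
...             + beta1 * feat_mask.mean() + beta2 * mask_entropy(feat_mask))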
- explain_graph(graph, feat, **kwargs)[source]
Learn and return node feature masks and edge masks that play a crucial role in explaining the prediction made by the GNN for a graph.

Parameters:
- graph (DGLGraph) – A heterogeneous graph that will be explained.
- feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\). \(N_t\) is the number of nodes for node type \(t\), and \(D_t\) is the feature size for node type \(t\).
- kwargs (dict) – Additional arguments passed to the GNN model.
 
- Returns:
- feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\). The higher, the more important.
- edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the graph. The values are within range \((0, 1)\). The higher, the more important.
 
Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer

>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                                                   fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)
...             return hg

>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Explain the prediction for the graph
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> feat_mask, edge_mask = explainer.explain_graph(g, feat)
>>> feat_mask
{'game': tensor([0.2684, 0.2597, 0.3135, 0.2976, 0.2607]),
 'user': tensor([0.2216, 0.2908, 0.2644, 0.2738, 0.2663])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.8922, 0.1966, 0.8371, 0.1330]),
 ('user', 'plays', 'game'): tensor([0.1785, 0.1696, 0.8065, 0.2167])}
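The returned masks are soft scores rather than a hard selection. One possible post-processing step (illustrative only, not part of the API; the 0.5 cutoff is an arbitrary assumption) is to threshold each per-etype edge mask and induce a compact explanation subgraph:

>>> # Keep only edges whose learned importance exceeds an assumed 0.5 threshold
>>> important_edges = {c_etype: mask > 0.5 for c_etype, mask in edge_mask.items()}
>>> explanation_sg = dgl.edge_subgraph(g, important_edges)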
- explain_node(ntype, node_id, graph, feat, **kwargs)[source]
Learn and return node feature masks and a subgraph that play a crucial role in explaining the prediction made by the GNN for node node_id of type ntype.

It requires model to return a dictionary mapping node types to type-specific predictions.

Parameters:
- ntype (str) – The type of the node to explain. model must be trained to make predictions for this particular node type.
- node_id (int) – The ID of the node to explain.
- graph (DGLGraph) – A heterogeneous graph.
- feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\). \(N_t\) is the number of nodes for node type \(t\), and \(D_t\) is the feature size for node type \(t\).
- kwargs (dict) – Additional arguments passed to the GNN model.
 
- Returns:
- new_node_id (Tensor) – The new ID of the input center node.
- sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node.
- feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\). The higher, the more important.
- edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the subgraph. The values are within range \((0, 1)\). The higher, the more important.
 
Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer

>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                                                   fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             return graph.ndata['h']

>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)['user']
...     loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Explain the prediction for node 0 of type 'user'
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node('user', 0, g, feat)
>>> new_center
tensor([0])
>>> sg
Graph(num_nodes={'game': 1, 'user': 1},
      num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1,
                 ('user', 'rev_rev_plays', 'game'): 1},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays'),
                 ('user', 'game', 'rev_rev_plays')])
>>> feat_mask
{'game': tensor([0.2348, 0.2780, 0.2611, 0.2513, 0.2823]),
 'user': tensor([0.2716, 0.2450, 0.2658, 0.2876, 0.2738])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.0630]),
 ('user', 'plays', 'game'): tensor([0.1939]),
 ('user', 'rev_rev_plays', 'game'): tensor([0.9166])}
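As an optional follow-up (illustrative only; k is a hypothetical choice, not an API parameter), the feature masks can be ranked to surface the input dimensions that mattered most for the explained node's type:

>>> # Rank 'user' feature dimensions by learned importance
>>> k = 2
>>> scores, top_dims = th.topk(feat_mask['user'], k)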
- forward(*input: Any) → None
Define the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.