"""Module for message propagation."""from__future__importabsolute_importfrom.importbackendasF,traversalastrvfrom.heterographimportDGLGraph__all__=["prop_nodes","prop_nodes_bfs","prop_nodes_topo","prop_edges","prop_edges_dfs",]
def prop_nodes(
    graph,
    nodes_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Functional method for :func:`dgl.DGLGraph.prop_nodes`.

    This is a thin wrapper: every argument is forwarded positionally and
    unchanged to ``graph.prop_nodes``.

    Parameters
    ----------
    graph : DGLGraph
        The graph object whose ``prop_nodes`` method is invoked.
    nodes_generator : generator
        The generator of node frontiers.
    message_func : callable, optional
        The message function. The ``"default"`` placeholder is passed through
        as-is; its interpretation is left to ``graph.prop_nodes``.
    reduce_func : callable, optional
        The reduce function. Same pass-through behaviour as ``message_func``.
    apply_node_func : callable, optional
        The update function. Same pass-through behaviour as ``message_func``.

    Returns
    -------
    None
        The result of ``graph.prop_nodes`` is intentionally not returned.

    See Also
    --------
    dgl.DGLGraph.prop_nodes
    """
    graph.prop_nodes(
        nodes_generator, message_func, reduce_func, apply_node_func
    )
def prop_edges(
    graph,
    edges_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Functional method for :func:`dgl.DGLGraph.prop_edges`.

    This is a thin wrapper: every argument is forwarded positionally and
    unchanged to ``graph.prop_edges``.

    Parameters
    ----------
    graph : DGLGraph
        The graph object whose ``prop_edges`` method is invoked.
    edges_generator : generator
        The generator of edge frontiers.
    message_func : callable, optional
        The message function. The ``"default"`` placeholder is passed through
        as-is; its interpretation is left to ``graph.prop_edges``.
    reduce_func : callable, optional
        The reduce function. Same pass-through behaviour as ``message_func``.
    apply_node_func : callable, optional
        The update function. Same pass-through behaviour as ``message_func``.

    Returns
    -------
    None
        The result of ``graph.prop_edges`` is intentionally not returned.

    See Also
    --------
    dgl.DGLGraph.prop_edges
    """
    graph.prop_edges(
        edges_generator, message_func, reduce_func, apply_node_func
    )
def prop_nodes_bfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    apply_node_func=None,
):
    """Message propagation using node frontiers generated by BFS.

    The frontiers are computed on a CPU copy of the graph, copied back to
    the graph's own device, and then handed to :func:`prop_nodes`.

    Parameters
    ----------
    graph : DGLGraph
        The graph object.
    source : list, tensor of nodes
        Source nodes.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a ``DGLGraph`` or does not have exactly one
        canonical edge type.

    See Also
    --------
    dgl.traversal.bfs_nodes_generator
    """
    # NOTE: these checks are plain asserts, so they are skipped under
    # ``python -O``; kept as asserts to preserve the raised exception type.
    assert isinstance(
        graph, DGLGraph
    ), "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "prop_nodes_bfs only support homogeneous graph"
    # TODO(murphy): Graph traversal currently is only supported on
    # CPP graphs. Move graph to CPU as a workaround,
    # which should be fixed in the future.
    frontiers = trv.bfs_nodes_generator(graph.cpu(), source, reverse)
    # The target device is loop-invariant, so look it up once.
    device = graph.device
    nodes_gen = [F.copy_to(frontier, device) for frontier in frontiers]
    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(
    graph, message_func, reduce_func, reverse=False, apply_node_func=None
):
    """Message propagation using node frontiers generated by topological order.

    The frontiers are computed on a CPU copy of the graph, copied back to
    the graph's own device, and then handed to :func:`prop_nodes`.

    Parameters
    ----------
    graph : DGLGraph
        The graph object.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a ``DGLGraph`` or does not have exactly one
        canonical edge type.

    See Also
    --------
    dgl.traversal.topological_nodes_generator
    """
    # NOTE: these checks are plain asserts, so they are skipped under
    # ``python -O``; kept as asserts to preserve the raised exception type.
    assert isinstance(
        graph, DGLGraph
    ), "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "prop_nodes_topo only support homogeneous graph"
    # TODO(murphy): Graph traversal currently is only supported on
    # CPP graphs. Move graph to CPU as a workaround,
    # which should be fixed in the future.
    frontiers = trv.topological_nodes_generator(graph.cpu(), reverse)
    # The target device is loop-invariant, so look it up once.
    device = graph.device
    nodes_gen = [F.copy_to(frontier, device) for frontier in frontiers]
    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    has_reverse_edge=False,
    has_nontree_edge=False,
    apply_node_func=None,
):
    """Message propagation using edge frontiers generated by labeled DFS.

    The frontiers are computed on a CPU copy of the graph (with
    ``return_labels=False``), copied back to the graph's own device, and
    then handed to :func:`prop_edges`.

    Parameters
    ----------
    graph : DGLGraph
        The graph object.
    source : list, tensor of nodes
        Source nodes.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    has_reverse_edge : bool, optional
        If true, REVERSE edges are included.
    has_nontree_edge : bool, optional
        If true, NONTREE edges are included.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a ``DGLGraph`` or does not have exactly one
        canonical edge type.

    See Also
    --------
    dgl.traversal.dfs_labeled_edges_generator
    """
    # NOTE: these checks are plain asserts, so they are skipped under
    # ``python -O``; kept as asserts to preserve the raised exception type.
    assert isinstance(
        graph, DGLGraph
    ), "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "prop_edges_dfs only support homogeneous graph"
    # TODO(murphy): Graph traversal currently is only supported on
    # CPP graphs. Move graph to CPU as a workaround,
    # which should be fixed in the future.
    frontiers = trv.dfs_labeled_edges_generator(
        graph.cpu(),
        source,
        reverse,
        has_reverse_edge,
        has_nontree_edge,
        return_labels=False,
    )
    # The target device is loop-invariant, so look it up once.
    device = graph.device
    edges_gen = [F.copy_to(frontier, device) for frontier in frontiers]
    prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)