"""Feature fetchers"""
from functools import partial
from typing import Dict
import torch
from torch.utils.data import functional_datapipe
from .base import etype_tuple_to_str
from .minibatch_transformer import MiniBatchTransformer
# Names exported by ``from <this module> import *``.
__all__ = ["FeatureFetcher", "FeatureFetcherStartMarker"]
def get_feature_key_list(feature_keys, domain):
    """Flattens node or edge feature keys into ``(domain, type_name, name)`` tuples.

    Parameters
    ----------
    feature_keys : List[str] or Dict[str, List[str]] or None
        A mapping from type name to a list of feature names (heterogeneous),
        an iterable of feature names (homogeneous; ``type_name`` becomes
        None), or None.
    domain : str
        Value placed first in every tuple (callers pass "node" or "edge").

    Returns
    -------
    list of tuple
        One ``(domain, type_name, feature_name)`` tuple per requested feature,
        in iteration order; an empty list when ``feature_keys`` is None.
    """
    if feature_keys is None:
        return []
    if not isinstance(feature_keys, Dict):
        # Homogeneous: plain feature names without a type.
        return [(domain, None, name) for name in feature_keys]
    flattened = []
    for type_name, names in feature_keys.items():
        flattened.extend((domain, type_name, name) for name in names)
    return flattened
@functional_datapipe("mark_feature_fetcher_start")
class FeatureFetcherStartMarker(MiniBatchTransformer):
    """No-op transformer marking the point where a FeatureFetcher begins.

    Every datapipe created while a FeatureFetcher is being instantiated is
    guaranteed to lie between a FeatureFetcherStartMarker and that
    FeatureFetcher instance in the datapipe graph.
    """

    def __init__(self, datapipe):
        # Minibatches pass through unchanged; this stage only marks a position.
        passthrough = self._identity
        super().__init__(datapipe, passthrough)
class _DeferredRead:
    """Postpones a synchronous ``feature.read(index)`` until ``wait()``.

    Used when the overlap optimization is enabled but a particular feature
    needs no asynchronous stages. It gives that feature the same ``wait()``
    interface that ``FeatureFetcher._final_stage`` calls on every entry.
    """

    def __init__(self, feature, index):
        self.feature = feature
        self.index = index

    def wait(self):
        """Performs the read and returns its result."""
        result = self.feature.read(self.index)
        # Drop the references so the feature and index can be freed.
        self.feature = self.index = None
        return result


def _all_feature_dicts(data):
    """Returns the node feature dict followed by one edge feature dict per layer."""
    return [data.node_features] + [
        data.edge_features[i] for i in range(data.num_layers())
    ]


@functional_datapipe("fetch_feature")
class FeatureFetcher(MiniBatchTransformer):
    """A feature fetcher used to fetch features for node/edge in graphbolt.

    Functional name: :obj:`fetch_feature`.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    feature_store : FeatureStore
        A storage for features, support read and update.
    node_feature_keys : List[str] or Dict[str, List[str]], optional
        Node feature keys indicating the node features to be read.

        - If `node_feature_keys` is a list: the graph is homogeneous, and
          the strings inside are feature names.
        - If `node_feature_keys` is a dictionary: the keys are node types
          and the values are lists of feature names.
    edge_feature_keys : List[str] or Dict[str, List[str]], optional
        Edge feature keys indicating the edge features to be read.

        - If `edge_feature_keys` is a list: the graph is homogeneous, and
          the strings inside are feature names.
        - If `edge_feature_keys` is a dictionary: the keys are edge types,
          following the format 'str:str:str', and the values are lists of
          feature names.
    overlap_fetch : bool, optional
        If True, the feature fetcher will overlap the UVA feature fetcher
        operations with the rest of operations by using an alternative CUDA
        stream or utilizing asynchronous operations. Default is True.
    """

    def __init__(
        self,
        datapipe,
        feature_store,
        node_feature_keys=None,
        edge_feature_keys=None,
        overlap_fetch=True,
    ):
        datapipe = datapipe.mark_feature_fetcher_start()
        self.feature_store = feature_store
        self.node_feature_keys = node_feature_keys
        self.edge_feature_keys = edge_feature_keys
        max_val = 0
        if overlap_fetch:
            max_val = self._compute_max_num_stages(
                feature_store, node_feature_keys, edge_feature_keys
            )
        datapipe = datapipe.transform(self._read)
        # One transform per async stage, from the deepest stage down to 1.
        # Each is followed by a one-element buffer (presumably so that
        # consecutive minibatches can be in different stages at once).
        for i in range(max_val, 0, -1):
            datapipe = datapipe.transform(
                partial(self._execute_stage, i)
            ).buffer(1)
        super().__init__(
            datapipe, self._identity if max_val == 0 else self._final_stage
        )
        # A positive value indicates that the overlap optimization is enabled.
        self.max_num_stages = max_val

    @staticmethod
    def _compute_max_num_stages(
        feature_store, node_feature_keys, edge_feature_keys
    ):
        """Returns the largest async stage count over the requested features.

        Each requested feature present in ``feature_store`` is probed with
        both a CPU and a CUDA device. Returns 0 when no probe reports any
        asynchronous stage.
        """
        max_val = 0
        feature_keys = get_feature_key_list(
            node_feature_keys, "node"
        ) + get_feature_key_list(edge_feature_keys, "edge")
        for feature_key in feature_keys:
            if feature_key not in feature_store:
                continue
            for device_str in ("cpu", "cuda"):
                try:
                    num_stages = feature_store[
                        feature_key
                    ].read_async_num_stages(torch.device(device_str))
                except AssertionError:
                    # Deliberately best-effort: presumably raised when the
                    # feature does not support this device; skip it.
                    continue
                max_val = max(num_stages, max_val)
        return max_val

    @staticmethod
    def _execute_stage(current_stage, data):
        """Advances every pending read whose stage equals ``current_stage``.

        Each feature entry is a ``(handle, stage)`` pair. When ``stage`` equals
        ``current_stage``, the handle (a ``read_async`` generator) is advanced
        once and its remaining stage count drops by one. On its last stage the
        produced value replaces the handle.
        """
        for features in _all_feature_dicts(data):
            for key in features:
                handle, stage = features[key]
                assert current_stage >= stage
                if current_stage == stage:
                    value = next(handle)
                    features[key] = (handle if stage > 1 else value, stage - 1)
        return data

    @staticmethod
    def _final_stage(data):
        """Replaces every ``(value, 0)`` feature entry with ``value.wait()``."""
        for features in _all_feature_dicts(data):
            for key in features:
                value, stage = features[key]
                assert stage == 0
                features[key] = value.wait()
        return data

    def _read_feature(self, feature_key, index):
        """Reads one feature, or schedules the read when overlap is enabled.

        Without overlap, returns the result of ``feature_store.read``. With
        overlap, returns a ``(handle, num_stages)`` pair. That pair is
        consumed by ``_execute_stage`` and ``_final_stage``.
        """
        if self.max_num_stages > 0:
            feature = self.feature_store[feature_key]
            num_stages = feature.read_async_num_stages(index.device)
            if num_stages > 0:
                return (feature.read_async(index), num_stages)
            # Asynchronicity is not needed; the read happens in _final_stage.
            return (_DeferredRead(feature, index), 0)
        domain, type_name, feature_name = feature_key
        return self.feature_store.read(domain, type_name, feature_name, index)

    def _read_node_features(self, input_nodes, is_heterogeneous):
        """Builds the node feature dict for ``input_nodes``.

        Keys are ``(type_name, feature_name)`` when heterogeneous, otherwise
        ``feature_name``. Returns an empty dict when no node features are
        requested or ``input_nodes`` is None.
        """
        node_features = {}
        if not self.node_feature_keys or input_nodes is None:
            return node_features
        if is_heterogeneous:
            for type_name, nodes in input_nodes.items():
                if type_name not in self.node_feature_keys or nodes is None:
                    continue
                for feature_name in self.node_feature_keys[type_name]:
                    node_features[(type_name, feature_name)] = (
                        self._read_feature(
                            ("node", type_name, feature_name), nodes
                        )
                    )
        else:
            for feature_name in self.node_feature_keys:
                node_features[feature_name] = self._read_feature(
                    ("node", None, feature_name), input_nodes
                )
        return node_features

    def _read_edge_features(self, data, is_heterogeneous):
        """Builds one edge feature dict per layer of ``data``.

        A layer whose ``data.edge_ids(i)`` is None has no edge IDs to index
        with, so its dict is left empty.
        """
        edge_features = [{} for _ in range(data.num_layers())]
        if not self.edge_feature_keys:
            return edge_features
        for i, layer_features in enumerate(edge_features):
            original_edge_ids = data.edge_ids(i)
            if original_edge_ids is None:
                continue
            if is_heterogeneous:
                # Convert tuple edge types to strings so they match the
                # 'str:str:str' keys of edge_feature_keys.
                original_edge_ids = {
                    (
                        etype_tuple_to_str(key)
                        if isinstance(key, tuple)
                        else key
                    ): value
                    for key, value in original_edge_ids.items()
                }
                for type_name, edges in original_edge_ids.items():
                    if (
                        type_name not in self.edge_feature_keys
                        or edges is None
                    ):
                        continue
                    for feature_name in self.edge_feature_keys[type_name]:
                        layer_features[(type_name, feature_name)] = (
                            self._read_feature(
                                ("edge", type_name, feature_name), edges
                            )
                        )
            else:
                for feature_name in self.edge_feature_keys:
                    layer_features[feature_name] = self._read_feature(
                        ("edge", None, feature_name), original_edge_ids
                    )
        return edge_features

    def _read(self, data):
        """
        Fill in the node/edge features field in data.

        Parameters
        ----------
        data : MiniBatch
            An instance of :class:`MiniBatch`. Any node/edge features it
            already holds are replaced by the dicts built here, via
            ``data.set_node_features`` and ``data.set_edge_features``.

        Returns
        -------
        MiniBatch
            An instance of :class:`MiniBatch` filled with required features.
        """
        is_heterogeneous = isinstance(
            self.node_feature_keys, Dict
        ) or isinstance(self.edge_feature_keys, Dict)
        node_features = self._read_node_features(
            data.node_ids(), is_heterogeneous
        )
        edge_features = self._read_edge_features(data, is_heterogeneous)
        data.set_node_features(node_features)
        data.set_edge_features(edge_features)
        return data