Source code for dgl.graphbolt.impl.cpu_cached_feature

"""CPU cached feature for GraphBolt."""
from typing import Dict, Optional, Union

import torch

from ..base import get_device_to_host_uva_stream, get_host_to_device_uva_stream
from ..feature_store import (
    bytes_to_number_of_items,
    Feature,
    FeatureKey,
    wrap_with_cached_feature,
)

from .cpu_feature_cache import CPUFeatureCache

__all__ = ["CPUCachedFeature", "cpu_cached_feature"]


class CPUCachedFeature(Feature):
    r"""CPU cached feature wrapping a fallback feature. Use `cpu_cached_feature`
    to construct an instance of this class.

    Parameters
    ----------
    fallback_feature : Feature
        The fallback feature.
    cache : CPUFeatureCache
        A CPUFeatureCache instance to serve as the cache backend.
    offset : int, optional
        The offset value to add to the given ids before using the cache. This
        parameter is useful if multiple `CPUCachedFeature`s are sharing a
        single CPUFeatureCache object.
    """

    _cache_type = CPUFeatureCache

    def __init__(
        self,
        fallback_feature: Feature,
        cache: CPUFeatureCache,
        offset: int = 0,
    ):
        super(CPUCachedFeature, self).__init__()
        assert isinstance(fallback_feature, Feature), (
            f"The fallback_feature must be an instance of Feature, but got "
            f"{type(fallback_feature)}."
        )
        self._fallback_feature = fallback_feature
        self._feature = cache
        self._offset = offset
    def read(self, ids: torch.Tensor = None):
        """Read the feature by index.

        Parameters
        ----------
        ids : torch.Tensor, optional
            The index of the feature. If specified, only the specified indices
            of the feature are read. If None, the entire feature is returned.

        Returns
        -------
        torch.Tensor
            The read feature.
        """
        if ids is None:
            return self._fallback_feature.read()
        return self._feature.query_and_replace(
            ids.cpu(), self._fallback_feature.read, self._offset
        ).to(ids.device)
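    # Example usage of ``read`` (a minimal sketch; ``gb.TorchBasedFeature``
    # and the random tensor below are illustrative assumptions, not part of
    # this module):
    #
    #     import dgl.graphbolt as gb
    #     fallback = gb.TorchBasedFeature(torch.randn(100, 8))
    #     cached = gb.cpu_cached_feature(fallback, max_cache_size_in_bytes=2**16)
    #     values = cached.read(torch.tensor([0, 2]))  # Misses fall back.
    #     everything = cached.read()  # None reads the entire feature.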
    def read_async(self, ids: torch.Tensor):
        r"""Read the feature by index asynchronously.

        Parameters
        ----------
        ids : torch.Tensor
            The index of the feature. Only the specified indices of the
            feature are read.

        Returns
        -------
        A generator object.
            The returned generator object returns a future on
            ``read_async_num_stages(ids.device)``\ th invocation. The return
            result can be accessed by calling ``.wait()`` on the returned
            future object. It is undefined behavior to call ``.wait()`` more
            than once.

        Examples
        --------
        >>> import dgl.graphbolt as gb
        >>> feature = gb.Feature(...)
        >>> ids = torch.tensor([0, 2])
        >>> for stage, future in enumerate(feature.read_async(ids)):
        ...     pass
        >>> assert stage + 1 == feature.read_async_num_stages(ids.device)
        >>> result = future.wait()  # result contains the read values.
        """
        policy = self._feature._policy
        cache = self._feature._cache
        if ids.is_cuda and self.is_pinned():
            # GPU ids with pinned cache storage: overlap the transfers via
            # dedicated UVA copy streams.
            ids_device = ids.device
            current_stream = torch.cuda.current_stream()
            device_to_host_stream = get_device_to_host_uva_stream()
            device_to_host_stream.wait_stream(current_stream)
            with torch.cuda.stream(device_to_host_stream):
                ids.record_stream(torch.cuda.current_stream())
                ids = ids.to("cpu", non_blocking=True)
                ids_copy_event = torch.cuda.Event()
                ids_copy_event.record()

            yield  # first stage is done.

            ids_copy_event.synchronize()
            policy_future = policy.query_and_replace_async(ids, self._offset)

            yield

            (
                positions,
                index,
                pointers,
                missing_keys,
                found_offsets,
                missing_offsets,
            ) = policy_future.wait()
            self._feature.total_queries += ids.shape[0]
            self._feature.total_miss += missing_keys.shape[0]
            found_cnt = ids.size(0) - missing_keys.size(0)
            found_positions = positions[:found_cnt]
            missing_positions = positions[found_cnt:]
            found_pointers = pointers[:found_cnt]
            missing_pointers = pointers[found_cnt:]

            host_to_device_stream = get_host_to_device_uva_stream()
            with torch.cuda.stream(host_to_device_stream):
                found_positions = found_positions.to(
                    ids_device, non_blocking=True
                )
                values_from_cpu = cache.index_select(found_positions)
                values_from_cpu.record_stream(current_stream)
                values_from_cpu_copy_event = torch.cuda.Event()
                values_from_cpu_copy_event.record()

            fallback_reader = self._fallback_feature.read_async(missing_keys)
            for _ in range(
                self._fallback_feature.read_async_num_stages(
                    missing_keys.device
                )
            ):
                missing_values_future = next(fallback_reader, None)
                yield  # fallback feature stages.
            values_from_cpu_copy_event.synchronize()
            reading_completed = policy.reading_completed_async(
                found_pointers, found_offsets
            )
            missing_values = missing_values_future.wait()
            replace_future = cache.replace_async(
                missing_positions, missing_values
            )
            host_to_device_stream = get_host_to_device_uva_stream()
            with torch.cuda.stream(host_to_device_stream):
                index = index.to(ids_device, non_blocking=True)
                missing_values = missing_values.to(
                    ids_device, non_blocking=True
                )
                index.record_stream(current_stream)
                missing_values.record_stream(current_stream)
                missing_values_copy_event = torch.cuda.Event()
                missing_values_copy_event.record()

            yield

            reading_completed.wait()
            replace_future.wait()
            writing_completed = policy.writing_completed_async(
                missing_pointers, missing_offsets
            )

            class _Waiter:
                # Stitches the found and missing rows together once all
                # pending events have completed.
                def __init__(self, events, existing, missing, index):
                    self.events = events
                    self.existing = existing
                    self.missing = missing
                    self.index = index

                def wait(self):
                    """Returns the stored value when invoked."""
                    for event in self.events:
                        event.wait()
                    values = torch.empty(
                        (self.index.shape[0],) + self.missing.shape[1:],
                        dtype=self.missing.dtype,
                        device=ids_device,
                    )
                    num_found = self.existing.size(0)
                    found_index = self.index[:num_found]
                    missing_index = self.index[num_found:]
                    values[found_index] = self.existing
                    values[missing_index] = self.missing
                    # Ensure there is no memory leak.
                    self.events = self.existing = None
                    self.missing = self.index = None
                    return values

            yield _Waiter(
                [
                    writing_completed,
                    values_from_cpu_copy_event,
                    missing_values_copy_event,
                ],
                values_from_cpu,
                missing_values,
                index,
            )
        elif ids.is_cuda:
            # GPU ids with pageable cache storage.
            ids_device = ids.device
            current_stream = torch.cuda.current_stream()
            device_to_host_stream = get_device_to_host_uva_stream()
            device_to_host_stream.wait_stream(current_stream)
            with torch.cuda.stream(device_to_host_stream):
                ids.record_stream(torch.cuda.current_stream())
                ids = ids.to("cpu", non_blocking=True)
                ids_copy_event = torch.cuda.Event()
                ids_copy_event.record()

            yield  # first stage is done.

            ids_copy_event.synchronize()
            policy_future = policy.query_and_replace_async(ids, self._offset)

            yield

            (
                positions,
                index,
                pointers,
                missing_keys,
                found_offsets,
                missing_offsets,
            ) = policy_future.wait()
            self._feature.total_queries += ids.shape[0]
            self._feature.total_miss += missing_keys.shape[0]
            found_cnt = ids.size(0) - missing_keys.size(0)
            found_positions = positions[:found_cnt]
            missing_positions = positions[found_cnt:]
            found_pointers = pointers[:found_cnt]
            missing_pointers = pointers[found_cnt:]
            values_future = cache.query_async(
                found_positions, index, ids.shape[0]
            )
            fallback_reader = self._fallback_feature.read_async(missing_keys)
            for _ in range(
                self._fallback_feature.read_async_num_stages(
                    missing_keys.device
                )
            ):
                missing_values_future = next(fallback_reader, None)
                yield  # fallback feature stages.
            values = values_future.wait()
            reading_completed = policy.reading_completed_async(
                found_pointers, found_offsets
            )
            missing_index = index[found_cnt:]
            missing_values = missing_values_future.wait()
            replace_future = cache.replace_async(
                missing_positions, missing_values
            )
            values = torch.ops.graphbolt.scatter_async(
                values, missing_index, missing_values
            )

            yield

            host_to_device_stream = get_host_to_device_uva_stream()
            with torch.cuda.stream(host_to_device_stream):
                values = values.wait().to(ids_device, non_blocking=True)
                values.record_stream(current_stream)
                values_copy_event = torch.cuda.Event()
                values_copy_event.record()
            reading_completed.wait()
            replace_future.wait()
            writing_completed = policy.writing_completed_async(
                missing_pointers, missing_offsets
            )

            class _Waiter:
                def __init__(self, events, values):
                    self.events = events
                    self.values = values

                def wait(self):
                    """Returns the stored value when invoked."""
                    for event in self.events:
                        event.wait()
                    values = self.values
                    # Ensure there is no memory leak.
                    self.events = self.values = None
                    return values

            yield _Waiter([values_copy_event, writing_completed], values)
        else:
            # CPU ids and CPU cache: no device transfers are needed.
            policy_future = policy.query_and_replace_async(ids, self._offset)

            yield

            (
                positions,
                index,
                pointers,
                missing_keys,
                found_offsets,
                missing_offsets,
            ) = policy_future.wait()
            self._feature.total_queries += ids.shape[0]
            self._feature.total_miss += missing_keys.shape[0]
            found_cnt = ids.size(0) - missing_keys.size(0)
            found_positions = positions[:found_cnt]
            missing_positions = positions[found_cnt:]
            found_pointers = pointers[:found_cnt]
            missing_pointers = pointers[found_cnt:]
            values_future = cache.query_async(
                found_positions, index, ids.shape[0]
            )
            fallback_reader = self._fallback_feature.read_async(missing_keys)
            for _ in range(
                self._fallback_feature.read_async_num_stages(
                    missing_keys.device
                )
            ):
                missing_values_future = next(fallback_reader, None)
                yield  # fallback feature stages.

            values = values_future.wait()
            reading_completed = policy.reading_completed_async(
                found_pointers, found_offsets
            )
            missing_index = index[found_cnt:]
            missing_values = missing_values_future.wait()
            replace_future = cache.replace_async(
                missing_positions, missing_values
            )
            values = torch.ops.graphbolt.scatter_async(
                values, missing_index, missing_values
            )

            yield

            reading_completed.wait()
            replace_future.wait()
            writing_completed = policy.writing_completed_async(
                missing_pointers, missing_offsets
            )

            class _Waiter:
                def __init__(self, event, values):
                    self.event = event
                    self.values = values

                def wait(self):
                    """Returns the stored value when invoked."""
                    self.event.wait()
                    values = self.values.wait()
                    # Ensure there is no memory leak.
                    self.event = self.values = None
                    return values

            yield _Waiter(writing_completed, values)
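    # The staged protocol above can also be driven by hand. A minimal sketch
    # for CPU-resident ids, reusing the hypothetical ``cached`` instance from
    # the example after ``read``; each ``next()`` call advances one stage:
    #
    #     ids = torch.tensor([0, 2])
    #     reader = cached.read_async(ids)
    #     for _ in range(cached.read_async_num_stages(ids.device) - 1):
    #         next(reader)           # Run the intermediate stages.
    #     future = next(reader)      # The final stage yields a waiter object.
    #     values = future.wait()     # Call ``.wait()`` exactly once.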
    def read_async_num_stages(self, ids_device: torch.device):
        """The number of stages of the read_async operation. See read_async
        function for directions on its use. This function is required to
        return the number of yield operations when read_async is used with a
        tensor residing on ids_device.

        Parameters
        ----------
        ids_device : torch.device
            The device of the ids parameter passed into read_async.

        Returns
        -------
        int
            The number of stages of the read_async operation.
        """
        if ids_device.type == "cuda":
            return 4 + self._fallback_feature.read_async_num_stages(
                torch.device("cpu")
            )
        else:
            return 3 + self._fallback_feature.read_async_num_stages(ids_device)
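    # Worked example with a hypothetical fallback whose CPU read_async has 2
    # stages: GPU-resident ids take 4 + 2 = 6 stages, while CPU-resident ids
    # take 3 + 2 = 5 stages.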
    def size(self):
        """Get the size of the feature.

        Returns
        -------
        torch.Size
            The size of the feature.
        """
        return self._fallback_feature.size()
    def count(self):
        """Get the count of the feature.

        Returns
        -------
        int
            The count of the feature.
        """
        return self._fallback_feature.count()
    def update(self, value: torch.Tensor, ids: torch.Tensor = None):
        """Update the feature.

        Parameters
        ----------
        value : torch.Tensor
            The updated value of the feature.
        ids : torch.Tensor, optional
            The indices of the feature to update. If specified, only the
            specified indices of the feature will be updated. Row `ids[i]` of
            the feature is updated to `value[i]`, so `ids` and `value` must
            have the same length. If None, the entire feature will be updated.
        """
        if ids is None:
            feat0 = value[:1]
            self._fallback_feature.update(value)
            cache_size = min(
                bytes_to_number_of_items(self.cache_size_in_bytes, feat0),
                value.shape[0],
            )
            self._feature = None  # Destroy the existing cache first.
            self._feature = self._cache_type(
                (cache_size,) + feat0.shape[1:], feat0.dtype
            )
        else:
            self._fallback_feature.update(value, ids)
            self._feature.replace(ids, value, None, self._offset)
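    # The two update modes in a short sketch (``cached``, ``new_values``, and
    # ``rows`` are hypothetical):
    #
    #     cached.update(new_values)                      # Full update; the
    #                                                    # cache is rebuilt.
    #     cached.update(rows, ids=torch.tensor([1, 3]))  # rows[i] -> ids[i].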
    def is_pinned(self):
        """Returns True if the cache storage is pinned."""
        return self._feature.is_pinned()
    @property
    def cache_size_in_bytes(self):
        """Return the size taken by the cache in bytes."""
        return self._feature.max_size_in_bytes

    @property
    def miss_rate(self):
        """Returns the cache miss rate since creation."""
        return self._feature.miss_rate
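# Multiple ``CPUCachedFeature``s can share one ``CPUFeatureCache`` by giving
# each wrapper a disjoint id range via ``offset``. A sketch under assumed
# shapes and offsets (``fallback_a``/``fallback_b`` are hypothetical features
# with fewer than 1_000_000 rows each):
#
#     cache = CPUFeatureCache((1024, 16), torch.float32)
#     feat_a = CPUCachedFeature(fallback_a, cache, offset=0)
#     feat_b = CPUCachedFeature(fallback_b, cache, offset=1_000_000)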
def cpu_cached_feature(
    fallback_features: Union[Feature, Dict[FeatureKey, Feature]],
    max_cache_size_in_bytes: int,
    policy: Optional[str] = None,
    pin_memory: bool = False,
) -> Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]:
    r"""CPU cached feature wrapping a fallback feature.

    Parameters
    ----------
    fallback_features : Union[Feature, Dict[FeatureKey, Feature]]
        The fallback feature(s).
    max_cache_size_in_bytes : int
        The capacity of the cache in bytes. It should be several times larger
        than the size of each read request. Otherwise, the caching policy will
        hang because all cache entries end up read and/or write locked,
        resulting in a deadlock.
    policy : str, optional
        The cache eviction policy algorithm name. The available policies are
        ["s3-fifo", "sieve", "lru", "clock"]. Default is "sieve".
    pin_memory : bool, optional
        Whether the cache storage should be allocated on system pinned memory.
        Default is False.

    Returns
    -------
    Union[CPUCachedFeature, Dict[FeatureKey, CPUCachedFeature]]
        New feature(s) wrapped with CPUCachedFeature.
    """
    return wrap_with_cached_feature(
        CPUCachedFeature,
        fallback_features,
        max_cache_size_in_bytes,
        policy=policy,
        pin_memory=pin_memory,
    )
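# End-to-end usage sketch (``gb.TorchBasedFeature`` and the sizes below are
# assumptions for illustration):
#
#     import torch
#     import dgl.graphbolt as gb
#
#     fallback = gb.TorchBasedFeature(torch.randn(10_000, 128))
#     cached = gb.cpu_cached_feature(
#         fallback, max_cache_size_in_bytes=4 * 2**20, policy="sieve"
#     )
#     values = cached.read(torch.arange(256))
#     print(cached.miss_rate)  # Fraction of reads that missed the cache.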