"""Utility functions for GraphBolt."""importhashlibimportjsonimportosimportshutilfromtypingimportList,Unionimportnumpyasnpimportpandasaspdimporttorchfromnumpy.lib.formatimportread_array_header_1_0,read_array_header_2_0
def numpy_save_aligned(*args, **kwargs):
    """A wrapper for numpy.save() that ensures the array is stored 4KiB-aligned."""
    # https://github.com/numpy/numpy/blob/2093a6d5b933f812d15a3de0eafeeb23c61f948a/numpy/lib/format.py#L179
    has_array_align = hasattr(np.lib.format, "ARRAY_ALIGN")
    if has_array_align:
        default_alignment = np.lib.format.ARRAY_ALIGN
        # The maximum alignment allowed by the numpy code linked above is 4K.
        # Most filesystems work with block sizes of 4K, so in practice the
        # file size on disk won't be larger.
        np.lib.format.ARRAY_ALIGN = 4096
    np.save(*args, **kwargs)
    if has_array_align:
        np.lib.format.ARRAY_ALIGN = default_alignment
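# A minimal usage sketch (the path and array below are hypothetical): on numpy
# versions that expose np.lib.format.ARRAY_ALIGN, the header is padded so that
# the array payload starts on a 4KiB boundary, which can be verified by
# re-reading the header:
#
#     arr = np.arange(1024, dtype=np.int64)
#     numpy_save_aligned("/tmp/aligned.npy", arr)
#     with open("/tmp/aligned.npy", "rb") as f:
#         np.lib.format.read_magic(f)
#         np.lib.format.read_array_header_1_0(f)
#         assert f.tell() % 4096 == 0  # payload offset is 4KiB-aligned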
def _read_torch_data(path):
    return torch.load(path)


def _read_numpy_data(path, in_memory=True):
    if in_memory:
        return torch.from_numpy(np.load(path))
    return torch.as_tensor(np.load(path, mmap_mode="r+"))


def read_data(path, fmt, in_memory=True):
    """Read data from disk."""
    if fmt == "torch":
        return _read_torch_data(path)
    elif fmt == "numpy":
        return _read_numpy_data(path, in_memory=in_memory)
    else:
        raise RuntimeError(f"Unsupported format: {fmt}")


def save_data(data, path, fmt):
    """Save data to disk."""
    # Make sure the directory exists.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    if fmt not in ["numpy", "torch"]:
        raise RuntimeError(f"Unsupported format: {fmt}")

    # Perform necessary conversion.
    if fmt == "numpy" and isinstance(data, torch.Tensor):
        data = data.cpu().numpy()
    elif fmt == "torch" and isinstance(data, np.ndarray):
        data = torch.from_numpy(data).cpu()

    # Save the data.
    if fmt == "numpy":
        if not data.flags["C_CONTIGUOUS"]:
            warnings.warn(
                "The ndarray saved to disk is not contiguous, "
                "so it will be copied to contiguous memory."
            )
            data = np.ascontiguousarray(data)
        numpy_save_aligned(path, data)
    elif fmt == "torch":
        if not data.is_contiguous():
            warnings.warn(
                "The tensor saved to disk is not contiguous, "
                "so it will be copied to contiguous memory."
            )
            data = data.contiguous()
        torch.save(data, path)


def get_npy_dim(npy_path):
    """Get the number of dimensions of a numpy file."""
    with open(npy_path, "rb") as f:
        # The read_array_header APIs provided by numpy parse only the header
        # itself, so they fail with an error if the first 8 bytes, which
        # contain the magic string and the version, have not been consumed
        # beforehand. Make sure these 8 bytes are skipped.
        f.seek(8, 0)
        try:
            shape, _, _ = read_array_header_1_0(f)
        except ValueError:
            try:
                shape, _, _ = read_array_header_2_0(f)
            except ValueError:
                raise ValueError("Invalid file format")

        return len(shape)


def _to_int32(data):
    if isinstance(data, torch.Tensor):
        return data.to(torch.int32)
    elif isinstance(data, np.ndarray):
        return data.astype(np.int32)
    else:
        raise TypeError(
            "Unsupported input type. "
            "Please provide a torch tensor or numpy array."
        )
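# A brief round-trip sketch for save_data/read_data/get_npy_dim (the file name
# is hypothetical): a torch tensor saved in "numpy" format comes back as a
# torch tensor, optionally memory-mapped instead of fully loaded:
#
#     t = torch.arange(6).reshape(2, 3)
#     save_data(t, "/tmp/t.npy", "numpy")  # stored via numpy_save_aligned
#     assert get_npy_dim("/tmp/t.npy") == 2  # header-only dimension check
#     back = read_data("/tmp/t.npy", "numpy", in_memory=False)  # mmap-backed
#     assert torch.equal(back, t)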
def copy_or_convert_data(
    input_path,
    output_path,
    input_format,
    output_format="numpy",
    in_memory=True,
    is_feature=False,
    within_int32=False,
):
    """Copy or convert the data from input_path to output_path."""
    assert (
        output_format == "numpy"
    ), "The output format of the data should be numpy."
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    # We always read the data in case we need to cast its type.
    data = read_data(input_path, input_format, in_memory)
    if within_int32:
        data = _to_int32(data)
    if input_format == "numpy":
        # If the data is 1-D, reshape it to (N, 1) and save it to output_path.
        if is_feature and get_npy_dim(input_path) == 1:
            data = data.reshape(-1, 1)
        # If the data does not need to be modified, just copy the file.
        elif not within_int32 and data.numpy().flags["C_CONTIGUOUS"]:
            shutil.copyfile(input_path, output_path)
            return
    else:
        # If the data is 1-D, reshape it to (N, 1) and save it to output_path.
        if is_feature and data.dim() == 1:
            data = data.reshape(-1, 1)
    save_data(data, output_path, output_format)


def read_edges(dataset_dir, edge_fmt, edge_path):
    """Read edge data from numpy or csv."""
    assert edge_fmt in [
        "numpy",
        "csv",
    ], f"`numpy` or `csv` is expected when reading edges but got `{edge_fmt}`."
    if edge_fmt == "numpy":
        edge_data = read_data(
            os.path.join(dataset_dir, edge_path),
            edge_fmt,
        )
        assert (
            edge_data.shape[0] == 2 and len(edge_data.shape) == 2
        ), f"The shape of edges should be (2, N), but got {edge_data.shape}."
        src, dst = edge_data.numpy()
    else:
        edge_data = pd.read_csv(
            os.path.join(dataset_dir, edge_path),
            names=["src", "dst"],
        )
        src, dst = edge_data["src"].to_numpy(), edge_data["dst"].to_numpy()
    return (src, dst)


def calculate_file_hash(file_path, hash_algo="md5"):
    """Calculate the hash value of a file."""
    hash_algos = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
    if hash_algo in hash_algos:
        hash_obj = getattr(hashlib, hash_algo)()
    else:
        raise ValueError(
            f"Hash algorithm must be one of: {hash_algos}, "
            f"but got `{hash_algo}`."
        )

    with open(file_path, "rb") as file:
        for chunk in iter(lambda: file.read(4096), b""):
            hash_obj.update(chunk)
    return hash_obj.hexdigest()


def calculate_dir_hash(
    dir_path, hash_algo="md5", ignore: Union[str, List[str]] = None
):
    """Calculate the hash values of all files under the directory."""
    hashes = {}
    for dirpath, _, filenames in os.walk(dir_path):
        for filename in filenames:
            if ignore and filename in ignore:
                continue
            filepath = os.path.join(dirpath, filename)
            file_hash = calculate_file_hash(filepath, hash_algo=hash_algo)
            hashes[filepath] = file_hash
    return hashes


def check_dataset_change(dataset_dir, processed_dir):
    """Check whether the dataset has changed by comparing its hash value."""
    hash_value_file = "dataset_hash_value.txt"
    hash_value_file_path = os.path.join(
        dataset_dir, processed_dir, hash_value_file
    )
    if not os.path.exists(hash_value_file_path):
        return True
    with open(hash_value_file_path, "r") as f:
        original_hash_value = json.load(f)
    present_hash_value = calculate_dir_hash(dataset_dir, ignore=hash_value_file)
    # Force re-preprocessing whenever any file hash differs from the record.
    force_preprocess = original_hash_value != present_hash_value
    return force_preprocess
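# A sketch of the intended change-detection flow (the directory layout below
# is hypothetical): preprocessing records the per-file hashes once, and
# check_dataset_change() returns True, i.e. forces re-preprocessing, whenever
# any file under the dataset directory is added, removed, or modified:
#
#     hashes = calculate_dir_hash("dataset/", ignore="dataset_hash_value.txt")
#     with open("dataset/processed/dataset_hash_value.txt", "w") as f:
#         json.dump(hashes, f)
#     assert not check_dataset_change("dataset/", "processed/")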