
Commit

done
Sagi Polaczek committed Jun 19, 2023
1 parent 35719cf commit bc34093
Showing 164 changed files with 4,495 additions and 1,201 deletions.
8 changes: 7 additions & 1 deletion fuse/data/__init__.py
@@ -10,7 +10,13 @@
get_specific_sample_from_potentially_morphed,
)
from fuse.data.ops.op_base import OpBase # DataTypeForTesting,
from fuse.data.ops.ops_common import OpApplyPatterns, OpLambda, OpFunc, OpRepeat, OpKeepKeypaths
from fuse.data.ops.ops_common import (
OpApplyPatterns,
OpLambda,
OpFunc,
OpRepeat,
OpKeepKeypaths,
)
from fuse.data.ops.ops_aug_common import OpRandApply, OpSample, OpSampleAndRepeat
from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.ops.ops_cast import OpToTensor, OpToNumpy
106 changes: 82 additions & 24 deletions fuse/data/datasets/caching/samples_cacher.py
@@ -16,16 +16,30 @@

from fuse.data.pipelines.pipeline_default import PipelineDefault
from collections import OrderedDict
from fuse.data.datasets.caching.object_caching_handlers import _object_requires_hdf5_recurse
from fuse.data.datasets.caching.object_caching_handlers import (
_object_requires_hdf5_recurse,
)
from fuse.utils.ndict import NDict
import os
import psutil
from fuse.utils.file_io.file_io import load_hdf5, save_hdf5_safe, load_pickle, save_pickle_safe
from fuse.data import get_sample_id, create_initial_sample, get_specific_sample_from_potentially_morphed
from fuse.utils.file_io.file_io import (
load_hdf5,
save_hdf5_safe,
load_pickle,
save_pickle_safe,
)
from fuse.data import (
get_sample_id,
create_initial_sample,
get_specific_sample_from_potentially_morphed,
)
import hashlib
from fuse.utils.file_io import delete_directory_tree
from glob import glob
from fuse.utils.multiprocessing.run_multiprocessed import run_multiprocessed, get_from_global_storage
from fuse.utils.multiprocessing.run_multiprocessed import (
run_multiprocessed,
get_from_global_storage,
)
from fuse.data.datasets.sample_caching_audit import SampleCachingAudit
from fuse.data.utils.sample import get_initial_sample_id, set_initial_sample_id
from warnings import warn
@@ -80,7 +94,9 @@ def __init__(
self._write_dir_logic = custom_write_dir_callable

if custom_read_dirs_callable is None:
self._read_dirs_logic = partial(default_read_dirs_logic, cache_dirs=self._cache_dirs)
self._read_dirs_logic = partial(
default_read_dirs_logic, cache_dirs=self._cache_dirs
)
else:
self._read_dirs_logic = custom_read_dirs_callable

@@ -89,14 +105,19 @@ def __init__(

self._pipeline_desc_text = str(pipeline)
if use_pipeline_hash:
self._pipeline_desc_hash = "hash_" + hashlib.md5(self._pipeline_desc_text.encode("utf-8")).hexdigest()
self._pipeline_desc_hash = (
"hash_"
+ hashlib.md5(self._pipeline_desc_text.encode("utf-8")).hexdigest()
)
else:
self._pipeline_desc_hash = "hash_fixed"

self._verbose = verbose

if self._verbose > 0:
print(f"pipeline description hash for [{unique_name}] is: {self._pipeline_desc_hash}")
print(
f"pipeline description hash for [{unique_name}] is: {self._pipeline_desc_hash}"
)

self._restart_cache = restart_cache
if self._restart_cache:
Expand All @@ -123,14 +144,21 @@ def _verify_no_other_pipelines_cache(self) -> None:
continue
if os.path.basename(found_dir) != self._pipeline_desc_hash:
new_desc = self._pipeline_desc_text
new_file = os.path.join(found_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt")
new_file = os.path.join(
found_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt"
)
with open(new_file, "wt") as f:
f.write(new_desc)

pipeline_desc_file = os.path.join(found_dir, f"pipeline_{os.path.basename(found_dir)}_desc.txt")
pipeline_desc_file = os.path.join(
found_dir, f"pipeline_{os.path.basename(found_dir)}_desc.txt"
)
if os.path.exists(pipeline_desc_file):
print("*** Old pipeline description:", pipeline_desc_file)
print("*** New pipeline description (does not match old pipeline):", new_file)
print(
"*** New pipeline description (does not match old pipeline):",
new_file,
)

raise Exception(
f"Found samples cache for pipeline hash {os.path.basename(found_dir)} which is different from the current loaded pipeline hash {self._pipeline_desc_hash} !!\n"
@@ -168,7 +196,9 @@ def _get_read_dirs(self) -> List[str]:
ans = [os.path.join(x, self._pipeline_desc_hash) for x in ans]
return ans

def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[None, List[str]], str]]:
def cache_samples(
self, orig_sample_ids: List[Any]
) -> List[Tuple[str, Union[None, List[str]], str]]:
"""
Go over all of orig_sample_ids, and cache resulting samples
returns information that helps to map from original sample id to the resulting sample id
@@ -187,9 +217,13 @@ def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[Non

read_dirs = self._get_read_dirs()
for curr_read_dir in read_dirs:
fullpath_filename = os.path.join(curr_read_dir, "full_sets_info", hash_filename)
fullpath_filename = os.path.join(
curr_read_dir, "full_sets_info", hash_filename
)
if os.path.isfile(fullpath_filename):
print(f"entire samples set {hash_filename} already cached. Found {os.path.abspath(fullpath_filename)}")
print(
f"entire samples set {hash_filename} already cached. Found {os.path.abspath(fullpath_filename)}"
)
return load_pickle(fullpath_filename)

orig_sid_to_final = OrderedDict()
@@ -211,7 +245,9 @@ def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[Non
set_info_dir = os.path.join(write_dir, "full_sets_info")
os.makedirs(set_info_dir, exist_ok=True)

pipeline_desc_file = os.path.join(write_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt")
pipeline_desc_file = os.path.join(
write_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt"
)
if not os.path.exists(pipeline_desc_file):
with open(pipeline_desc_file, "wt") as f:
f.write(self._pipeline_desc_text)
@@ -240,7 +276,9 @@ def get_orig_sample_id_hash(orig_sample_id: Any) -> str:
orig_sample_id is the original sample_id that was provided, regardless of whether it turned out to become None, the same sample_id, or different sample_id(s)
"""
orig_sample_id_str = str(orig_sample_id)
if orig_sample_id_str.startswith("<") and orig_sample_id_str.endswith(">"): # and '0x' in orig_sample_id_str
if orig_sample_id_str.startswith("<") and orig_sample_id_str.endswith(
">"
): # and '0x' in orig_sample_id_str
# <__main__.SomeClass at 0x7fc3e6645e20>
raise Exception(
f"You must implement a proper __str__ for orig_sample_id. String representations like <__main__.SomeClass at 0x7fc3e6645e20> are not descriptibe enough and also not persistent between runs. Got: {orig_sample_id_str}"
@@ -249,7 +287,9 @@ def get_orig_sample_id_hash(orig_sample_id: Any) -> str:
ans = "out_info_for_orig_sample@" + ans
return ans

def load_sample(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def load_sample(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
"""
:param sample_id: the sample_id of the sample to load
:param keys: optionally, provide a subset of the keys to load in this sample.
@@ -262,18 +302,24 @@ def load_sample(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None)
if audit_required:
initial_sample_id = get_initial_sample_id(sample_from_cache)
fresh_sample = self._load_sample_using_pipeline(initial_sample_id, keys)
fresh_sample = get_specific_sample_from_potentially_morphed(fresh_sample, sample_id)
fresh_sample = get_specific_sample_from_potentially_morphed(
fresh_sample, sample_id
)

self._audit.audit(sample_from_cache, fresh_sample)

return sample_from_cache

def _load_sample_using_pipeline(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def _load_sample_using_pipeline(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
sample_dict = create_initial_sample(sample_id)
result_sample = self._pipeline(sample_dict)
return result_sample

def _load_sample_from_cache(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def _load_sample_from_cache(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
"""
TODO: add comments
"""
@@ -289,7 +335,9 @@ def _load_sample_from_cache(self, sample_id: Hashable, keys: Optional[Sequence[s
loaded_sample.merge(loaded_sample_hdf5_part)
return loaded_sample

raise Exception(f"Expected to find a cached sample for sample_id={sample_id} but could not find any!")
raise Exception(
f"Expected to find a cached sample for sample_id={sample_id} but could not find any!"
)

@staticmethod
def _cache_worker(orig_sample_id: Any) -> Any:
@@ -339,21 +387,29 @@ def _cache(self, orig_sample_id: Any) -> Any:
for curr_sample in result_sample:
curr_sample_id = get_sample_id(curr_sample)
output_info.append(curr_sample_id)
output_sample_hash = SamplesCacher.get_final_sample_id_hash(curr_sample_id)
output_sample_hash = SamplesCacher.get_final_sample_id_hash(
curr_sample_id
)

requiring_hdf5_keys = _object_requires_hdf5_recurse(curr_sample)
if len(requiring_hdf5_keys) > 0:
requiring_hdf5_dict = curr_sample.get_multi(requiring_hdf5_keys)
requiring_hdf5_dict = requiring_hdf5_dict.flatten()

hdf5_filename = os.path.join(write_dir, output_sample_hash + ".hdf5")
hdf5_filename = os.path.join(
write_dir, output_sample_hash + ".hdf5"
)
save_hdf5_safe(hdf5_filename, **requiring_hdf5_dict)

# remove all hdf5 entries from the sample_dict that will be pickled
for k in requiring_hdf5_dict:
_ = curr_sample.pop(k)

save_pickle_safe(curr_sample, os.path.join(write_dir, output_sample_hash + ".pkl.gz"), compress=True)
save_pickle_safe(
curr_sample,
os.path.join(write_dir, output_sample_hash + ".pkl.gz"),
compress=True,
)
else:
output_info = None
# requiring_hdf5_keys = None
@@ -362,7 +418,9 @@ def _cache(self, orig_sample_id: Any) -> Any:
return output_info


def _get_available_write_location(cache_dirs: List[str], max_allowed_used_space: Optional[float] = None) -> str:
def _get_available_write_location(
cache_dirs: List[str], max_allowed_used_space: Optional[float] = None
) -> str:
"""
:param cache_dirs: write directories. Directories are checked in order that they are provided.
:param max_allowed_used_space: set to a value between 0.0 to 1.0.
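
For orientation only (not part of this commit), here is a minimal, hedged sketch of driving the SamplesCacher API reformatted above. The op class, cache directory path, and sample ids are illustrative assumptions; the constructor arguments mirror the SamplesCacher call in test_sample_caching.py below, and the method names come from the diff itself.

from typing import List, Union

from fuse.data.datasets.caching.samples_cacher import SamplesCacher
from fuse.data.ops.op_base import OpBase
from fuse.data.pipelines.pipeline_default import PipelineDefault
from fuse.utils.ndict import NDict


class OpMinimalLoad(OpBase):
    """Placeholder op (an assumption), mirroring OpFakeLoad in the test below."""

    def __call__(
        self, sample_dict: NDict, **kwargs: dict
    ) -> Union[None, dict, List[dict]]:
        sample_dict["data.value"] = 42  # dummy payload for the sketch
        return sample_dict


# A static pipeline description: a list of (op, kwargs) pairs, as in the tests below.
pipeline = PipelineDefault("example_pipeline", [(OpMinimalLoad(), {})])

# Same positional arguments as the SamplesCacher call in the test below:
# unique name, pipeline, list of cache dirs, restart_cache flag.
cacher = SamplesCacher(
    "example_cache",
    pipeline,
    ["/tmp/fuse_cache"],  # placeholder cache location
    restart_cache=True,   # rebuild the cache for this pipeline hash from scratch
)

# cache_samples() runs the pipeline once per original sample id and persists each
# result (a .pkl.gz file, plus a separate .hdf5 file for keys that require it).
mapping_info = cacher.cache_samples(orig_sample_ids=["case_1", "case_2"])

# load_sample() reads a cached sample back, optionally restricted to a subset of keys.
sample = cacher.load_sample("case_1")
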
13 changes: 11 additions & 2 deletions fuse/data/datasets/caching/tests/test_sample_caching.py
@@ -34,7 +34,9 @@ class OpFakeLoad(OpBase):
def __init__(self) -> None:
super().__init__()

def __call__(self, sample_dict: NDict, **kwargs: dict) -> Union[None, dict, List[dict]]:
def __call__(
self, sample_dict: NDict, **kwargs: dict
) -> Union[None, dict, List[dict]]:
sid = get_sample_id(sample_dict)
if "case_1" == sid:
sample_dict.merge(_generate_sample_1())
@@ -111,7 +113,14 @@ def test_same_uniquely_named_cache_and_multiple_pipeline_hashes(self) -> None:
(OpFakeLoad(), {}), ###just doubled it to change the pipeline hash
]
pl = PipelineDefault("example_pipeline", pipeline_desc)
self.assertRaises(Exception, SamplesCacher, "unittests_cache", pl, cache_dirs, restart_cache=False)
self.assertRaises(
Exception,
SamplesCacher,
"unittests_cache",
pl,
cache_dirs,
restart_cache=False,
)

def tearDown(self) -> None:
pass
5 changes: 4 additions & 1 deletion fuse/data/datasets/dataset_base.py
@@ -40,7 +40,10 @@ def summary(self) -> str:

@abstractmethod
def get_multi(
self, items: Optional[Sequence[Union[int, Hashable]]] = None, *args: list, **kwargs: dict
self,
items: Optional[Sequence[Union[int, Hashable]]] = None,
*args: list,
**kwargs: dict
) -> List[Dict]:
"""
Get multiple items, optionally just some of the keys