Remove pyproject.toml #323

Merged 2 commits on Jun 20, 2023

Changes from 1 commit
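The deletion of pyproject.toml itself is not visible in this single-commit view; every hunk below is a formatting-only re-wrap of long statements. A plausible reading, stated here as an assumption rather than anything confirmed by the diff, is that removing the project-specific configuration lets Black fall back to its default 88-character line length. A minimal sketch of reproducing one such re-wrap with Black's Python API (requires the black package):

# Hypothetical check, not part of this patch: with no project configuration,
# Black uses its default 88-character line length, which re-wraps the long
# import from the first hunk below into a parenthesized multi-line form.
import black

src = (
    "from fuse.data.ops.ops_common import "
    "OpApplyPatterns, OpLambda, OpFunc, OpRepeat, OpKeepKeypaths\n"
)

# black.FileMode() carries the default line_length=88.
print(black.format_str(src, mode=black.FileMode()))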
fuse/data/__init__.py (8 changes: 7 additions & 1 deletion)

@@ -10,7 +10,13 @@
get_specific_sample_from_potentially_morphed,
)
from fuse.data.ops.op_base import OpBase # DataTypeForTesting,
from fuse.data.ops.ops_common import OpApplyPatterns, OpLambda, OpFunc, OpRepeat, OpKeepKeypaths
from fuse.data.ops.ops_common import (
OpApplyPatterns,
OpLambda,
OpFunc,
OpRepeat,
OpKeepKeypaths,
)
from fuse.data.ops.ops_aug_common import OpRandApply, OpSample, OpSampleAndRepeat
from fuse.data.ops.ops_read import OpReadDataframe
from fuse.data.ops.ops_cast import OpToTensor, OpToNumpy
fuse/data/datasets/caching/samples_cacher.py (106 changes: 82 additions & 24 deletions)

@@ -16,16 +16,30 @@

from fuse.data.pipelines.pipeline_default import PipelineDefault
from collections import OrderedDict
from fuse.data.datasets.caching.object_caching_handlers import _object_requires_hdf5_recurse
from fuse.data.datasets.caching.object_caching_handlers import (
_object_requires_hdf5_recurse,
)
from fuse.utils.ndict import NDict
import os
import psutil
from fuse.utils.file_io.file_io import load_hdf5, save_hdf5_safe, load_pickle, save_pickle_safe
from fuse.data import get_sample_id, create_initial_sample, get_specific_sample_from_potentially_morphed
from fuse.utils.file_io.file_io import (
load_hdf5,
save_hdf5_safe,
load_pickle,
save_pickle_safe,
)
from fuse.data import (
get_sample_id,
create_initial_sample,
get_specific_sample_from_potentially_morphed,
)
import hashlib
from fuse.utils.file_io import delete_directory_tree
from glob import glob
from fuse.utils.multiprocessing.run_multiprocessed import run_multiprocessed, get_from_global_storage
from fuse.utils.multiprocessing.run_multiprocessed import (
run_multiprocessed,
get_from_global_storage,
)
from fuse.data.datasets.sample_caching_audit import SampleCachingAudit
from fuse.data.utils.sample import get_initial_sample_id, set_initial_sample_id
from warnings import warn
@@ -80,7 +94,9 @@ def __init__(
self._write_dir_logic = custom_write_dir_callable

if custom_read_dirs_callable is None:
self._read_dirs_logic = partial(default_read_dirs_logic, cache_dirs=self._cache_dirs)
self._read_dirs_logic = partial(
default_read_dirs_logic, cache_dirs=self._cache_dirs
)
else:
self._read_dirs_logic = custom_read_dirs_callable

@@ -89,14 +105,19 @@ def __init__(

self._pipeline_desc_text = str(pipeline)
if use_pipeline_hash:
self._pipeline_desc_hash = "hash_" + hashlib.md5(self._pipeline_desc_text.encode("utf-8")).hexdigest()
self._pipeline_desc_hash = (
"hash_"
+ hashlib.md5(self._pipeline_desc_text.encode("utf-8")).hexdigest()
)
else:
self._pipeline_desc_hash = "hash_fixed"

self._verbose = verbose

if self._verbose > 0:
print(f"pipeline description hash for [{unique_name}] is: {self._pipeline_desc_hash}")
print(
f"pipeline description hash for [{unique_name}] is: {self._pipeline_desc_hash}"
)

self._restart_cache = restart_cache
if self._restart_cache:
@@ -123,14 +144,21 @@ def _verify_no_other_pipelines_cache(self) -> None:
continue
if os.path.basename(found_dir) != self._pipeline_desc_hash:
new_desc = self._pipeline_desc_text
new_file = os.path.join(found_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt")
new_file = os.path.join(
found_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt"
)
with open(new_file, "wt") as f:
f.write(new_desc)

pipeline_desc_file = os.path.join(found_dir, f"pipeline_{os.path.basename(found_dir)}_desc.txt")
pipeline_desc_file = os.path.join(
found_dir, f"pipeline_{os.path.basename(found_dir)}_desc.txt"
)
if os.path.exists(pipeline_desc_file):
print("*** Old pipeline description:", pipeline_desc_file)
print("*** New pipeline description (does not match old pipeline):", new_file)
print(
"*** New pipeline description (does not match old pipeline):",
new_file,
)

raise Exception(
f"Found samples cache for pipeline hash {os.path.basename(found_dir)} which is different from the current loaded pipeline hash {self._pipeline_desc_hash} !!\n"
@@ -168,7 +196,9 @@ def _get_read_dirs(self) -> List[str]:
ans = [os.path.join(x, self._pipeline_desc_hash) for x in ans]
return ans

def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[None, List[str]], str]]:
def cache_samples(
self, orig_sample_ids: List[Any]
) -> List[Tuple[str, Union[None, List[str]], str]]:
"""
Go over all of orig_sample_ids, and cache resulting samples
returns information that helps to map from original sample id to the resulting sample id
@@ -187,9 +217,13 @@ def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[Non

read_dirs = self._get_read_dirs()
for curr_read_dir in read_dirs:
fullpath_filename = os.path.join(curr_read_dir, "full_sets_info", hash_filename)
fullpath_filename = os.path.join(
curr_read_dir, "full_sets_info", hash_filename
)
if os.path.isfile(fullpath_filename):
print(f"entire samples set {hash_filename} already cached. Found {os.path.abspath(fullpath_filename)}")
print(
f"entire samples set {hash_filename} already cached. Found {os.path.abspath(fullpath_filename)}"
)
return load_pickle(fullpath_filename)

orig_sid_to_final = OrderedDict()
@@ -211,7 +245,9 @@ def cache_samples(self, orig_sample_ids: List[Any]) -> List[Tuple[str, Union[Non
set_info_dir = os.path.join(write_dir, "full_sets_info")
os.makedirs(set_info_dir, exist_ok=True)

pipeline_desc_file = os.path.join(write_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt")
pipeline_desc_file = os.path.join(
write_dir, f"pipeline_{self._pipeline_desc_hash}_desc.txt"
)
if not os.path.exists(pipeline_desc_file):
with open(pipeline_desc_file, "wt") as f:
f.write(self._pipeline_desc_text)
@@ -240,7 +276,9 @@ def get_orig_sample_id_hash(orig_sample_id: Any) -> str:
orig_sample_id is the original sample_id that was provided, regardless if it turned out to become None, the same sample_id, or different sample_id(s)
"""
orig_sample_id_str = str(orig_sample_id)
if orig_sample_id_str.startswith("<") and orig_sample_id_str.endswith(">"): # and '0x' in orig_sample_id_str
if orig_sample_id_str.startswith("<") and orig_sample_id_str.endswith(
">"
): # and '0x' in orig_sample_id_str
# <__main__.SomeClass at 0x7fc3e6645e20>
raise Exception(
f"You must implement a proper __str__ for orig_sample_id. String representations like <__main__.SomeClass at 0x7fc3e6645e20> are not descriptibe enough and also not persistent between runs. Got: {orig_sample_id_str}"
@@ -249,7 +287,9 @@ def get_orig_sample_id_hash(orig_sample_id: Any) -> str:
ans = "out_info_for_orig_sample@" + ans
return ans

def load_sample(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def load_sample(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
"""
:param sample_id: the sample_id of the sample to load
:param keys: optionally, provide a subset of the keys to load in this sample.
@@ -262,18 +302,24 @@ def load_sample(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None)
if audit_required:
initial_sample_id = get_initial_sample_id(sample_from_cache)
fresh_sample = self._load_sample_using_pipeline(initial_sample_id, keys)
fresh_sample = get_specific_sample_from_potentially_morphed(fresh_sample, sample_id)
fresh_sample = get_specific_sample_from_potentially_morphed(
fresh_sample, sample_id
)

self._audit.audit(sample_from_cache, fresh_sample)

return sample_from_cache

def _load_sample_using_pipeline(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def _load_sample_using_pipeline(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
sample_dict = create_initial_sample(sample_id)
result_sample = self._pipeline(sample_dict)
return result_sample

def _load_sample_from_cache(self, sample_id: Hashable, keys: Optional[Sequence[str]] = None) -> NDict:
def _load_sample_from_cache(
self, sample_id: Hashable, keys: Optional[Sequence[str]] = None
) -> NDict:
"""
TODO: add comments
"""
@@ -289,7 +335,9 @@ def _load_sample_from_cache(self, sample_id: Hashable, keys: Optional[Sequence[s
loaded_sample.merge(loaded_sample_hdf5_part)
return loaded_sample

raise Exception(f"Expected to find a cached sample for sample_id={sample_id} but could not find any!")
raise Exception(
f"Expected to find a cached sample for sample_id={sample_id} but could not find any!"
)

@staticmethod
def _cache_worker(orig_sample_id: Any) -> Any:
@@ -339,21 +387,29 @@ def _cache(self, orig_sample_id: Any) -> Any:
for curr_sample in result_sample:
curr_sample_id = get_sample_id(curr_sample)
output_info.append(curr_sample_id)
output_sample_hash = SamplesCacher.get_final_sample_id_hash(curr_sample_id)
output_sample_hash = SamplesCacher.get_final_sample_id_hash(
curr_sample_id
)

requiring_hdf5_keys = _object_requires_hdf5_recurse(curr_sample)
if len(requiring_hdf5_keys) > 0:
requiring_hdf5_dict = curr_sample.get_multi(requiring_hdf5_keys)
requiring_hdf5_dict = requiring_hdf5_dict.flatten()

hdf5_filename = os.path.join(write_dir, output_sample_hash + ".hdf5")
hdf5_filename = os.path.join(
write_dir, output_sample_hash + ".hdf5"
)
save_hdf5_safe(hdf5_filename, **requiring_hdf5_dict)

# remove all hdf5 entries from the sample_dict that will be pickled
for k in requiring_hdf5_dict:
_ = curr_sample.pop(k)

save_pickle_safe(curr_sample, os.path.join(write_dir, output_sample_hash + ".pkl.gz"), compress=True)
save_pickle_safe(
curr_sample,
os.path.join(write_dir, output_sample_hash + ".pkl.gz"),
compress=True,
)
else:
output_info = None
# requiring_hdf5_keys = None
Expand All @@ -362,7 +418,9 @@ def _cache(self, orig_sample_id: Any) -> Any:
return output_info


def _get_available_write_location(cache_dirs: List[str], max_allowed_used_space: Optional[float] = None) -> str:
def _get_available_write_location(
cache_dirs: List[str], max_allowed_used_space: Optional[float] = None
) -> str:
"""
:param cache_dirs: write directories. Directories are checked in order that they are provided.
:param max_allowed_used_space: set to a value between 0.0 to 1.0.
fuse/data/datasets/caching/tests/test_sample_caching.py (13 changes: 11 additions & 2 deletions)

@@ -34,7 +34,9 @@ class OpFakeLoad(OpBase):
def __init__(self) -> None:
super().__init__()

def __call__(self, sample_dict: NDict, **kwargs: dict) -> Union[None, dict, List[dict]]:
def __call__(
self, sample_dict: NDict, **kwargs: dict
) -> Union[None, dict, List[dict]]:
sid = get_sample_id(sample_dict)
if "case_1" == sid:
sample_dict.merge(_generate_sample_1())
@@ -111,7 +113,14 @@ def test_same_uniquely_named_cache_and_multiple_pipeline_hashes(self) -> None:
(OpFakeLoad(), {}), ###just doubled it to change the pipeline hash
]
pl = PipelineDefault("example_pipeline", pipeline_desc)
self.assertRaises(Exception, SamplesCacher, "unittests_cache", pl, cache_dirs, restart_cache=False)
self.assertRaises(
Exception,
SamplesCacher,
"unittests_cache",
pl,
cache_dirs,
restart_cache=False,
)

def tearDown(self) -> None:
pass
fuse/data/datasets/dataset_base.py (5 changes: 4 additions & 1 deletion)

@@ -40,7 +40,10 @@ def summary(self) -> str:

@abstractmethod
def get_multi(
self, items: Optional[Sequence[Union[int, Hashable]]] = None, *args: list, **kwargs: dict
self,
items: Optional[Sequence[Union[int, Hashable]]] = None,
*args: list,
**kwargs: dict
) -> List[Dict]:
"""
Get multiple items, optionally just some of the keys
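None of the re-wraps above change runtime behavior or public signatures. For orientation, here is a rough usage sketch of the cacher API as it reads after this commit, based only on the signatures visible in this diff; the op, cache directory, and sample id are hypothetical stand-ins modeled on OpFakeLoad from the test file, and the snippet is untested:

from typing import List, Union

import numpy as np

from fuse.data.datasets.caching.samples_cacher import SamplesCacher
from fuse.data.ops.op_base import OpBase
from fuse.data.pipelines.pipeline_default import PipelineDefault
from fuse.utils.ndict import NDict


class OpToyLoad(OpBase):
    # Minimal stand-in for OpFakeLoad above: fabricate a small sample payload.
    def __call__(
        self, sample_dict: NDict, **kwargs: dict
    ) -> Union[None, dict, List[dict]]:
        sample_dict["data.img"] = np.zeros((4, 4))  # array payload (assumed to be stored via hdf5)
        return sample_dict


pipeline = PipelineDefault("toy_pipeline", [(OpToyLoad(), {})])

cacher = SamplesCacher(
    "toy_cache",  # unique_name
    pipeline,
    ["/tmp/fuse_toy_cache"],  # cache_dirs, checked in the order provided
    restart_cache=True,
)

# cache_samples() returns mapping info from original to resulting sample ids;
# load_sample() reads one cached sample back (optionally a subset of keys).
mapping = cacher.cache_samples(["case_1"])
sample = cacher.load_sample("case_1")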