diff --git a/include/sourmash.h b/include/sourmash.h
index b5e111b662..56224dc90b 100644
--- a/include/sourmash.h
+++ b/include/sourmash.h
@@ -263,6 +263,8 @@ uintptr_t nodegraph_get_kmer(const SourmashNodegraph *ptr, const char *kmer);
 
 const uint64_t *nodegraph_hashsizes(const SourmashNodegraph *ptr, uintptr_t *size);
 
+uintptr_t nodegraph_intersection_count(const SourmashNodegraph *ptr, const SourmashNodegraph *optr);
+
 uintptr_t nodegraph_ksize(const SourmashNodegraph *ptr);
 
 uintptr_t nodegraph_matches(const SourmashNodegraph *ptr, const SourmashKmerMinHash *mh_ptr);
diff --git a/sourmash/cli/index.py b/sourmash/cli/index.py
index 57fcb8ef40..b71e295a1a 100644
--- a/sourmash/cli/index.py
+++ b/sourmash/cli/index.py
@@ -55,7 +55,9 @@ def subparser(subparsers):
     )
     subparser.add_argument(
         '-x', '--bf-size', metavar='S', type=float, default=1e5,
-        help='Bloom filter size used for internal nodes'
+        help='Maximum Bloom filter size used for internal nodes. Set to 0 '
+             'to calculate an optimal size given the complexity of the datasets '
+             '(using a HyperLogLog to estimate the number of unique k-mers)'
     )
     subparser.add_argument(
         '-f', '--force', action='store_true',
diff --git a/sourmash/commands.py b/sourmash/commands.py
index 73b39adef8..15fccfe9eb 100644
--- a/sourmash/commands.py
+++ b/sourmash/commands.py
@@ -327,10 +327,20 @@ def index(args):
     set_quiet(args.quiet)
     moltype = sourmash_args.calculate_moltype(args)
 
+    batch = False
     if args.append:
        tree = load_sbt_index(args.sbt_name)
     else:
-        tree = create_sbt_index(args.bf_size, n_children=args.n_children)
+        bf_size = args.bf_size
+        if bf_size == 0:
+            bf_size = None
+
+        tree = create_sbt_index(bf_size, n_children=args.n_children)
+        batch = True
+
+        # TODO: set up storage here
+        storage_info = tree._setup_storage(args.sbt_name)
+        tree.storage = storage_info.storage
 
     if args.sparseness < 0 or args.sparseness > 1.0:
         error('sparseness must be in range [0.0, 1.0].')
@@ -375,7 +385,7 @@
             ss.minhash = ss.minhash.downsample(scaled=args.scaled)
             scaleds.add(ss.minhash.scaled)
 
-            tree.insert(ss)
+            tree.insert(ss, batch=batch)
             n += 1
 
         if not ss:
@@ -398,6 +408,8 @@
         error('nums = {}; scaleds = {}', repr(nums), repr(scaleds))
         sys.exit(-1)
 
+    tree = tree.finish()
+
     notify('')
 
     # did we load any!?
@@ -406,6 +418,10 @@
         sys.exit(-1)
 
     notify('loaded {} sigs; saving SBT under "{}"', n, args.sbt_name)
+    # TODO: if all nodes are already saved (like in the scaffold/batch case)
+    # we can potentially set structure_only=True here. An alternative is to
+    # modify Node.save to verify if the data is already saved or still needs to
+    # be saved (dirty flag?)
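# The batch path added above amounts to roughly the following library-level
# flow (a sketch only; the file names are illustrative and not part of this
# patch):
#
#     from sourmash import load_one_signature
#     from sourmash.sbtmh import create_sbt_index
#
#     tree = create_sbt_index(0)            # bf_size == 0: defer Nodegraph sizing
#     storage_info = tree._setup_storage("example.sbt.zip")
#     tree.storage = storage_info.storage
#
#     for filename in ("a.sig", "b.sig"):   # hypothetical inputs
#         tree.insert(load_one_signature(filename), batch=True)
#
#     tree = tree.finish()                  # scaffold() builds the tree here
#     tree.save("example.sbt.zip")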
tree.save(args.sbt_name, sparseness=args.sparseness) diff --git a/sourmash/logging.py b/sourmash/logging.py index 49c3dc26b3..097c0e194b 100644 --- a/sourmash/logging.py +++ b/sourmash/logging.py @@ -1,13 +1,27 @@ +import atexit +import os import sys from io import StringIO _quiet = False _debug = False def set_quiet(val, print_debug=False): - global _quiet, _debug + global _quiet, _debug, _trace _quiet = bool(val) _debug = bool(print_debug) +#_trace = True if "SOURMASH_TRACE" in os.environ else False +if "SOURMASH_TRACE" in os.environ: + _trace = open(os.environ["SOURMASH_TRACE"], "w") + + @atexit.register + def flush_and_close(): + global _trace + _trace.flush() + _trace.close() +else: + _trace = None + def print_results(s, *args, **kwargs): if _quiet: @@ -41,6 +55,17 @@ def debug(s, *args, **kwargs): sys.stderr.flush() +def trace(s, *args, **kwargs): + "Low level execution information (even more verbose than debug)" + if not _trace: + return + + print(s.format(*args, **kwargs), file=_trace, + end=kwargs.get('end', u'\n')) + if kwargs.get('flush'): + sys.stderr.flush() + + def error(s, *args, **kwargs): "A simple error logging function => stderr." print(u'\r\033[K', end=u'', file=sys.stderr) diff --git a/sourmash/nodegraph.py b/sourmash/nodegraph.py index 8faa2eb874..2b6b1d897f 100644 --- a/sourmash/nodegraph.py +++ b/sourmash/nodegraph.py @@ -1,5 +1,7 @@ # -*- coding: UTF-8 -*- +from collections import namedtuple +import math from struct import pack, unpack import sys from tempfile import NamedTemporaryFile @@ -86,6 +88,14 @@ def matches(self, mh): return self._methodcall(lib.nodegraph_matches, mh._objptr) + def intersection_count(self, other): + if isinstance(other, Nodegraph): + return self._methodcall(lib.nodegraph_intersection_count, other._objptr) + else: + # FIXME: we could take MinHash and sets here too (or anything that can be + # converted to a list of ints...) + raise TypeError("Must be a Nodegraph") + def to_khmer_nodegraph(self): import khmer try: @@ -159,3 +169,62 @@ def calc_expected_collisions(graph, force=False, max_false_pos=.2): raise SystemExit(1) return fp_all + + +def optimal_size(num_kmers, mem_cap=None, fp_rate=None): + """ + Utility function for estimating optimal nodegraph args. + + - num_kmers: number of unique kmers [required] + - mem_cap: the allotted amount of memory [optional, conflicts with f] + - fp_rate: the desired false positive rate [optional, conflicts with M] + """ + if all((num_kmers is not None, mem_cap is not None, fp_rate is None)): + return estimate_optimal_with_K_and_M(num_kmers, mem_cap) + elif all((num_kmers is not None, mem_cap is None, fp_rate is not None)): + return estimate_optimal_with_K_and_f(num_kmers, fp_rate) + else: + raise TypeError("num_kmers and either mem_cap or fp_rate" + " must be defined.") + + +def estimate_optimal_with_K_and_M(num_kmers, mem_cap): + """ + Estimate optimal nodegraph args. + + - num_kmers: number of unique kmer + - mem_cap: the allotted amount of memory + """ + n_tables = math.log(2) * (mem_cap / float(num_kmers)) + int_n_tables = int(n_tables) + if int_n_tables == 0: + int_n_tables = 1 + ht_size = int(mem_cap / int_n_tables) + mem_cap = ht_size * int_n_tables + fp_rate = (1 - math.exp(-num_kmers / float(ht_size))) ** int_n_tables + res = namedtuple("result", ["num_htables", "htable_size", "mem_use", + "fp_rate"]) + return res(int_n_tables, ht_size, mem_cap, fp_rate) + + +def estimate_optimal_with_K_and_f(num_kmers, des_fp_rate): + """ + Estimate optimal memory. 
+ + - num_kmers: the number of unique kmers + - des_fp_rate: the desired false positive rate + """ + n_tables = math.log(des_fp_rate, 0.5) + int_n_tables = int(n_tables) + if int_n_tables == 0: + int_n_tables = 1 + + ht_size = int(-num_kmers / ( + math.log(1 - des_fp_rate ** (1 / float(int_n_tables))))) + ht_size = max(ht_size, 1) + mem_cap = ht_size * int_n_tables + fp_rate = (1 - math.exp(-num_kmers / float(ht_size))) ** int_n_tables + + res = namedtuple("result", ["num_htables", "htable_size", "mem_use", + "fp_rate"]) + return res(int_n_tables, ht_size, mem_cap, fp_rate) diff --git a/sourmash/sbt.py b/sourmash/sbt.py index 9cae16b693..ead8b6c66e 100644 --- a/sourmash/sbt.py +++ b/sourmash/sbt.py @@ -44,6 +44,8 @@ def search_transcript(node, seq, threshold): from collections import namedtuple, Counter from collections.abc import Mapping +import hashlib +import heapq from copy import copy import json @@ -56,9 +58,10 @@ def search_transcript(node, seq, threshold): from .exceptions import IndexNotSupported from .sbt_storage import FSStorage, IPFSStorage, RedisStorage, ZipStorage -from .logging import error, notify, debug +from .logging import error, notify, debug, trace from .index import Index -from .nodegraph import Nodegraph, extract_nodegraph_info, calc_expected_collisions +from .nodegraph import Nodegraph, extract_nodegraph_info, calc_expected_collisions, optimal_size +from .hll import HLL STORAGES = { 'FSStorage': FSStorage, @@ -69,6 +72,9 @@ def search_transcript(node, seq, threshold): NodePos = namedtuple("NodePos", ["pos", "node"]) +StorageInfo = namedtuple("StorageInfo", + ["kind", "storage", "backend", "name", "subdir", + "storage_args", "index_filename"]) class GraphFactory(object): @@ -177,6 +183,7 @@ def __init__(self, factory, *, d=2, storage=None, cache_size=None): self._nodes = {} self._missing_nodes = set() self._leaves = {} + self._batched = [] self.d = d self.next_node = 0 self.storage = storage @@ -229,12 +236,31 @@ def new_node_pos(self, node): return self.next_node - def insert(self, signature): + def insert(self, signature, batch=False): "Add a new SourmashSignature in to the SBT." from .sbtmh import SigLeaf - + leaf = SigLeaf(signature.md5sum(), signature) - self.add_node(leaf) + + if batch: + self._batched.append(leaf) + else: + self.add_node(leaf) + + def finish(self): + # If not in batch mode, nothing to do + if not self._batched: + return self + + # TODO: check if SBT is empty, if it is not add current leaves to be + # scaffolded too + tree = scaffold(self._batched, self.storage, self.factory) + self._batched.clear() + + # TODO: make this inplace? + # self.__dict__.update(tree.__dict__) + # return self + return tree def add_node(self, node): pos = self.new_node_pos(node) @@ -317,9 +343,9 @@ def find(self, search_fn, *args, **kwargs): if node_p not in visited: visited.add(node_p) + trace("(TRAVERSAL) {0}", node_p) # apply search fn. If return false, truncate search. if search_fn(node_g, *args): - # leaf node? it's a match! 
if isinstance(node_g, Leaf): matches.append(node_g) @@ -327,9 +353,10 @@ def find(self, search_fn, *args, **kwargs): elif isinstance(node_g, Node): if kwargs.get('dfs', True): # defaults search to dfs for c in self.children(node_p): - queue.insert(0, c.pos) + if c.node: + queue.insert(0, c.pos) else: # bfs - queue.extend(c.pos for c in self.children(node_p)) + queue.extend(c.pos for c in self.children(node_p) if c.node) if unload_data: node_g.unload() @@ -533,40 +560,14 @@ def child(self, parent, pos): node = self._nodes.get(cd, None) return NodePos(cd, node) - def save(self, path, storage=None, sparseness=0.0, structure_only=False): - """Saves an SBT description locally and node data to a storage. - - Parameters - ---------- - path : str - path to where the SBT description should be saved. - storage : Storage, optional - Storage to be used for saving node data. - Defaults to FSStorage (a hidden directory at the same level of path) - sparseness : float - How much of the internal nodes should be saved. - Defaults to 0.0 (save all internal nodes data), - can go up to 1.0 (don't save any internal nodes data) - structure_only: boolean - Write only the index schema and metadata, but not the data. - Defaults to False (save data too) - - Returns - ------- - str - full path to the new SBT description - """ - info = {} - info['d'] = self.d - info['version'] = 6 - info["index_type"] = self.__class__.__name__ # TODO: check - + def _setup_storage(self, path, storage=None): # choose between ZipStorage and FS (file system/directory) storage. if not path.endswith(".sbt.json"): kind = "Zip" if not path.endswith('.sbt.zip'): path += '.sbt.zip' - storage = ZipStorage(path) + if storage is None: + storage = ZipStorage(path) backend = "FSStorage" name = os.path.basename(path[:-8]) subdir = '.sbt.{}'.format(name) @@ -576,6 +577,7 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False): else: # path.endswith('.sbt.json') assert path.endswith('.sbt.json') kind = "FS" + subdir = None name = os.path.basename(path) name = name[:-9] index_filename = os.path.abspath(path) @@ -591,9 +593,43 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False): backend = [k for (k, v) in STORAGES.items() if v == type(storage)][0] storage_args = storage.init_args() + return StorageInfo(kind, storage, backend, name, subdir, storage_args, index_filename) + + def save(self, path, storage=None, sparseness=0.0, structure_only=False): + """Saves an SBT description locally and node data to a storage. + + Parameters + ---------- + path : str + path to where the SBT description should be saved. + storage : Storage, optional + Storage to be used for saving node data. + Defaults to FSStorage (a hidden directory at the same level of path) + sparseness : float + How much of the internal nodes should be saved. + Defaults to 0.0 (save all internal nodes data), + can go up to 1.0 (don't save any internal nodes data) + structure_only: boolean + Write only the index schema and metadata, but not the data. 
+        Defaults to False (save data too)
+
+        Returns
+        -------
+        str
+            full path to the new SBT description
+        """
+        info = {}
+        info['d'] = self.d
+        info['version'] = 7
+        info["index_type"] = self.__class__.__name__  # TODO: check
+
+        if storage is None and self.storage is not None:
+            storage = self.storage
+        sinfo = self._setup_storage(path, storage)
+
         info['storage'] = {
-            'backend': backend,
-            'args': storage_args
+            'backend': sinfo.backend,
+            'args': sinfo.storage_args
         }
         info['factory'] = {
             'class': GraphFactory.__name__,
@@ -611,29 +647,27 @@ def save(self, path, storage=None, sparseness=0.0, structure_only=False):
             if random() - sparseness <= 0:
                 continue
 
-            data = {
-                # TODO: start using md5sum instead?
-                'filename': os.path.basename(node.name),
-                'name': node.name
-            }
-
+            data = {"name": node.name}
             try:
                 node.metadata.pop('max_n_below')
             except (AttributeError, KeyError):
                 pass
-
             data['metadata'] = node.metadata
 
             if structure_only is False:
                 # trigger data loading before saving to the new place
                 node.data
 
-                node.storage = storage
+                node.storage = sinfo.storage
 
-                if kind == "Zip":
-                    node.save(os.path.join(subdir, data['filename']))
-                elif kind == "FS":
-                    data['filename'] = node.save(data['filename'])
+                basepath = None
+                if sinfo.kind == "Zip":
+                    basepath = sinfo.subdir
+
+                data["filename"] = _save_node(node, basepath)
+            else:
+                # TODO: a dry-run mode for calculating the name?
+                data["filename"] = node._path
 
             if isinstance(node, Node):
                 nodes[i] = data
@@ -647,17 +681,17 @@
         info['nodes'] = nodes
         info['signatures'] = leaves
 
-        if kind == "Zip":
+        if sinfo.kind == "Zip":
             tree_data = json.dumps(info).encode("utf-8")
-            save_path = "{}.sbt.json".format(name)
-            storage.save(save_path, tree_data)
-            storage.close()
+            save_path = "{}.sbt.json".format(sinfo.name)
+            sinfo.storage.save(save_path, tree_data)
+            sinfo.storage.close()
 
-        elif kind == "FS":
-            with open(index_filename, 'w') as fp:
+        elif sinfo.kind == "FS":
+            with open(sinfo.index_filename, 'w') as fp:
                 json.dump(info, fp)
 
-        notify("Finished saving SBT index, available at {0}\n".format(index_filename))
+        notify("Finished saving SBT index, available at {0}\n".format(sinfo.index_filename))
 
         return path
 
@@ -738,6 +772,7 @@ def load(cls, location, *, leaf_loader=None, storage=None, print_version_warning
             4: cls._load_v4,
             5: cls._load_v5,
             6: cls._load_v6,
+            7: cls._load_v6,
         }
 
         try:
@@ -1087,13 +1122,15 @@ def print(self):
         while stack:
             node_p = stack.pop()
             node_g = self._nodes.get(node_p, None)
+            if node_g is None:
+                node_g = self._leaves.get(node_p, None)
             if node_p not in visited and node_g is not None:
                 visited.add(node_p)
                 depth = int(math.floor(math.log(node_p + 1, self.d)))
-                print(" " * 4 * depth, node_g)
+                print(" " * 4 * depth, node_p, node_g)
                 if isinstance(node_g, Node):
-                    stack.extend(c.pos for c in self.children(node_p)
-                                 if c.pos not in visited)
+                    stack.extend(reversed([c.pos for c in self.children(node_p)
+                                           if c.pos not in visited]))
 
     def __iter__(self):
         for i, node in self._nodes.items():
@@ -1171,9 +1208,17 @@ def __str__(self):
                 name=self.name, nb=self.data.n_occupied(),
                 fpr=calc_expected_collisions(self.data, True, 1.1))
 
-    def save(self, path):
+    def save(self, subdir=None):
         buf =
self.data.to_bytes(compression=1) - return self.storage.save(path, buf) + hash_md5 = hashlib.md5() + hash_md5.update(buf) + path = "internal/" + hash_md5.hexdigest() + + if subdir is not None: + path = os.path.join(subdir, path) + + self._path = self.storage.save(path, buf) + return self._path @property def data(self): @@ -1250,9 +1295,17 @@ def unload(self): # TODO: Check that data is actually in the storage? self._data = None - def save(self, path): + def save(self, subdir=None): buf = self.data.to_bytes(compression=1) - return self.storage.save(path, buf) + hash_md5 = hashlib.md5() + hash_md5.update(buf) + path = hash_md5.hexdigest() + + if subdir is not None: + path = os.path.join(subdir, path) + + self._path = self.storage.save("signatures/" + path, buf) + return self._path def update(self, parent): parent.data.update(self.data) @@ -1265,6 +1318,340 @@ def load(cls, info, storage=None): storage=storage) +def scaffold(original_datasets, storage, factory=None): + """ Generate an SBT with nodes clustered by amount of shared hashes + + Parameters + ---------- + + datasets: List[Union[SourmashSignature, Leaf]] + List of signatures, or already existing leaves from another SBT, to index + storage: Storage + Where to save the nodes data during construction + factory: Factory, optional + A factory for internal nodes. If None, will use a HyperLogLog counter + to estimate unique k-mers and calculate an optimal size for Nodegraphs. + + Returns + ------- + SBT + A newly initialized SBT with nodes clustered by amount of shared hashes + """ + + from .sbtmh import SigLeaf + from .signature import SourmashSignature + + InternalNode = namedtuple("InternalNode", "element left right") + BinaryLeaf = namedtuple("BinaryLeaf", "element") + THRESHOLD = 0 + + hll = None + ksize = None + + subdir = None + if isinstance(storage, ZipStorage): + if storage.subdir is None: + name = storage.path.split('/')[-1][:-8] + storage.subdir = '.sbt.{}'.format(name) + subdir = storage.subdir + + datasets = [] + for d in original_datasets: + if isinstance(d, SourmashSignature): + d = SigLeaf(d.md5sum(), d) + + # save MH to storage + if isinstance(d, SigLeaf): + d.data + if storage is not None: + d.storage = storage + _save_node(d, subdir) + datasets.append(d) + else: + raise ValueError("unknown dataset type") + del original_datasets + + # TODO: we can build the heap in parallel, if the data was + # pickle-able for multiprocessing... + # on top of doing the count_common calculations in parallel, + # we can also avoid building the heap (just build a list first) + # and then call heapify on it after the list is ready + heap = [] + for i, data1 in enumerate(datasets): + if i % 100 == 0: + print(f"processed {i} out of {len(datasets)}", end='\r') + + d1 = data1.data.minhash + + if hll is None and factory is None: + ksize = d1.ksize + hll = HLL(0.01, ksize) + + if hll is not None: + hll.update(d1) + + for j, data2 in enumerate(datasets): + if i > j: + d2 = data2.data.minhash + common = d1.count_common(d2) + + # TODO: check for a low threshold to insert + # (need to change logic further down before setting threshold) + if common >= THRESHOLD: + # heapq defaults to a min heap, + # invert "common" here to avoid having to use the + # internal methods for a max heap + heap.append((-common, i, j)) + heapq.heapify(heap) + + if factory is None: + n_unique_hashes = len(hll) + num_htables, htable_size, mem_use, fp_rate = optimal_size(n_unique_hashes, fp_rate=0.9) + # TODO: check this, we prefer ntables = 1? 
+ #htable_size *= num_htables + print(len(hll), num_htables, htable_size) + + # TODO: turns out len(hll) is too big in most cases. + # need a better heuristic for optimal size... + htable_size = 1e5 + num_htables = 1 + + factory = GraphFactory(ksize, htable_size, num_htables) + + print("Processing first round of internal") + processed = set() + total_datasets = len(datasets) + next_round = [] + while heap: + i = len(heap) + if i % 100 == 0: + print(f"processed {i} out of {total_datasets}", end='\r') + + (_, p1, p2) = heapq.heappop(heap) + if p1 not in processed and p2 not in processed: + data1 = datasets[p1] + data2 = datasets[p2] + datasets[p1] = None + datasets[p2] = None + processed.add(p1) + processed.add(p2) + + # Name will be updated later, when we have the right position + new_node = Node(factory) + data1.update(new_node) + data2.update(new_node) + + new_internal = InternalNode(new_node, BinaryLeaf(data1), BinaryLeaf(data2)) + next_round.append(new_internal) + + # unload d1 and d2 + data1.unload() + data2.unload() + + elif p1 in processed and p2 in processed: + # already processed both, nothing to do + continue + else: + # at least one (p1, p2) is still valid. + # only process if it is the last one + if total_datasets - len(processed) == 1: + d = datasets[p1] + if d is None: + d = datasets[p2] + assert d is not None + datasets[p2] = None + processed.add(p2) + else: + datasets[p1] = None + processed.add(p1) + + # Name will be updated later, when we have the right position + new_node = Node(factory) + d.update(new_node) + + new_internal = InternalNode(new_node, BinaryLeaf(d), None) + next_round.append(new_internal) + + # unload d + d.unload() + + # Finish processing leaves, start dealing with internal nodes in next_round + while len(next_round) > 1: + print("Processing next round of internal") + current_round = next_round + next_round = [] + total_nodes = len(current_round) + + # TODO: we can build the heap in parallel, if the data was + # pickle-able for multiprocessing... + heap = [] + for (i, d1) in enumerate(current_round): + for (j, d2) in enumerate(current_round): + if i > j: + common = d1.element.data.intersection_count(d2.element.data) + heap.append((-common, i, j)) + heapq.heapify(heap) + + processed = set() + while heap: + i = len(heap) + if i % 100 == 0: + print(f"processed {i} out of {total_nodes}", end='\r') + + (_, p1, p2) = heapq.heappop(heap) + if p1 not in processed and p2 not in processed: + d1 = current_round[p1] + d2 = current_round[p2] + current_round[p1] = None + current_round[p2] = None + processed.add(p1) + processed.add(p2) + + new_node = Node(factory) + d1.element.update(new_node) + d2.element.update(new_node) + + new_internal = InternalNode(new_node, d1, d2) + next_round.append(new_internal) + + # save d1 and d2 Nodes into storage and unload them, if a storage is available + if storage is not None: + d1.element.storage = storage + _save_node(d1.element) + d1.element.unload() + + d2.element.storage = storage + _save_node(d2.element) + d2.element.unload() + + elif p1 in processed and p2 in processed: + # already processed both, nothing to do + continue + else: + # at least one (p1, p2) is still valid. 
+ # process it if it is the last element left + if total_nodes - len(processed) == 1: + d = current_round[p1] + if d is None: + d = current_round[p2] + current_round[p2] = None + assert d is not None + processed.add(p2) + else: + processed.add(p1) + current_round[p1] = None + + new_node = Node(factory) + d.element.update(new_node) + + new_internal = InternalNode(new_node, d, None) + next_round.append(new_internal) + + # save d into storage and unload it, if a storage is available + if storage is not None: + d.element.storage = storage + _save_node(d.element) + d.element.unload() + + # next_round only contains the root of the SBT + # Convert from binary tree to nodes/leaves + if total_datasets == 1: + d = datasets.pop() + common = factory() + common.update(d.data.minhash) + root = InternalNode(common, BinaryLeaf(d), None) + else: + root = next_round.pop() + # This approach puts the more complete tree on the right, + # we prefer to have it on the left, + # so let's rotate it for now. + root = InternalNode(root.element, root.right, root.left) + + assert not next_round + + nodes = {} + leaves = {} + + def check_lift(cnode): + # Base case + if isinstance(cnode, BinaryLeaf): + return cnode + + if cnode.right is None and cnode.left is None: + return cnode + elif cnode.right is None and cnode.left is not None: + return check_lift(cnode.left) + elif cnode.right is not None and cnode.left is None: + return check_lift(cnode.right) + + # search the tree in order and build an SBT + visited = set() + queue = [(0, root)] + while queue: + (pos, cnode) = queue.pop(0) + if pos not in visited: + visited.add(pos) + + if isinstance(cnode, BinaryLeaf): + leaves[pos] = cnode.element + elif isinstance(cnode, InternalNode): + #to_lift = check_lift(cnode) + to_lift = None + if cnode.right is None and cnode.left is not None: + to_lift = cnode.left + elif cnode.right is not None and cnode.left is None: + to_lift = cnode.right + + if to_lift: + # This is a lift: if the internal node only have one child, + # lift the child to be the node + cnode = to_lift + if isinstance(cnode, BinaryLeaf): + # if InternalNode only has one child and it is a Leaf, + # use it instead (avoid creating a Node with only one child) + + # Technically the parent here should have all the data, + # but... it didn't. I might be missing something + # somewhere else... + ppos = int(math.floor((pos - 1) / 2)) + if ppos != -1: + # ppos == -1 means this is the root node + cnode.element.update(nodes[ppos]) + + leaves[pos] = cnode.element + continue + + nodes[pos] = cnode.element + + # TODO: this is a one-level rotation of the internal nodes, + # since we want all "complete" subtrees to be to the left, + # and all spaces as much to the right as possible. + # Ideally do a pre-processing step before searching the tree in + # order and building the SBT. + left = cnode.left + right = cnode.right + if isinstance(left, InternalNode) and isinstance(right, InternalNode): + # make sure the left one is "more complete" than the right one + if any(c is None for c in (left.left, left.right)): + left, right = right, left + + if left: + queue.append((2 * pos + 1, left)) + if right: + queue.append((2 * pos + 2, right)) + + new_tree = SBT(factory, storage=storage) + new_tree._nodes = nodes + new_tree._leaves = leaves + + return new_tree + + +def _save_node(node, basepath=None): + path = basepath + return node.save(path) + + def filter_distance(filter_a, filter_b, n=1000): """ Compute a heuristic distance per bit between two Bloom filters. 
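# The pairing step in scaffold() above reduces to the following toy model over
# plain sets of hashes (a simplified sketch, not part of the patch; the real
# code pairs SigLeaf/Node objects and measures overlap with count_common()
# and intersection_count()):

import heapq

def greedy_pairs(hash_sets):
    """Repeatedly pair up the two sets sharing the most hashes."""
    heap = []
    for i, a in enumerate(hash_sets):
        for j, b in enumerate(hash_sets):
            if i > j:
                # negate the overlap so heapq's min-heap behaves as a max-heap
                heap.append((-len(a & b), i, j))
    heapq.heapify(heap)

    paired, pairs = set(), []
    while heap:
        _, i, j = heapq.heappop(heap)
        if i not in paired and j not in paired:
            paired.update((i, j))
            pairs.append((i, j))
    return pairs

# greedy_pairs([{1, 2, 3}, {2, 3, 4}, {7, 8}, {8, 9}]) -> [(1, 0), (3, 2)]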
diff --git a/sourmash/sbt_storage.py b/sourmash/sbt_storage.py index 3d78f7eaee..3bdf3a9306 100644 --- a/sourmash/sbt_storage.py +++ b/sourmash/sbt_storage.py @@ -72,6 +72,10 @@ def save(self, path, content): newpath = "{}_{}".format(path, n) fullpath = os.path.join(self.location, self.subdir, newpath) + dirpath = os.path.dirname(fullpath) + if not os.path.exists(dirpath): + os.makedirs(dirpath) + with open(fullpath, 'wb') as f: f.write(content) diff --git a/sourmash/sbtmh.py b/sourmash/sbtmh.py index c8b879118c..725bb69c88 100644 --- a/sourmash/sbtmh.py +++ b/sourmash/sbtmh.py @@ -1,8 +1,10 @@ from io import BytesIO +import os import sys from .sbt import Leaf, SBT, GraphFactory from . import signature +from .logging import trace def load_sbt_index(filename, *, print_version_warning=True, cache_size=None): @@ -14,7 +16,10 @@ def load_sbt_index(filename, *, print_version_warning=True, cache_size=None): def create_sbt_index(bloom_filter_size=1e5, n_children=2): "Create an empty SBT index." - factory = GraphFactory(1, bloom_filter_size, 4) + if bloom_filter_size == 0: + factory = None + else: + factory = GraphFactory(1, bloom_filter_size, 4) tree = SBT(factory, d=n_children) return tree @@ -39,14 +44,21 @@ def __str__(self): return '**Leaf:{name} -> {metadata}'.format( name=self.name, metadata=self.metadata) - def save(self, path): + def save(self, subdir=None): # this is here only for triggering the property load # before we reopen the file (and overwrite the previous # content...) self.data + path = "signatures/" + self.data.md5sum() + + if subdir is not None: + path = os.path.join(subdir, path) + buf = signature.save_signatures([self.data], compression=1) - return self.storage.save(path, buf) + self._path = self.storage.save(path, buf) + + return self._path def update(self, parent): mh = self.data.minhash @@ -62,7 +74,7 @@ def update(self, parent): @property def data(self): if self._data is None: - buf = BytesIO(self.storage.load(self._path)) + buf = self.storage.load(self._path) self._data = signature.load_one_signature(buf) return self._data @@ -112,6 +124,8 @@ def search_minhashes(node, sig, threshold, results=None): else: # Node minhash comparison score = _max_jaccard_underneath_internal_node(node, sig_mh) + trace("(SCORE) {0}: {1}", node.name, score) + if results is not None: results[node.name] = score @@ -134,6 +148,8 @@ def search(self, node, sig, threshold, results=None): else: # internal object, not leaf. 
score = _max_jaccard_underneath_internal_node(node, sig_mh) + trace("(SCORE) {0}: {1}", node.name, score) + if results is not None: results[node.name] = score @@ -156,10 +172,15 @@ def search_minhashes_containment(node, sig, threshold, results=None, downsample= else: # Node or Leaf, Nodegraph by minhash comparison matches = node.data.matches(mh) + len_mh = max(len(mh), 1) + + score = float(matches) / len_mh + trace("(SCORE) {0}: {1}", node.name, score) + if results is not None: - results[node.name] = float(matches) / len(mh) + results[node.name] = score - if len(mh) and float(matches) / len(mh) >= threshold: + if len_mh and score >= threshold: return 1 return 0 @@ -170,7 +191,9 @@ def __init__(self): def search(self, node, query, threshold, results=None): mh = query.minhash + score = 0 if not len(mh): + trace("(SCORE) {0}: 0", node.name) return 0 if isinstance(node, SigLeaf): @@ -179,9 +202,11 @@ def search(self, node, query, threshold, results=None): matches = node.data.matches(mh) if not matches: + trace("(SCORE) {0}: 0", node.name) return 0 score = float(matches) / len(mh) + trace("(SCORE) {0}: {1}", node.name, score) if score < threshold: return 0 diff --git a/src/core/src/ffi/nodegraph.rs b/src/core/src/ffi/nodegraph.rs index 29a1c3e84c..cc2191a0d4 100644 --- a/src/core/src/ffi/nodegraph.rs +++ b/src/core/src/ffi/nodegraph.rs @@ -149,6 +149,18 @@ pub unsafe extern "C" fn nodegraph_update( ong.update(ng).unwrap(); } +#[no_mangle] +pub unsafe extern "C" fn nodegraph_intersection_count( + ptr: *const SourmashNodegraph, + optr: *const SourmashNodegraph, +) -> usize { + let ng = SourmashNodegraph::as_rust(ptr); + let ong = SourmashNodegraph::as_rust(optr); + + // FIXME raise an exception properly + ng.intersection_count(ong) +} + #[no_mangle] pub unsafe extern "C" fn nodegraph_update_mh( ptr: *mut SourmashNodegraph, diff --git a/src/core/src/sketch/hyperloglog/estimators.rs b/src/core/src/sketch/hyperloglog/estimators.rs index 4c2fbe02cc..f3a6cb8197 100644 --- a/src/core/src/sketch/hyperloglog/estimators.rs +++ b/src/core/src/sketch/hyperloglog/estimators.rs @@ -32,7 +32,7 @@ pub fn mle(counts: &[u16], p: usize, q: usize, relerr: f64) -> f64 { let mut z = 0.; for i in num_iter::range_step_inclusive(k_max_prime as i32, k_min_prime as i32, -1) { - z = 0.5 * z + counts[i as usize] as f64; + z = 0.5 * z + (counts[i as usize] as f64); } // ldexp(x, i) = x * (2 ** i) diff --git a/src/core/src/sketch/nodegraph.rs b/src/core/src/sketch/nodegraph.rs index 1e2fac1eb2..50e07855bd 100644 --- a/src/core/src/sketch/nodegraph.rs +++ b/src/core/src/sketch/nodegraph.rs @@ -289,31 +289,29 @@ impl Nodegraph { self.unique_kmers } - pub fn similarity(&self, other: &Nodegraph) -> f64 { - let result: usize = self - .bs + pub fn intersection_count(&self, other: &Nodegraph) -> usize { + self.bs .iter() .zip(&other.bs) .map(|(bs, bs_other)| bs.intersection(bs_other).count()) - .sum(); + .sum() + } + + pub fn similarity(&self, other: &Nodegraph) -> f64 { + let intersection = self.intersection_count(other); let size: usize = self .bs .iter() .zip(&other.bs) .map(|(bs, bs_other)| bs.union(bs_other).count()) .sum(); - result as f64 / size as f64 + intersection as f64 / size as f64 } pub fn containment(&self, other: &Nodegraph) -> f64 { - let result: usize = self - .bs - .iter() - .zip(&other.bs) - .map(|(bs, bs_other)| bs.intersection(bs_other).count()) - .sum(); + let intersection = self.intersection_count(other); let size: usize = self.bs.iter().map(|bs| bs.len()).sum(); - result as f64 / size as f64 + intersection 
as f64 / size as f64 } } diff --git a/tests/test_sbt.py b/tests/test_sbt.py index 4d7bb51db6..944d4f990b 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -7,7 +7,7 @@ import sourmash from sourmash import load_one_signature, SourmashSignature, load_signatures from sourmash.exceptions import IndexNotSupported -from sourmash.sbt import SBT, GraphFactory, Leaf, Node +from sourmash.sbt import SBT, GraphFactory, Leaf, Node, scaffold from sourmash.sbtmh import (SigLeaf, search_minhashes, search_minhashes_containment) from sourmash.sbt_storage import (FSStorage, RedisStorage, @@ -915,3 +915,44 @@ def test_sbt_node_cache(): assert tree._nodescache.currsize == 1 assert tree._nodescache.currsize == 1 + + +def test_sbt_scaffold(tmpdir): + factory = GraphFactory(31, 1e5, 4) + + tree = SBT(factory) + leaves = [] + + for f in utils.SIG_FILES: + sig = next(load_signatures(utils.get_test_data(f))) + leaf = SigLeaf(os.path.basename(f), sig) + tree.add_node(leaf) + leaves.append(leaf) + to_search = leaf + + print('*' * 60) + print("{}:".format(to_search.metadata)) + old_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} + print(*old_result, sep='\n') + + with ZipStorage(str(tmpdir.join("tree.sbt.zip"))) as storage: + tree = scaffold(list(leaves), storage) + new_leaves = set(tree.leaves()) + assert len(new_leaves) == len(leaves) + assert new_leaves == set(leaves) + + tree.save(str(tmpdir.join("tree")), storage=storage) + + with ZipStorage(str(tmpdir.join("tree.sbt.zip"))) as storage: + tree = SBT.load(str(tmpdir.join("tree")), + leaf_loader=SigLeaf.load, + storage=storage) + + print('*' * 60) + print("{}:".format(to_search.metadata)) + new_result = {str(s) for s in tree.find(search_minhashes, + to_search.data, 0.1)} + print(*new_result, sep='\n') + + assert old_result == new_result
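A quick worked example of the sizing helpers added in sourmash/nodegraph.py (the numbers are illustrative only):

    from sourmash.nodegraph import optimal_size

    # With a memory budget of 8e6 for 1e6 unique k-mers:
    # ln(2) * (8e6 / 1e6) ~= 5 tables, each of 8e6 / 5 = 1.6e6 entries,
    # for a false positive rate of (1 - exp(-1e6 / 1.6e6)) ** 5, roughly 2%.
    res = optimal_size(int(1e6), mem_cap=int(8e6))
    print(res.num_htables, res.htable_size, res.mem_use, res.fp_rate)

    # Alternatively, target a false positive rate and get the memory use back.
    res = optimal_size(int(1e6), fp_rate=0.05)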
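And a short sketch of the new Nodegraph.intersection_count() binding (the hash values are made up; the constructor and count() are the existing Python API):

    from sourmash.nodegraph import Nodegraph

    a = Nodegraph(31, int(1e5), 4)
    b = Nodegraph(31, int(1e5), 4)
    for h in (10, 20, 30):
        a.count(h)
    for h in (20, 30, 40):
        b.count(h)

    # Sums the set bits shared between the two filters across all tables;
    # scaffold() uses this to rank candidate pairs of internal nodes.
    shared = a.intersection_count(b)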