From 4729f15a622ae2a26359eee09538f1244ae567d5 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Mon, 1 May 2023 19:54:34 +0200 Subject: [PATCH] Fix base and add test --- neurom/core/types.py | 17 ++++++++++++----- tests/test_mixed.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/neurom/core/types.py b/neurom/core/types.py index c5c71338..3c7300f2 100644 --- a/neurom/core/types.py +++ b/neurom/core/types.py @@ -94,9 +94,17 @@ def _ids_to_index(ids, base): def _index_to_ids(index, base): """Convert a linear index into ids on a square grid with side 'base'.""" # find number of integers in linear index - n_digits = math.ceil(len(str(index)) / max(1, len(str(base)) - 1)) + if index > 1: + ratio = math.log(index) / math.log(base) + n_digits = math.ceil(ratio) + if int(ratio) == n_digits: + n_digits += 1 + elif index == 1: + n_digits = 1 + else: + return [0] - int_types = np.trim_zeros(np.unravel_index(index, shape=(base,) * n_digits)) + int_types = np.unravel_index(index, shape=(base,) * n_digits) if _ALL_SUBTYPE in int_types and len(int_types) > 1: raise NeuroMError( @@ -228,9 +236,8 @@ def _enum_accept_undefined(cls, value): return cls._member_map_[value] # SectionType or raw integer - elif isinstance(value, collections.abc.Hashable): - if value in cls._value2member_map_: - return cls._value2member_map_[value] + elif isinstance(value, collections.abc.Hashable) and value in cls._value2member_map_: + return cls._value2member_map_[value] # Composite type or unhashable type (e.g. list) else: diff --git a/tests/test_mixed.py b/tests/test_mixed.py index fd412f6f..146b3894 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -119,6 +119,31 @@ def test_integer_behavior(self): assert SubtypeCollection(2) * 2 == 4 assert SubtypeCollection(2) / 2 == 1 + @pytest.mark.parametrize("base", [2, 3, 11, 99, 100]) + def test_base(self, monkeypatch, base): + monkeypatch.setattr(SubtypeCollection, "_BASE", base) + + def int_to_base(value, base): + """Convert an integer into a list of numbers in the given base.""" + digits = [] + while value: + digits.append(int(value % base)) + value = value // base + if not digits: + digits = [0] + return digits[::-1] + + SubtypeCollection(0).subtypes == [0] + SubtypeCollection(1).subtypes == [1] + SubtypeCollection(5).subtypes == int_to_base(5, base) + SubtypeCollection(7).subtypes == int_to_base(7, base) + SubtypeCollection(11).subtypes == int_to_base(11, base) + SubtypeCollection(23).subtypes == int_to_base(23, base) + SubtypeCollection(99).subtypes == int_to_base(99, base) + SubtypeCollection(100).subtypes == int_to_base(100, base) + SubtypeCollection(101).subtypes == int_to_base(101, base) + SubtypeCollection(10101).subtypes == int_to_base(10101, base) + class TestNeuriteType: def test_repr(self): @@ -217,6 +242,16 @@ def test_eq(self): assert NeuriteType.axon_carrying_dendrite != SubtypeCollection(NeuriteType.apical_dendrite) def test_raise(self): + NeuriteType("all") + NeuriteType((3, 2)) + with pytest.raises(ValueError, match="None is not a valid registered NeuriteType"): + NeuriteType(None) + with pytest.raises( + ValueError, match="{'WRONG TYPE': 999} is not a valid registered NeuriteType" + ): + NeuriteType({"WRONG TYPE": 999}) + with pytest.raises(ValueError, match="UNKNOWN VALUE is not a valid registered NeuriteType"): + NeuriteType("UNKNOWN VALUE") with pytest.raises(ValueError, match=r"\[2, 3, 4\] is not a valid registered NeuriteType"): NeuriteType([2, 3, 4])