Skip to content
This repository has been archived by the owner on Feb 26, 2025. It is now read-only.

Commit

Permalink
Improve NeuriteType constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed May 1, 2023
1 parent 0f2d645 commit e178100
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
40 changes: 30 additions & 10 deletions neurom/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __new__(cls, *value):
SectionType(int_type) for int_type in cls._index_to_ids(int(obj), cls._BASE)
)

obj._value_ = int(obj)
obj._value_ = value
return obj

@staticmethod
Expand All @@ -94,9 +94,9 @@ def _ids_to_index(ids, base):
def _index_to_ids(index, base):
"""Convert a linear index into ids on a square grid with side 'base'."""
# find number of integers in linear index
n_digits = math.ceil(len(str(index)) / (len(str(base)) - 1))
n_digits = math.ceil(len(str(index)) / max(1, len(str(base)) - 1))

int_types = np.unravel_index(index, shape=(base,) * n_digits)
int_types = np.trim_zeros(np.unravel_index(index, shape=(base,) * n_digits))

if _ALL_SUBTYPE in int_types and len(int_types) > 1:
raise NeuroMError(
Expand Down Expand Up @@ -215,15 +215,35 @@ def unregister(cls, name):

def _enum_accept_undefined(cls, value):
# pylint: disable=protected-access
try:
obj = cls._member_map_[value]
except (KeyError, TypeError):

# Use NeuriteType name
if isinstance(value, NeuriteType):
value_str = value.name
if value_str in cls._member_map_:
return cls._member_map_[value_str]

# Name given as string
elif isinstance(value, str):
if value in cls._member_map_:
return cls._member_map_[value]

# SectionType or raw integer
elif isinstance(value, collections.abc.Hashable):
if value in cls._value2member_map_:
return cls._value2member_map_[value]

# Composite type or unhashable type (e.g. list)
else:
try:
subtype_value = SubtypeCollection(value)
obj = cls._value2member_map_[subtype_value]
except (KeyError, ValueError) as exc2:
raise ValueError(f"{value} is not a valid NeuriteType") from exc2
return obj
except Exception: # pylint: disable=broad-exception-caught
pass
else:
if subtype_value in cls._value2member_map_:
return cls._value2member_map_[subtype_value]

# Invalid value
raise ValueError(f"{value} is not a valid registered NeuriteType")


NeuriteType.__new__ = _enum_accept_undefined
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_tree_type_checker_broken():


def test_tree_type_checker_error():
with pytest.raises(ValueError, match='is not a valid NeuriteType'):
with pytest.raises(ValueError, match='is not a valid registered NeuriteType'):
tree_type_checker('NOT A VALID NeuriteType')


Expand Down
4 changes: 2 additions & 2 deletions tests/test_mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def test_eq(self):
assert NeuriteType.axon_carrying_dendrite != SubtypeCollection(NeuriteType.apical_dendrite)

def test_raise(self):
with pytest.raises(ValueError, match=r"\[2, 3, 4\] is not a valid NeuriteType"):
with pytest.raises(ValueError, match=r"\[2, 3, 4\] is not a valid registered NeuriteType"):
NeuriteType([2, 3, 4])

with pytest.raises(ValueError, match="20304 is not a valid NeuriteType"):
with pytest.raises(ValueError, match="20304 is not a valid registered NeuriteType"):
NeuriteType(20304)

def test_integer_behavior(self):
Expand Down

0 comments on commit e178100

Please sign in to comment.