diff --git a/pyttb/sptensor.py b/pyttb/sptensor.py
index f6d26fc4..9aeb83a5 100644
--- a/pyttb/sptensor.py
+++ b/pyttb/sptensor.py
@@ -6,7 +6,7 @@
 
 import logging
 import warnings
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from typing import Any, Callable, List, Optional, Tuple, Union, cast, overload
 
 import numpy as np
@@ -620,7 +620,7 @@ def innerprod(
             if self.shape != other.shape:
                 assert False, "Sptensor and tensor must be same shape for innerproduct"
             [subsSelf, valsSelf] = self.find()
-            valsOther = other[subsSelf, "extract"]
+            valsOther = other[subsSelf.transpose(), "extract"]
             return valsOther.transpose().dot(valsSelf)
 
         if isinstance(other, (ttb.ktensor, ttb.ttensor)):  # pragma: no cover
@@ -685,7 +685,7 @@ def is_length_2(x):
 
         if isinstance(B, ttb.tensor):
             BB = sptensor.from_data(
-                self.subs, B[self.subs, "extract"][:, None], self.shape
+                self.subs, B[self.subs.transpose(), "extract"][:, None], self.shape
             )
             C = self.logical_and(BB)
             return C
@@ -1053,7 +1053,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
                 assert False, "Size mismatch in scale"
            return ttb.sptensor.from_data(
                self.subs,
-               self.vals * factor[self.subs[:, dims], "extract"][:, None],
+               self.vals * factor[self.subs[:, dims].transpose(), "extract"][:, None],
                self.shape,
            )
        if isinstance(factor, ttb.sptensor):
@@ -1368,9 +1368,9 @@ def __getitem__(self, item):
             if (
                 isinstance(item, np.ndarray)
                 and len(item.shape) == 2
-                and item.shape[1] == self.ndims
+                and item.shape[0] == self.ndims
             ):
-                srchsubs = np.array(item)
+                srchsubs = np.array(item.transpose())
 
             # *** CASE 2b: Linear indexing ***
             else:
@@ -1463,21 +1463,21 @@ def _set_subscripts(self, key, value):
         tt_subscheck(newsubs, nargout=False)
 
         # Error check on subscripts
-        if newsubs.shape[1] < self.ndims:
+        if newsubs.shape[0] < self.ndims:
             assert False, "Invalid subscripts"
 
         # Check for expanding the order
-        if newsubs.shape[1] > self.ndims:
+        if newsubs.shape[0] > self.ndims:
             newshape = list(self.shape)
             # TODO no need for loop, just add correct size
-            for _ in range(self.ndims, newsubs.shape[1]):
+            for _ in range(self.ndims, newsubs.shape[0]):
                 newshape.append(1)
             if self.subs.size > 0:
                 self.subs = np.concatenate(
                     (
                         self.subs,
                         np.ones(
-                            (self.shape[0], newsubs.shape[1] - self.ndims),
+                            (self.shape[0], newsubs.shape[0] - self.ndims),
                             dtype=int,
                         ),
                     ),
@@ -1497,7 +1497,7 @@ def _set_subscripts(self, key, value):
 
         # Determine number of nonzeros being inserted.
         # (This is determined by number of subscripts)
-        newnnz = newsubs.shape[0]
+        newnnz = newsubs.shape[1]
 
         # Error check on size of newvals
         if newvals.size == 1:
@@ -1510,7 +1510,7 @@ def _set_subscripts(self, key, value):
             assert False, "Number of subscripts and number of values do not match!"
 
         # Remove duplicates and print warning if any duplicates were removed
-        newsubs, idx = np.unique(newsubs, axis=0, return_index=True)
+        newsubs, idx = np.unique(newsubs.transpose(), axis=0, return_index=True)
         if newsubs.shape[0] != newnnz:
             warnings.warn("Duplicate assignments discarded")
 
@@ -1647,6 +1647,8 @@ def _set_subtensor(self, key, value):
                         newsz.append(self.shape[n])
                     else:
                         newsz.append(max([self.shape[n], key[n].stop]))
+                elif isinstance(key[n], Iterable):
+                    newsz.append(max([self.shape[n], max(key[n]) + 1]))
                 else:
                     newsz.append(max([self.shape[n], key[n] + 1]))
 
@@ -1660,7 +1662,7 @@ def _set_subtensor(self, key, value):
                         )
                     else:
                         newsz.append(key[n].stop)
-                elif isinstance(key[n], np.ndarray):
+                elif isinstance(key[n], (np.ndarray, Iterable)):
                     newsz.append(max(key[n]) + 1)
                 else:
                     newsz.append(key[n] + 1)
@@ -1671,7 +1673,8 @@ def _set_subtensor(self, key, value):
             self.subs = np.append(
                 self.subs,
                 np.zeros(
-                    shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1])
+                    shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]),
+                    dtype=int,
                 ),
                 axis=1,
             )
@@ -1689,7 +1692,7 @@ def _set_subtensor(self, key, value):
        if isinstance(value, (int, float)):
            # Determine number of dimensions (may be larger than current number)
            N = len(key)
-           keyCopy = np.array(key)
+           keyCopy = [None] * N
            # Figure out how many indices are in each dimension
            nssubs = np.zeros((N, 1))
            for n in range(0, N):
@@ -1697,7 +1700,11 @@ def _set_subtensor(self, key, value):
                    # Generate slice explicitly to determine its length
                    keyCopy[n] = np.arange(0, self.shape[n])[key[n]]
                    indicesInN = len(keyCopy[n])
+               elif isinstance(key[n], Iterable):
+                   keyCopy[n] = key[n]
+                   indicesInN = len(key[n])
                else:
+                   keyCopy[n] = key[n]
                    indicesInN = 1
                nssubs[n] = indicesInN
 
@@ -1806,7 +1813,7 @@ def __eq__(self, other):
             ]
 
             # Find where their nonzeros intersect
-            othervals = other[self.subs, "extract"]
+            othervals = other[self.subs.transpose(), "extract"]
             znzsubs = self.subs[(othervals[:, None] == self.vals).transpose()[0], :]
 
             return sptensor.from_data(
@@ -1887,7 +1894,7 @@ def __ne__(self, other):
                 subs1 = np.empty((0, self.subs.shape[1]))
             # find entries where x is nonzero but not equal to y
             subs2 = self.subs[
-                self.vals.transpose()[0] != other[self.subs, "extract"], :
+                self.vals.transpose()[0] != other[self.subs.transpose(), "extract"], :
             ]
             if subs2.size == 0:
                 subs2 = np.empty((0, self.subs.shape[1]))
@@ -2002,7 +2009,7 @@ def __mul__(self, other):
             )
         if isinstance(other, ttb.tensor):
             csubs = self.subs
-            cvals = self.vals * other[csubs, "extract"][:, None]
+            cvals = self.vals * other[csubs.transpose(), "extract"][:, None]
             return ttb.sptensor.from_data(csubs, cvals, self.shape)
         if isinstance(other, ttb.ktensor):
             csubs = self.subs
@@ -2124,7 +2131,7 @@ def __le__(self, other):
 
             # self nonzero
             subs2 = self.subs[
-                self.vals.transpose()[0] <= other[self.subs, "extract"], :
+                self.vals.transpose()[0] <= other[self.subs.transpose(), "extract"], :
             ]
 
             # assemble
@@ -2212,7 +2219,9 @@ def __lt__(self, other):
             subs1 = subs1[ttb.tt_setdiff_rows(subs1, self.subs), :]
 
             # self nonzero
-            subs2 = self.subs[self.vals.transpose()[0] < other[self.subs, "extract"], :]
+            subs2 = self.subs[
+                self.vals.transpose()[0] < other[self.subs.transpose(), "extract"], :
+            ]
 
             # assemble
             subs = np.vstack((subs1, subs2))
@@ -2267,7 +2276,10 @@ def __ge__(self, other):
 
             # self nonzero
             subs2 = self.subs[
-                (self.vals >= other[self.subs, "extract"][:, None]).transpose()[0], :
+                (
+                    self.vals >= other[self.subs.transpose(), "extract"][:, None]
+                ).transpose()[0],
+                :,
             ]
 
             # assemble
@@ -2325,7 +2337,10 @@ def __gt__(self, other):
 
             # self and other nonzero
             subs2 = self.subs[
-                (self.vals > other[self.subs, "extract"][:, None]).transpose()[0], :
+                (
+                    self.vals > other[self.subs.transpose(), "extract"][:, None]
+                ).transpose()[0],
+                :,
             ]
 
             # assemble
@@ -2428,7 +2443,7 @@ def __truediv__(self, other):
 
         if isinstance(other, ttb.tensor):
             csubs = self.subs
-            cvals = self.vals / other[csubs, "extract"][:, None]
+            cvals = self.vals / other[csubs.transpose(), "extract"][:, None]
             return ttb.sptensor.from_data(csubs, cvals, self.shape)
         if isinstance(other, ttb.ktensor):
             # TODO consider removing epsilon and generating nans consistent with above
diff --git a/pyttb/tensor.py b/pyttb/tensor.py
index 3bebc6ea..b415824d 100644
--- a/pyttb/tensor.py
+++ b/pyttb/tensor.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import logging
+from collections.abc import Iterable
 from itertools import permutations
 from math import factorial
 from typing import Any, Callable, List, Optional, Tuple, Union
@@ -1276,11 +1277,14 @@ def __setitem__(self, key, value):
         # Figure out if we are doing a subtensor, a list of subscripts or a list of
         # linear indices
         access_type = "error"
-        if self.ndims <= 1:
-            if isinstance(key, np.ndarray):
-                access_type = "subscripts"
-            else:
+        # TODO pull out this big decision tree into a function
+        if isinstance(key, (float, int, np.generic, slice)):
+            access_type = "linear indices"
+        elif self.ndims <= 1:
+            if isinstance(key, tuple):
                 access_type = "subtensor"
+            elif isinstance(key, np.ndarray):
+                access_type = "subscripts"
         else:
             if isinstance(key, np.ndarray):
                 if len(key.shape) > 1 and key.shape[1] >= self.ndims:
@@ -1289,10 +1293,14 @@ def __setitem__(self, key, value):
                     access_type = "linear indices"
             elif isinstance(key, tuple):
                 validSubtensor = [
-                    isinstance(keyElement, (int, slice)) for keyElement in key
+                    isinstance(keyElement, (int, slice, Iterable)) for keyElement in key
                 ]
                 if np.all(validSubtensor):
                     access_type = "subtensor"
+            elif isinstance(key, Iterable):
+                key = np.array(key)
+                if len(key.shape) == 1 or key.shape[1] == 1:
+                    access_type = "linear indices"
 
         # Case 1: Rectangular Subtensor
         if access_type == "subtensor":
@@ -1310,10 +1318,14 @@ def _set_linear(self, key, value):
         idx = key
-        if (idx > np.prod(self.shape)).any():
+        if not isinstance(idx, slice) and (idx > np.prod(self.shape)).any():
             assert (
                 False
             ), "TTB:BadIndex In assignment X[I] = Y, a tensor X cannot be resized"
+        if isinstance(key, (int, float, np.generic)):
+            idx = np.array([key])
+        elif isinstance(key, slice):
+            idx = np.array(range(np.prod(self.shape))[key])
         idx = tt_ind2sub(self.shape, idx)
         if idx.shape[0] == 1:
             self.data[tuple(idx[0, :])] = value
@@ -1333,6 +1345,14 @@ def _set_subtensor(self, key, value):
                     sliceCheck.append(1)
                 else:
                     sliceCheck.append(element.stop)
+            elif isinstance(element, Iterable):
+                if any(
+                    not isinstance(entry, (float, int, np.generic)) for entry in element
+                ):
+                    raise ValueError(
+                        f"Entries for setitem must be numeric but received {element}"
+                    )
+                sliceCheck.append(max(element))
             else:
                 sliceCheck.append(element)
         bsiz = np.array(sliceCheck)
@@ -1443,6 +1463,17 @@ def __getitem__(self, item):
         -------
         :class:`pyttb.tensor` or :class:`numpy.ndarray`
         """
+        # Case 0: Single Index Linear
+        if isinstance(item, (int, float, np.generic, slice)):
+            if isinstance(item, (int, float, np.generic)):
+                idx = np.array(item)
+            elif isinstance(item, slice):
+                idx = np.array(range(np.prod(self.shape))[item])
+            a = np.squeeze(
+                self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())]
+            )
+            # Todo if row make column?
+            return ttb.tt_subsubsref(a, idx)
         # Case 1: Rectangular Subtensor
         if (
             isinstance(item, tuple)
@@ -1484,17 +1515,28 @@ def __getitem__(self, item):
             return a
 
         # *** CASE 2a: Subscript indexing ***
-        if len(item) > 1 and isinstance(item[-1], str) and item[-1] == "extract":
+        if isinstance(item, np.ndarray) and len(item) > 1:
             # Extract array of subscripts
+            subs = np.array(item)
+            a = np.squeeze(self.data[tuple(subs)])
+            # TODO if is row make column?
+            return ttb.tt_subsubsref(a, subs)
+        if (
+            len(item) > 1
+            and isinstance(item[0], np.ndarray)
+            and isinstance(item[-1], str)
+            and item[-1] == "extract"
+        ):
+            # TODO dry this up
             subs = np.array(item[0])
-            a = np.squeeze(self.data[tuple(subs.transpose())])
+            a = np.squeeze(self.data[tuple(subs)])
             # TODO if is row make column?
             return ttb.tt_subsubsref(a, subs)
 
         # Case 2b: Linear Indexing
-        if len(item) >= 2 and not isinstance(item[-1], str):
+        if isinstance(item, tuple) and len(item) >= 2 and not isinstance(item[-1], str):
             assert False, "Linear indexing requires single input array"
-        idx = item[0]
+        idx = np.array(item)
         a = np.squeeze(self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())])
         # Todo if row make column?
         return ttb.tt_subsubsref(a, idx)
diff --git a/tests/test_sptensor.py b/tests/test_sptensor.py
index 3a14ccdb..f9662d10 100644
--- a/tests/test_sptensor.py
+++ b/tests/test_sptensor.py
@@ -293,9 +293,9 @@ def test_sptensor__getitem__(sample_sptensor):
     # TODO need to understand what this intends to do
 
     ## Case 2 subscript indexing
-    assert sptensorInstance[np.array([[1, 2, 1]])] == np.array([[0]])
+    assert sptensorInstance[np.array([[1], [2], [1]])] == np.array([[0]])
     assert (
-        sptensorInstance[np.array([[1, 2, 1], [1, 3, 1]])] == np.array([[0], [0]])
+        sptensorInstance[np.array([[1, 1], [2, 3], [1, 1]])] == np.array([[0], [0]])
     ).all()
 
     ## Case 2 Linear Indexing
@@ -533,6 +533,18 @@ def test_sptensor_setitem_Case1(sample_sptensor):
     assert (sptensorInstance.vals == np.vstack((data["vals"], np.array([[7]])))).all()
     assert sptensorInstance.shape == data["shape"]
 
+    # Case I(b)ii: Set with scalar, iterable index, empty sptensor
+    someTensor = ttb.sptensor()
+    someTensor[[0, 1], 0] = 1
+    assert someTensor[0, 0] == 1
+    assert someTensor[1, 0] == 1
+    assert np.all(someTensor[[0, 1], 0].vals == 1)
+    # Case I(b)ii: Set with scalar, iterable index, non-empty sptensor
+    someTensor[[0, 1], 1] = 2
+    assert someTensor[0, 1] == 2
+    assert someTensor[1, 1] == 2
+    assert np.all(someTensor[[0, 1], 1].vals == 2)
+
     # Case I: Assign with non-scalar or sptensor
     sptensorInstanceLarger = ttb.sptensor.from_tensor_type(sptensorInstance)
     with pytest.raises(AssertionError) as excinfo:
@@ -551,12 +563,14 @@ def test_sptensor_setitem_Case2(sample_sptensor):
 
     # Case II: Too few keys in setitem for number of assignement values
     with pytest.raises(AssertionError) as excinfo:
-        sptensorInstance[np.array([1, 1, 1]).astype(int)] = np.array([[999.0], [888.0]])
+        sptensorInstance[np.array([[1], [1], [1]]).astype(int)] = np.array(
+            [[999.0], [888.0]]
+        )
     assert "Number of subscripts and number of values do not match!" in str(excinfo)
 
     # Case II: Warning For duplicates
     with pytest.warns(Warning) as record:
-        sptensorInstance[np.array([[1, 1, 1], [1, 1, 1]]).astype(int)] = np.array(
+        sptensorInstance[np.array([[1, 1], [1, 1], [1, 1]]).astype(int)] = np.array(
             [[999.0], [999.0]]
         )
    assert "Duplicate assignments discarded" in str(record[0].message)
@@ -567,54 +581,54 @@ def test_sptensor_setitem_Case2(sample_sptensor):
     assert np.all(empty_tensor[np.array([[0, 1], [2, 2]])] == 4)
 
     # Case II: Single entry, for single sub that exists
-    sptensorInstance[np.array([1, 1, 1]).astype(int)] = 999.0
-    assert (sptensorInstance[np.array([[1, 1, 1]])] == np.array([[999]])).all()
+    sptensorInstance[np.array([[1], [1], [1]]).astype(int)] = 999.0
+    assert (sptensorInstance[np.array([[1], [1], [1]])] == np.array([[999]])).all()
     assert (sptensorInstance.subs == data["subs"]).all()
 
     # Case II: Single entry, for multiple subs that exist
     (data, sptensorInstance) = sample_sptensor
-    sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = 999.0
+    sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = 999.0
     assert (
-        sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]])] == np.array([[999], [999]])
+        sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]])] == np.array([[999], [999]])
     ).all()
     assert (sptensorInstance.subs == data["subs"]).all()
 
     # Case II: Multiple entries, for multiple subs that exist
     (data, sptensorInstance) = sample_sptensor
-    sptensorInstance[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = np.array(
+    sptensorInstance[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = np.array(
         [[888], [999]]
     )
     assert (
-        sptensorInstance[np.array([[1, 1, 3], [1, 1, 1]])] == np.array([[999], [888]])
+        sptensorInstance[np.array([[1, 1], [1, 1], [3, 1]])] == np.array([[999], [888]])
     ).all()
     assert (sptensorInstance.subs == data["subs"]).all()
 
     # Case II: Single entry, for single sub that doesn't exist
     (data, sptensorInstance) = sample_sptensor
     copy = ttb.sptensor.from_tensor_type(sptensorInstance)
-    copy[np.array([[1, 1, 2]]).astype(int)] = 999.0
-    assert (copy[np.array([[1, 1, 2]])] == np.array([999])).all()
+    copy[np.array([[1], [1], [2]]).astype(int)] = 999.0
+    assert (copy[np.array([[1], [1], [2]])] == np.array([999])).all()
     assert (copy.subs == np.concatenate((data["subs"], np.array([[1, 1, 2]])))).all()
 
     # Case II: Single entry, for single sub that doesn't exist, expand dimensions
     (data, sptensorInstance) = sample_sptensor
     copy = ttb.sptensor.from_tensor_type(sptensorInstance)
-    copy[np.array([[1, 1, 2, 1]]).astype(int)] = 999.0
-    assert (copy[np.array([[1, 1, 2, 1]])] == np.array([999])).all()
+    copy[np.array([[1], [1], [2], [1]]).astype(int)] = 999.0
+    assert (copy[np.array([[1], [1], [2], [1]])] == np.array([999])).all()
     # assert (copy.subs == np.concatenate((data['subs'], np.array([[1, 1, 2]])))).all()
 
     # Case II: Single entry, for multiple subs one that exists and the other doesn't
     (data, sptensorInstance) = sample_sptensor
     copy = ttb.sptensor.from_tensor_type(sptensorInstance)
-    copy[np.array([[1, 1, 1], [2, 1, 3]]).astype(int)] = 999.0
-    assert (copy[np.array([[2, 1, 3]])] == np.array([999])).all()
+    copy[np.array([[1, 2], [1, 1], [1, 3]]).astype(int)] = 999.0
+    assert (copy[np.array([[2], [1], [3]])] == np.array([999])).all()
    assert (copy.subs == np.concatenate((data["subs"], np.array([[2, 1, 3]])))).all()
 
     # Case II: Multiple entries, for multiple subs that don't exist
     (data, sptensorInstance) = sample_sptensor
     copy = ttb.sptensor.from_tensor_type(sptensorInstance)
-    copy[np.array([[1, 1, 2], [2, 1, 3]]).astype(int)] = np.array([[888], [999]])
-    assert (copy[np.array([[1, 1, 2], [2, 1, 3]])] == np.array([[888], [999]])).all()
+    copy[np.array([[1, 2], [1, 1], [2, 3]]).astype(int)] = np.array([[888], [999]])
+    assert (copy[np.array([[1, 2], [1, 1], [2, 3]])] == np.array([[888], [999]])).all()
     assert (
         copy.subs == np.concatenate((data["subs"], np.array([[1, 1, 2], [2, 1, 3]])))
     ).all()
@@ -622,8 +636,8 @@ def test_sptensor_setitem_Case2(sample_sptensor):
     # Case II: Multiple entries, for multiple subs that exist and need to be removed
     (data, sptensorInstance) = sample_sptensor
     copy = ttb.sptensor.from_tensor_type(sptensorInstance)
-    copy[np.array([[1, 1, 1], [1, 1, 3]]).astype(int)] = np.array([[0], [0]])
-    assert (copy[np.array([[1, 1, 2], [2, 1, 3]])] == np.array([[0], [0]])).all()
+    copy[np.array([[1, 1], [1, 1], [1, 3]]).astype(int)] = np.array([[0], [0]])
+    assert (copy[np.array([[1, 2], [1, 1], [1, 3]])] == np.array([[0], [0]])).all()
     assert (copy.subs == np.array([[2, 2, 2], [3, 3, 3]])).all()
 
 
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index 8894e0ea..d27463c1 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -251,6 +251,13 @@ def test_tensor__setitem__(sample_tensor_2way):
     # Subtensor add dimension
     empty_tensor[0, 0, 0] = 2
 
+    # Subtensor with lists
+    some_tensor = ttb.tenones((3, 3))
+    some_tensor[[0, 1], [0, 1]] = 11
+    assert some_tensor[0, 0] == 11
+    assert some_tensor[1, 1] == 11
+    assert np.all(some_tensor[[0, 1], [0, 1]].data == 11)
+
     # Subscripts with constant
     tensorInstance[np.array([[1, 1]])] = 13.0
     dataGrowth[1, 1] = 13.0
@@ -280,11 +287,26 @@ def test_tensor__setitem__(sample_tensor_2way):
     dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 13.0
     assert (tensorInstance.data == dataGrowth).all()
 
+    tensorInstance[0] = 14.0
+    dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 14.0
+    assert (tensorInstance.data == dataGrowth).all()
+
+    tensorInstance[0:1] = 14.0
+    dataGrowth[np.unravel_index([0], dataGrowth.shape, "F")] = 14.0
+    assert (tensorInstance.data == dataGrowth).all()
+
     # Linear Index with constant
     tensorInstance[np.array([0, 3, 4])] = 13.0
     dataGrowth[np.unravel_index([0, 3, 4], dataGrowth.shape, "F")] = 13
     assert (tensorInstance.data == dataGrowth).all()
 
+    # Linear index with multiple indices
+    some_tensor = ttb.tenones((3, 3))
+    some_tensor[[0, 1]] = 2
+    assert some_tensor[0] == 2
+    assert some_tensor[1] == 2
+    assert np.array_equal(some_tensor[[0, 1]], [2, 2])
+
     # Test Empty Tensor Set Item, subtensor
     emptyTensor = ttb.tensor.from_data(np.array([]))
     emptyTensor[0, 0, 0] = 0
@@ -305,8 +327,16 @@ def test_tensor__setitem__(sample_tensor_2way):
     )
 
     # Attempting to set some other way
-    with pytest.raises(AssertionError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:
         tensorInstance[0, "a", 5] = 13.0
+    assert "must be numeric" in str(excinfo)
+
+    with pytest.raises(AssertionError) as excinfo:
+
+        class BadKey:
+            pass
+
+        tensorInstance[BadKey] = 13.0
     assert "Invalid use of tensor setitem" in str(excinfo)
 
 
@@ -335,11 +365,18 @@ def test_tensor__getitem__(sample_tensor_2way):
     assert tensorInstance[np.array([0, 0]), "extract"] == params["data"][0, 0]
     assert (
         tensorInstance[np.array([[0, 0], [1, 1]]), "extract"]
-        == params["data"][([0, 1], [0, 1])]
+        == params["data"][([0, 0], [1, 1])]
     ).all()
+    # Case 2a: Extract doesn't seem to be needed
+    assert tensorInstance[np.array([0, 0])] == params["data"][0, 0]
+    assert (
+        tensorInstance[np.array([[0, 0], [1, 1]])] == params["data"][([0, 0], [1, 1])]
+    ).all()
 
     # Case 2b: Linear Indexing
     assert tensorInstance[np.array([0])] == params["data"][0, 0]
+    assert tensorInstance[0] == params["data"][0, 0]
+    assert np.array_equal(tensorInstance[0:1], params["data"][0, 0])
     with pytest.raises(AssertionError) as excinfo:
         tensorInstance[np.array([0]), np.array([0]), np.array([0])]
     assert "Linear indexing requires single input array" in str(excinfo)
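# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the indexing convention
# exercised by the updated tests behaves once the diff above is applied.
# Assumes pyttb is importable as `ttb`; the tensor below is an example only.
# ---------------------------------------------------------------------------
import numpy as np
import pyttb as ttb

T = ttb.tenones((3, 3))

# Subtensor assignment with per-dimension index lists.
T[[0, 1], [0, 1]] = 11.0
assert T[0, 0] == 11.0 and T[1, 1] == 11.0

# Subscript extraction now expects one row per dimension and one column per
# entry, so this pulls out entries (0, 0) and (1, 1).
vals = T[np.array([[0, 1], [0, 1]])]

# Linear indexing: scalars, slices, and plain lists index in Fortran order.
T[0] = 5.0
first_two = T[[0, 3]]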