Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exclude dims #91

Merged
merged 3 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions pyttb/ktensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ def tovec(self, include_weights=True):
# offset += f.shape[0]
return x

def ttv(self, vector, dims=None):
def ttv(self, vector, dims=None, exclude_dims=None):
"""
`Tensor` times vector for `ktensors`.

Expand Down Expand Up @@ -1833,7 +1833,7 @@ def ttv(self, vector, dims=None):
>>> weights = 2 * np.ones(rank)
>>> weights_and_data = np.concatenate((weights, data), axis=0)
>>> K = ttb.ktensor.from_vector(weights_and_data[:], shape, True)
>>> K0 = K.ttv(np.array([1, 1, 1]), dims=1) # compute along a single dimension
>>> K0 = K.ttv(np.array([1, 1, 1]),dims=1) # compute along a single dimension
>>> print(K0)
ktensor of shape 2 x 4
weights=[36. 54.]
Expand All @@ -1857,7 +1857,7 @@ def ttv(self, vector, dims=None):

Compute the product of a `ktensor` and multiple vectors out of order (results in a `ktensor`):

>>> K2 = K.ttv([vec4, vec3], np.array([2, 1]))
>>> K2 = K.ttv([vec4, vec3],np.array([2, 1]))
>>> print(K2)
ktensor of shape 2
weights=[1800. 3564.]
Expand All @@ -1866,17 +1866,20 @@ def ttv(self, vector, dims=None):
[2. 4.]]
"""

if dims is None:
if dims is None and exclude_dims is None:
dims = np.array([])
elif isinstance(dims, (float, int)):
dims = np.array([dims])

if isinstance(exclude_dims, (float, int)):
exclude_dims = np.array([exclude_dims])

# Check that vector is a list of vectors, if not place single vector as element in list
if len(vector) > 0 and isinstance(vector[0], (int, float, np.int_, np.float_)):
return self.ttv([vector], dims)

# Get sorted dims and index for multiplicands
dims, vidx = ttb.tt_dimscheck(dims, self.ndims, len(vector))
dims, vidx = ttb.tt_dimscheck(self.ndims, len(vector), dims, exclude_dims)

# Check that each multiplicand is the right size.
for i in range(dims.size):
Expand Down
60 changes: 46 additions & 14 deletions pyttb/pyttb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,30 @@ def tt_union_rows(MatrixA, MatrixB):


@overload
def tt_dimscheck(dims: np.ndarray, N: int, M: None = None) -> Tuple[np.ndarray, None]:
def tt_dimscheck(
N: int,
M: None = None,
dims: Optional[np.ndarray] = None,
exclude_dims: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, None]:
... # pragma: no cover see coveragepy/issues/970


@overload
def tt_dimscheck(dims: np.ndarray, N: int, M: int) -> Tuple[np.ndarray, np.ndarray]:
def tt_dimscheck(
N: int,
M: int,
dims: Optional[np.ndarray] = None,
exclude_dims: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray]:
... # pragma: no cover see coveragepy/issues/970


def tt_dimscheck(
dims: np.ndarray, N: int, M: Optional[int] = None
N: int,
M: Optional[int] = None,
dims: Optional[np.ndarray] = None,
exclude_dims: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Used to preprocess dimensions for tensor dimensions
Expand All @@ -136,24 +149,43 @@ def tt_dimscheck(
-------

"""
# Fix empty case
if dims.size == 0:
dims = np.arange(0, N)
if dims is not None and exclude_dims is not None:
raise ValueError("Either specify dims to include or exclude, but not both")

dim_array: np.ndarray = np.empty((1,))

# Fix "minus" case
if np.max(dims) < 0:
# Explicit exclude to resolve ambiguous -0
if exclude_dims is not None:
# Check that all members in range
if not np.all(np.isin(-dims, np.arange(0, N + 1))):
assert False, "Invalid magnitude for negative dims selection"
dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1
valid_indices = np.isin(exclude_dims, np.arange(0, N))
if not np.all(valid_indices):
invalid_indices = np.logical_not(valid_indices)
raise ValueError(
f"Exclude dims provided: {exclude_dims} "
f"but, {exclude_dims[invalid_indices]} were out of valid range"
f"[0,{N}]"
)
dim_array = np.setdiff1d(np.arange(0, N), exclude_dims)

# Fix empty case
if (dims is None or dims.size == 0) and exclude_dims is None:
dim_array = np.arange(0, N)
elif isinstance(dims, np.ndarray):
dim_array = dims

# Catch minus case to avoid silent errors
if np.any(dim_array < 0):
raise ValueError(
"Negative dims aren't allowed in pyttb, see exclude_dims argument instead"
)

# Save dimensions of dims
P = len(dims)
P = len(dim_array)

# Reorder dims from smallest to largest (this matters in particular for the vector
# multiplicand case, where the order affects the result)
sidx = np.argsort(dims)
sdims = dims[sidx]
sidx = np.argsort(dim_array)
sdims = dim_array[sidx]
vidx = None

if M is not None:
Expand Down
32 changes: 23 additions & 9 deletions pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def collapse(
if dims is None:
dims = np.arange(0, self.ndims)

dims, _ = tt_dimscheck(dims, self.ndims)
dims, _ = tt_dimscheck(self.ndims, dims=dims)
remdims = np.setdiff1d(np.arange(0, self.ndims), dims)

# Check for the case where we accumulate over *all* dimensions
Expand Down Expand Up @@ -882,7 +882,7 @@ def mttkrp(self, U: Union[ttb.ktensor, List[np.ndarray]], n: int) -> np.ndarray:
else:
Z.append(np.array([]))
# Perform ttv multiplication
V[:, r] = self.ttv(Z, -(n + 1)).double()
V[:, r] = self.ttv(Z, exclude_dims=n).double()

return V

Expand Down Expand Up @@ -1044,7 +1044,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
"""
if isinstance(dims, (float, int)):
dims = np.array([dims])
dims, _ = ttb.tt_dimscheck(dims, self.ndims)
dims, _ = ttb.tt_dimscheck(self.ndims, dims=dims)

if isinstance(factor, ttb.tensor):
shapeArray = np.array(self.shape)
Expand Down Expand Up @@ -1181,6 +1181,7 @@ def ttv(
self,
vector: Union[np.ndarray, List[np.ndarray]],
dims: Optional[Union[int, np.ndarray]] = None,
exclude_dims: Optional[Union[int, np.ndarray]] = None,
) -> Union[sptensor, ttb.tensor]:
"""
Sparse tensor times vector
Expand All @@ -1189,20 +1190,24 @@ def ttv(
----------
vector: Vector(s) to multiply against
dims: Dimensions to multiply with vector(s)
exclude_dims: Use all dimensions but these
"""

if dims is None:
if dims is None and exclude_dims is None:
dims = np.array([])
elif isinstance(dims, (float, int)):
dims = np.array([dims])

if isinstance(exclude_dims, (float, int)):
exclude_dims = np.array([exclude_dims])

# Check that vector is a list of vectors,
# if not place single vector as element in list
if len(vector) > 0 and isinstance(vector[0], (int, float, np.int_, np.float_)):
return self.ttv(np.array([vector]), dims)
return self.ttv(np.array([vector]), dims, exclude_dims)

# Get sorted dims and index for multiplicands
dims, vidx = ttb.tt_dimscheck(dims, self.ndims, len(vector))
dims, vidx = ttb.tt_dimscheck(self.ndims, len(vector), dims, exclude_dims)
remdims = np.setdiff1d(np.arange(0, self.ndims), dims).astype(int)

# Check that each multiplicand is the right size.
Expand Down Expand Up @@ -2495,6 +2500,7 @@ def ttm(
self,
matrices: Union[np.ndarray, List[np.ndarray]],
dims: Optional[Union[float, np.ndarray]] = None,
exclude_dims: Optional[Union[float, np.ndarray]] = None,
transpose: bool = False,
):
"""
Expand All @@ -2503,24 +2509,28 @@ def ttm(
Parameters
----------
matrices: A matrix or list of matrices
dims: :class:`Numpy.ndarray`, int
dims: Dimensions to multiply against
exclude_dims: Use all dimensions but these
transpose: Transpose matrices to be multiplied

Returns
-------

"""
if dims is None:
if dims is None and exclude_dims is None:
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif isinstance(dims, (float, int, np.generic)):
dims = np.array([dims])

if isinstance(exclude_dims, (float, int)):
exclude_dims = np.array([exclude_dims])

# Handle list of matrices
if isinstance(matrices, list):
# Check dimensions are valid
[dims, vidx] = tt_dimscheck(dims, self.ndims, len(matrices))
[dims, vidx] = tt_dimscheck(self.ndims, len(matrices), dims, exclude_dims)
# Calculate individual products
Y = self.ttm(matrices[vidx[0]], dims[0], transpose=transpose)
for i in range(1, dims.size):
Expand All @@ -2535,6 +2545,10 @@ def ttm(
if transpose:
matrices = matrices.transpose()

# FIXME: This made typing happy but shouldn't be possible
if not isinstance(dims, np.ndarray): # pragma: no cover
raise ValueError("Dims should be an array here")

# Ensure this is the terminal single dimension case
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
assert False, "dims must contain values in [0,self.dims)"
Expand Down
Loading