diff --git a/Project.toml b/Project.toml index 3a5a3d3d..4474e5a6 100644 --- a/Project.toml +++ b/Project.toml @@ -7,18 +7,19 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] LazyArraysBandedMatricesExt = "BandedMatrices" LazyArraysBlockArraysExt = "BlockArrays" LazyArraysBlockBandedMatricesExt = "BlockBandedMatrices" +LazyArraysSparseArraysExt = "SparseArrays" LazyArraysStaticArraysExt = "StaticArrays" [compat] @@ -48,8 +49,9 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["Aqua", "BandedMatrices", "Base64", "BlockArrays", "BlockBandedMatrices", "StaticArrays", "Tracker", "Test", "Infinities", "Random"] +test = ["Aqua", "BandedMatrices", "Base64", "BlockArrays", "BlockBandedMatrices", "StaticArrays", "SparseArrays", "Tracker", "Test", "Infinities", "Random"] diff --git a/ext/LazyArraysSparseArraysExt.jl b/ext/LazyArraysSparseArraysExt.jl new file mode 100644 index 00000000..68d5ccd1 --- /dev/null +++ b/ext/LazyArraysSparseArraysExt.jl @@ -0,0 +1,13 @@ +module LazyArraysSparseArraysExt + +using SparseArrays: issparse, nnz, AbstractSparseArray +import LazyArrays: my_issparse, my_nnz + +my_nnz(A::AbstractSparseArray) = nnz(A) + +my_issparse(A::AbstractArray) = issparse(A) +my_issparse(A::DenseArray) = issparse(A) +my_issparse(S::AbstractSparseArray) = issparse(S) + + +end diff --git a/src/LazyArrays.jl b/src/LazyArrays.jl index 8aa52a37..6dafb0be 100644 --- a/src/LazyArrays.jl +++ b/src/LazyArrays.jl @@ -3,7 +3,7 @@ module LazyArrays # Use README as the docstring of the module: @doc read(joinpath(dirname(@__DIR__), "README.md"), String) LazyArrays -using Base.Broadcast, LinearAlgebra, FillArrays, ArrayLayouts, SparseArrays +using Base.Broadcast, LinearAlgebra, FillArrays, ArrayLayouts #, SparseArrays import LinearAlgebra.BLAS import Base: *, +, -, /, <, ==, >, \, ≤, ≥, (:), @_gc_preserve_begin, @_gc_preserve_end, @propagate_inbounds, diff --git a/src/lazyoperations.jl b/src/lazyoperations.jl index 347615f5..a3a79ada 100644 --- a/src/lazyoperations.jl +++ b/src/lazyoperations.jl @@ -161,6 +161,8 @@ function copy(M::Mul{ApplyLayout{typeof(kron)}}) return shuffle_algorithm(algo_type, M.A, M.B, eltype(M)) end +my_issparse(A) = false +my_nnz(A) = prod(size(A)) function shuffle_algorithm( ::ModifiedShuffle, K::Kron{T,2} where T, p::AbstractVecOrMat, OT::Type{<:Number} @@ -184,11 +186,11 @@ function shuffle_algorithm( R_H::Vector{Vector{Int}} = [] C_H::Vector{Vector{Int}} = [] - is_dense = !any(issparse, K.args) + is_dense = !any(my_issparse, K.args) # note: the following computation costs are for multiplication against # a single vector. - nnz_ = [issparse(X_h) ? nnz(X_h) : prod(size(X_h)) for X_h in K.args] + nnz_ = [my_issparse(X_h) ? my_nnz(X_h) : prod(size(X_h)) for X_h in K.args] trad_cost = 2*sum([ prod(size.(K.args[1:h-1], 2)) * nnz_[h] * prod(size.(K.args[h+1:end], 1)) for (h, X_h) in enumerate(K.args) @@ -226,7 +228,7 @@ function shuffle_algorithm( return shuffle_algorithm(Shuffle(), K, p, OT) end - nnz_m = [issparse(X_h) ? nnz(X_h) : prod(size(X_h)) for X_h in K_shrunk_factors] + nnz_m = [my_issparse(X_h) ? my_nnz(X_h) : prod(size(X_h)) for X_h in K_shrunk_factors] modified_cost = 2*sum([ prod(length.(C_H[1:h-1])) * nnz_m[h] * prod(length.(R_H[h+1:end])) for (h, X_h) in enumerate(K.args)