Skip to content

Commit

Permalink
Merge pull request #1 from sgaure/callitbatched
Browse files Browse the repository at this point in the history
Change keyword in ParticleSwarm to "batched". Remove an unused argument
  • Loading branch information
sgaure authored Feb 7, 2025
2 parents b498881 + edd362d commit 4d310b5
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions src/multivariate/solvers/zeroth_order/particle_swarm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ struct ParticleSwarm{Tl, Tu} <: ZerothOrderOptimizer
lower::Tl
upper::Tu
n_particles::Int
parallel::Bool
batched::Bool
end

ParticleSwarm(a, b, c) = ParticleSwarm(a, b, c, false)
Expand All @@ -14,15 +14,15 @@ ParticleSwarm(a, b, c) = ParticleSwarm(a, b, c, false)
ParticleSwarm(; lower = [],
upper = [],
n_particles = 0,
parallel = false)
batched = false)
```
The constructor takes 4 keywords:
* `lower = []`, a vector of lower bounds, unbounded below if empty or `-Inf`'s
* `upper = []`, a vector of upper bounds, unbounded above if empty or `Inf`'s
* `n_particles = 0`, the number of particles in the swarm, defaults to least three
* `parallel = false`, if true, the objective function is evaluated on a matrix
of column vectors.
* `batched = false`, if true, the objective function is evaluated on a matrix
of column vectors.
## Description
The Particle Swarm implementation in Optim.jl is the so-called Adaptive Particle
Expand All @@ -33,7 +33,7 @@ particle and move it away from its (potentially and probably) local optimum, to
improve the ability to find a global optimum. Of course, this comes a the cost
of slower convergence, but hopefully converges to the global optimum as a result.
If `parallel = true` is specified, there should be a 2-argument method for the objective
If `batched = true` is specified, there should be a 2-argument method for the objective
function, `f(val, X)`. The input vectors are columns of `X`. The outputs are written
into `val`. This makes it possible to parallelize the function evaluations, e.g. with:
Expand All @@ -52,7 +52,7 @@ reaches the maximum number of iterations set in Optim.Options(iterations=x)`.
## References
- [1] Zhan, Zhang, and Chung. Adaptive particle swarm optimization, IEEE Transactions on Systems, Man, and Cybernetics, Part B: CyberneticsVolume 39, Issue 6 (2009): 1362-1381
"""
ParticleSwarm(; lower = [], upper = [], n_particles = 0, parallel=false) = ParticleSwarm(lower, upper, n_particles, parallel)
ParticleSwarm(; lower = [], upper = [], n_particles = 0, batched=false) = ParticleSwarm(lower, upper, n_particles, batched)

Base.summary(::ParticleSwarm) = "Particle Swarm"

Expand Down Expand Up @@ -113,7 +113,7 @@ function initial_state(method::ParticleSwarm, options, d, initial_x::AbstractArr

@assert length(lower) == length(initial_x) "limits must be of same length as x_initial."
@assert all(upper .>= lower) "upper must be greater than or equal to lower"

if method.n_particles > 0
if method.n_particles < 3
@warn("Number of particles is set to 3 (minimum required)")
Expand Down Expand Up @@ -176,7 +176,7 @@ function initial_state(method::ParticleSwarm, options, d, initial_x::AbstractArr
X_best[j, 1] = initial_x[j]
end

if method.parallel
if method.batched
# here we could make a view of X, but then the objective will
# be compiled for a view also. We avoid that.
value(d, @view(score[2:n_particles]), X[:, 2:n_particles])
Expand Down Expand Up @@ -271,12 +271,12 @@ function update_state!(f, state::ParticleSwarmState{T}, method::ParticleSwarm) w
if state.limit_search_space
limit_X!(state.X, state.lower, state.upper, state.n_particles, n)
end
if method.parallel
compute_cost_parallel!(f, state.n_particles, state.X, state.score)
if method.batched
compute_cost_batched!(f, state.X, state.score)
else
compute_cost!(f, state.n_particles, state.X, state.score)
end

state.iteration += 1
false
end
Expand Down Expand Up @@ -512,8 +512,7 @@ function compute_cost!(f,
nothing
end

function compute_cost_parallel!(f,
n_particles::Int,
function compute_cost_batched!(f,
X::Matrix,
score::Vector)
value(f, score, X)
Expand Down

0 comments on commit 4d310b5

Please sign in to comment.