Update logging_optimizer.gtsam_optimize to use NonlinearOptimizerParams::iterationHook
gchenfc committed Apr 19, 2022
1 parent ddca736 commit 71aa20f
Showing 2 changed files with 11 additions and 10 deletions.
6 changes: 5 additions & 1 deletion python/gtsam/tests/test_logging_optimizer.py
@@ -63,12 +63,14 @@ def test_simple_printing(self):
         def hook(_, error):
             print(error)
 
-        # Only thing we require from optimizer is an iterate method
+        # Wrapper function sets the hook and calls optimizer.optimize() for us.
         gtsam_optimize(self.optimizer, self.params, hook)
 
         # Check that optimizing yields the identity.
         actual = self.optimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
 
     def test_lm_simple_printing(self):
         """Make sure we are properly terminating LM"""
@@ -79,6 +81,8 @@ def hook(_, error):
 
         actual = self.lmoptimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
 
     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):
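For context, here is a minimal sketch of how gtsam_optimize is driven with a printing hook, mirroring the test above; the one-factor Rot3 problem and key 0 are illustrative assumptions, not part of this commit:

# Minimal sketch: log per-iteration error via gtsam_optimize.
# The graph below is a hypothetical one-factor Rot3 problem.
import gtsam
from gtsam.utils.logging_optimizer import gtsam_optimize

graph = gtsam.NonlinearFactorGraph()
graph.add(gtsam.PriorFactorRot3(0, gtsam.Rot3(), gtsam.noiseModel.Unit.Create(3)))
initial = gtsam.Values()
initial.insert(0, gtsam.Rot3.Rz(0.1))

params = gtsam.GaussNewtonParams()
optimizer = gtsam.GaussNewtonOptimizer(graph, initial, params)

def hook(_, error):
    print(error)  # called once at the start, then after every iteration

result = gtsam_optimize(optimizer, params, hook)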
15 changes: 6 additions & 9 deletions python/gtsam/utils/logging_optimizer.py
@@ -21,7 +21,8 @@ def optimize(optimizer, check_convergence, hook):
     current_error = optimizer.error()
     hook(optimizer, current_error)
 
-    # Iterative loop
+    # Iterative loop. Cannot use `params.iterationHook` because we don't have access to params
+    # (backwards compatibility issue).
     while True:
         # Do next iteration
         optimizer.iterate()
@@ -35,18 +36,14 @@ def optimize(optimizer, check_convergence, hook):
 def gtsam_optimize(optimizer,
                    params,
                    hook):
-    """ Given an optimizer and params, iterate until convergence.
+    """ Given an optimizer and its params, iterate until convergence.
     After each iteration, hook(optimizer) is called.
     After the function, use values and errors to get the result.
     Arguments:
         optimizer {NonlinearOptimizer} -- Nonlinear optimizer
         params {NonlinearOptimizerParams} -- Nonlinear optimizer parameters
         hook -- hook function to record the error
     """
-    def check_convergence(optimizer, current_error, new_error):
-        return (optimizer.iterations() >= params.getMaxIterations()) or (
-            gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
-                                   current_error, new_error)) or (
-            isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound())
-    optimize(optimizer, check_convergence, hook)
-    return optimizer.values()
+    hook(optimizer, optimizer.error())  # call once at start (backwards compatibility)
+    params.iterationHook = lambda iteration, error_before, error_after: hook(optimizer, error_after)
+    return optimizer.optimize()
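The rewrite relies on NonlinearOptimizerParams::iterationHook, which the optimizer's optimize() loop invokes after every iteration with (iteration, error_before, error_after), as the lambda above shows. Here is a minimal sketch of using the hook directly, without the logging_optimizer wrapper; the tiny Rot3 problem is again an illustrative assumption:

# Direct use of params.iterationHook, bypassing gtsam_optimize entirely.
import gtsam

graph = gtsam.NonlinearFactorGraph()
graph.add(gtsam.PriorFactorRot3(0, gtsam.Rot3(), gtsam.noiseModel.Unit.Create(3)))
initial = gtsam.Values()
initial.insert(0, gtsam.Rot3.Rz(0.1))

params = gtsam.LevenbergMarquardtParams()
# Hook signature: (iteration, error_before, error_after).
params.iterationHook = lambda i, before, after: print(f"iter {i}: {before:.6f} -> {after:.6f}")

optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
result = optimizer.optimize()  # the hook fires after each LM iteration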
