From 386f08e9654fb247a53e5cce0d7df43a8f54c15e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 8 Feb 2025 14:34:36 +0530 Subject: [PATCH 1/2] fix(HomotopyContinuation): handle `NaN` results from `unpolynomialize` --- .../Project.toml | 4 +++- .../src/solve.jl | 15 ++++++++++++-- .../test/allroots.jl | 20 +++++++++++++++++++ .../test/single_root.jl | 17 ++++++++++++++++ 4 files changed, 53 insertions(+), 3 deletions(-) diff --git a/lib/NonlinearSolveHomotopyContinuation/Project.toml b/lib/NonlinearSolveHomotopyContinuation/Project.toml index 06a21cbd0..0072625ad 100644 --- a/lib/NonlinearSolveHomotopyContinuation/Project.toml +++ b/lib/NonlinearSolveHomotopyContinuation/Project.toml @@ -26,6 +26,7 @@ DocStringExtensions = "0.9.3" Enzyme = "0.13" HomotopyContinuation = "2.12.0" LinearAlgebra = "1.10" +NaNMath = "1.1" NonlinearSolve = "4" NonlinearSolveBase = "1.3.3" SciMLBase = "2.72.2" @@ -37,8 +38,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "NonlinearSolve", "Enzyme"] +test = ["Aqua", "Test", "NonlinearSolve", "Enzyme", "NaNMath"] diff --git a/lib/NonlinearSolveHomotopyContinuation/src/solve.jl b/lib/NonlinearSolveHomotopyContinuation/src/solve.jl index b3e92b6b3..dee0bd921 100644 --- a/lib/NonlinearSolveHomotopyContinuation/src/solve.jl +++ b/lib/NonlinearSolveHomotopyContinuation/src/solve.jl @@ -84,7 +84,11 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{t end # unpack solutions and make them real u = isscalar ? only(result.solution) : result.solution - append!(validsols, f.unpolynomialize(real.(u), p)) + unpolysols = f.unpolynomialize(real.(u), p) + for sol in unpolysols + any(isnan, sol) && continue + push!(validsols, sol) + end end # if there are no valid solutions @@ -137,10 +141,17 @@ function CommonSolve.solve(prob::NonlinearProblem, alg::HomotopyContinuationJL{f T = eltype(u0) validsols = f.unpolynomialize(realsol, p) _, idx = findmin(validsols) do sol - norm(sol - u0_p) + any(isnan, sol) ? Inf : norm(sol - u0_p) end + u = map(real, validsols[idx]) + if any(isnan, u) + retcode = SciMLBase.ReturnCode.Infeasible + resid = NonlinearSolveBase.Utils.evaluate_f(prob, u0) + return SciMLBase.build_solution(prob, alg, u0, resid; retcode, original = orig_sol) + end + if u0 isa Number u = only(u) end diff --git a/lib/NonlinearSolveHomotopyContinuation/test/allroots.jl b/lib/NonlinearSolveHomotopyContinuation/test/allroots.jl index e74991afc..79a874a49 100644 --- a/lib/NonlinearSolveHomotopyContinuation/test/allroots.jl +++ b/lib/NonlinearSolveHomotopyContinuation/test/allroots.jl @@ -3,6 +3,7 @@ using NonlinearSolveHomotopyContinuation using SciMLBase: NonlinearSolution using ADTypes using Enzyme +import NaNMath alg = HomotopyContinuationJL{true}(; threading = false) @@ -170,3 +171,22 @@ end end end end + +@testset "`NaN` unpolynomialize" begin + polynomialize = function (u, p) + return sin(u^2) + end + unpolynomialize = function (u, p) + return (-NaNMath.sqrt(NaNMath.asin(u)), NaNMath.sqrt(NaNMath.asin(u))) + end + rhs = function (u, p) + return u^2 + u - 1 + end + prob = NonlinearProblem( + HomotopyNonlinearFunction(rhs; polynomialize, unpolynomialize), 1.0) + sol = solve(prob, alg) + @test sol isa EnsembleSolution + for nlsol in sol.u + @test !isnan(nlsol.u) + end +end diff --git a/lib/NonlinearSolveHomotopyContinuation/test/single_root.jl b/lib/NonlinearSolveHomotopyContinuation/test/single_root.jl index f56219586..615c433e6 100644 --- a/lib/NonlinearSolveHomotopyContinuation/test/single_root.jl +++ b/lib/NonlinearSolveHomotopyContinuation/test/single_root.jl @@ -1,6 +1,7 @@ using NonlinearSolve using NonlinearSolveHomotopyContinuation using SciMLBase: NonlinearSolution +import NaNMath alg = HomotopyContinuationJL{false}(; threading = false) @@ -146,3 +147,19 @@ end end end end + +@testset "`NaN` unpolynomialize" begin + polynomialize = function (u, p) + return sin(u^2) + end + unpolynomialize = function (u, p) + return NaN + end + rhs = function (u, p) + return u^2 + u - 1 + end + prob = NonlinearProblem( + HomotopyNonlinearFunction(rhs; polynomialize, unpolynomialize), 1.0) + sol = solve(prob, alg) + @test !SciMLBase.successful_retcode(sol) +end From 2c80a9c7fb4c457d66cf5ffb1ccb0737d0e52b64 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 10 Feb 2025 11:42:02 +0530 Subject: [PATCH 2/2] docs: fix docs block --- docs/src/api/homotopycontinuation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api/homotopycontinuation.md b/docs/src/api/homotopycontinuation.md index 1e46c6378..783eb6883 100644 --- a/docs/src/api/homotopycontinuation.md +++ b/docs/src/api/homotopycontinuation.md @@ -14,5 +14,5 @@ using NonlinearSolveHomotopyContinuation, NonlinearSolve ```@docs NonlinearSolveHomotopyContinuation.HomotopyContinuationJL -SciMLBase.HomotopyContinuationFunction +SciMLBase.HomotopyNonlinearFunction ```