Skip to content

Commit

Permalink
add sequential option for SLS, fixes to import/export methods SLS<->SMT
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Nov 15, 2024
1 parent 6a9d591 commit 8e3b9f6
Show file tree
Hide file tree
Showing 16 changed files with 224 additions and 63 deletions.
3 changes: 1 addition & 2 deletions src/ast/sls/sat_ddfw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ namespace sat {
}

void ddfw::check_with_plugin() {
m_plugin->init_search();
unsigned steps = 0;
if (m_min_sz <= m_unsat.size())
save_best_values();
Expand All @@ -77,7 +76,6 @@ namespace sat {
IF_VERBOSE(0, verbose_stream() << "Exception: " << ex.what() << "\n");
throw;
}
m_plugin->finish_search();
}

void ddfw::log() {
Expand Down Expand Up @@ -246,6 +244,7 @@ namespace sat {

void ddfw::flip(bool_var v) {
++m_flips;
m_limit.inc();
literal lit = literal(v, !value(v));
literal nlit = ~lit;
SASSERT(is_true(lit));
Expand Down
4 changes: 2 additions & 2 deletions src/ast/sls/sat_ddfw.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ namespace sat {
class local_search_plugin {
public:
virtual ~local_search_plugin() {}
virtual void init_search() = 0;
virtual void finish_search() = 0;
//virtual void init_search() = 0;
//virtual void finish_search() = 0;
virtual void on_rescale() = 0;
virtual void on_save_model() = 0;
virtual void on_restart() = 0;
Expand Down
8 changes: 6 additions & 2 deletions src/ast/sls/sls_arith_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2016,9 +2016,13 @@ namespace sls {
w = mk_term(e);

num_t n;
if (!is_num(v, n))
try {
if (!is_num(v, n))
return false;
}
catch (overflow_exception const&) {
return false;
// verbose_stream() << "set value " << w << " " << mk_bounded_pp(e, m) << " " << n << " " << value(w) << "\n";
}
if (n == value(w))
return true;
return update(w, n);
Expand Down
110 changes: 80 additions & 30 deletions src/ast/sls/sls_smt_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Module Name:
#include "ast/sls/sls_smt_plugin.h"
#include "ast/for_each_expr.h"
#include "ast/bv_decl_plugin.h"
#include "ast/ast_pp.h"
#include "smt/params/smt_params_helper.hpp"

namespace sls {

Expand All @@ -29,6 +31,8 @@ namespace sls {
m_sync(),
m_smt2sync_tr(m, m_sync),
m_smt2sls_tr(m, m_sls),
m_sls2sync_tr(m_sls, m_sync),
m_sls2smt_tr(m_sls, m),
m_sync_uninterp(m_sync),
m_sls_uninterp(m_sls),
m_sync_values(m_sync),
Expand Down Expand Up @@ -85,7 +89,10 @@ namespace sls {
add_shared_term(t);
}

m_thread = std::thread([this]() { run(); });
if (ctx.parallel_mode())
m_thread = std::thread([this]() { run(); });
else
m_completed = true;
}

void smt_plugin::run() {
Expand All @@ -94,8 +101,19 @@ namespace sls {
m_result = m_ddfw->check(0, nullptr);
m_ddfw->collect_statistics(m_st);
IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n");
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
m_rewards[v] = m_ddfw->get_reward_avg(w);
}
m_completed = true;
}

void smt_plugin::bounded_run(unsigned max_iterations) {
m_ddfw->rlimit().reset_count();
m_ddfw->rlimit().push(max_iterations);
run();
m_ddfw->rlimit().pop();
}

void smt_plugin::finalize(model_ref& mdl, ::statistics& st) {
auto* d = m_ddfw;
Expand Down Expand Up @@ -191,13 +209,47 @@ namespace sls {
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
if (m_sat_phase[v] != is_true(sat::literal(w, false)))
flip(w);
flip(w);
m_ddfw->bias(w) = m_sat_phase[v] ? 1 : -1;
}
smt_phase_to_sls();
m_has_new_sat_phase = false;
return true;
}

void smt_plugin::smt_phase_to_sls() {
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
auto phase = ctx.get_best_phase(v);
if (phase != is_true(sat::literal(w, false)))
flip(w);
m_ddfw->bias(w) = phase ? 1 : -1;
}
}

void smt_plugin::smt_values_to_sls() {
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
expr_ref val_t(m);
if (!ctx.get_value(t, val_t))
continue;
expr* t_sls = m_smt2sls_tr(t);
auto val_sls = expr_ref(m_smt2sls_tr(val_t.get()), m_sls);
m_context.set_value(t_sls, val_sls);
}
}

void smt_plugin::sls_phase_to_smt() {
IF_VERBOSE(2, verbose_stream() << "SLS -> SMT phase\n");
for (auto v : m_shared_bool_vars)
ctx.force_phase(sat::literal(v, !m_sls_phase[v]));
}

void smt_plugin::sls_activity_to_smt() {
IF_VERBOSE(2, verbose_stream() << "SLS -> SMT activity\n");
for (auto v : m_shared_bool_vars)
ctx.inc_activity(v, 200 * m_rewards[v]);
}

bool smt_plugin::export_units_to_sls() {
if (!m_has_units)
return false;
Expand Down Expand Up @@ -225,65 +277,63 @@ namespace sls {
if (unsat().size() > m_min_unsat_size)
return;
m_min_unsat_size = unsat().size();
std::lock_guard<std::mutex> lock(m_mutex);
export_phase_from_sls();
export_values_from_sls();
}

void smt_plugin::export_phase_from_sls() {
std::lock_guard<std::mutex> lock(m_mutex);
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
m_rewards[v] = m_ddfw->get_reward_avg(w);
//verbose_stream() << v << " " << w << "\n";
VERIFY(m_ddfw->get_model().size() > w);
VERIFY(m_sls_phase.size() > v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[w];
m_has_new_sls_phase = true;
m_sls_phase[v] = l_true == m_ddfw->get_model()[w];
}
// export_values_from_sls();
m_has_new_sls_phase = true;
}

void smt_plugin::export_values_from_sls() {
IF_VERBOSE(3, verbose_stream() << "import values from sls\n");
IF_VERBOSE(3, verbose_stream() << "export values from sls\n");
std::lock_guard<std::mutex> lock(m_mutex);
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
expr_ref val_t = m_context.get_value(t_sync);
m_sync_values.set(t_sync->get_id(), val_t.get());
expr_ref val_t = m_context.get_value(t);
auto sync_val = m_sls2sync_tr(val_t.get());
m_sync_values.setx(t_sync->get_id(), sync_val);
}
m_has_new_sls_values = true;
}

void smt_plugin::import_from_sls() {
export_activity_to_smt();
export_values_to_smt();
export_phase_to_smt();
if (m_has_new_sls_values) {
std::lock_guard<std::mutex> lock(m_mutex);
sls_values_to_smt();
m_has_new_sls_values = false;
}
if (m_has_new_sls_phase) {
std::lock_guard<std::mutex> lock(m_mutex);
sls_phase_to_smt();
m_has_new_sls_phase = false;
}
}

void smt_plugin::export_activity_to_smt() {

}

void smt_plugin::export_values_to_smt() {
void smt_plugin::sls_values_to_smt() {
if (!m_has_new_sls_values)
return;
IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n");
std::lock_guard<std::mutex> lock(m_mutex);
IF_VERBOSE(2, verbose_stream() << "SLS -> SMT values\n");
ast_translation tr(m_sync, m);
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr);
if (!sync_val)
continue;
expr_ref val(tr(sync_val), m);
ctx.initialize_value(t, val);
}
m_has_new_sls_values = false;
}

void smt_plugin::export_phase_to_smt() {
if (!m_has_new_sls_phase)
return;
std::lock_guard<std::mutex> lock(m_mutex);
IF_VERBOSE(3, verbose_stream() << "SLS -> SMT phase\n");
for (auto v : m_shared_bool_vars) {
auto w = m_smt_bool_var2sls_bool_var[v];
ctx.force_phase(sat::literal(w, m_sls_phase[v]));
ctx.set_value(t, val);
}
m_has_new_sls_phase = false;
}

void smt_plugin::add_shared_term(expr* t) {
Expand All @@ -310,6 +360,6 @@ namespace sls {
m_ddfw->reinit();
m_new_clause_added = false;
}
// export_from_sls();
export_from_sls();
}
}
23 changes: 16 additions & 7 deletions src/ast/sls/sls_smt_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ namespace sls {
virtual ~smt_context() {}
virtual ast_manager& get_manager() = 0;
virtual params_ref get_params() = 0;
virtual void initialize_value(expr* t, expr* v) = 0;
virtual void set_value(expr* t, expr* v) = 0;
virtual void force_phase(sat::literal lit) = 0;
virtual void set_has_new_best_phase(bool b) = 0;
virtual bool get_value(expr* v, expr_ref& val) = 0;
virtual bool get_best_phase(sat::bool_var v) = 0;
virtual expr* bool_var2expr(sat::bool_var v) = 0;
virtual void inc_activity(sat::bool_var v, double inc) = 0;
virtual void set_finished() = 0;
virtual unsigned get_num_bool_vars() const = 0;
virtual bool parallel_mode() const = 0;
};


Expand All @@ -50,7 +53,7 @@ namespace sls {
ast_manager& m;
ast_manager m_sls;
ast_manager m_sync;
ast_translation m_smt2sync_tr, m_smt2sls_tr;
ast_translation m_smt2sync_tr, m_smt2sls_tr, m_sls2sync_tr, m_sls2smt_tr;
expr_ref_vector m_sync_uninterp, m_sls_uninterp;
expr_ref_vector m_sync_values;
sat::ddfw* m_ddfw = nullptr;
Expand All @@ -63,6 +66,7 @@ namespace sls {
sat::literal_vector m_units;
model_ref m_sls_model;
::statistics m_st;

bool m_new_clause_added = false;
unsigned m_min_unsat_size = UINT_MAX;
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
Expand All @@ -78,19 +82,20 @@ namespace sls {

bool is_shared(sat::literal lit);
void run();

void add_shared_term(expr* t);
void add_uninterp(expr* smt_t);
void add_shared_var(sat::bool_var v, sat::bool_var w);

void import_phase_from_smt();
void import_values_from_sls();
void export_values_from_sls();
void export_phase_from_sls();
void import_activity_from_sls();
bool export_phase_to_sls();
bool export_units_to_sls();
void export_values_to_smt();
void export_activity_to_smt();
void export_phase_to_smt();

void export_from_sls();

Expand All @@ -106,9 +111,12 @@ namespace sls {
void updt_params(params_ref& p) {}
std::ostream& display(std::ostream& out) override;

void bounded_run(unsigned max_iterations);

bool export_to_sls();
void import_from_sls();
bool completed() { return m_completed; }
lbool result() { return m_result; }
void add_unit(sat::literal lit);

// local_search_plugin:
Expand All @@ -124,12 +132,13 @@ namespace sls {
m_sls_model = mdl;
}

void init_search() override {}

void finish_search() override {}

void on_rescale() override {}

void smt_phase_to_sls();
void smt_values_to_sls();
void sls_phase_to_smt();
void sls_values_to_smt();
void sls_activity_to_smt();


// sat_solver_context:
Expand Down
4 changes: 0 additions & 4 deletions src/ast/sls/sls_smt_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ namespace sls {
m.limit().pop_child(&m_ddfw.rlimit());
}

void init_search() override {}

void finish_search() override {}

void on_rescale() override {}

void on_restart() override {
Expand Down
2 changes: 1 addition & 1 deletion src/sat/smt/sls_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace sls {
return s().params();
}

void solver::initialize_value(expr* t, expr* v) {
void solver::set_value(expr* t, expr* v) {
ctx.user_propagate_initialize_value(t, v);
}

Expand Down
5 changes: 4 additions & 1 deletion src/sat/smt/sls_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,16 @@ namespace sls {

ast_manager& get_manager() override { return m; }
params_ref get_params() override;
void initialize_value(expr* t, expr* v) override;
void set_value(expr* t, expr* v) override;
void force_phase(sat::literal lit) override;
void set_has_new_best_phase(bool b) override;
bool get_best_phase(sat::bool_var v) override;
expr* bool_var2expr(sat::bool_var v) override;
void set_finished() override;
void inc_activity(sat::bool_var v, double inc) override {}
unsigned get_num_bool_vars() const override;
bool parallel_mode() const override { return false; }
bool get_value(expr* v, expr_ref& value) override { return false; }

};

Expand Down
1 change: 1 addition & 0 deletions src/smt/params/smt_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void smt_params::updt_local_params(params_ref const & _p) {
m_threads_cube_frequency = p.threads_cube_frequency();
m_core_validate = p.core_validate();
m_sls_enable = p.sls_enable();
m_sls_parallel = p.sls_parallel();
m_logic = _p.get_sym("logic", m_logic);
m_string_solver = p.string_solver();
m_up_persist_clauses = p.up_persist_clauses();
Expand Down
1 change: 1 addition & 0 deletions src/smt/params/smt_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct smt_params : public preprocessor_params,
bool m_clause_proof = false;
symbol m_proof_log;
bool m_sls_enable = false;
bool m_sls_parallel = true;

// -----------------------------------
//
Expand Down
1 change: 1 addition & 0 deletions src/smt/params/smt_params_helper.pyg
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def_module_params(module_name='smt',
('str.fixed_length_refinement', BOOL, False, 'use abstraction refinement in fixed-length equation solver (Z3str3 only)'),
('str.fixed_length_naive_cex', BOOL, True, 'construct naive counterexamples when fixed-length model construction fails for a given length assignment (Z3str3 only)'),
('sls.enable', BOOL, False, 'enable sls co-processor with SMT engine'),
('sls.parallel', BOOL, True, 'use sls co-processor in parallel or sequential with SMT engine'),
('core.minimize', BOOL, False, 'minimize unsat core produced by SMT context'),
('core.extend_patterns', BOOL, False, 'extend unsat core with literals that trigger (potential) quantifier instances'),
('core.extend_patterns.max_distance', UINT, UINT_MAX, 'limits the distance of a pattern-extended unsat core'),
Expand Down
Loading

0 comments on commit 8e3b9f6

Please sign in to comment.