diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index f1e4eb8569..686e5130b4 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -489,11 +489,11 @@ namespace sls { if (vi.m_op == arith_op_kind::OP_NUM) return; if (is_add(v) && m_allow_recursive_delta) - add_update_add(m_adds[vi.m_def_idx], delta_out); + add_update_add(get_add(v), delta_out); else if (is_mul(v) && m_allow_recursive_delta) - add_update_mul(m_muls[vi.m_def_idx], delta_out); + add_update_mul(get_mul(v), delta_out); else if (is_op(v) && m_allow_recursive_delta) - add_update(m_ops[vi.m_def_idx], delta_out); + add_update(get_op(v), delta_out); else if (vi.is_if_op() && m_allow_recursive_delta) { expr* c, * t, * e; VERIFY(m.is_ite(vi.m_expr, c, t, e)); @@ -1283,7 +1283,7 @@ namespace sls { case LAST_ARITH_OP: break; case OP_ADD: { - auto const& ad = m_adds[vi.m_def_idx]; + auto const& ad = get_add(v); auto const& args = ad.m_args; result = ad.m_coeff; for (auto [c, w] : args) @@ -1291,40 +1291,40 @@ namespace sls { break; } case OP_MUL: { - auto const& [w, monomial] = m_muls[vi.m_def_idx]; + auto const& [w, monomial] = get_mul(v); result = num_t(1); for (auto [w, p] : monomial) result *= power_of(value(w), p); break; } case OP_MOD: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); + v1 = value(get_op(v).m_arg1); + v2 = value(get_op(v).m_arg2); result = v2 == 0 ? num_t(0) : mod(v1, v2); break; case OP_DIV: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); + v1 = value(get_op(v).m_arg1); + v2 = value(get_op(v).m_arg2); result = v2 == 0 ? num_t(0) : v1 / v2; break; case OP_IDIV: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); + v1 = value(get_op(v).m_arg1); + v2 = value(get_op(v).m_arg2); result = v2 == 0 ? num_t(0) : div(v1, v2); break; case OP_REM: - v1 = value(m_ops[vi.m_def_idx].m_arg1); - v2 = value(m_ops[vi.m_def_idx].m_arg2); + v1 = value(get_op(v).m_arg1); + v2 = value(get_op(v).m_arg2); result = v2 == 0 ? num_t(0) : v1 %= v2; break; case OP_ABS: - result = abs(value(m_ops[vi.m_def_idx].m_arg1)); + result = abs(value(get_op(v).m_arg1)); break; case OP_TO_REAL: - result = value(m_ops[vi.m_def_idx].m_arg1); + result = value(get_op(v).m_arg1); break; case OP_TO_INT: { - rational r = value(m_ops[vi.m_def_idx].m_arg1).to_rational(); + rational r = value(get_op(v).m_arg1).to_rational(); result = to_num(floor(r)); break; } @@ -1368,25 +1368,25 @@ namespace sls { case arith_op_kind::LAST_ARITH_OP: break; case arith_op_kind::OP_ADD: - return repair_add(m_adds[vi.m_def_idx]); + return repair_add(get_add(v)); case arith_op_kind::OP_MUL: - return repair_mul(m_muls[vi.m_def_idx]); + return repair_mul(get_mul(v)); case arith_op_kind::OP_MOD: - return repair_mod(m_ops[vi.m_def_idx]); + return repair_mod(get_op(v)); case arith_op_kind::OP_REM: - return repair_rem(m_ops[vi.m_def_idx]); + return repair_rem(get_op(v)); case arith_op_kind::OP_POWER: - return repair_power(m_ops[vi.m_def_idx]); + return repair_power(get_op(v)); case arith_op_kind::OP_IDIV: - return repair_idiv(m_ops[vi.m_def_idx]); + return repair_idiv(get_op(v)); case arith_op_kind::OP_DIV: - return repair_div(m_ops[vi.m_def_idx]); + return repair_div(get_op(v)); case arith_op_kind::OP_ABS: - return repair_abs(m_ops[vi.m_def_idx]); + return repair_abs(get_op(v)); case arith_op_kind::OP_TO_INT: - return repair_to_int(m_ops[vi.m_def_idx]); + return repair_to_int(get_op(v)); case arith_op_kind::OP_TO_REAL: - return repair_to_real(m_ops[vi.m_def_idx]); + return repair_to_real(get_op(v)); default: throw default_exception("no repair " + mk_pp(e, m)); } @@ -1514,7 +1514,7 @@ namespace sls { case OP_REM: break; case OP_MOD: { - auto v2 = m_ops[vi.m_def_idx].m_arg2; + auto v2 = get_op(v).m_arg2; auto const& vi2 = m_vars[v2]; if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) { add_le(v, vi2.m_lo->value - 1); @@ -1532,7 +1532,7 @@ namespace sls { } template - void arith_base::initialize_of_bool_var(sat::bool_var bv) { + void arith_base::initialize_vars_of(sat::bool_var bv) { auto* ineq = get_ineq(bv); if (!ineq) return; @@ -1542,11 +1542,9 @@ namespace sls { m_tmp_set.reset(); for (unsigned i = 0; i < todo.size(); ++i) { var_t u = todo[i]; - auto& ui = m_vars[u]; if (m_tmp_set.contains(u)) continue; m_tmp_set.insert(u); - ui.m_bool_vars_of.push_back(bv); if (is_add(u)) { auto const& ad = get_add(u); for (auto const& [c, w] : ad.m_args) @@ -1558,45 +1556,25 @@ namespace sls { todo.push_back(w); } if (is_op(u)) { - auto const& op = m_ops[ui.m_def_idx]; + auto const& op = get_op(u); todo.push_back(op.m_arg1); todo.push_back(op.m_arg2); } } } + template + void arith_base::initialize_of_bool_var(sat::bool_var bv) { + initialize_vars_of(bv); + for (auto v : m_tmp_set) + m_vars[v].m_bool_vars_of.push_back(bv); + } + template void arith_base::initialize_clauses_of(sat::bool_var bv, unsigned ci) { - auto* ineq = get_ineq(bv); - if (!ineq) - return; - buffer todo; - for (auto const& [coeff, v] : ineq->m_args) - todo.push_back(v); - m_tmp_set.reset(); - for (unsigned i = 0; i < todo.size(); ++i) { - var_t u = todo[i]; - auto& ui = m_vars[u]; - if (m_tmp_set.contains(u)) - continue; - m_tmp_set.insert(u); - ui.m_clauses_of.push_back(ci); - if (is_add(u)) { - auto const& ad = get_add(u); - for (auto const& [c, w] : ad.m_args) - todo.push_back(w); - } - if (is_mul(u)) { - auto const& [w, monomial] = get_mul(u); - for (auto [w, p] : monomial) - todo.push_back(w); - } - if (is_op(u)) { - auto const& op = m_ops[ui.m_def_idx]; - todo.push_back(op.m_arg1); - todo.push_back(op.m_arg2); - } - } + initialize_vars_of(bv); + for (auto v : m_tmp_set) + m_vars[v].m_clauses_of.push_back(ci); } template @@ -1942,8 +1920,7 @@ namespace sls { template num_t arith_base::mul_value_without(var_t m, var_t x) { - auto const& vi = m_vars[m]; - auto const& [w, monomial] = m_muls[vi.m_def_idx]; + auto const& [w, monomial] = get_mul(m); SASSERT(m == w); num_t r(1); for (auto [y, p] : monomial) @@ -2477,52 +2454,52 @@ namespace sls { case arith_op_kind::LAST_ARITH_OP: break; case arith_op_kind::OP_ADD: { - auto ad = m_adds[vi.m_def_idx]; + auto ad = get_add(v); num_t sum(ad.m_coeff); for (auto [c, w] : ad.m_args) sum += c * value(w); return sum == value(v); } case arith_op_kind::OP_MUL: { - auto md = m_muls[vi.m_def_idx]; + auto md = get_mul(v); num_t prod(1); for (auto [w, p] : md.m_monomial) prod *= power_of(value(w), p); return prod == value(v); } case arith_op_kind::OP_MOD: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); } case arith_op_kind::OP_REM: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); } case arith_op_kind::OP_POWER: { - //auto od = m_ops[vi.m_def_idx]; + //auto od = get_op(v); throw default_exception("unsupported " + mk_pp(vi.m_expr, m)); break; } case arith_op_kind::OP_IDIV: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2))); } case arith_op_kind::OP_DIV: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2)); } case arith_op_kind::OP_ABS: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); return value(v) == abs(value(od.m_arg1)); } case arith_op_kind::OP_TO_INT: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); auto val = value(od.m_var); auto v1 = value(od.m_arg1); return val - 1 < v1 && v1 <= val; } case arith_op_kind::OP_TO_REAL: { - auto od = m_ops[vi.m_def_idx]; + auto od = get_op(v); auto val = value(od.m_var); auto v1 = value(od.m_arg1); return val == v1; diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h index b7c1ab872c..a504583751 100644 --- a/src/ast/sls/sls_arith_base.h +++ b/src/ast/sls/sls_arith_base.h @@ -283,6 +283,7 @@ namespace sls { bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); } mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; } add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; } + op_def const& get_op(var_t v) const { SASSERT(is_op(v)); return m_ops[m_vars[v].m_def_idx]; } bool update(var_t v, num_t const& new_value); bool apply_update(); @@ -295,8 +296,9 @@ namespace sls { double compute_score(var_t x, num_t const& delta); void save_best_values(); - void initialize_of_bool_var(sat::bool_var v); - void initialize_clauses_of(sat::bool_var v, unsigned cl); + void initialize_vars_of(sat::bool_var bv); + void initialize_of_bool_var(sat::bool_var bv); + void initialize_clauses_of(sat::bool_var bv, unsigned cl); var_t mk_var(expr* e); var_t mk_term(expr* e); var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); diff --git a/src/ast/sls/sls_arith_clausal.cpp b/src/ast/sls/sls_arith_clausal.cpp index 3de698afd4..4f4fdd7772 100644 --- a/src/ast/sls/sls_arith_clausal.cpp +++ b/src/ast/sls/sls_arith_clausal.cpp @@ -73,8 +73,8 @@ namespace sls { if (bv != sat::null_bool_var) tout << "bool flip " << bv << "\n"; else if (v != null_arith_var) tout << "arith flip v" << v << "\n"; else tout << "no flip\n"; - tout << "unsat-vars " << vars_in_unsat << "\n"; - tout << "bools: " << bool_in_unsat << " timeup-bool " << time_up_bool << "\n"; + tout << "unsat-vars " << ctx.unsat_vars().size() << "\n"; + tout << "bools: " << (ctx.unsat_vars().size() - ctx.num_external_in_unsat_vars()) << " timeup-bool " << time_up_bool << "\n"; tout << "no-improve bool: " << m_no_improve_bool << "\n"; tout << "no-improve arith: " << m_no_improve_arith << "\n"; tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n";