Skip to content

Commit

Permalink
consolidate functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Jan 26, 2025
1 parent a701057 commit 12e8082
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 77 deletions.
123 changes: 50 additions & 73 deletions src/ast/sls/sls_arith_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,11 @@ namespace sls {
if (vi.m_op == arith_op_kind::OP_NUM)
return;
if (is_add(v) && m_allow_recursive_delta)
add_update_add(m_adds[vi.m_def_idx], delta_out);
add_update_add(get_add(v), delta_out);
else if (is_mul(v) && m_allow_recursive_delta)
add_update_mul(m_muls[vi.m_def_idx], delta_out);
add_update_mul(get_mul(v), delta_out);
else if (is_op(v) && m_allow_recursive_delta)
add_update(m_ops[vi.m_def_idx], delta_out);
add_update(get_op(v), delta_out);
else if (vi.is_if_op() && m_allow_recursive_delta) {
expr* c, * t, * e;
VERIFY(m.is_ite(vi.m_expr, c, t, e));
Expand Down Expand Up @@ -1283,48 +1283,48 @@ namespace sls {
case LAST_ARITH_OP:
break;
case OP_ADD: {
auto const& ad = m_adds[vi.m_def_idx];
auto const& ad = get_add(v);
auto const& args = ad.m_args;
result = ad.m_coeff;
for (auto [c, w] : args)
result += c * value(w);
break;
}
case OP_MUL: {
auto const& [w, monomial] = m_muls[vi.m_def_idx];
auto const& [w, monomial] = get_mul(v);
result = num_t(1);
for (auto [w, p] : monomial)
result *= power_of(value(w), p);
break;
}
case OP_MOD:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
v1 = value(get_op(v).m_arg1);
v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : mod(v1, v2);
break;
case OP_DIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
v1 = value(get_op(v).m_arg1);
v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : v1 / v2;
break;
case OP_IDIV:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
v1 = value(get_op(v).m_arg1);
v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : div(v1, v2);
break;
case OP_REM:
v1 = value(m_ops[vi.m_def_idx].m_arg1);
v2 = value(m_ops[vi.m_def_idx].m_arg2);
v1 = value(get_op(v).m_arg1);
v2 = value(get_op(v).m_arg2);
result = v2 == 0 ? num_t(0) : v1 %= v2;
break;
case OP_ABS:
result = abs(value(m_ops[vi.m_def_idx].m_arg1));
result = abs(value(get_op(v).m_arg1));
break;
case OP_TO_REAL:
result = value(m_ops[vi.m_def_idx].m_arg1);
result = value(get_op(v).m_arg1);
break;
case OP_TO_INT: {
rational r = value(m_ops[vi.m_def_idx].m_arg1).to_rational();
rational r = value(get_op(v).m_arg1).to_rational();
result = to_num(floor(r));
break;
}
Expand Down Expand Up @@ -1368,25 +1368,25 @@ namespace sls {
case arith_op_kind::LAST_ARITH_OP:
break;
case arith_op_kind::OP_ADD:
return repair_add(m_adds[vi.m_def_idx]);
return repair_add(get_add(v));
case arith_op_kind::OP_MUL:
return repair_mul(m_muls[vi.m_def_idx]);
return repair_mul(get_mul(v));
case arith_op_kind::OP_MOD:
return repair_mod(m_ops[vi.m_def_idx]);
return repair_mod(get_op(v));
case arith_op_kind::OP_REM:
return repair_rem(m_ops[vi.m_def_idx]);
return repair_rem(get_op(v));
case arith_op_kind::OP_POWER:
return repair_power(m_ops[vi.m_def_idx]);
return repair_power(get_op(v));
case arith_op_kind::OP_IDIV:
return repair_idiv(m_ops[vi.m_def_idx]);
return repair_idiv(get_op(v));
case arith_op_kind::OP_DIV:
return repair_div(m_ops[vi.m_def_idx]);
return repair_div(get_op(v));
case arith_op_kind::OP_ABS:
return repair_abs(m_ops[vi.m_def_idx]);
return repair_abs(get_op(v));
case arith_op_kind::OP_TO_INT:
return repair_to_int(m_ops[vi.m_def_idx]);
return repair_to_int(get_op(v));
case arith_op_kind::OP_TO_REAL:
return repair_to_real(m_ops[vi.m_def_idx]);
return repair_to_real(get_op(v));
default:
throw default_exception("no repair " + mk_pp(e, m));
}
Expand Down Expand Up @@ -1514,7 +1514,7 @@ namespace sls {
case OP_REM:
break;
case OP_MOD: {
auto v2 = m_ops[vi.m_def_idx].m_arg2;
auto v2 = get_op(v).m_arg2;
auto const& vi2 = m_vars[v2];
if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) {
add_le(v, vi2.m_lo->value - 1);
Expand All @@ -1532,7 +1532,7 @@ namespace sls {
}

template<typename num_t>
void arith_base<num_t>::initialize_of_bool_var(sat::bool_var bv) {
void arith_base<num_t>::initialize_vars_of(sat::bool_var bv) {
auto* ineq = get_ineq(bv);
if (!ineq)
return;
Expand All @@ -1542,11 +1542,9 @@ namespace sls {
m_tmp_set.reset();
for (unsigned i = 0; i < todo.size(); ++i) {
var_t u = todo[i];
auto& ui = m_vars[u];
if (m_tmp_set.contains(u))
continue;
m_tmp_set.insert(u);
ui.m_bool_vars_of.push_back(bv);
if (is_add(u)) {
auto const& ad = get_add(u);
for (auto const& [c, w] : ad.m_args)
Expand All @@ -1558,45 +1556,25 @@ namespace sls {
todo.push_back(w);
}
if (is_op(u)) {
auto const& op = m_ops[ui.m_def_idx];
auto const& op = get_op(u);
todo.push_back(op.m_arg1);
todo.push_back(op.m_arg2);
}
}
}

template<typename num_t>
void arith_base<num_t>::initialize_of_bool_var(sat::bool_var bv) {
initialize_vars_of(bv);
for (auto v : m_tmp_set)
m_vars[v].m_bool_vars_of.push_back(bv);
}

template<typename num_t>
void arith_base<num_t>::initialize_clauses_of(sat::bool_var bv, unsigned ci) {
auto* ineq = get_ineq(bv);
if (!ineq)
return;
buffer<var_t> todo;
for (auto const& [coeff, v] : ineq->m_args)
todo.push_back(v);
m_tmp_set.reset();
for (unsigned i = 0; i < todo.size(); ++i) {
var_t u = todo[i];
auto& ui = m_vars[u];
if (m_tmp_set.contains(u))
continue;
m_tmp_set.insert(u);
ui.m_clauses_of.push_back(ci);
if (is_add(u)) {
auto const& ad = get_add(u);
for (auto const& [c, w] : ad.m_args)
todo.push_back(w);
}
if (is_mul(u)) {
auto const& [w, monomial] = get_mul(u);
for (auto [w, p] : monomial)
todo.push_back(w);
}
if (is_op(u)) {
auto const& op = m_ops[ui.m_def_idx];
todo.push_back(op.m_arg1);
todo.push_back(op.m_arg2);
}
}
initialize_vars_of(bv);
for (auto v : m_tmp_set)
m_vars[v].m_clauses_of.push_back(ci);
}

template<typename num_t>
Expand Down Expand Up @@ -1942,8 +1920,7 @@ namespace sls {

template<typename num_t>
num_t arith_base<num_t>::mul_value_without(var_t m, var_t x) {
auto const& vi = m_vars[m];
auto const& [w, monomial] = m_muls[vi.m_def_idx];
auto const& [w, monomial] = get_mul(m);
SASSERT(m == w);
num_t r(1);
for (auto [y, p] : monomial)
Expand Down Expand Up @@ -2477,52 +2454,52 @@ namespace sls {
case arith_op_kind::LAST_ARITH_OP:
break;
case arith_op_kind::OP_ADD: {
auto ad = m_adds[vi.m_def_idx];
auto ad = get_add(v);
num_t sum(ad.m_coeff);
for (auto [c, w] : ad.m_args)
sum += c * value(w);
return sum == value(v);
}
case arith_op_kind::OP_MUL: {
auto md = m_muls[vi.m_def_idx];
auto md = get_mul(v);
num_t prod(1);
for (auto [w, p] : md.m_monomial)
prod *= power_of(value(w), p);
return prod == value(v);
}
case arith_op_kind::OP_MOD: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2)));
}
case arith_op_kind::OP_REM: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2)));
}
case arith_op_kind::OP_POWER: {
//auto od = m_ops[vi.m_def_idx];
//auto od = get_op(v);
throw default_exception("unsupported " + mk_pp(vi.m_expr, m));
break;
}
case arith_op_kind::OP_IDIV: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2)));
}
case arith_op_kind::OP_DIV: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2));
}
case arith_op_kind::OP_ABS: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
return value(v) == abs(value(od.m_arg1));
}
case arith_op_kind::OP_TO_INT: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
return val - 1 < v1 && v1 <= val;
}
case arith_op_kind::OP_TO_REAL: {
auto od = m_ops[vi.m_def_idx];
auto od = get_op(v);
auto val = value(od.m_var);
auto v1 = value(od.m_arg1);
return val == v1;
Expand Down
6 changes: 4 additions & 2 deletions src/ast/sls/sls_arith_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ namespace sls {
bool is_if(var_t v) const { return m.is_ite(m_vars[v].m_expr); }
mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; }
add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; }
op_def const& get_op(var_t v) const { SASSERT(is_op(v)); return m_ops[m_vars[v].m_def_idx]; }

bool update(var_t v, num_t const& new_value);
bool apply_update();
Expand All @@ -295,8 +296,9 @@ namespace sls {
double compute_score(var_t x, num_t const& delta);
void save_best_values();

void initialize_of_bool_var(sat::bool_var v);
void initialize_clauses_of(sat::bool_var v, unsigned cl);
void initialize_vars_of(sat::bool_var bv);
void initialize_of_bool_var(sat::bool_var bv);
void initialize_clauses_of(sat::bool_var bv, unsigned cl);
var_t mk_var(expr* e);
var_t mk_term(expr* e);
var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y);
Expand Down
4 changes: 2 additions & 2 deletions src/ast/sls/sls_arith_clausal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ namespace sls {
if (bv != sat::null_bool_var) tout << "bool flip " << bv << "\n";
else if (v != null_arith_var) tout << "arith flip v" << v << "\n";
else tout << "no flip\n";
tout << "unsat-vars " << vars_in_unsat << "\n";
tout << "bools: " << bool_in_unsat << " timeup-bool " << time_up_bool << "\n";
tout << "unsat-vars " << ctx.unsat_vars().size() << "\n";
tout << "bools: " << (ctx.unsat_vars().size() - ctx.num_external_in_unsat_vars()) << " timeup-bool " << time_up_bool << "\n";
tout << "no-improve bool: " << m_no_improve_bool << "\n";
tout << "no-improve arith: " << m_no_improve_arith << "\n";
tout << "ext: " << ext_in_unsat << " timeup-arith " << time_up_arith << "\n";
Expand Down

0 comments on commit 12e8082

Please sign in to comment.