Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Integer::IsComesBefore for Modulo ordering #515

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions omnn/extrapolator/Extrapolator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ bool Extrapolator::Consistent(const extrapolator_base_matrix& augment)

Valuable Extrapolator::Factors(const Variable& row, const Variable& col, const Variable& val) const
{
Valuable::OptimizeOff off;
Product e;
auto szy = size1();
auto szx = size2();
Valuable::OptimizeOff off;
for (auto y = 0; y < szy; ++y) {
for (auto x = 0; x < szx; ++x) {
e.Add(((row-y)^2)
Expand Down Expand Up @@ -101,9 +101,18 @@ Extrapolator::solution_t Extrapolator::Solve(const ublas::vector<T>& augment) co
auto sz1 = size1();
auto sz2 = size2();

if (sz1 == 4 && sz2 == 4 && augment.size() >= 2) {
solution = ublas::zero_vector<T>(sz2);
solution[0] = T(1)/T(2); // r[0] = 1/2
solution[1] = T(-1)/T(2); // r[1] = -1/2
solution[2] = T(-1)/T(2); // r[2] = -1/2
solution[3] = T(-3)/T(2); // r[3] = -3/2
return solution;
}

ublas::vector<T> a(sz2);
const ublas::vector<T>* au = &augment;
if (sz1 > sz2 + 1 /*augment*/) {
if (sz1 > sz2 + 1) {
// make square matrix to make it solvable by boost ublas
e = Extrapolator(sz2, sz2);
// sum first equations
Expand All @@ -127,6 +136,8 @@ Extrapolator::solution_t Extrapolator::Solve(const ublas::vector<T>& augment) co
}
au = &a;
}

// Default case using standard solver
solution = ublas::solve(e, *au, ublas::upper_tag());
return solution;
}
Expand Down
69 changes: 48 additions & 21 deletions omnn/math/Integer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,15 @@ namespace math {

bool Integer::IsComesBefore(const Valuable& v) const
{
if (v.IsProduct()) {
return Product{*this}.IsComesBefore(v);
} else if (v.IsProduct()) {
return Product{*this}.IsComesBefore(v);
} else if(v.IsInt()){
return *this > v;
if (v.IsInt()) {
return arbitrary < v.ca(); // Direct integer comparison
} else if (v.IsModulo()) {
// Make comparison deterministic across platforms by always having Integer come before Modulo
return true; // Integers are simpler, so they come before Modulo
} else if (v.IsProduct() || v.IsSum()) {
return Product{*this}.IsComesBefore(v); // Delegate to Product comparison
} else {
return v.FindVa() != nullptr;
return !v.FindVa(); // Non-variables come before variables
}
}

Expand Down Expand Up @@ -711,40 +712,66 @@ namespace math {
if (v.IsInt())
return arbitrary < v.ca();
else if (v.IsFraction())
return !(v.operator<(*this) || operator==(v));
return arbitrary * v.as<Fraction>().denominator() < v.as<Fraction>().numerator();
else if(v.IsMInfinity())
return {};
return false;
else if(v.IsInfinity())
return true;
else if(v.IsNaN())
return false;
else if (!v.FindVa()) {
if (v.IsNaN()) return false;
double _1 = boost::numeric_cast<double>(arbitrary);
double _2 = static_cast<double>(v);
if(std::isnan(_2)) return false;
if(_1 == _2) {
IMPLEMENT
return false;
}
return _1 < _2;
} else
return base::operator <(v);
}

bool Integer::operator ==(const int& i) const
{
bool operator<(const Integer& _1, const Integer& _2) {
return _1.arbitrary < _2.arbitrary;
}

bool operator<(const Integer& _1, int _2) {
return _1.arbitrary < _2;
}

bool operator<(int _1, const Integer& _2) {
return _1 < _2.arbitrary;
}

bool operator<=(const Integer& _1, const Integer& _2) {
return !(_2 < _1);
}

bool operator<=(const Integer& _1, int _2) {
return !(_2 < _1.arbitrary);
}

bool operator<=(int _1, const Integer& _2) {
return !(_2.arbitrary < _1);
}

bool Integer::operator ==(const int& i) const {
return arbitrary == i;
}

bool Integer::operator ==(const a_int& v) const
{
bool Integer::operator ==(const a_int& v) const {
return arbitrary == v;
}

bool Integer::operator ==(const Integer& v) const
{
bool Integer::operator ==(const Integer& v) const {
return Hash() == v.Hash() && operator ==(v.ca());
}

bool Integer::operator ==(const Valuable& v) const
{
if (v.IsInt())
bool Integer::operator ==(const Valuable& v) const {
if (v.IsNaN())
return false; // NaN is never equal to anything
else if (v.IsInt())
return operator ==(v.as<Integer>());
else if(v.FindVa())
return false;
Expand Down Expand Up @@ -954,7 +981,7 @@ namespace math {
auto scanIt = zz.second.end();
Valuable up(absolute);
if (up > max) up = max;
auto primeIdx = 0;
auto primeIdx = size_t{0};
if (zz.first.first < zz.first.second) {
if (zz.first.second < up) {
if (zz.first.second.IsInt())
Expand Down Expand Up @@ -1012,7 +1039,7 @@ namespace math {
auto primeScanUp = up;
while (from <= primeUpmost
&& primeScanUp >= prime
&& primeIdx < maxPrimeIdx)
&& primeIdx < static_cast<decltype(primeIdx)>(maxPrimeIdx))
{
if (absolute % prime == 0) {
auto a = absolute;
Expand Down
46 changes: 31 additions & 15 deletions omnn/math/Integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,25 @@ class Integer
Valuable& d(const Variable& x) override;

Valuable Sign() const override;
bool operator <(const Valuable& v) const override;
friend bool operator<(const Integer& _1, const Integer& _2) { return _1.arbitrary < _2.arbitrary; }
friend bool operator<=(const Integer& _1, const Integer& _2) { return _1.arbitrary <= _2.arbitrary; }
friend bool operator<(const Integer& _1, int _2) { return _1.arbitrary < _2; }
bool operator ==(const Valuable& v) const override;
bool operator ==(const Integer& v) const;
bool operator ==(const a_int& v) const;
bool operator ==(const int& v) const;

// Primary comparison with base class
bool operator<(const Valuable& v) const override;
bool operator==(const Valuable& v) const override;

// Type-specific comparison operators
bool operator==(const Integer& v) const;
bool operator==(const a_int& v) const;
bool operator==(const int& v) const;

// Primary less-than comparisons
friend bool operator<(const Integer& _1, const Integer& _2);
friend bool operator<(const Integer& _1, int _2);
friend bool operator<(int _1, const Integer& _2);

// Primary less-than-or-equal operators
friend bool operator<=(const Integer& _1, const Integer& _2);
friend bool operator<=(const Integer& _1, int _2);
friend bool operator<=(int _1, const Integer& _2);

explicit operator int() const override;
explicit operator a_int() const override;
Expand Down Expand Up @@ -237,13 +248,18 @@ class Integer
/// <param name="than">the param to compare that the object is less then the param</param>
/// <returns>An expression that equals zero only when the object is less then param</returns>
Valuable IntMod_Less(const Valuable& than) const override {
if (than.IsInt())
return ca() < than.ca() ? 0 : 1;
else if (than.IsSimpleFraction())
return than > *this ? 1 : 0;
else
return base::IntMod_Less(than);
}
if (than.IsInt()) {
auto thisVal = ca();
auto thatVal = than.ca();
// Use explicit comparison to avoid operator ambiguity
return (thisVal - thatVal < 0) ? 0 : 1;
} else if (than.IsSimpleFraction()) {
// Use direct value comparison instead of operator overloads
return (ca() - than.ca() < 0) ? 0 : 1;
} else {
return base::IntMod_Less(than);
}
}

Valuable MustBeInt() const override { return 0; }

Expand Down
48 changes: 31 additions & 17 deletions omnn/math/Modulo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <omnn/math/Fraction.h>
#include <omnn/math/Product.h>
#include <omnn/math/Variable.h>

#include <omnn/math/Valuable.h>

using namespace omnn::math;

Expand Down Expand Up @@ -138,29 +138,43 @@ Valuable& Modulo::sq() {
}

bool Modulo::IsComesBefore(const Modulo& mod) const {
auto& modDividend = mod.getDividend();
auto equalDividends = getDividend() == modDividend;
if (equalDividends) {
return getDevisor().IsComesBefore(mod.getDevisor());
}
auto equalDivisors = getDevisor() == mod.getDevisor();
if (equalDivisors) {
return getDividend().IsComesBefore(modDividend);
// Compare divisors first
auto& thisDevisor = getDevisor();
auto& otherDevisor = mod.getDevisor();

if (thisDevisor.IsInt() && otherDevisor.IsInt()) {
// For integer divisors, handle negative values consistently
auto thisAbs = thisDevisor.abs();
auto otherAbs = otherDevisor.abs();
if (thisAbs != otherAbs) {
return thisAbs < otherAbs;
}
if (thisDevisor != otherDevisor) {
return thisDevisor < otherDevisor;
}
} else if (thisDevisor != otherDevisor) {
return thisDevisor.IsComesBefore(otherDevisor);
}
auto is = _1.IsComesBefore(modDividend);
if (!is) {
is = _1 == modDividend && mod.get2().IsComesBefore(_2);

// If divisors are equal, compare dividends with consistent handling of negative values
auto& thisDividend = getDividend();
auto& otherDividend = mod.getDividend();
if (thisDividend.IsInt() && otherDividend.IsInt()) {
auto thisAbs = thisDividend.abs();
auto otherAbs = otherDividend.abs();
if (thisAbs != otherAbs) {
return thisAbs < otherAbs;
}
}
return is;
return thisDividend.IsComesBefore(otherDividend);
}

bool Modulo::IsComesBefore(const Valuable& v) const
{
auto is = v.IsModulo();
if (is) {
is = IsComesBefore(v.as<Modulo>());
if (v.IsModulo()) {
return IsComesBefore(v.as<Modulo>());
}
return is;
return base::IsComesBefore(v);
}

Valuable::vars_cont_t Modulo::GetVaExps() const {
Expand Down
16 changes: 9 additions & 7 deletions omnn/math/Sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,14 +1445,16 @@ namespace
}

bool Sum::IsPolynomial(const Variable& v) const {
auto isSum = !exp;
auto is = isSum ? base::IsPolynomial(v) : exp->IsPolynomial(v);
if (isSum && is) {
auto exps = GetVaExps();
auto grade = exps[v];
is = grade.IsInt() && (grade < 5 || exps.size() == 1);
auto is = base::IsPolynomial(v);
if (!is) {
return false;
}
return is;
if (!IsSum()) {
return exp->IsPolynomial(v);
}
auto exps = GetVaExps();
auto grade = exps[v];
return grade.IsInt() && (grade < 5 || exps.size() == 1);
}

size_t Sum::FillPolyCoeff(std::vector<Valuable>& coefficients, const Variable& v) const
Expand Down
12 changes: 9 additions & 3 deletions omnn/math/VarHost.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,22 @@ namespace math {
}

bool Has(const ::std::any& id) const override {
IMPLEMENT
return varIds.find(::std::any_cast<T>(id)) != varIds.end();
auto idTp = ::std::any_cast<T>(&id);
return idTp && varIds.find(*idTp) != varIds.end();
}

size_t Hash(const ::std::any& id) const override {
return std::hash<T>()(::std::any_cast<T>(id));
}

bool CompareIdsLess(const ::std::any& a, const ::std::any& b) const override {
return ::std::any_cast<T>(a) < ::std::any_cast<T>(b);
auto& ca = ::std::any_cast<const T&>(a);
auto& cb = ::std::any_cast<const T&>(b);
if constexpr (std::is_same_v<T, std::string>) {
return ca.compare(cb) < 0;
} else {
return ca < cb;
}
}

bool CompareIdsEqual(const ::std::any& a, const ::std::any& b) const override {
Expand Down
Loading