Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix various bugs in DiscordBitSet #772

Merged
merged 1 commit into from
Feb 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 31 additions & 24 deletions common/src/main/kotlin/DiscordBitSet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import kotlin.math.max
import kotlin.math.min

private const val SAFE_LENGTH = 19
private const val WIDTH = Byte.SIZE_BITS
private const val WIDTH = Long.SIZE_BITS

@Suppress("FunctionName")
public fun EmptyBitSet(): DiscordBitSet = DiscordBitSet(0)
public fun EmptyBitSet(): DiscordBitSet = DiscordBitSet()

@Serializable(with = DiscordBitSetSerializer::class)
public class DiscordBitSet(internal var data: LongArray) {
public class DiscordBitSet(internal var data: LongArray) { // data is in little-endian order

public val isEmpty: Boolean
get() = data.all { it == 0L }

public val value: String
get() {
// need to convert from little-endian data to big-endian expected by BigInteger
val buffer = ByteBuffer.allocate(data.size * Long.SIZE_BYTES)
buffer.asLongBuffer().put(data.reversedArray())
return BigInteger(buffer.array()).toString()
Expand All @@ -35,53 +36,64 @@ public class DiscordBitSet(internal var data: LongArray) {
get() = data.size * WIDTH

public val binary: String
get() = data.joinToString("") { it.toULong().toString(2) }.reversed().padEnd(8, '0')
get() = data.map { it.toULong().toString(radix = 2).padStart(length = ULong.SIZE_BITS, '0') }
.reversed()
.joinToString(separator = "")
.trimStart('0')
.ifEmpty { "0" }

override fun equals(other: Any?): Boolean {
if (other !is DiscordBitSet) return false
for (i in 0 until max(data.size, other.data.size)) {
if (getOrZero(i) != getOrZero(i)) return false
// trailing zeros are ignored -> getOrZero
for (i in 0 until max(this.data.size, other.data.size)) {
if (this.getOrZero(i) != other.getOrZero(i)) return false
}
return true
}

override fun hashCode(): Int {
var result = 1
// trailing zeros are ignored to have the same hashCode for equal bit sets
for (i in 0..(data.indexOfLast { it != 0L })) {
result = (31 * result) + data[i].hashCode()
}
return result
}

private fun getOrZero(i: Int) = data.getOrNull(i) ?: 0L

public operator fun get(index: Int): Boolean {
if (index !in 0 until size) return false
require(index >= 0)
if (index >= size) return false
val indexOfWidth = index / WIDTH
val bitIndex = index % WIDTH
return data[indexOfWidth] and (1L shl bitIndex) != 0L
}

public operator fun contains(other: DiscordBitSet): Boolean {
if (other.size > size) return false
for (i in other.data.indices) {
if (data[i] and other.data[i] != other.data[i]) return false
for ((index, value) in other.data.withIndex()) {
if ((this.getOrZero(index) and value) != value) return false
}
return true
}

public operator fun set(index: Int, value: Boolean) {
if (index !in 0 until size) data.copyOf((63 + index) / WIDTH)
require(index >= 0)
val indexOfWidth = index / WIDTH
if (index >= size) data = data.copyOf(indexOfWidth + 1)
val bitIndex = index % WIDTH
val bit = if (value) 1L else 0L
data[index] = data[indexOfWidth] or (bit shl bitIndex)
val prev = data[indexOfWidth]
data[indexOfWidth] = if (value) prev or (1L shl bitIndex) else prev and (1L shl bitIndex).inv()
}

public operator fun plus(another: DiscordBitSet): DiscordBitSet {
val dist = LongArray(data.size)
data.copyInto(dist)
val copy = DiscordBitSet(dist)
val copy = DiscordBitSet(data.copyOf())
copy.add(another)
return copy
}

public operator fun minus(another: DiscordBitSet): DiscordBitSet {
val dist = LongArray(data.size)
data.copyInto(dist)
val copy = DiscordBitSet(dist)
val copy = DiscordBitSet(data.copyOf())
copy.remove(another)
return copy
}
Expand All @@ -100,11 +112,6 @@ public class DiscordBitSet(internal var data: LongArray) {
}
}

override fun hashCode(): Int {
var result = data.contentHashCode()
result = 31 * result + size
return result
}

override fun toString(): String {
return "DiscordBitSet($binary)"
Expand Down
68 changes: 55 additions & 13 deletions common/src/test/kotlin/BitSetTests.kt
Original file line number Diff line number Diff line change
@@ -1,26 +1,58 @@
import dev.kord.common.DiscordBitSet
import org.junit.jupiter.api.Test
import dev.kord.common.EmptyBitSet
import kotlin.test.*

class BitSetTests {
@Test
fun `b contains a`() {
val a = DiscordBitSet(0b101)
val b = DiscordBitSet(0b111)
assert(a in b)
fun `a contains b and c`() {
val a = DiscordBitSet(0b111)
val b = DiscordBitSet(0b101)
val c = DiscordBitSet(0b101, 0)
assertTrue(b in a)
assertTrue(c in a)
}

@Test
fun `a equals b`() {
fun `a and b are equal and have the same hashCode`() {
val a = DiscordBitSet(0b111, 0)
val b = DiscordBitSet(0b111)
assert(a == b)
assertEquals(a, b)
assertEquals(a.hashCode(), b.hashCode())
}

@Test
fun `a does not equal b`() {
val a = DiscordBitSet(0b111, 0)
val b = DiscordBitSet(0b111, 0b1)
assertNotEquals(a, b)
}

@Test
fun `get a bit`() {
fun `get bits`() {
val a = DiscordBitSet(0b101, 0)
assert(!a[1])
assertTrue(a[0])
assertFalse(a[1])
assertTrue(a[2])
for (i in 3..64) assertFalse(a[i])

val b = DiscordBitSet(1L shl 63)
for (i in 0..62) assertFalse(b[i])
assertTrue(b[63])
}

@Test
fun `set bits`() {
val a = EmptyBitSet()
for (i in 0..64) a[i] = true
assertEquals(DiscordBitSet(ULong.MAX_VALUE.toLong(), 1), a)

val b = EmptyBitSet()
b[1] = true
b[2] = true
b[5] = true
assertEquals(DiscordBitSet(0b100110), b)
b[2] = false
assertEquals(DiscordBitSet(0b100010), b)
}

@Test
Expand All @@ -30,21 +62,31 @@ class BitSetTests {
}

@Test
fun `add and remove a bit`() {
fun `add and remove a bit`() {
val a = DiscordBitSet(0b101, 0)
a.add(DiscordBitSet(0b111))
assert(a.value == 0b111.toString())
a.remove(DiscordBitSet(0b001))
assert(a.value == 0b110.toString())

}

@Test
fun `remove a bit`() {
val a = DiscordBitSet(0b101, 0)
a.remove(DiscordBitSet(0b111))
assert(a.value == "0")

}

}
@Test
fun `binary works`() {
assertEquals("0", DiscordBitSet().binary)
assertEquals("0", DiscordBitSet(0).binary)
assertEquals("10011", DiscordBitSet(0b10011).binary)
assertEquals(
"110" +
"0000000000000000000000000000000000000000000000000000000000111001" +
"0000000000000000000000000000000000000000000000000000000000001011",
DiscordBitSet(0b1011, 0b111001, 0b110).binary,
)
}
}