Active ranking #394

Merged
merged 17 commits on Jan 28, 2025
179 changes: 179 additions & 0 deletions fuse/eval/metrics/libs/efficient_active_ranking_pairwise_model.py
@@ -0,0 +1,179 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, List

import numpy as np


@dataclass
class ItemStats:
wins: int = 0
comparisons: int = 0
lower_bound: float = 0.0
upper_bound: float = 1.0


class EfficientRanking:
def __init__(
self,
items: List[Any],
compare_pairwise_fn: Callable[[Any, Any], bool],
confidence: float = 0.95,
min_comparisons: int = 32,
):
"""
items: items to be ranked
compare_pairwise_fn: a function like this:

def pairwise_compare(a:int, b:int) -> bool:
# return boolean - should have a lower ranking number than b?
# an AI model (or anything) that predicts it.
# note: for binding affinity the convention is that being ranked lower means stronger binding.
# for example, rank 0 is the strongest binder in the list.
confidence: confidence bounds that will be used in the method
min_comparison: the minimum amount of comparison each items must participate in before providing the global ranking on the total list.
"""

self.items = items
self.compare = compare_pairwise_fn
self.confidence = confidence
self.min_comparisons = min_comparisons
self.stats = defaultdict(ItemStats)
self.total_comparisons = 0

def _update_bounds(self, item: Any) -> None:
"""Update confidence bounds using Hoeffding's inequality"""
stats = self.stats[item]
if stats.comparisons == 0:
return

p_hat = stats.wins / stats.comparisons
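        # Hoeffding's inequality: P(|p_hat - p| >= eps) <= 2 * exp(-2 * n * eps^2),
        # so a two-sided interval that holds with probability >= `confidence` uses
        # eps = sqrt(ln(2 / delta) / (2 * n)) with delta = 1 - confidence.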
        epsilon = np.sqrt(np.log(2 / (1 - self.confidence)) / (2 * stats.comparisons))
stats.lower_bound = max(0.0, p_hat - epsilon)
stats.upper_bound = min(1.0, p_hat + epsilon)

def _compare_with_bounds(self, a: Any, b: Any) -> bool:
"""Compare items using confidence bounds to minimize comparisons"""
stats_a = self.stats[a]
stats_b = self.stats[b]

# If bounds don't overlap, we can decide without comparison
if stats_a.lower_bound > stats_b.upper_bound:
return True
if stats_b.lower_bound > stats_a.upper_bound:
return False

# If either item has few comparisons, compare directly
if (
stats_a.comparisons < self.min_comparisons
or stats_b.comparisons < self.min_comparisons
):
result = self.compare(a, b)
self.total_comparisons += 1

# Update statistics
if result:
stats_a.wins += 1
else:
stats_b.wins += 1
stats_a.comparisons += 1
stats_b.comparisons += 1

self._update_bounds(a)
self._update_bounds(b)

return result

# If bounds overlap but items have enough comparisons,
# use current best estimate
return stats_a.wins / stats_a.comparisons > stats_b.wins / stats_b.comparisons

def _adaptive_merge(self, items: List[Any]) -> List[Any]:
"""Merge sort with adaptive sampling"""
if len(items) <= 1:
return items

mid = len(items) // 2
left = self._adaptive_merge(items[:mid])
right = self._adaptive_merge(items[mid:])

# Merge with confidence bounds
result = []
i = j = 0

while i < len(left) and j < len(right):
if self._compare_with_bounds(left[i], right[j]):
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1

result.extend(left[i:])
result.extend(right[j:])
return result

def _quicksort_partition(self, items: List[Any], start: int, end: int) -> int:
"""Partition items around pivot using adaptive sampling"""
if end - start <= 1:
return start

pivot = items[end - 1]
i = start

for j in range(start, end - 1):
if self._compare_with_bounds(items[j], pivot):
items[i], items[j] = items[j], items[i]
i += 1

items[i], items[end - 1] = items[end - 1], items[i]
return i

def _adaptive_quicksort(self, items: List[Any], start: int, end: int) -> None:
"""Quicksort with adaptive sampling"""
if end - start <= 1:
return

pivot = self._quicksort_partition(items, start, end)
self._adaptive_quicksort(items, start, pivot)
self._adaptive_quicksort(items, pivot + 1, end)

def rank(self, method: str = "merge") -> List[Any]:
"""Rank items using specified method"""
items = self.items.copy()

if method == "merge":
return self._adaptive_merge(items)
elif method == "quick":
self._adaptive_quicksort(items, 0, len(items))
return items
else:
raise ValueError(f"Unknown method: {method}")


if __name__ == "__main__":
from functools import partial

from scipy.stats import spearmanr

def compare_fn(a: Any, b: Any, noise_rate: float = 0.0) -> bool:
# Your comparison function
# return model.predict(a, b)
if np.random.random() < noise_rate:
return np.random.random() < 0.5
return a < b

num_samples = 10000

true_scores = np.arange(0, num_samples, 1)

to_be_ranked = true_scores.copy()
np.random.shuffle(to_be_ranked)

ranker = EfficientRanking(
to_be_ranked, partial(compare_fn, noise_rate=0.1), confidence=0.95
)
ranked_items = ranker.rank(method="merge") # or 'quick'
print(f"Total comparisons: {ranker.total_comparisons}")
sr = spearmanr(ranked_items, true_scores)
print(f"spearman r = {sr.statistic} p = {sr.pvalue}")
171 changes: 171 additions & 0 deletions fuse/eval/metrics/libs/efficient_ranking_batch_rank.py
@@ -0,0 +1,171 @@
import random
from collections import defaultdict
from typing import Any, Callable, List, Optional

import numpy as np
from tqdm import trange


def aggregate_rankings(
all_items: List[Any],
ranking_model: Callable[[List[Any]], List[Any]],
ranking_model_rank_batch_size: int = 8,
budget: Optional[int] = None,
) -> List[Any]:
"""
Aggregate rankings by efficient pairwise comparison.

Args:
- ranking_model: Function that returns a subset ranking

for example:

def compare_fn(items: List, number_of_random_flipped: int = 0) -> List:
## for simulating synthetic data with noise, in reality an AI model will be used
ans = sorted(items)

if number_of_random_flipped > 0:
length = len(items)
for _ in range(number_of_random_flipped):
i = np.random.randint(0, length)
val_i = items[i]
j = np.random.randint(0, length)
val_j = items[j]
items[i] = val_j
items[j] = val_i

return ans


- all_items: Complete list of items to be ranked
- num_samples: Number of random sampling iterations

Returns:
Globally optimized ranking of items
"""
total_num_samples = len(all_items)
if budget is None:
budget = int(np.ceil(np.log10(total_num_samples) * total_num_samples * 2))
print("no budget selected, defaulting to budget=", budget)
else:
print("budget=", budget)
item_scores = defaultdict(int)

    # Track unique ordered pairs so the same comparison is not counted twice
unique_pairs = set()

# Generate pairwise comparisons through random sampling
for _ in trange(budget):
# Randomly select a subset of items
sample_size = min(
len(all_items), random.randint(2, ranking_model_rank_batch_size)
)
        # sample distinct items so an item is never compared against itself
        sample = np.random.choice(all_items, size=sample_size, replace=False)

# Get model's ranking for this subset
ranked_sample = ranking_model(sample)

# Record pairwise comparisons
for i in range(len(ranked_sample)):
for j in range(i + 1, len(ranked_sample)):
higher_item = ranked_sample[i]
lower_item = ranked_sample[j]

# Create a hashable pair
pair = (higher_item, lower_item)

# Avoid redundant comparisons
if pair not in unique_pairs:
item_scores[higher_item] += 1
item_scores[lower_item] -= 1
unique_pairs.add(pair)

# Sort items based on their aggregate score
global_ranking = sorted(all_items, key=lambda x: item_scores[x], reverse=True)

return global_ranking


if __name__ == "__main__":
from functools import partial

from scipy.stats import spearmanr

    def compare_fn(items: List, number_of_random_flipped: int = 0) -> List:
        # simulates a noisy ranking model: sort, then randomly swap a few positions
        ans = sorted(items)

        if number_of_random_flipped > 0:
            length = len(items)
            for _ in range(number_of_random_flipped):
                i = np.random.randint(0, length)
                j = np.random.randint(0, length)
                ans[i], ans[j] = ans[j], ans[i]

        return ans

num_samples = 100000
budget = 100000

true_scores = np.arange(0, num_samples, 1)
to_be_ranked = true_scores.copy()
np.random.shuffle(to_be_ranked)

ranked_items = aggregate_rankings(
to_be_ranked,
partial(compare_fn, number_of_random_flipped=20),
budget=budget,
)

# print(f"Total comparisons: {ranker.total_comparisons}")
sr = spearmanr(ranked_items, true_scores)
print(f"spearman r = {sr.statistic} p = {sr.pvalue}")

##############

# an example that uses a pairwise compare function

def pairwise_compare_fn(a: Any, b: Any, noise_rate: float = 0.0) -> bool:
# Your comparison function
# return model.predict(a, b)
if np.random.random() < noise_rate:
return np.random.random() < 0.5
return a < b

def convert_pairwise_to_ranker(
pairwise_model: Callable[[Any, Any], bool]
) -> Callable[[List], List]:
"""
A helper function that converts a pairwise model to a subsample ranker of length 2,
to support using pairwise model in `aggregate_rankings()`

pairwise_model: a function that returns true if the item provided as first argument should be ranked higher than the item provided as the second argument
"""

def build_func(items: List) -> List:
assert 2 == len(items)
if pairwise_model(items[0], items[1]):
return items
# flip the order
return items[::-1]

return build_func

ranker_func = convert_pairwise_to_ranker(pairwise_compare_fn)

true_scores = np.arange(0, num_samples, 1)
to_be_ranked = true_scores.copy()
np.random.shuffle(to_be_ranked)

ranked_items = aggregate_rankings(
to_be_ranked,
ranker_func,
budget=budget * 4,
ranking_model_rank_batch_size=2,
)

# print(f"Total comparisons: {ranker.total_comparisons}")
sr = spearmanr(ranked_items, true_scores)
print(f"spearman r = {sr.statistic} p = {sr.pvalue}")