Active ranking #394
@@ -0,0 +1,164 @@
import numpy as np
from typing import Callable, List, Any
from dataclasses import dataclass
from collections import defaultdict


@dataclass
class ItemStats:
    wins: int = 0
    comparisons: int = 0
    lower_bound: float = 0.0
    upper_bound: float = 1.0


class EfficientRanking:
    def __init__(
        self,
        items: List[Any],
        compare_pairwise_fn: Callable[[Any, Any], bool],
Review comment: Can you share some details about the expected arguments of this function?
        confidence: float = 0.95,
        min_comparisons: int = 32,
    ):
        self.items = items
        self.compare = compare_pairwise_fn
        self.confidence = confidence
        self.min_comparisons = min_comparisons
        self.stats = defaultdict(ItemStats)
        self.total_comparisons = 0

    def _update_bounds(self, item: Any) -> None:
        """Update confidence bounds using Hoeffding's inequality."""
        stats = self.stats[item]
        if stats.comparisons == 0:
            return

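        # Hoeffding's inequality: the empirical win rate p_hat deviates from the true
        # rate by more than epsilon with probability at most 2 * exp(-2 * n * epsilon^2),
        # so solving for epsilon at failure probability delta = 1 - confidence gives
        # epsilon = sqrt(ln(2 / delta) / (2 * n)), with n the number of comparisons.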
        p_hat = stats.wins / stats.comparisons
        epsilon = np.sqrt(
            np.log(2 / (1 - self.confidence)) / (2 * stats.comparisons)
        )
        stats.lower_bound = max(0.0, p_hat - epsilon)
        stats.upper_bound = min(1.0, p_hat + epsilon)

    def _compare_with_bounds(self, a: Any, b: Any) -> bool:
        """Compare items using confidence bounds to minimize comparisons"""
        stats_a = self.stats[a]
        stats_b = self.stats[b]

        # If bounds don't overlap, we can decide without comparison
        if stats_a.lower_bound > stats_b.upper_bound:
            return True
        if stats_b.lower_bound > stats_a.upper_bound:
            return False

        # If either item has few comparisons, compare directly
        if (
            stats_a.comparisons < self.min_comparisons
            or stats_b.comparisons < self.min_comparisons
        ):
            result = self.compare(a, b)
            self.total_comparisons += 1

            # Update statistics
            if result:
                stats_a.wins += 1
            else:
                stats_b.wins += 1
            stats_a.comparisons += 1
            stats_b.comparisons += 1

            self._update_bounds(a)
            self._update_bounds(b)

            return result

        # If bounds overlap but items have enough comparisons,
        # use current best estimate
        return stats_a.wins / stats_a.comparisons > stats_b.wins / stats_b.comparisons

    def _adaptive_merge(self, items: List[Any]) -> List[Any]:
        """Merge sort with adaptive sampling"""
        if len(items) <= 1:
            return items

        mid = len(items) // 2
        left = self._adaptive_merge(items[:mid])
        right = self._adaptive_merge(items[mid:])

        # Merge with confidence bounds
        result = []
        i = j = 0

        while i < len(left) and j < len(right):
            if self._compare_with_bounds(left[i], right[j]):
                result.append(left[i])
                i += 1
            else:
                result.append(right[j])
                j += 1

        result.extend(left[i:])
        result.extend(right[j:])
        return result

    def _quicksort_partition(self, items: List[Any], start: int, end: int) -> int:
        """Partition items around pivot using adaptive sampling"""
        if end - start <= 1:
            return start

        pivot = items[end - 1]
        i = start

        for j in range(start, end - 1):
            if self._compare_with_bounds(items[j], pivot):
                items[i], items[j] = items[j], items[i]
                i += 1

        items[i], items[end - 1] = items[end - 1], items[i]
        return i

    def _adaptive_quicksort(self, items: List[Any], start: int, end: int) -> None:
        """Quicksort with adaptive sampling"""
        if end - start <= 1:
            return

        pivot = self._quicksort_partition(items, start, end)
        self._adaptive_quicksort(items, start, pivot)
        self._adaptive_quicksort(items, pivot + 1, end)

    def rank(self, method: str = "merge") -> List[Any]:
        """Rank items using the specified method"""
        items = self.items.copy()

        if method == "merge":
            return self._adaptive_merge(items)
        elif method == "quick":
            self._adaptive_quicksort(items, 0, len(items))
            return items
        else:
            raise ValueError(f"Unknown method: {method}")


if __name__ == "__main__":
Review comment: Consider converting it to unittest.
    from scipy.stats import spearmanr
    from functools import partial

    def compare_fn(a: Any, b: Any, noise_rate: float = 0.0) -> bool:
        # Your comparison function
        # return model.predict(a, b)
        if np.random.random() < noise_rate:
            return np.random.random() < 0.5
        return a < b

    num_samples = 10000

    true_scores = np.arange(0, num_samples, 1)

    to_be_ranked = true_scores.copy()
    np.random.shuffle(to_be_ranked)

    ranker = EfficientRanking(
        to_be_ranked, partial(compare_fn, noise_rate=0.1), confidence=0.95
    )
    ranked_items = ranker.rank(method="merge")  # or 'quick'
    print(f"Total comparisons: {ranker.total_comparisons}")
    sr = spearmanr(ranked_items, true_scores)
    print(f"spearman r = {sr.statistic} p = {sr.pvalue}")
@@ -0,0 +1,149 @@
import random
from typing import List, Callable, Any, Optional
import numpy as np
from tqdm import trange
from collections import defaultdict


def aggregate_rankings(
    all_items: List[Any],
    ranking_model: Callable[[List[Any]], List[Any]],
    ranking_model_rank_batch_size: int = 8,
    budget: Optional[int] = None,
) -> List[Any]:
    """
    Aggregate rankings by efficient pairwise comparison.

    Args:
    - all_items: Complete list of items to be ranked
    - ranking_model: Function that takes a subset of items and returns it ranked
    - ranking_model_rank_batch_size: Maximum subset size passed to the ranking model
    - budget: Number of random sampling iterations (defaults to roughly 2 * n * log10(n))

    Returns:
        Globally optimized ranking of items
    """
    total_num_samples = len(all_items)
    if budget is None:
        budget = int(np.ceil(np.log10(total_num_samples) * total_num_samples * 2))
        print("no budget selected, defaulting to budget=", budget)
    else:
        print("budget=", budget)
    item_scores = defaultdict(int)

    # Track pairs that were already counted to avoid redundant comparisons
    unique_pairs = set()

    # Generate pairwise comparisons through random sampling
    for _ in trange(budget):
        # Randomly select a subset of items (without replacement, so no duplicates)
        sample_size = min(
            len(all_items), random.randint(2, ranking_model_rank_batch_size)
        )
        sample = np.random.choice(all_items, size=sample_size, replace=False)

        # Get model's ranking for this subset
        ranked_sample = ranking_model(sample)

        # Record pairwise comparisons
        for i in range(len(ranked_sample)):
            for j in range(i + 1, len(ranked_sample)):
                higher_item = ranked_sample[i]
                lower_item = ranked_sample[j]

                # Create a hashable pair
                pair = (higher_item, lower_item)

                # Avoid redundant comparisons
                if pair not in unique_pairs:
                    item_scores[higher_item] += 1
                    item_scores[lower_item] -= 1
                    unique_pairs.add(pair)

    # Sort items based on their aggregate score
    global_ranking = sorted(all_items, key=lambda x: item_scores[x], reverse=True)

    return global_ranking


if __name__ == "__main__":
Review comment: consider converting it to unittest
    from scipy.stats import spearmanr
    from functools import partial

    def compare_fn(items: List, number_of_random_flipped: int = 0) -> List:
        ans = sorted(items)

        # optionally corrupt the returned ranking with a few random position swaps
        if number_of_random_flipped > 0:
            length = len(items)
            for _ in range(number_of_random_flipped):
                i = np.random.randint(0, length)
                j = np.random.randint(0, length)
                ans[i], ans[j] = ans[j], ans[i]

        return ans

    num_samples = 100000
    budget = 100000

    true_scores = np.arange(0, num_samples, 1)
    to_be_ranked = true_scores.copy()
    np.random.shuffle(to_be_ranked)

    ranked_items = aggregate_rankings(
        to_be_ranked,
        partial(compare_fn, number_of_random_flipped=20),
        budget=budget,
    )

    sr = spearmanr(ranked_items, true_scores)
    print(f"spearman r = {sr.statistic} p = {sr.pvalue}")

|
||
############## | ||
|
||
# an example that uses a pairwise compare function | ||
|
||
def pairwise_compare_fn(a: Any, b: Any, noise_rate: float = 0.0) -> bool: | ||
# Your comparison function | ||
# return model.predict(a, b) | ||
if np.random.random() < noise_rate: | ||
return np.random.random() < 0.5 | ||
return a < b | ||
|
||
def convert_pairwise_to_ranker( | ||
pairwise_model: Callable[[Any, Any], bool] | ||
) -> Callable[[List], List]: | ||
""" | ||
A helper function that converts a pairwise model to a subsample ranker of length 2, | ||
to support using pairwise model in `aggregate_rankings()` | ||
|
||
pairwise_model: a function that returns true if the item provided as first argument should be ranked higher than the item provided as the second argument | ||
""" | ||
|
||
def build_func(items: List) -> List: | ||
assert 2 == len(items) | ||
if pairwise_model(items[0], items[1]): | ||
return items | ||
# flip the order | ||
return items[::-1] | ||
|
||
return build_func | ||
|
    ranker_func = convert_pairwise_to_ranker(pairwise_compare_fn)
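    # e.g. with the "a < b" comparison above (and no noise), ranker_func([3, 1])
    # returns [1, 3], while ranker_func([1, 3]) returns [1, 3] unchanged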

    true_scores = np.arange(0, num_samples, 1)
    to_be_ranked = true_scores.copy()
    np.random.shuffle(to_be_ranked)

    ranked_items = aggregate_rankings(
        to_be_ranked,
        ranker_func,
        budget=budget * 4,
        ranking_model_rank_batch_size=2,
    )

    sr = spearmanr(ranked_items, true_scores)
    print(f"spearman r = {sr.statistic} p = {sr.pvalue}")
Review comment: Note the isort error. Maybe you need to reinstall the pre-commit hooks? Something like `pre-commit install` when the repo is the current directory. If I'm right, you'll need to re-run the pre-commit hooks afterwards to make sure it applies the changes :)

Review comment: Besides that, the Jenkins build seems to pass (after a rebuild).

Review comment: Thanks!! @SagiPolaczek