Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow skipping stats in GroupAnalysis metric #383

Merged
merged 2 commits into from
Dec 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions fuse/eval/metrics/metrics_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,19 +688,27 @@ def eval(

class GroupAnalysis(MetricWithCollectorBase):
"""
Evaluate a metric per group and compute basic statistics about the different per group results.
Evaluate a metric per group and compute basic statistics over the different per group results.
eval() method returns a dictionary of the following format:
{'mean': <>, 'std': <>, 'median': <>, <group 0>: <>, <group 1>: <>, ...}
"""

def __init__(self, metric: MetricBase, group: str, **super_kwargs: Any) -> None:
def __init__(
self,
metric: MetricBase,
group: str,
compute_group_stats: bool = True,
**super_kwargs: Any,
) -> None:
"""
:param metric: metric to analyze
:param group: key to extract the group from
:param compute_group_stats: whether to compute stats such as mean, std, median over the per group results
:param super_kwargs: additional arguments for super class (MetricWithCollectorBase) constructor
"""
super().__init__(group=group, **super_kwargs)
self._metric = metric
self._compute_group_stats = compute_group_stats

def collect(self, batch: Dict) -> None:
"See super class"
Expand All @@ -718,7 +726,9 @@ def reset(self) -> None:
return super().reset()

def eval(
self, results: Dict[str, Any] = None, ids: Optional[Sequence[Hashable]] = None
self,
results: Dict[str, Any] = None,
ids: Optional[Sequence[Hashable]] = None,
) -> Dict[str, Any]:
"""
See super class
Expand All @@ -745,31 +755,32 @@ def eval(
)

# compute stats
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
if self._compute_group_stats:
group_results_list = list(group_analysis_results.values())
if isinstance(group_results_list[0], dict): # multiple values
# get all keys
all_keys = set()
for group_result in group_results_list:
all_keys |= set(group_result.keys())

for key in all_keys:
values = [group_result[key] for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results[f"{key}.mean"] = np.mean(values)
group_analysis_results[f"{key}.std"] = np.std(values)
group_analysis_results[f"{key}.median"] = np.median(values)
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass
else: # single value
values = [group_result for group_result in group_results_list]
try:
group_analysis_results["mean"] = np.mean(values)
group_analysis_results["std"] = np.std(values)
group_analysis_results["median"] = np.median(values)
except:
# do nothing
pass

return group_analysis_results

Expand Down