import numpy as np
from collections import deque
class PerPromptStatTracker:
    """Turn raw rewards into per-prompt standardized advantages.

    Rewards are accumulated per prompt in ``self.stats`` across calls to
    :meth:`update` until :meth:`clear` is called, so each prompt is normalized
    against its accumulated history rather than only the current batch.
    ``self.history_prompts`` stores ``hash(prompt)`` for every prompt ever
    seen and is *not* reset by :meth:`clear`.
    """

    def __init__(self, global_std=False):
        """Create an empty tracker.

        Args:
            global_std: When True, :meth:`update` divides by the standard
                deviation of the whole current batch instead of the standard
                deviation of each prompt's accumulated rewards.
        """
        self.global_std = global_std
        # prompt -> rewards accumulated so far (a numpy array once update returns)
        self.stats = {}
        # hash(prompt) of every prompt passed to update; survives clear()
        self.history_prompts = set()

    def update(self, prompts, rewards, exp=False):
        """Record a batch of rewards and return their per-prompt advantages.

        Args:
            prompts: One prompt identifier per sample.
            rewards: Rewards aligned with ``prompts`` along axis 0.
            exp: Accepted for interface compatibility (intended for RWR-style
                exponentiated rewards) but currently unused.

        Returns:
            A float64 array with the shape of ``rewards``. Each entry is
            ``(reward - mean) / (std + 1e-4)``, where ``mean`` is taken over all
            rewards accumulated so far for that sample's prompt and ``std`` over
            the same history, or over the current batch if ``global_std``.
        """
        prompts = np.array(prompts)
        rewards = np.array(rewards, dtype=np.float64)
        # zeros_like rather than empty_like: never multiply uninitialized memory.
        advantages = np.zeros_like(rewards)
        # The batch-wide std does not depend on the prompt, so compute it once.
        batch_std = (
            np.std(rewards, axis=0, keepdims=True) + 1e-4 if self.global_std else None
        )
        for prompt in np.unique(prompts):
            mask = prompts == prompt
            prompt_rewards = rewards[mask]
            previous = self.stats.get(prompt)
            if previous is None:
                history = prompt_rewards
            else:
                # A previous update stored an ndarray here (which has no
                # ``extend``), so append the new rewards along axis 0.
                history = np.concatenate(
                    [np.asarray(previous, dtype=np.float64), prompt_rewards]
                )
            self.stats[prompt] = history
            self.history_prompts.add(hash(prompt))

            mean = np.mean(history, axis=0, keepdims=True)
            if batch_std is not None:
                std = batch_std
            else:
                std = np.std(history, axis=0, keepdims=True) + 1e-4
            advantages[mask] = (prompt_rewards - mean) / std
        return advantages

    def get_stats(self):
        """Return ``(avg_group_size, history_prompts)``.

        ``avg_group_size`` is the mean ``len`` of the accumulated rewards per
        prompt currently in ``self.stats`` (0 when empty). ``history_prompts``
        is the number of distinct prompt hashes ever recorded.
        """
        if not self.stats:
            avg_group_size = 0
        else:
            avg_group_size = sum(len(v) for v in self.stats.values()) / len(self.stats)
        return avg_group_size, len(self.history_prompts)

    def clear(self):
        """Drop accumulated rewards; ``history_prompts`` is left untouched."""
        self.stats = {}

    def get_mean_of_top_rewards(self, top_percentage):
        """Average, across prompts, of the mean of each prompt's top rewards.

        Args:
            top_percentage: Percentage (0-100) of each prompt's rewards to
                keep. Rewards at or above the ``100 - top_percentage``
                percentile are averaged; 100 averages everything.

        Returns:
            The mean of the per-prompt top-reward means, or 0.0 when nothing
            has been accumulated.
        """
        if not self.stats:
            return 0.0
        assert 0 <= top_percentage <= 100
        per_prompt_top_means = []
        for prompt_rewards in self.stats.values():
            # Accept either a list or an ndarray of rewards.
            values = np.asarray(prompt_rewards)
            if values.size == 0:
                continue
            if top_percentage == 100:
                per_prompt_top_means.append(np.mean(values))
                continue
            threshold = np.percentile(values, 100 - top_percentage)
            top_rewards = values[values >= threshold]
            if top_rewards.size > 0:
                per_prompt_top_means.append(np.mean(top_rewards))
        if not per_prompt_top_means:
            return 0.0
        return np.mean(per_prompt_top_means)
def main():
    """Run the tracker on a tiny hand-made batch and print its statistics."""
    stat_tracker = PerPromptStatTracker()
    batch_prompts = ["a", "b", "a", "c", "b", "a"]
    batch_rewards = [1, 2, 3, 4, 5, 6]
    print("Advantages:", stat_tracker.update(batch_prompts, batch_rewards))

    group_size, seen_prompts = stat_tracker.get_stats()
    for label, value in (
        ("Average Group Size:", group_size),
        ("History Prompts:", seen_prompts),
    ):
        print(label, value)

    stat_tracker.clear()
    print("Stats after clear:", stat_tracker.stats)


if __name__ == "__main__":
    main()