@osmosis_reward 装饰函数并返回 0 到 1 之间的浮点值即可。
基本示例
文件:reward_fn/compute_reward.py
复制
询问AI
import re
from osmosis_ai import osmosis_reward
# Final-answer pattern: a "####" marker, optional whitespace, then an optionally
# signed number. The first alternative accepts comma thousands separators
# (e.g. "1,234.5"); the second is the plain form ("42", "-3.5", ".5").
_SOLUTION_RE = re.compile(
    r'####\s*([-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+))'
)


def extract_solution(solution_str):
    """Return the number that follows the first ``####`` marker, as a string.

    Args:
        solution_str: Text that may contain a ``#### <number>`` answer line.

    Returns:
        The matched number as a string with any thousands-separator commas
        removed (so it can be passed to ``float``), or ``None`` when
        ``solution_str`` is not a string or contains no such marker.
    """
    if not isinstance(solution_str, str):
        return None
    match = _SOLUTION_RE.search(solution_str)
    if match is None:
        return None
    # Strip separators so "1,234" is returned as "1234" rather than truncated.
    return match.group(1).replace(",", "")
@osmosis_reward
def numbers_match_reward(
    solution_str: str,
    ground_truth: str,
    extra_info: dict = None,
    **kwargs
) -> float:
    """Reward 1.0 when the number extracted from the solution equals the ground truth.

    The answer is pulled out of ``solution_str`` with ``extract_solution`` and
    compared numerically to ``ground_truth`` with an absolute tolerance of 1e-7.

    Args:
        solution_str: Model output expected to contain a ``#### <number>`` answer.
        ground_truth: The expected value; anything ``float()`` accepts.
        extra_info: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        1.0 for a match, 0.0 otherwise (including when no answer can be
        extracted or either value cannot be converted to a float).
    """
    # Guard before extraction: the helper runs a regex search, which would
    # raise on non-string input instead of yielding a 0.0 reward.
    if not isinstance(solution_str, str):
        return 0.0

    extracted = extract_solution(solution_str)
    if extracted is None:
        return 0.0

    try:
        sol_val = float(extracted)
        gt_val = float(ground_truth)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are treated as "no match"; anything else
        # (e.g. KeyboardInterrupt) is allowed to propagate.
        return 0.0

    return 1.0 if abs(gt_val - sol_val) < 1e-7 else 0.0
常见模式
精确匹配
复制
询问AI
@osmosis_reward
def exact_match_reward(
    solution_str: str,
    ground_truth: str,
    extra_info: dict = None,
    **kwargs
) -> float:
    """Reward 1.0 when the solution equals the ground truth, ignoring outer whitespace.

    Args:
        solution_str: Model output to compare.
        ground_truth: Expected text.
        extra_info: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        1.0 if both inputs are strings and are equal after ``str.strip()``,
        0.0 otherwise.
    """
    # Non-string inputs (e.g. None) cannot match; score them 0.0 rather than
    # letting ``.strip()`` raise AttributeError.
    if not isinstance(solution_str, str) or not isinstance(ground_truth, str):
        return 0.0
    return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
多条件评估
复制
询问AI
import json
@osmosis_reward
def multi_criteria_reward(
    solution_str: str,
    ground_truth: str,
    extra_info: dict = None,
    **kwargs
) -> float:
    """Score a JSON solution on three weighted criteria and return the total.

    Both ``solution_str`` and ``ground_truth`` are parsed as JSON. The score is
    the sum of:

    * correctness (weight 0.5): ``solution["answer"] == expected["answer"]``;
    * explanation (weight 0.3): length of ``solution["explanation"]`` divided
      by 100, capped at 1.0 (missing key counts as empty);
    * code (weight 0.2): ``solution["code"]`` is present and non-empty.

    Args:
        solution_str: JSON text produced by the model.
        ground_truth: JSON text containing the expected ``"answer"``.
        extra_info: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A float in [0.0, 1.0]. Returns 0.0 when either input is not valid
        JSON, is not an object with an ``"answer"`` key, or has fields of an
        unexpected type.
    """
    try:
        solution = json.loads(solution_str)
        expected = json.loads(ground_truth)

        # Criterion 1: Correctness (weight: 0.5)
        correctness = 1.0 if solution["answer"] == expected["answer"] else 0.0

        # Criterion 2: Explanation quality (weight: 0.3)
        explanation_length = len(solution.get("explanation", ""))
        explanation_score = min(explanation_length / 100, 1.0)

        # Criterion 3: Code validity (weight: 0.2)
        has_code = "code" in solution and len(solution["code"]) > 0
    except (TypeError, ValueError, KeyError, AttributeError, RecursionError):
        # Malformed or unexpectedly shaped input earns no reward. The list is
        # explicit so that KeyboardInterrupt / SystemExit are not swallowed.
        return 0.0

    code_score = 1.0 if has_code else 0.0
    return correctness * 0.5 + explanation_score * 0.3 + code_score * 0.2