diff --git a/py/autoevals/ragas.py b/py/autoevals/ragas.py
index 9a6d8a5..3ec690d 100644
--- a/py/autoevals/ragas.py
+++ b/py/autoevals/ragas.py
@@ -681,6 +681,37 @@ def extract_context_precision_request(question, answer, context, **extra_args):
     )
 
 
+def _calculate_context_precision_score(verdicts: list) -> float:
+    """Calculate the RAGAS ContextPrecision score from ranked chunk verdicts.
+
+    Score = sum(precision_at_k * verdict_k) / total_relevant
+    where precision_at_k = relevant chunks up to position k / k
+
+    Args:
+        verdicts: Per-chunk verdicts in retrieval order; a chunk counts as
+            relevant only when its verdict equals 1.
+
+    Returns:
+        A score in [0, 1]; 0.0 when no chunk is relevant (including an
+        empty list).
+    """
+    # Use one relevance test for both numerator and denominator so that an
+    # unexpected verdict value cannot push the score outside [0, 1].
+    relevant_flags = [verdict == 1 for verdict in verdicts]
+    total_relevant = sum(relevant_flags)
+    if total_relevant == 0:
+        return 0.0
+
+    score = 0.0
+    relevant_so_far = 0
+    for k, is_relevant in enumerate(relevant_flags, start=1):
+        if is_relevant:
+            relevant_so_far += 1
+            score += relevant_so_far / k
+
+    return score / total_relevant
+
+
 class ContextPrecision(OpenAILLMScorer):
     """Measures how precise and focused the context is for answering the question.
 
@@ -730,31 +761,55 @@ def _postprocess(self, response):
     async def _run_eval_async(self, output, expected=None, input=None, context=None, **kwargs):
         check_required("ContextPrecision", input=input, expected=expected, context=context)
 
-        if isinstance(context, list):
-            context = "\n".join(context)
+        if not isinstance(context, list):
+            context = [context]
 
-        return self._postprocess(
-            await arun_cached_request(
+        # Judge each chunk on its own, in order: the final score is positional
+        verdicts = []
+        for chunk in context:
+            response = await arun_cached_request(
                 client=self.client,
                 **extract_context_precision_request(
-                    question=input, answer=expected, context=context, model=self.model, **self.extra_args
+                    question=input, answer=expected, context=chunk, model=self.model, **self.extra_args
                 ),
             )
+            precision = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"])
+            verdicts.append(precision["verdict"])
+
+        # Apply RAGAS positional formula
+        score = _calculate_context_precision_score(verdicts)
+
+        return Score(
+            name=self._name(),
+            score=score,
+            metadata={"verdicts": verdicts},
         )
 
     def _run_eval_sync(self, output, expected=None, input=None, context=None, **kwargs):
         check_required("ContextPrecision", input=input, expected=expected, context=context)
 
-        if isinstance(context, list):
-            context = "\n".join(context)
+        if not isinstance(context, list):
+            context = [context]
 
-        return self._postprocess(
-            run_cached_request(
+        # Judge each chunk on its own, in order: the final score is positional
+        verdicts = []
+        for chunk in context:
+            response = run_cached_request(
                 client=self.client,
                 **extract_context_precision_request(
-                    question=input, answer=expected, context=context, model=self.model, **self.extra_args
+                    question=input, answer=expected, context=chunk, model=self.model, **self.extra_args
                 ),
             )
+            precision = json.loads(response["choices"][0]["message"]["tool_calls"][0]["function"]["arguments"])
+            verdicts.append(precision["verdict"])
+
+        # Apply RAGAS positional formula
+        score = _calculate_context_precision_score(verdicts)
+
+        return Score(
+            name=self._name(),
+            score=score,
+            metadata={"verdicts": verdicts},
         )
 
 