Commit f9c50d5

add assert test fix
1 parent: ae48920

2 files changed: +30 -44 lines changed


src/judgeval/judgment_client.py

Lines changed: 6 additions & 4 deletions
@@ -480,7 +480,7 @@ def push_classifier_scorer(self, scorer: ClassifierScorer, slug: str = None) ->
 
         return response.json()["slug"]
 
-    def assert_test(
+    async def assert_test(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
         examples: Optional[List[Example]] = None,
@@ -494,7 +494,8 @@ def assert_test(
         override: bool = False,
         rules: Optional[List[Rule]] = None,
         function: Optional[Callable] = None,
-        tracer: Optional[Union[Tracer, BaseCallbackHandler]] = None
+        tracer: Optional[Union[Tracer, BaseCallbackHandler]] = None,
+        async_execution: bool = False
     ) -> None:
         """
         Asserts a test by running the evaluation and checking the results for success
@@ -532,7 +533,7 @@ def assert_test(
                 test_file=test_file
             )
         else:
-            results = self.run_evaluation(
+            results = await self.run_evaluation(
                 examples=examples,
                 scorers=scorers,
                 model=model,
@@ -542,7 +543,8 @@ def assert_test(
                 project_name=project_name,
                 eval_run_name=eval_run_name,
                 override=override,
-                rules=rules
+                rules=rules,
+                async_execution=async_execution
            )
 
        assert_test(results)
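
Note: with this change `assert_test` is a coroutine, so callers now have to await it (for example via `asyncio.run`). A minimal usage sketch, assuming a `JudgmentClient` instance plus `examples` and `scorers` objects constructed elsewhere; the model and name values below are illustrative, not taken from this commit:

import asyncio

async def main():
    # Assumed setup: `client`, `examples`, and `scorers` already exist.
    await client.assert_test(
        examples=examples,
        scorers=scorers,
        model="gpt-4o",               # illustrative value
        project_name="demo-project",  # illustrative value
        eval_run_name="demo-run",     # illustrative value
        async_execution=True,         # new flag introduced by this commit
    )

asyncio.run(main())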

src/judgeval/run_evaluation.py

Lines changed: 24 additions & 40 deletions
@@ -471,7 +471,7 @@ async def get_evaluation_status(eval_name: str, project_name: str, judgment_api_
         error(f"Failed to check evaluation status: {str(e)}")
         raise JudgmentAPIError(f"Failed to check evaluation status: {str(e)}")
 
-async def _poll_evaluation_until_complete(eval_name: str, project_name: str, judgment_api_key: str, organization_id: str, poll_interval_seconds: int = 5, original_examples: Optional[List[Example]] = None, expected_scorers: Optional[List[Union[str, Any]]] = None) -> List[ScoringResult]:
+async def _poll_evaluation_until_complete(eval_name: str, project_name: str, judgment_api_key: str, organization_id: str, poll_interval_seconds: int = 5, original_examples: Optional[List[Example]] = None) -> List[ScoringResult]:
     """
     Polls until the evaluation is complete and returns the results.
 
@@ -483,8 +483,6 @@ async def _poll_evaluation_until_complete(eval_name: str, project_name: str, jud
         poll_interval_seconds (int, optional): Time between status checks in seconds. Defaults to 5.
         original_examples (List[Example], optional): The original examples sent for evaluation.
             If provided, will match results with original examples.
-        expected_scorers (List[Union[str, Any]], optional): List of expected scorer names or scorer objects.
-            Used to verify all scorer data is present.
 
     Returns:
         List[ScoringResult]: The evaluation results
@@ -496,19 +494,8 @@ async def _poll_evaluation_until_complete(eval_name: str, project_name: str, jud
         for example in original_examples:
             original_example_map[example.example_id] = example
 
-    # Extract expected scorer names if provided
-    expected_scorer_names = []
-    if expected_scorers:
-        for scorer in expected_scorers:
-            if isinstance(scorer, str):
-                expected_scorer_names.append(scorer)
-            elif hasattr(scorer, 'name'):
-                expected_scorer_names.append(scorer.name)
-            elif hasattr(scorer, 'score_type') and hasattr(scorer.score_type, 'value'):
-                expected_scorer_names.append(scorer.score_type.value)
-
-    debug(f"Expecting results for these scorers: {expected_scorer_names}")
-
+    # Remove the expected scorer names extraction and checking
+    # We'll instead verify all examples have consistent scorer data
     while True:
         poll_count += 1
         try:
@@ -567,6 +554,7 @@ async def _poll_evaluation_until_complete(eval_name: str, project_name: str, jud
 
                 if "examples" in result_data:
                     examples_data = result_data.get("examples", [])
+
 
                     info(f"Successfully fetched {len(examples_data)} results for evaluation '{eval_name}'")
 
@@ -576,6 +564,7 @@ async def _poll_evaluation_until_complete(eval_name: str, project_name: str, jud
                     has_invalid_results = False
                     for example_data in examples_data:
                         example_id = example_data.get("example_id")
+
                         if example_id not in original_example_map:
                             warning(f"Server returned example with ID {example_id} not found in original examples. " +
                                     f"This indicates stale or incorrect data. Continuing to poll...")
@@ -594,32 +583,28 @@ async def _poll_evaluation_until_complete(eval_name: str, project_name: str, jud
                                     f"This indicates incomplete data. Continuing to poll...")
                             await asyncio.sleep(poll_interval_seconds)
                             continue
-
-                    # Verify all scorer data is present if expected_scorer_names is provided
-                    if expected_scorer_names:
-                        has_incomplete_scorer_data = False
+
+                    # Collect all example IDs from scorer data
+                    scorer_example_ids = set()
                     for example_data in examples_data:
                         scorer_data_list = example_data.get("scorer_data", [])
-
-                        # Extract scorer names from the retrieved data
-                        retrieved_scorer_names = set()
                         for scorer_data in scorer_data_list:
-                            name = scorer_data.get("name")
-                            if name:
-                                retrieved_scorer_names.add(name)
-
-                        # Check if all expected scorers are present
-                        missing_scorers = set(expected_scorer_names) - retrieved_scorer_names
-                        if missing_scorers:
-                            example_id = example_data.get("example_id", "unknown")
-                            warning(f"Example {example_id} is missing scorer data for: {missing_scorers}. " +
-                                    f"Continuing to poll for complete data...")
-                            has_incomplete_scorer_data = True
-                            break
+                            if "example_id" in scorer_data:
+                                scorer_example_ids.add(scorer_data["example_id"])
+
+                    # Get the set of original example IDs
+                    original_example_ids = set(original_example_map.keys())
+
+                    # Check if the sets are equal
+                    missing_in_scorer = original_example_ids - scorer_example_ids
+                    extra_in_scorer = scorer_example_ids - original_example_ids
 
-                    # If any example has incomplete scorer data, continue polling
-                    if has_incomplete_scorer_data:
-                        info("Detected incomplete scorer data. Waiting before polling again...")
+                    if missing_in_scorer or extra_in_scorer:
+                        if missing_in_scorer:
+                            warning(f"Examples missing in scorer data: {missing_in_scorer}")
+                        if extra_in_scorer:
+                            warning(f"Extra examples in scorer data: {extra_in_scorer}")
+                        info("Detected mismatched example IDs in scorer data. Waiting before polling again...")
                         await asyncio.sleep(poll_interval_seconds)
                         continue
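
The hunk above swaps per-scorer-name verification for a set comparison: the example IDs collected from the returned scorer data must exactly match the IDs of the originally submitted examples, otherwise the loop keeps polling. A standalone sketch of that check (the helper name and the sample payload are hypothetical, but the field names `example_id` and `scorer_data` come from the diff):

def scorer_ids_consistent(examples_data, original_example_ids):
    # Collect every example_id that appears inside the returned scorer data.
    scorer_example_ids = set()
    for example_data in examples_data:
        for scorer_data in example_data.get("scorer_data", []):
            if "example_id" in scorer_data:
                scorer_example_ids.add(scorer_data["example_id"])
    # The poll loop keeps waiting while either set difference is non-empty.
    missing_in_scorer = original_example_ids - scorer_example_ids
    extra_in_scorer = scorer_example_ids - original_example_ids
    return not missing_in_scorer and not extra_in_scorer

# Scorer data for "ex-1" has arrived, "ex-2" has not -> keep polling.
payload = [
    {"example_id": "ex-1", "scorer_data": [{"example_id": "ex-1", "name": "faithfulness"}]},
    {"example_id": "ex-2", "scorer_data": []},
]
print(scorer_ids_consistent(payload, {"ex-1", "ex-2"}))  # False
print(scorer_ids_consistent(payload[:1], {"ex-1"}))      # True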

@@ -807,8 +792,7 @@ async def _async_evaluation_workflow():
                 project_name=evaluation_run.project_name,
                 judgment_api_key=evaluation_run.judgment_api_key,
                 organization_id=evaluation_run.organization_id,
-                original_examples=evaluation_run.examples, # Pass the original examples
-                expected_scorers=evaluation_run.scorers # Pass the expected scorers for verification
+                original_examples=evaluation_run.examples # Pass the original examples
             )
 
             # Create and return a task that can be awaited
