Skip to content

Commit 7ca6c59

Browse files
committed
Added tool dependency metric
1 parent e94850d commit 7ca6c59

File tree

8 files changed

+58
-16
lines changed

8 files changed

+58
-16
lines changed

src/demo/multi_agent/multi_agent.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydantic import BaseModel
33
from judgeval.common.tracer import Tracer, wrap
44
from judgeval import JudgmentClient
5-
from judgeval.scorers import ToolOrderScorer
5+
from judgeval.scorers import ToolOrderScorer, ToolDependencyScorer
66
from judgeval.common.tracer import Tracer
77
import os
88

@@ -56,12 +56,13 @@ def run_simple_task(self, prompt: str):
5656
# Create two agents
5757
alice = self.add_agent("Alice")
5858
bob = self.add_agent("Bob")
59+
charles = self.add_agent("Charles")
5960

6061
# Have them exchange messages
61-
alice.send_message("Hello Bob, how are you?", "Bob")
62-
bob.send_message("I'm good Alice, thanks for asking!", "Alice")
63-
alice.send_message("Great to hear! Let's work together on a task.", "Bob")
6462

63+
bob.send_message("I'm good Alice, thanks for asking!", "Alice")
64+
alice.send_message("Great to hear! What about you, Charles?", "Charles")
65+
charles.send_message("I'm good Alice, thanks for asking!", "Alice")
6566
# Print the conversation
6667
print("\nAlice's messages:")
6768
for msg in alice.get_all_messages():
@@ -70,17 +71,29 @@ def run_simple_task(self, prompt: str):
7071
print("\nBob's messages:")
7172
for msg in bob.get_all_messages():
7273
print(f"From {msg.sender}: {msg.content}")
74+
75+
print("\nCharles's messages:")
76+
for msg in charles.get_all_messages():
77+
print(f"From {msg.sender}: {msg.content}")
7378

7479
# Example usage
7580
if __name__ == "__main__":
7681
system = MultiAgentSystem()
7782

78-
test_file = os.path.join(os.path.dirname(__file__), "tests.yaml")
83+
# test_file = os.path.join(os.path.dirname(__file__), "tests.yaml")
84+
# judgment_client.assert_test(
85+
# scorers=[ToolOrderScorer(threshold=0.5)],
86+
# function=system.run_simple_task,
87+
# tracer=judgment,
88+
# override=True,
89+
# test_file=test_file
90+
# )
91+
92+
test_file2 = os.path.join(os.path.dirname(__file__), "tests2.yaml")
7993
judgment_client.assert_test(
80-
scorers=[ToolOrderScorer(threshold=0.5)],
94+
scorers=[ToolDependencyScorer(threshold=0.5)],
8195
function=system.run_simple_task,
8296
tracer=judgment,
8397
override=True,
84-
test_file=test_file
98+
test_file=test_file2
8599
)
86-

src/demo/multi_agent/tests.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ examples:
1616
# name: "Random Tool Agent"
1717
- tool_name: "send_message"
1818
agent: Alice
19+
- tool_name: "send_message"
20+
agent: Charles
1921
# parameters:
2022
# self:
2123
# name: "Random Tool Agent"

src/demo/multi_agent/tests2.yaml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
expected_tools:
2-
- tool_name: "send_message"
3-
agent: Alice
4-
- tool_name: "send_message"
5-
agent: Bob
6-
- tool_name: "send_message"
7-
agent: Alice
1+
examples:
2+
- input:
3+
prompt: "Do something random"
4+
expected_tools:
5+
- tool_name: "send_message"
6+
agent: Bob
7+
dependencies:
8+
- tool_name: "send_message"
9+
agent: Alice
10+
- tool_name: "send_message"
11+
agent: Charles
12+
# require_all: true

src/judgeval/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class APIScorer(str, Enum):
2828
GROUNDEDNESS = "groundedness"
2929
DERAILMENT = "derailment"
3030
TOOL_ORDER = "tool_order"
31+
TOOL_DEPENDENCY = "tool_dependency"
3132
@classmethod
3233
def _missing_(cls, value):
3334
# Handle case-insensitive lookup

src/judgeval/run_evaluation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def run_trace_eval(trace_run: TraceRun, override: bool = False, ignore_errors: b
387387
trace_run.organization_id,
388388
True
389389
)
390-
391390
if function and tracer:
392391
new_traces: List[Trace] = []
393392
tracer.offline_mode = True

src/judgeval/scorers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
GroundednessScorer,
1818
DerailmentScorer,
1919
ToolOrderScorer,
20+
ToolDependencyScorer,
2021
)
2122
from judgeval.scorers.judgeval_scorers.classifiers import (
2223
Text2SQLScorer,
@@ -43,4 +44,5 @@
4344
"GroundednessScorer",
4445
"DerailmentScorer",
4546
"ToolOrderScorer",
47+
"ToolDependencyScorer",
4648
]

src/judgeval/scorers/judgeval_scorers/api_scorers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from judgeval.scorers.judgeval_scorers.api_scorers.groundedness import GroundednessScorer
1414
from judgeval.scorers.judgeval_scorers.api_scorers.derailment_scorer import DerailmentScorer
1515
from judgeval.scorers.judgeval_scorers.api_scorers.tool_order import ToolOrderScorer
16+
from judgeval.scorers.judgeval_scorers.api_scorers.tool_dependency import ToolDependencyScorer
1617
__all__ = [
1718
"ExecutionOrderScorer",
1819
"JSONCorrectnessScorer",
@@ -29,4 +30,5 @@
2930
"GroundednessScorer",
3031
"DerailmentScorer",
3132
"ToolOrderScorer",
33+
"ToolDependencyScorer",
3234
]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
`judgeval` tool dependency scorer
3+
"""
4+
5+
# Internal imports
6+
from judgeval.scorers.api_scorer import APIJudgmentScorer
7+
from judgeval.constants import APIScorer
8+
9+
class ToolDependencyScorer(APIJudgmentScorer):
10+
def __init__(self, threshold: float=1.0):
11+
super().__init__(
12+
threshold=threshold,
13+
score_type=APIScorer.TOOL_DEPENDENCY,
14+
)
15+
16+
@property
17+
def __name__(self):
18+
return "Tool Dependency"

0 commit comments

Comments
 (0)