Skip to content

Commit aa3c40c

Browse files
authored
Fireworks rft (JUD-1754) (#511)
* organization code * add custom trainable model * add multiturn * add custom training function * add final steps * add model config * improve syntax * improve syntax * add wrapper fix * add spinner for visual indicator * reduce visual noise * improvement to spinner * add async fix * add some fixes * test this * test this * try this again * change so no weird rich print * revert back * make it work * need to await * add feedback * renaming * fix * fix comments * fix stuff * more fixes * resolve comments * remove print
1 parent 424289f commit aa3c40c

File tree

12 files changed

+1705
-33
lines changed

12 files changed

+1705
-33
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"langchain-core",
3232
"click<8.2.0",
3333
"typer>=0.9.0",
34+
"fireworks-ai>=0.19.18",
3435
]
3536

3637
[project.urls]

src/e2etests/test_tracer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def validate_trace_token_counts(
9090
"TOGETHER_API_CALL",
9191
"GOOGLE_API_CALL",
9292
"GROQ_API_CALL",
93+
"FIREWORKS_TRAINABLE_MODEL_CALL",
9394
}
9495

9596
for span in trace_spans:

src/judgeval/common/tracer/core.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,8 @@ def __init__(
815815
== "true",
816816
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
817817
== "true",
818+
show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
819+
== "true",
818820
# S3 configuration
819821
use_s3: bool = False,
820822
s3_bucket_name: Optional[str] = None,
@@ -859,6 +861,7 @@ def __init__(
859861
self.traces: List[Trace] = []
860862
self.enable_monitoring: bool = enable_monitoring
861863
self.enable_evaluations: bool = enable_evaluations
864+
self.show_trace_urls: bool = show_trace_urls
862865
self.class_identifiers: Dict[
863866
str, str
864867
] = {} # Dictionary to store class identifiers
@@ -1731,6 +1734,93 @@ def _cleanup_on_exit(self):
17311734
f"Error during background service shutdown: {e}"
17321735
)
17331736

1737+
def trace_to_message_history(
    self, trace: Union[Trace, TraceClient]
) -> List[Dict[str, str]]:
    """
    Rebuild a chat-style message list from the spans of a trace.

    Spans are visited in ascending ``created_at`` order (spans lacking
    that attribute sort with key 0). The first LLM span that carries a
    non-empty ``inputs["messages"]`` list contributes its input messages;
    every LLM span with an output contributes an assistant message; user
    and tool spans with an output contribute a user message.

    Args:
        trace: Trace or TraceClient instance to extract messages from

    Returns:
        List of message dictionaries with 'role' and 'content' keys

    Raises:
        ValueError: If no trace is provided
    """
    if not trace:
        raise ValueError("No trace provided")

    # TraceClient exposes trace_spans directly; any other object is probed
    # and falls back to an empty span list.
    if isinstance(trace, TraceClient):
        span_list = trace.trace_spans
    else:
        span_list = getattr(trace, "trace_spans", [])

    history = []
    prompt_captured = False

    for span in sorted(span_list, key=lambda s: getattr(s, "created_at", 0)):
        kind = span.span_type
        reply = span.output

        if kind != "llm":
            # Non-LLM spans only matter when they produced an output.
            # Tool results are recorded under the "user" role, same as
            # genuine user spans.
            if reply is not None and kind in ("user", "tool"):
                history.append({"role": "user", "content": str(reply)})
            continue

        # Only the first LLM span with input messages seeds the history
        # (typically the system and user prompts).
        span_inputs = getattr(span, "inputs", None)
        if not prompt_captured and span_inputs:
            prompt_messages = span_inputs.get("messages", [])
            if prompt_messages:
                prompt_captured = True
                # Keep only well-formed entries, reduced to role/content.
                history.extend(
                    {"role": entry["role"], "content": entry["content"]}
                    for entry in prompt_messages
                    if isinstance(entry, dict)
                    and "role" in entry
                    and "content" in entry
                )

        # The LLM span's own output is the assistant turn.
        if reply is not None:
            history.append({"role": "assistant", "content": str(reply)})

    return history
1807+
1808+
def get_current_message_history(self) -> List[Dict[str, str]]:
    """
    Return the message history of the trace that is currently active.

    Looks up the current trace via ``get_current_trace()`` and passes it
    to ``trace_to_message_history``.

    Returns:
        List of message dictionaries from the current trace context

    Raises:
        ValueError: If no current trace is found
    """
    active_trace = self.get_current_trace()
    if active_trace:
        return self.trace_to_message_history(active_trace)

    raise ValueError("No current trace found")
1823+
17341824

17351825
def _get_current_trace(
17361826
trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
@@ -1746,7 +1836,7 @@ def wrap(
17461836
) -> Any:
17471837
"""
17481838
Wraps an API client to add tracing capabilities.
1749-
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1839+
Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
17501840
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
17511841
"""
17521842
(
@@ -1871,6 +1961,39 @@ async def wrapper(*args, **kwargs):
18711961
setattr(client.chat.completions, "create", wrapped(original_create))
18721962
elif isinstance(client, (groq_AsyncGroq)):
18731963
setattr(client.chat.completions, "create", wrapped_async(original_create))
1964+
1965+
# Check for TrainableModel from judgeval.common.trainer
1966+
try:
1967+
from judgeval.common.trainer import TrainableModel
1968+
1969+
if isinstance(client, TrainableModel):
1970+
# Define a wrapper function that can be reapplied to new model instances
1971+
def wrap_model_instance(model_instance):
1972+
"""Wrap a model instance with tracing functionality"""
1973+
if hasattr(model_instance, "chat") and hasattr(
1974+
model_instance.chat, "completions"
1975+
):
1976+
if hasattr(model_instance.chat.completions, "create"):
1977+
setattr(
1978+
model_instance.chat.completions,
1979+
"create",
1980+
wrapped(model_instance.chat.completions.create),
1981+
)
1982+
if hasattr(model_instance.chat.completions, "acreate"):
1983+
setattr(
1984+
model_instance.chat.completions,
1985+
"acreate",
1986+
wrapped_async(model_instance.chat.completions.acreate),
1987+
)
1988+
1989+
# Register the wrapper function with the TrainableModel
1990+
client._register_tracer_wrapper(wrap_model_instance)
1991+
1992+
# Apply wrapping to the current model
1993+
wrap_model_instance(client._current_model)
1994+
except ImportError:
1995+
pass # TrainableModel not available
1996+
18741997
return client
18751998

18761999

@@ -1977,6 +2100,22 @@ def _get_client_config(
19772100
return "GROQ_API_CALL", client.chat.completions.create, None, None, None
19782101
elif isinstance(client, (groq_AsyncGroq)):
19792102
return "GROQ_API_CALL", client.chat.completions.create, None, None, None
2103+
2104+
# Check for TrainableModel
2105+
try:
2106+
from judgeval.common.trainer import TrainableModel
2107+
2108+
if isinstance(client, TrainableModel):
2109+
return (
2110+
"FIREWORKS_TRAINABLE_MODEL_CALL",
2111+
client._current_model.chat.completions.create,
2112+
None,
2113+
None,
2114+
None,
2115+
)
2116+
except ImportError:
2117+
pass # TrainableModel not available
2118+
19802119
raise ValueError(f"Unsupported client type: {type(client)}")
19812120

19822121

@@ -2155,6 +2294,37 @@ def _format_output_data(
21552294
cache_creation_input_tokens,
21562295
)
21572296

2297+
# Check for TrainableModel
2298+
try:
2299+
from judgeval.common.trainer import TrainableModel
2300+
2301+
if isinstance(client, TrainableModel):
2302+
# TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
2303+
if (
2304+
hasattr(response, "model")
2305+
and hasattr(response, "usage")
2306+
and hasattr(response, "choices")
2307+
):
2308+
model_name = response.model
2309+
prompt_tokens = response.usage.prompt_tokens if response.usage else 0
2310+
completion_tokens = (
2311+
response.usage.completion_tokens if response.usage else 0
2312+
)
2313+
message_content = response.choices[0].message.content
2314+
2315+
# Use LiteLLM cost calculation with fireworks_ai prefix
2316+
# LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
2317+
fireworks_model_name = f"fireworks_ai/{model_name}"
2318+
return message_content, _create_usage(
2319+
fireworks_model_name,
2320+
prompt_tokens,
2321+
completion_tokens,
2322+
cache_read_input_tokens,
2323+
cache_creation_input_tokens,
2324+
)
2325+
except ImportError:
2326+
pass # TrainableModel not available
2327+
21582328
judgeval_logger.warning(f"Unsupported client type: {type(client)}")
21592329
return None, None
21602330

src/judgeval/common/tracer/trace_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,12 @@ def upsert_trace(
7171

7272
server_response = self.api_client.upsert_trace(trace_data)
7373

74-
if not offline_mode and show_link and "ui_results_url" in server_response:
74+
if (
75+
not offline_mode
76+
and show_link
77+
and "ui_results_url" in server_response
78+
and self.tracer.show_trace_urls
79+
):
7580
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
7681
rprint(pretty_str)
7782

src/judgeval/common/trainer/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .trainer import JudgmentTrainer
2+
from .config import TrainerConfig, ModelConfig
3+
from .trainable_model import TrainableModel
4+
5+
__all__ = ["JudgmentTrainer", "TrainerConfig", "ModelConfig", "TrainableModel"]

src/judgeval/common/trainer/config.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Dict, Any
3+
import json
4+
5+
6+
@dataclass
class TrainerConfig:
    """Settings consumed by JudgmentTrainer.

    The three identifier fields are required; every other field has a
    default and may be overridden by keyword.
    """

    # Required identifiers (no defaults).
    deployment_id: str
    user_id: str
    model_id: str

    # Model / provider selection.
    base_model_name: str = "qwen2p5-7b-instruct"
    rft_provider: str = "fireworks"

    # Training loop sizing.
    num_steps: int = 5
    # Number of rollouts/generations per input prompt
    num_generations_per_prompt: int = 5
    # Number of input prompts to sample per training step
    num_prompts_per_step: int = 4
    concurrency: int = 100
    epochs: int = 1
    learning_rate: float = 1e-5

    # Hardware request.
    accelerator_count: int = 1
    accelerator_type: str = "NVIDIA_A100_80GB"

    # Generation settings.
    temperature: float = 1.5
    max_tokens: int = 50

    enable_addons: bool = True
28+
29+
30+
@dataclass
class ModelConfig:
    """
    Configuration class for storing and loading trained model state.

    This class enables persistence of trained models so they can be loaded
    and used later without retraining.

    Example usage:
        trainer = JudgmentTrainer(config)
        model_config = trainer.train(agent_function, scorers, prompts)

        # Save the trained model configuration
        model_config.save_to_file("my_trained_model.json")

        # Later, load and use the trained model
        loaded_config = ModelConfig.load_from_file("my_trained_model.json")
        trained_model = TrainableModel.from_model_config(loaded_config)

        # Use the trained model for inference
        response = trained_model.chat.completions.create(
            model="current",  # Uses the loaded trained model
            messages=[{"role": "user", "content": "Hello!"}]
        )
    """

    # Base model configuration
    base_model_name: str
    deployment_id: str
    user_id: str
    model_id: str
    enable_addons: bool

    # Training state
    current_step: int
    total_steps: int

    # Current model information
    current_model_name: Optional[str] = None
    is_trained: bool = False

    # Training parameters used (for reference)
    training_params: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert ModelConfig to dictionary for serialization.

        The returned dict holds one key per field; ``training_params`` is
        included by reference, not copied.
        """
        return {
            "base_model_name": self.base_model_name,
            "deployment_id": self.deployment_id,
            "user_id": self.user_id,
            "model_id": self.model_id,
            "enable_addons": self.enable_addons,
            "current_step": self.current_step,
            "total_steps": self.total_steps,
            "current_model_name": self.current_model_name,
            "is_trained": self.is_trained,
            "training_params": self.training_params,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
        """Create ModelConfig from dictionary.

        Missing keys fall back to the defaults spelled out below, so a
        partial dictionary still yields a fully populated instance.
        """
        return cls(
            base_model_name=data.get("base_model_name", "qwen2p5-7b-instruct"),
            deployment_id=data.get("deployment_id", "my-base-deployment"),
            user_id=data.get("user_id", ""),
            model_id=data.get("model_id", ""),
            enable_addons=data.get("enable_addons", True),
            current_step=data.get("current_step", 0),
            total_steps=data.get("total_steps", 0),
            current_model_name=data.get("current_model_name"),
            is_trained=data.get("is_trained", False),
            training_params=data.get("training_params"),
        )

    def to_json(self) -> str:
        """Convert ModelConfig to JSON string (2-space indented)."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "ModelConfig":
        """Create ModelConfig from JSON string."""
        return cls.from_dict(json.loads(json_str))

    def save_to_file(self, filepath: str) -> None:
        """Save ModelConfig to a JSON file.

        The file is written as UTF-8 regardless of the platform's default
        encoding, so saved configs are portable between machines.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(self.to_json())

    @classmethod
    def load_from_file(cls, filepath: str) -> "ModelConfig":
        """Load ModelConfig from a JSON file.

        The file is decoded as UTF-8 (the standard JSON encoding) rather
        than the locale default, so configs containing non-ASCII text load
        the same way on every platform.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            return cls.from_json(f.read())

0 commit comments

Comments
 (0)