Skip to content

Commit 63d2b6e

Browse files
authored
fix thread/run thread/{thread_id} route order for craete_thread_and_run to work (#70)
1 parent 3666c71 commit 63d2b6e

File tree

2 files changed

+60
-39
lines changed

2 files changed

+60
-39
lines changed

client/tests/astra-assistants/test_run_v2.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,38 @@ def run_with_assistant(assistant, client):
3939

4040

4141

42+
def create_and_run_with_assistant(assistant, client):
43+
user_message = "What's your favorite animal."
44+
45+
thread = client.beta.threads.create()
46+
47+
client.beta.threads.messages.create(
48+
thread_id=thread.id, role="user", content=user_message
49+
)
50+
run = client.beta.threads.create_and_run(
51+
thread=thread,
52+
assistant_id=assistant.id,
53+
)
54+
55+
logger.info(run)
56+
57+
58+
59+
4260

4361
instructions="You're an animal expert who gives very long winded answers with flowery prose. Keep answers below 3 sentences."
4462
def test_run_gpt_4o_mini(patched_openai_client):
4563
gpt3_assistant = patched_openai_client.beta.assistants.create(
4664
name="GPT3 Animal Tutor",
4765
instructions=instructions,
48-
model="gpt-4o_mini",
66+
model="gpt-4o-mini",
4967
)
5068

5169
assistant = patched_openai_client.beta.assistants.retrieve(gpt3_assistant.id)
5270
logger.info(assistant)
5371

5472
run_with_assistant(gpt3_assistant, patched_openai_client)
73+
create_and_run_with_assistant(gpt3_assistant, patched_openai_client)
5574

5675
def test_run_cohere(patched_openai_client):
5776
cohere_assistant = patched_openai_client.beta.assistants.create(
@@ -91,4 +110,4 @@ def test_run_gemini(patched_openai_client):
91110
instructions=instructions,
92111
model="gemini/gemini-1.5-flash",
93112
)
94-
run_with_assistant(gemini_assistant, patched_openai_client)
113+
run_with_assistant(gemini_assistant, patched_openai_client)

impl/routes_v2/threads_v2.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,45 @@ async def create_thread(
112112
)
113113
return astradb.upsert_table_from_base_model("threads", thread)
114114

115+
@router.post(
116+
"/threads/runs",
117+
responses={
118+
200: {"model": RunObject, "description": "OK"},
119+
},
120+
tags=["Assistants"],
121+
summary="Create a thread and run it in one request.",
122+
response_model_by_alias=True,
123+
response_model=None
124+
)
125+
async def create_thread_and_run(
126+
create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""),
127+
astradb: CassandraClient = Depends(verify_db_client),
128+
embedding_model: str = Depends(infer_embedding_model),
129+
embedding_api_key: str = Depends(infer_embedding_api_key),
130+
litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs),
131+
) -> RunObject:
132+
create_thread_request = create_thread_and_run_request.thread
133+
if create_thread_request is None:
134+
raise HTTPException(status_code=400, detail="thread is required.")
135+
136+
thread = await create_thread(create_thread_request, astradb)
137+
138+
create_run_request = CreateRunRequest(
139+
assistant_id=create_thread_and_run_request.assistant_id,
140+
model=create_thread_and_run_request.model,
141+
instructions=create_thread_and_run_request.instructions,
142+
tools=create_thread_and_run_request.tools,
143+
metadata=create_thread_and_run_request.metadata
144+
)
145+
return await create_run(
146+
thread_id=thread.id,
147+
create_run_request=create_run_request,
148+
astradb=astradb,
149+
embedding_model=embedding_model,
150+
embedding_api_key=embedding_api_key,
151+
litellm_kwargs=litellm_kwargs,
152+
)
153+
115154
@router.get(
116155
"/threads/{thread_id}",
117156
responses={
@@ -1823,41 +1862,4 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id):
18231862
return message_delta
18241863

18251864

1826-
@router.post(
1827-
"/threads/runs",
1828-
responses={
1829-
200: {"model": RunObject, "description": "OK"},
1830-
},
1831-
tags=["Assistants"],
1832-
summary="Create a thread and run it in one request.",
1833-
response_model_by_alias=True,
1834-
response_model=None
1835-
)
1836-
async def create_thread_and_run(
1837-
create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""),
1838-
astradb: CassandraClient = Depends(verify_db_client),
1839-
embedding_model: str = Depends(infer_embedding_model),
1840-
embedding_api_key: str = Depends(infer_embedding_api_key),
1841-
litellm_kwargs: tuple[Dict[str, Any]] = Depends(get_litellm_kwargs),
1842-
) -> RunObject:
1843-
create_thread_request = create_thread_and_run_request.thread
1844-
if create_thread_request is None:
1845-
raise HTTPException(status_code=400, detail="thread is required.")
1846-
1847-
thread = await create_thread(create_thread_request, astradb)
18481865

1849-
create_run_request = CreateRunRequest(
1850-
assistant_id=create_thread_and_run_request.assistant_id,
1851-
model=create_thread_and_run_request.model,
1852-
instructions=create_thread_and_run_request.instructions,
1853-
tools=create_thread_and_run_request.tools,
1854-
metadata=create_thread_and_run_request.metadata
1855-
)
1856-
return await create_run(
1857-
thread_id=thread.id,
1858-
create_run_request=create_run_request,
1859-
astradb=astradb,
1860-
embedding_model=embedding_model,
1861-
embedding_api_key=embedding_api_key,
1862-
litellm_kwargs=litellm_kwargs,
1863-
)

0 commit comments

Comments
 (0)