Skip to content

Commit 2b231b4

Browse files
authored
Do not convert json schema types to python types for tools in GenericProvider. (#17)
* Do not convert json schema types to python for tools in GenericProvider. * Fix bug in get_provider() * Update tests.
1 parent a3a428d commit 2b231b4

File tree

3 files changed

+104
-42
lines changed

3 files changed

+104
-42
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -814,9 +814,7 @@ def convert_to_oci_tool(
814814
"type": "object",
815815
"properties": {
816816
p_name: {
817-
"type": JSON_TO_PYTHON_TYPES.get(
818-
p_def.get("type"), p_def.get("type", "string")
819-
),
817+
"type": p_def.get("type", "any"),
820818
"description": p_def.get("description", ""),
821819
}
822820
for p_name, p_def in tool.args.items()

libs/oci/langchain_oci/llms/oci_generative_ai.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def completion_response_to_text(self, response: Any) -> str:
4242

4343
class GenericProvider(Provider):
4444
"""Provider for models using generic API spec."""
45+
4546
stop_sequence_key: str = "stop"
4647

4748
def __init__(self) -> None:
@@ -51,10 +52,11 @@ def __init__(self) -> None:
5152

5253
def completion_response_to_text(self, response: Any) -> str:
5354
return response.data.inference_response.choices[0].text
54-
55+
5556

5657
class MetaProvider(GenericProvider):
5758
"""Provider for Meta models. This provider is for backward compatibility."""
59+
5860
pass
5961

6062

@@ -217,15 +219,14 @@ def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
217219
elif self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
218220
raise ValueError("provider is required for custom endpoints.")
219221
else:
220-
221-
provider = provider_map.get(self.model_id.split(".")[0].lower(), "generic")
222+
provider = self.model_id.split(".")[0].lower()
223+
# Use generic provider for non-custom endpoint
224+
# if provider derived from the model_id is not in the provider map
225+
if provider not in provider_map:
226+
provider = "generic"
222227

223228
if provider not in provider_map:
224-
raise ValueError(
225-
f"Invalid provider derived from model_id: {self.model_id} "
226-
"Please explicitly pass in the supported provider "
227-
"when using custom endpoint"
228-
)
229+
raise ValueError(f"Invalid provider {provider}.")
229230
return provider_map[provider]
230231

231232

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __getattr__(self, val): # type: ignore[no-untyped-def]
2424

2525
@pytest.mark.requires("oci")
2626
@pytest.mark.parametrize(
27-
"test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
27+
"test_model_id", ["cohere.command-r-16k", "meta.llama-3.3-70b-instruct"]
2828
)
2929
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
3030
"""Test valid chat call to OCI Generative AI LLM service."""
@@ -77,25 +77,34 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
7777
{
7878
"chat_response": MockResponseDict(
7979
{
80+
"api_format": "GENERIC",
8081
"choices": [
8182
MockResponseDict(
8283
{
8384
"message": MockResponseDict(
8485
{
86+
"role": "ASSISTANT",
87+
"name": None,
8588
"content": [
8689
MockResponseDict(
8790
{
8891
"text": response_text, # noqa: E501
92+
"type": "TEXT",
8993
}
9094
)
9195
],
9296
"tool_calls": [
9397
MockResponseDict(
9498
{
95-
"type": "function",
99+
"type": "FUNCTION",
96100
"id": "call_123",
97-
"function": {
98-
"name": "get_weather", # noqa: E501
101+
"name": "get_weather", # noqa: E501
102+
"arguments": '{"location": "current location"}', # noqa: E501
103+
"attribute_map": {
104+
"id": "id",
105+
"type": "type",
106+
"name": "name",
107+
"arguments": "arguments", # noqa: E501
99108
},
100109
}
101110
)
@@ -106,10 +115,10 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
106115
}
107116
)
108117
],
109-
"time_created": "2024-09-01T00:00:00Z",
118+
"time_created": "2025-08-14T10:00:01.100000+00:00",
110119
}
111120
),
112-
"model_id": "meta.llama-3.1-70b-instruct",
121+
"model_id": "meta.llama-3.3-70b-instruct",
113122
"model_version": "1.0.0",
114123
}
115124
),
@@ -164,11 +173,15 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
164173
"tool_calls": [
165174
MockResponseDict(
166175
{
167-
"type": "function",
176+
"type": "FUNCTION",
168177
"id": "call_456",
169-
"function": {
170-
"name": "get_weather", # noqa: E501
171-
"arguments": '{"location": "San Francisco"}', # noqa: E501
178+
"name": "get_weather", # noqa: E501
179+
"arguments": '{"location": "San Francisco"}', # noqa: E501
180+
"attribute_map": {
181+
"id": "id",
182+
"type": "type",
183+
"name": "name",
184+
"arguments": "arguments", # noqa: E501
172185
},
173186
}
174187
)
@@ -179,7 +192,7 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
179192
}
180193
)
181194
],
182-
"time_created": "2024-09-01T00:00:00Z",
195+
"time_created": "2025-08-14T10:00:01.100000+00:00",
183196
}
184197
),
185198
"model_id": "meta.llama-3-70b-instruct",
@@ -285,36 +298,62 @@ def test_meta_tool_conversion(monkeypatch: MonkeyPatch) -> None:
285298
from pydantic import BaseModel, Field
286299

287300
oci_gen_ai_client = MagicMock()
288-
llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client)
301+
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
289302

290303
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
304+
request = args[0]
305+
# Check the conversion of tools to oci generic API spec
306+
# Function tool
307+
assert request.chat_request.tools[0].parameters["properties"] == {
308+
"x": {"description": "Input number", "type": "integer"}
309+
}
310+
# Pydantic tool
311+
assert request.chat_request.tools[1].parameters["properties"] == {
312+
"x": {"description": "Input number", "type": "integer"},
313+
"y": {"description": "Input string", "type": "string"},
314+
}
315+
291316
return MockResponseDict(
292317
{
293318
"status": 200,
294319
"data": MockResponseDict(
295320
{
296321
"chat_response": MockResponseDict(
297322
{
323+
"api_format": "GENERIC",
298324
"choices": [
299325
MockResponseDict(
300326
{
301327
"message": MockResponseDict(
302328
{
303-
"content": [
329+
"role": "ASSISTANT",
330+
"content": None,
331+
"tool_calls": [
304332
MockResponseDict(
305-
{"text": "Response"}
333+
{
334+
"arguments": '{"x": "10"}', # noqa: E501
335+
"id": "chatcmpl-tool-d123", # noqa: E501
336+
"name": "function_tool",
337+
"type": "FUNCTION",
338+
"attribute_map": {
339+
"id": "id",
340+
"type": "type",
341+
"name": "name",
342+
"arguments": "arguments", # noqa: E501
343+
},
344+
}
306345
)
307-
]
346+
],
308347
}
309348
),
310-
"finish_reason": "completed",
349+
"finish_reason": "tool_calls",
311350
}
312351
)
313352
],
314-
"time_created": "2024-09-01T00:00:00Z",
353+
"time_created": "2025-08-14T10:00:01.100000+00:00",
315354
}
316355
),
317-
"model_id": "meta.llama-3-70b-instruct",
356+
"model_id": "meta.llama-3.3-70b-instruct",
318357
"model_version": "1.0.0",
319358
}
320359
),
@@ -348,7 +387,10 @@ class PydanticTool(BaseModel):
348387
tools=[function_tool, PydanticTool],
349388
).invoke(messages)
350389

351-
assert response.content == "Response"
390+
# For tool calls, the response content should be empty.
391+
assert response.content == ""
392+
assert len(response.tool_calls) == 1
393+
assert response.tool_calls[0]["name"] == "function_tool"
352394

353395

354396
@pytest.mark.requires("oci")
@@ -411,13 +453,13 @@ class WeatherResponse(BaseModel):
411453
conditions: str = Field(description="Weather conditions")
412454

413455
oci_gen_ai_client = MagicMock()
414-
llm = ChatOCIGenAI(model_id="cohere.command-r-16k", client=oci_gen_ai_client)
456+
llm = ChatOCIGenAI(model_id="cohere.command-latest", client=oci_gen_ai_client)
415457

416458
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
417459
# Verify that response_format contains the schema
418460
request = args[0]
419-
assert request.response_format["type"] == "JSON_OBJECT"
420-
assert "schema" in request.response_format
461+
assert request.chat_request.response_format["type"] == "JSON_OBJECT"
462+
assert "schema" in request.chat_request.response_format
421463

422464
return MockResponseDict(
423465
{
@@ -426,16 +468,17 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
426468
{
427469
"chat_response": MockResponseDict(
428470
{
471+
"api_format": "COHERE",
429472
"text": '{"temperature": 25.5, "conditions": "Sunny"}',
430-
"finish_reason": "completed",
473+
"finish_reason": "COMPLETE",
431474
"is_search_required": None,
432475
"search_queries": None,
433476
"citations": None,
434477
"documents": None,
435478
"tool_calls": None,
436479
}
437480
),
438-
"model_id": "cohere.command-r-16k",
481+
"model_id": "cohere.command-latest",
439482
"model_version": "1.0.0",
440483
}
441484
),
@@ -462,13 +505,19 @@ def test_auth_file_location(monkeypatch: MonkeyPatch) -> None:
462505
from unittest.mock import patch
463506

464507
with patch("oci.config.from_file") as mock_from_file:
465-
custom_config_path = "/custom/path/config"
466-
ChatOCIGenAI(
467-
model_id="cohere.command-r-16k", auth_file_location=custom_config_path
468-
)
469-
mock_from_file.assert_called_once_with(
470-
file_location=custom_config_path, profile_name="DEFAULT"
471-
)
508+
with patch(
509+
"oci.generative_ai_inference.generative_ai_inference_client.validate_config"
510+
):
511+
with patch("oci.base_client.validate_config"):
512+
with patch("oci.signer.load_private_key"):
513+
custom_config_path = "/custom/path/config"
514+
ChatOCIGenAI(
515+
model_id="cohere.command-r-16k",
516+
auth_file_location=custom_config_path,
517+
)
518+
mock_from_file.assert_called_once_with(
519+
file_location=custom_config_path, profile_name="DEFAULT"
520+
)
472521

473522

474523
@pytest.mark.requires("oci")
@@ -524,3 +573,17 @@ def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
524573
assert isinstance(response["parsed"], WeatherResponse)
525574
assert response["parsed"].temperature == 25.5
526575
assert response["parsed"].conditions == "Sunny"
576+
577+
578+
def test_get_provider():
579+
"""Test determining the provider based on the model_id."""
580+
model_provider_map = {
581+
"cohere.command-latest": "CohereProvider",
582+
"meta.llama-3.3-70b-instruct": "MetaProvider",
583+
"xai.grok-3": "GenericProvider",
584+
}
585+
for model_id, provider_name in model_provider_map.items():
586+
assert (
587+
ChatOCIGenAI(model_id=model_id)._provider.__class__.__name__
588+
== provider_name
589+
)

0 commit comments

Comments
 (0)