Skip to content

Commit b70936a

Browse files
fix: add model id and model tag to all fastapi endpoint.
1 parent 51ec568 commit b70936a

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

src/emd/utils/framework_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def get_model_specific_path(model_id: str, model_tag: str, base_path: str) -> str:
2+
return f"/{model_id}/{model_tag}/{base_path.lstrip('/')}"

src/pipeline/framework/fast_api/fast_api.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi.responses import JSONResponse, Response, StreamingResponse
1111
from emd.utils.logger_utils import get_logger
1212
from fastapi.concurrency import run_in_threadpool
13+
from emd.utils.framework_utils import get_model_specific_path
1314

1415
model_id = os.environ.get("model_id")
1516
model_tag = os.environ.get("model_tag")
@@ -90,12 +91,24 @@ async def invocations(request: Request, authorization: str = Depends(get_authori
9091

9192
return await invoke(payload)
9293

94+
endpoints = {
95+
"ping": {"func": ping, "methods": ["GET"]},
96+
"health": {"func": health, "methods": ["GET"]},
97+
# Note: The functions for the POST endpoints all use "invocations".
98+
"invocations": {"func": invocations, "methods": ["POST"]},
99+
"v1/chat/completions": {"func": invocations, "methods": ["POST"]},
100+
"v1/embeddings": {"func": invocations, "methods": ["POST"]},
101+
"score": {"func": invocations, "methods": ["POST"]},
102+
}
103+
93104
if model_id and model_tag:
94-
app.add_api_route(
95-
path=f"/{model_id}/{model_tag}/v1/chat/completions",
96-
endpoint=invocations,
97-
methods=["POST"]
98-
)
105+
for base_path, route_info in endpoints.items():
106+
path = get_model_specific_path(model_id, model_tag, base_path)
107+
app.add_api_route(
108+
path=path,
109+
endpoint=route_info["func"],
110+
methods=route_info["methods"]
111+
)
99112

100113
if __name__ == "__main__":
101114
args = parse_args()

0 commit comments

Comments
 (0)