|
10 | 10 | from fastapi.responses import JSONResponse, Response, StreamingResponse
|
11 | 11 | from emd.utils.logger_utils import get_logger
|
12 | 12 | from fastapi.concurrency import run_in_threadpool
|
| 13 | +from emd.utils.framework_utils import get_model_specific_path |
13 | 14 |
|
14 | 15 | model_id = os.environ.get("model_id")
|
15 | 16 | model_tag = os.environ.get("model_tag")
|
@@ -90,12 +91,24 @@ async def invocations(request: Request, authorization: str = Depends(get_authori
|
90 | 91 |
|
91 | 92 | return await invoke(payload)
|
92 | 93 |
|
| 94 | +endpoints = { |
| 95 | + "ping": {"func": ping, "methods": ["GET"]}, |
| 96 | + "health": {"func": health, "methods": ["GET"]}, |
| 97 | + # Note: The functions for the POST endpoints all use "invocations". |
| 98 | + "invocations": {"func": invocations, "methods": ["POST"]}, |
| 99 | + "v1/chat/completions": {"func": invocations, "methods": ["POST"]}, |
| 100 | + "v1/embeddings": {"func": invocations, "methods": ["POST"]}, |
| 101 | + "score": {"func": invocations, "methods": ["POST"]}, |
| 102 | +} |
| 103 | + |
93 | 104 | if model_id and model_tag:
|
94 |
| - app.add_api_route( |
95 |
| - path=f"/{model_id}/{model_tag}/v1/chat/completions", |
96 |
| - endpoint=invocations, |
97 |
| - methods=["POST"] |
98 |
| - ) |
| 105 | + for base_path, route_info in endpoints.items(): |
| 106 | + path = get_model_specific_path(model_id, model_tag, base_path) |
| 107 | + app.add_api_route( |
| 108 | + path=path, |
| 109 | + endpoint=route_info["func"], |
| 110 | + methods=route_info["methods"] |
| 111 | + ) |
99 | 112 |
|
100 | 113 | if __name__ == "__main__":
|
101 | 114 | args = parse_args()
|
|
0 commit comments