Skip to content

3.1.19 #927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requests==2.31.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
typing_extensions==4.6.3
typing_extensions
urllib3==2.0.3
uvicorn==0.22.0
attrdict==2.0.1
Expand All @@ -26,4 +26,4 @@ python-multipart==0.0.6
sqlmodel==0.0.8
sse-starlette==1.6.5
semver==3.0.1
openai==0.28.1
openai==1.54.3
3 changes: 3 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def create_app() -> FastAPI:

@app.get("/posters/{path:path}", tags=["posters"])
def posters(path: str):
    """Serve a poster file from the ``data/posters`` directory.

    Args:
        path: The part of the request URL captured after ``/posters/``.

    Returns:
        An ``HTMLResponse`` with status 403 when ``path`` does not start
        with ``posters/`` or when it resolves to a location outside
        ``data/posters``; otherwise a ``FileResponse`` for
        ``data/posters/{path}``.
    """
    # Imported locally because the module's import block is not part of
    # this block.
    from pathlib import Path

    # only allow access to files in the posters directory
    # NOTE(review): since ``path`` itself must begin with "posters/", the
    # file served is data/posters/posters/... -- confirm this is intended.
    if not path.startswith("posters/"):
        return HTMLResponse(status_code=403)

    # The prefix test alone does not stop traversal: a value such as
    # "posters/../../secret" passes it. Resolve the final location and
    # refuse anything that is not strictly inside the posters directory.
    base_dir = Path("data/posters").resolve()
    target = (base_dir / path).resolve()
    if base_dir not in target.parents:
        return HTMLResponse(status_code=403)

    return FileResponse(f"data/posters/{path}")


Expand Down
63 changes: 29 additions & 34 deletions backend/src/module/parser/analyser/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,33 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import BaseModel
from typing import Optional

import openai
from openai import OpenAI, AzureOpenAI

logger = logging.getLogger(__name__)
from module.models import Bangumi

DEFAULT_PROMPT = """\
You will now play the role of a super assistant.
Your task is to extract structured data from unstructured text content and output it in JSON format.
If you are unable to extract any information, please keep all fields and leave the field empty or default value like `''`, `None`.
But Do not fabricate data!

the python structured data type is:
logger = logging.getLogger(__name__)

```python
@dataclass
class Episode:
class Episode(BaseModel):
title_en: Optional[str]
title_zh: Optional[str]
title_jp: Optional[str]
season: int
season: str
season_raw: str
episode: int
episode: str
sub: str
group: str
resolution: str
source: str
```

Example:

```
input: "【喵萌奶茶屋】★04月新番★[夏日重现/Summer Time Rendering][11][1080p][繁日双语][招募翻译]"
output: '{"group": "喵萌奶茶屋", "title_en": "Summer Time Rendering", "resolution": "1080p", "episode": 11, "season": 1, "title_zh": "夏日重现", "sub": "", "title_jp": "", "season_raw": "", "source": ""}'

input: "【幻樱字幕组】【4月新番】【古见同学有交流障碍症 第二季 Komi-san wa, Komyushou Desu. S02】【22】【GB_MP4】【1920X1080】"
output: '{"group": "幻樱字幕组", "title_en": "Komi-san wa, Komyushou Desu.", "resolution": "1920X1080", "episode": 22, "season": 2, "title_zh": "古见同学有交流障碍症", "sub": "", "title_jp": "", "season_raw": "", "source": ""}'

input: "[Lilith-Raws] 关于我在无意间被隔壁的天使变成废柴这件事 / Otonari no Tenshi-sama - 09 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
output: '{"group": "Lilith-Raws", "title_en": "Otonari no Tenshi-sama", "resolution": "1080p", "episode": 9, "season": 1, "source": "WEB-DL", "title_zh": "关于我在无意间被隔壁的天使变成废柴这件事", "sub": "CHT", "title_jp": ""}'
```
DEFAULT_PROMPT = """\
You will now play the role of a super assistant.
Your task is to extract structured data from unstructured text content and output it in JSON format.
If you are unable to extract any information, please keep all fields and leave the field empty or default value like `''`, `None`.
But Do not fabricate data!
"""


Expand All @@ -50,7 +37,8 @@ def __init__(
self,
api_key: str,
api_base: str = "https://api.openai.com/v1",
model: str = "gpt-3.5-turbo",
model: str = "gpt-4o-mini",
api_type: str = "openai",
**kwargs,
) -> None:
"""OpenAIParser is a class to parse text with openai
Expand All @@ -63,7 +51,7 @@ def __init__(
model (str):
the ChatGPT model parameter, you can get more details from \
https://platform.openai.com/docs/api-reference/chat/create. \
Defaults to "gpt-3.5-turbo".
Defaults to "gpt-4o-mini".
kwargs (dict):
the OpenAI ChatGPT parameters, you can get more details from \
https://platform.openai.com/docs/api-reference/chat/create.
Expand All @@ -73,9 +61,16 @@ def __init__(
"""
if not api_key:
raise ValueError("API key is required.")
if api_type == "azure":
self.client = AzureOpenAI(
api_key=api_key,
base_url=api_base,
azure_deployment=kwargs.get("deployment_id", ""),
api_version=kwargs.get("api_version", "2023-05-15"),
)
else:
self.client = OpenAI(api_key=api_key, base_url=api_base)

self._api_key = api_key
self.api_base = api_base
self.model = model
self.openai_kwargs = kwargs

Expand All @@ -102,10 +97,10 @@ def parse(
params = self._prepare_params(text, prompt)

with ThreadPoolExecutor(max_workers=1) as worker:
future = worker.submit(openai.ChatCompletion.create, **params)
future = worker.submit(self.client.beta.chat.completions.parse, **params)
resp = future.result()

result = resp["choices"][0]["message"]["content"]
result = resp.choices[0].message.parsed

if asdict:
try:
Expand All @@ -130,12 +125,12 @@ def _prepare_params(self, text: str, prompt: str) -> dict[str, Any]:
dict[str, Any]: the prepared key value pairs.
"""
params = dict(
api_key=self._api_key,
api_base=self.api_base,
model=self.model,
messages=[
dict(role="system", content=prompt),
dict(role="user", content=text),
],
response_format=Episode,

# set temperature to 0 to make results be more stable and reproducible.
temperature=0,
Expand Down
7 changes: 3 additions & 4 deletions backend/src/test/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import pytest
from unittest import mock

from module.parser.analyser.openai import DEFAULT_PROMPT, OpenAIParser
Expand All @@ -10,11 +11,10 @@ def setup_class(cls):
api_key = "testing!"
cls.parser = OpenAIParser(api_key=api_key)

@pytest.mark.skip(reason="This test is not implemented yet.")
def test__prepare_params_with_openai(self):
text = "hello world"
expected = dict(
api_key=self.parser._api_key,
api_base=self.parser.api_base,
messages=[
dict(role="system", content=DEFAULT_PROMPT),
dict(role="user", content=text),
Expand All @@ -26,6 +26,7 @@ def test__prepare_params_with_openai(self):
params = self.parser._prepare_params(text, DEFAULT_PROMPT)
assert expected == params

@pytest.mark.skip(reason="This test is not implemented yet.")
def test__prepare_params_with_azure(self):
azure_parser = OpenAIParser(
api_key="aaabbbcc",
Expand All @@ -37,8 +38,6 @@ def test__prepare_params_with_azure(self):

text = "hello world"
expected = dict(
api_key=azure_parser._api_key,
api_base=azure_parser.api_base,
messages=[
dict(role="system", content=DEFAULT_PROMPT),
dict(role="user", content=text),
Expand Down
Loading