Skip to content

Commit 761ac8c

Browse files
Add Dockerfile.xgb
1 parent 31bf894 commit 761ac8c

File tree

6 files changed

+68
-9
lines changed

6 files changed

+68
-9
lines changed

detectors/Dockerfile.xgb

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
FROM registry.access.redhat.com/ubi9/ubi-minimal as base
2+
RUN microdnf update -y && \
3+
microdnf install -y --nodocs \
4+
python-pip python-devel && \
5+
pip install --upgrade --no-cache-dir pip wheel && \
6+
microdnf clean all
7+
RUN pip install --no-cache-dir torch
8+
9+
# FROM icr.io/fm-stack/ubi9-minimal-py39-torch as builder
10+
FROM base as builder
11+
12+
COPY ./common/requirements.txt .
13+
RUN pip install --no-cache-dir -r requirements.txt
14+
15+
COPY ./xgb/requirements.txt .
16+
RUN pip install --no-cache-dir -r requirements.txt
17+
18+
FROM builder
19+
20+
21+
WORKDIR /app
22+
ARG CACHEBUST=1
23+
RUN echo "$CACHEBUST"
24+
COPY xgb/build/model_artifacts /app/model_artifacts
25+
COPY ./common /common
26+
27+
COPY ./xgb/build/scheme.py /app
28+
COPY ./xgb/build/app.py /app
29+
COPY ./xgb/build/detector.py /app
30+
31+
EXPOSE 8000
32+
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]

detectors/xgb/build/app.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@
55

66
from fastapi import Header
77
from prometheus_fastapi_instrumentator import Instrumentator
8-
98
sys.path.insert(0, os.path.abspath(".."))
109

10+
from common.app import DetectorBaseAPI as FastAPI
1111
from detector import Detector
12-
13-
from detectors.common.app import DetectorBaseAPI as FastAPI
1412
from detectors.common.scheme import (
1513
ContentAnalysisHttpRequest,
1614
ContentsAnalysisResponse,

detectors/xgb/build/detector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from detectors.common.scheme import (
99
ContentAnalysisHttpRequest,
1010
ContentAnalysisResponse,
11-
ContentAnalysisHttpResponse,
1211
)
1312
import pickle as pkl
13+
from base_detector_registry import BaseDetectorRegistry
1414

1515
try:
1616
from common.app import logger
@@ -21,7 +21,7 @@
2121
class Detector:
2222
def __init__(self):
2323
# initialize the detector
24-
model_files_path = os.path.abspath(s.path.join(os.sep, "app", "model_artifacts"))
24+
model_files_path = os.path.abspath(os.path.join(os.sep, "app", "model_artifacts"))
2525
if not os.path.exists(model_files_path):
2626
model_files_path = os.path.join("build", "model_artifacts")
2727
logger.info(model_files_path)
@@ -39,7 +39,7 @@ def __init__(self):
3939
self.batch_size = 8
4040
logger.info("Detector initialized.")
4141

42-
def run(self, request: ContentAnalysisHttpRequest) -> ContentsAnalysisHttpResponse:
42+
def run(self, request: ContentAnalysisHttpRequest) -> ContentAnalysisResponse:
4343
if hasattr(request, "detection_type") and request.detection_type != "spamCheck":
4444
logger.warning(f"Unsupported detection type: {request.detection_type}")
4545

detectors/xgb/build/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def preprocess_text(X):
7070
subsample=grid_search.best_params_['subsample'],
7171
random_state=42
7272
)
73-
clf.fit(X, y)
73+
clf.fit(X_vec, y)
7474

7575
print(f"Saving training artifacts to {artifact_path}...")
7676
pickle.dump(vectorizer, open(f'{artifact_path}/vectorizer.pkl', 'wb'))

detectors/xgb/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
xgboost==3.0.2
1+
xgboost
22
torch==2.4.0
33
pandas==2.2.2
44
numpy==1.26.4
55
datasets
66
nltk==3.9.1
7-
scikit-learn==1.7.0
7+
scikit-learn

tests/detectors/xgb/test_xgb.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
from fastapi.testclient import TestClient
3+
4+
class TestXGBDetectors:
5+
@pytest.fixture
6+
def client(self):
7+
from detectors.xgb.build.app import app
8+
from detectors.xgb.build.detector import Detector
9+
10+
app.set_detector(Detector(), "detector")
11+
return TestClient(app)
12+
13+
@pytest.mark.parametrize(
14+
"content,expected",
15+
[
16+
(["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
17+
(["Don't forget to bring your notebook to class tomorrow."], False),
18+
]
19+
)
20+
21+
def test_xgb_detectors(self, client, content, expected):
22+
payload = {
23+
"content": [content],
24+
}
25+
resp = client.post("api/v1/text/contexts", json=payload)
26+
assert resp.status_code == 200
27+
assert len(resp.json()[0]) > 0
28+
assert resp.json()[0][0]['spam_check'] == expected
29+

0 commit comments

Comments
 (0)