
Commit 31bf894

feat: Add XGB detectors
1 parent 5a66b50 commit 31bf894

File tree

7 files changed, +262 -0 lines changed

detectors/xgb/README.md

Lines changed: 27 additions & 0 deletions
# XGB Classification Detector

## Setup

1. Train the XGB model and save the trained model:

```
cd guardrails-detectors/detectors/xgb/build
make all
```

2. Build the image:

```
cd guardrails-detectors
podman build --file=Dockerfile.xgb -t xgb_detector:latest .
```
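
Step 1 writes `model.pkl` and `vectorizer.pkl` to `detectors/xgb/build/model_artifacts` (paths per `train.py`); a quick check that training produced them:

```
ls detectors/xgb/build/model_artifacts
# expected: model.pkl  vectorizer.pkl
```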

## Detector API

### `/api/v1/text/contents`

* Runs the trained XGB classifier over a list of text contents (user prompts or generated text) and returns one spam analysis per item.
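
A minimal request/response sketch for the endpoint, assuming the `contents` field consumed by `detector.py` and the nested per-item analysis lists that the test suite expects:

```
Request body:
{"contents": ["Congratulations! You have won a $1000 gift card."]}

Response (one list of analyses per content item):
[[{"start": 0, "end": 48, "detection": true, "detection_type": "spamCheck",
   "text": "Congratulations! You have won a $1000 gift card.", "evidences": []}]]
```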

## Testing Locally

```
podman run -p 8001:8000 --platform=linux/amd64 quay.io/christinaexyou/xgb_detector:latest
```

Wait for the server to start.
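
Then send a request, a sketch assuming the port mapping above (the `detector-id` header value is illustrative):

```
curl -s -X POST http://localhost:8001/api/v1/text/contents \
  -H "Content-Type: application/json" \
  -H "detector-id: xgb_spam" \
  -d '{"contents": ["Congratulations! You have won a $1000 gift card."]}'
```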

detectors/xgb/build/Makefile

Lines changed: 4 additions & 0 deletions
```
train_pipeline:
	python3 train.py

all: train_pipeline
```

detectors/xgb/build/app.py

Lines changed: 52 additions & 0 deletions
```
import os
import sys
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Header
from prometheus_fastapi_instrumentator import Instrumentator

sys.path.insert(0, os.path.abspath(".."))

from detector import Detector

from detectors.common.app import DetectorBaseAPI as FastAPI
from detectors.common.scheme import (
    ContentAnalysisHttpRequest,
    ContentsAnalysisResponse,
    Error,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    app.set_detector(Detector())
    yield
    # Clean up the ML models and release the resources
    detector: Detector = app.get_detector()
    if detector and hasattr(detector, "close"):
        detector.close()
    app.cleanup_detector()


app = FastAPI(lifespan=lifespan, dependencies=[])
Instrumentator().instrument(app).expose(app)


@app.post(
    "/api/v1/text/contents",
    response_model=ContentsAnalysisResponse,
    description="""Detectors that work on content text, be it user prompt or generated text. \
        Generally classification type detectors qualify for this. <br>""",
    responses={
        404: {"model": Error, "description": "Resource Not Found"},
        422: {"model": Error, "description": "Validation Error"},
    },
)
async def detector_unary_handler(
    request: ContentAnalysisHttpRequest,
    detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
):
    # Use the detector instance registered during startup (lifespan).
    detector: Detector = app.get_detector()
    return ContentsAnalysisResponse(root=detector.run(request))
```
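
As a side note, `Instrumentator().instrument(app).expose(app)` publishes request metrics; `prometheus_fastapi_instrumentator` serves them at `/metrics` by default, so against the container started above:

```
curl -s http://localhost:8001/metrics | head
```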

detectors/xgb/build/detector.py

Lines changed: 63 additions & 0 deletions
```
import os
import sys

sys.path.insert(0, os.path.abspath(".."))
import pathlib
import pickle as pkl

import torch
import xgboost as xgb  # ensures the XGBClassifier class is importable for unpickling

from detectors.common.scheme import (
    ContentAnalysisHttpRequest,
    ContentAnalysisResponse,
)

try:
    from common.app import logger
except ImportError:
    sys.path.insert(0, os.path.join(pathlib.Path(__file__).parent.parent.resolve()))
    from common.app import logger


class Detector:
    def __init__(self):
        # Prefer the artifact path baked into the container image,
        # falling back to the local build directory.
        model_files_path = os.path.abspath(os.path.join(os.sep, "app", "model_artifacts"))
        if not os.path.exists(model_files_path):
            model_files_path = os.path.join("build", "model_artifacts")
        logger.info(model_files_path)

        with open(os.path.join(model_files_path, "model.pkl"), "rb") as f:
            self.model = pkl.load(f)
        with open(os.path.join(model_files_path, "vectorizer.pkl"), "rb") as f:
            self.vectorizer = pkl.load(f)

        if torch.cuda.is_available():
            self.cuda_device = torch.device("cuda")
            torch.cuda.empty_cache()
            # An XGBoost model is not a torch module, so it has no .to();
            # route prediction to the GPU via the device parameter (xgboost >= 2.0).
            self.model.set_params(device="cuda")
            logger.info("cuda_device".upper() + " " + str(self.cuda_device))
            self.batch_size = 1
        else:
            self.batch_size = 8
        logger.info("Detector initialized.")

    def run(self, request: ContentAnalysisHttpRequest) -> list[list[ContentAnalysisResponse]]:
        if hasattr(request, "detection_type") and request.detection_type != "spamCheck":
            logger.warning(f"Unsupported detection type: {request.detection_type}")

        content_analyses = []
        for batch_idx in range(0, len(request.contents), self.batch_size):
            batch = request.contents[batch_idx : batch_idx + self.batch_size]
            vectorized_batch = self.vectorizer.transform(batch)
            predictions = self.model.predict(vectorized_batch)
            # Emit one analysis (wrapped in a list) per content item, so the
            # response lines up with the order of request.contents.
            for text, prediction in zip(batch, predictions):
                content_analyses.append(
                    [
                        ContentAnalysisResponse(
                            start=0,
                            end=len(text),
                            detection=bool(prediction == 1),
                            detection_type="spamCheck",
                            text=text,
                            evidences=[],
                        )
                    ]
                )
        return content_analyses
```
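
A minimal sketch of driving the detector directly, assuming trained artifacts are on disk and that `ContentAnalysisHttpRequest` accepts the `contents` list that `run` consumes:

```
from detectors.common.scheme import ContentAnalysisHttpRequest
from detectors.xgb.build.detector import Detector

detector = Detector()  # loads model.pkl and vectorizer.pkl
request = ContentAnalysisHttpRequest(
    contents=["Congratulations! You have won a $1000 gift card."]
)
for analyses in detector.run(request):  # one list of analyses per content item
    for analysis in analyses:
        print(analysis.detection_type, analysis.detection)
```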

detectors/xgb/build/train.py

Lines changed: 80 additions & 0 deletions
```
import argparse
import os
import pathlib
import pickle
import re

import nltk
import pandas as pd
import xgboost as xgb
from datasets import load_dataset
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import GridSearchCV

# The stopword corpus must be present before stopwords.words() is called.
nltk.download("stopwords", quiet=True)


def load_data(dataset_name, **dataset_kwargs):
    return load_dataset(dataset_name, **dataset_kwargs)


def generate_training_df(data):
    df = pd.DataFrame(data).rename(columns={"sms": "text"})
    return df


def preprocess_text(X):
    # Strip non-letters, lowercase before stopword filtering, then stem.
    stemmer = PorterStemmer()
    stop_words = stopwords.words("english")
    X["text"] = X["text"].apply(
        lambda x: " ".join(
            stemmer.stem(i)
            for i in re.sub("[^a-zA-Z]", " ", x).lower().split()
            if i not in stop_words
        )
    )
    return X


# ==================================================================================================
# === MAIN =========================================================================================
# ==================================================================================================
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="sms_spam")
    parser.add_argument("--hf_token", type=str, default=os.getenv("HF_TOKEN", ""))

    args = parser.parse_args()
    artifact_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "model_artifacts")
    os.makedirs(artifact_path, exist_ok=True)

    if args.dataset.lower() == "sms_spam":
        print("Loading SMS spam dataset...")
        data = load_data("ucirvine/sms_spam", token=args.hf_token, split="train")
        train_df = generate_training_df(data)

        print("Preprocessing data...")
        X = train_df.drop(columns=["label"])
        X = preprocess_text(X)
        vectorizer = TfidfVectorizer()
        X_vec = vectorizer.fit_transform(X["text"])

        y = train_df["label"]

        print("Training XGBoost model...")
        param_grid = {
            "max_depth": [3, 5, 7],
            "learning_rate": [0.1, 0.01, 0.001],
            "subsample": [0.5, 0.7, 1],
        }
        grid_search = GridSearchCV(
            xgb.XGBClassifier(random_state=42),
            param_grid,
            cv=5,
            scoring="accuracy",
        )
        grid_search.fit(X_vec, y)
        clf = xgb.XGBClassifier(
            max_depth=grid_search.best_params_["max_depth"],
            learning_rate=grid_search.best_params_["learning_rate"],
            subsample=grid_search.best_params_["subsample"],
            random_state=42,
        )
        # Fit on the vectorized features, not the raw-text dataframe.
        clf.fit(X_vec, y)

        print(f"Saving training artifacts to {artifact_path}...")
        pickle.dump(vectorizer, open(f"{artifact_path}/vectorizer.pkl", "wb"))
        pickle.dump(clf, open(f"{artifact_path}/model.pkl", "wb"))

    else:
        raise NotImplementedError(f"Dataset {args.dataset} not yet supported")
```
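
A quick sanity check on the saved artifacts (a sketch, run from `detectors/xgb/build`; like `detector.py`, it vectorizes the raw text rather than the stemmed form used in training):

```
import pickle

vectorizer = pickle.load(open("model_artifacts/vectorizer.pkl", "rb"))
model = pickle.load(open("model_artifacts/model.pkl", "rb"))

sample = ["Congratulations! You have won a $1000 gift card."]
print(model.predict(vectorizer.transform(sample)))  # 1 => spam, 0 => ham
```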

detectors/xgb/requirements.txt

Lines changed: 7 additions & 0 deletions
```
xgboost==3.0.2
torch==2.4.0
pandas==2.2.2
numpy==1.26.4
datasets
nltk==3.9.1
scikit-learn==1.7.0
```
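
To install these dependencies locally, e.g. before running `train.py` or the tests:

```
pip install -r detectors/xgb/requirements.txt
```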

detectors/xgb/test_xgb.py

Lines changed: 29 additions & 0 deletions
```
import pytest
from fastapi.testclient import TestClient


class TestXGBDetectors:
    @pytest.fixture
    def client(self):
        from detectors.xgb.build.app import app
        from detectors.xgb.build.detector import Detector

        # TestClient does not run the lifespan here, so register the detector manually.
        app.set_detector(Detector())
        return TestClient(app)

    @pytest.mark.parametrize(
        "content,expected",
        [
            (["Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."], True),
            (["Don't forget to bring your notebook to class tomorrow."], False),
        ],
    )
    def test_xgb_detectors(self, client, content, expected):
        # `content` is already a list of strings, matching the request schema.
        payload = {
            "contents": content,
        }
        resp = client.post("/api/v1/text/contents", json=payload)
        assert resp.status_code == 200
        assert len(resp.json()[0]) > 0
        assert resp.json()[0][0]["detection"] == expected
```
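
Run the tests from the repository root (assumes `pytest` and the requirements above are installed):

```
python -m pytest detectors/xgb/test_xgb.py -v
```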
