Skip to content

Commit 8a3b7fd

Browse files
committed
baseline code added
1 parent 7d26888 commit 8a3b7fd

File tree

11 files changed

+595
-0
lines changed

11 files changed

+595
-0
lines changed

requirements.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
langchain==0.2.6
2+
langchain-community==0.2.6
3+
sentence-transformers==2.6.1
4+
scikit-learn==0.22.2.post1
5+
coclust==0.2.1
6+
rouge-score==0.1.2
7+
scipy==1.13.0
8+
numpy==1.23.5
9+
tqdm==4.66.4
10+
torch==2.2.2

scripts/run_theme_detection.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
from argparse import ArgumentParser
5+
import json
6+
import os
7+
import copy
8+
import collections
9+
10+
import getpass
11+
import tqdm
12+
from langchain_huggingface import HuggingFaceEmbeddings
13+
from langchain_core.runnables import RunnableParallel
14+
from sklearn.cluster import KMeans
15+
16+
from dstc12.prompts import LABEL_CLUSTERS_PROMPT
17+
from dstc12.utils import get_llm, DotAllRegexParser
18+
import numpy as np
19+
20+
21+
def parse_args():
    """Parse the command-line arguments of the theme detection script.

    Positional arguments: ``dataset_file``, ``preferences_file`` and
    ``result_file`` (all strings). Optional arguments: ``--n-clusters``,
    ``--random-state``, ``--embedding-model-name`` and ``--llm-name``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = ArgumentParser()

    # Required file paths, in the order they must appear on the command line.
    for positional_name in ('dataset_file', 'preferences_file', 'result_file'):
        cli.add_argument(positional_name, type=str)

    # (flag, value type, default) for every optional setting.
    optional_settings = (
        ('--n-clusters', int, 10),
        ('--random-state', int, 42),
        ('--embedding-model-name', str, 'sentence-transformers/all-mpnet-base-v2'),
        ('--llm-name', str, 'mistralai/Mistral-7B-Instruct-v0.3'),
    )
    for flag, value_type, fallback in optional_settings:
        cli.add_argument(flag, type=value_type, default=fallback)

    return cli.parse_args()
31+
32+
33+
def find_second_closest_cluster(emb, centroids):
    """Return the index of the centroid that is second nearest to ``emb``.

    Nearness is the Euclidean norm of ``emb - centroid``. At least two
    centroids are required; otherwise indexing the ranking raises IndexError.
    """
    distance_per_centroid = []
    for center in centroids:
        distance_per_centroid.append(np.linalg.norm(emb - center))
    # Position 0 of the ranking is the nearest centroid, position 1 the runner-up.
    nearest_first = np.argsort(distance_per_centroid)
    return nearest_first[1]
37+
38+
39+
def apply_preferences_to_clusters(utterances, utterance_embs, cluster_labels, cluster_centroids, shouldlink_pairs, cannot_link_pairs):
    """Adjust cluster assignments so they respect pairwise linking preferences.

    Args:
        utterances: utterance strings, aligned with ``cluster_labels``.
        utterance_embs: embeddings aligned with ``utterances``; only read for
            utterances that have to be moved by a cannot-link pair.
        cluster_labels: initial cluster label of each utterance (not modified).
        cluster_centroids: centroid of every cluster.
        shouldlink_pairs: pairs ``(utt_a, utt_b)`` that should share a cluster;
            ``utt_b`` is moved into ``utt_a``'s cluster when they differ.
        cannot_link_pairs: pairs ``(utt_a, utt_b)`` that must not share a
            cluster; when they do, ``utt_b`` is moved to the cluster returned
            by ``find_second_closest_cluster`` for its embedding.

    Returns:
        A deep copy of ``cluster_labels`` with the preferences applied.

    Pairs that mention an utterance missing from ``utterances`` are skipped.
    Previously such utterances resolved to index/label ``-1``, which silently
    relabelled the *last* utterance or assigned the invalid label ``-1``.
    """
    assert len(utterances) == len(cluster_labels)

    # Plain dicts on purpose: a missing utterance must be detectable rather
    # than defaulting to -1 (a valid negative index into the label sequence).
    utterance_cluster_mapping = {}
    utterance_idx_mapping = {}
    for utt_idx, (utterance, cluster_label) in enumerate(zip(utterances, cluster_labels)):
        utterance_cluster_mapping[utterance] = cluster_label
        utterance_idx_mapping[utterance] = utt_idx

    modified_cluster_labels = copy.deepcopy(cluster_labels)

    for utt_a, utt_b in shouldlink_pairs:
        if utt_a not in utterance_cluster_mapping or utt_b not in utterance_cluster_mapping:
            continue
        cluster_a = utterance_cluster_mapping[utt_a]
        cluster_b = utterance_cluster_mapping[utt_b]
        if cluster_a != cluster_b:
            utt_b_idx = utterance_idx_mapping[utt_b]
            modified_cluster_labels[utt_b_idx] = cluster_a
            # Keep the lookup in sync so later pairs see the updated assignment.
            utterance_cluster_mapping[utt_b] = cluster_a

    for utt_a, utt_b in cannot_link_pairs:
        if utt_a not in utterance_cluster_mapping or utt_b not in utterance_cluster_mapping:
            continue
        cluster_a = utterance_cluster_mapping[utt_a]
        cluster_b = utterance_cluster_mapping[utt_b]
        if cluster_a == cluster_b:
            utt_b_idx = utterance_idx_mapping[utt_b]
            utt_b_new_cluster = find_second_closest_cluster(utterance_embs[utt_b_idx], cluster_centroids)
            modified_cluster_labels[utt_b_idx] = utt_b_new_cluster
            utterance_cluster_mapping[utt_b] = utt_b_new_cluster

    return modified_cluster_labels
67+
68+
69+
def main(utterances, linking_preferences, embedding_model_name, llm_name, n_clusters, random_state):
    """Cluster utterances, apply linking preferences and label every cluster.

    Args:
        utterances: list of utterance strings to cluster.
        linking_preferences: mapping with ``'should_link'`` and
            ``'cannot_link'`` entries, each a collection of utterance pairs.
        embedding_model_name: model name passed to ``HuggingFaceEmbeddings``.
        llm_name: name passed to ``get_llm``.
        n_clusters: number of KMeans clusters.
        random_state: seed passed to KMeans.

    Returns:
        Dict mapping each utterance to the theme label generated for its cluster.
    """
    llm = get_llm(llm_name)
    # Prompt -> LLM -> two regex parsers run in parallel over the same output.
    chain = (
        LABEL_CLUSTERS_PROMPT |
        llm |
        RunnableParallel(
            theme_label=DotAllRegexParser(regex=r'<theme_label>(.*?)</theme_label>', output_keys=['theme_label']),
            theme_label_explanation=DotAllRegexParser(regex=r'<theme_label_explanation>(.*?)</theme_label_explanation>', output_keys=['theme_label_explanation'])
        )
    )
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    query_embeddings = [embeddings.embed_query(utterance) for utterance in tqdm.tqdm(utterances)]
    kmeans = KMeans(n_clusters=n_clusters, n_init=1, init='k-means++', random_state=random_state)
    kmeans.fit(query_embeddings)
    clusters = kmeans.labels_
    centroids = kmeans.cluster_centers_
    clusters_with_preferences = apply_preferences_to_clusters(
        utterances,
        query_embeddings,
        clusters,
        centroids,
        linking_preferences['should_link'],
        linking_preferences['cannot_link']
    )
    clustered_utterances = [[] for _ in range(n_clusters)]
    for utterance, label in zip(utterances, clusters_with_preferences):
        clustered_utterances[label].append(utterance)
    cluster_label_map = {}
    # Iterating the list directly (no enumerate) lets tqdm display a total.
    for cluster in tqdm.tqdm(clustered_utterances):
        if not cluster:
            # Preferences can empty a cluster; there is nothing to label, so
            # do not spend an LLM call on an empty prompt.
            continue
        outputs_parsed = chain.invoke({'utterances': '\n'.join(cluster)})
        theme_label = outputs_parsed['theme_label']['theme_label']
        for utterance in cluster:
            cluster_label_map[utterance] = theme_label
    return cluster_label_map
102+
103+
104+
105+
if __name__ == '__main__':
    # Script entry point: read the JSONL dataset and the preferences file,
    # predict a theme label for every labelled utterance and write the
    # dataset back out with a 'theme_label_predicted' field per turn.
    args = parse_args()

    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = getpass.getpass("Enter your token: ")

    with open(args.dataset_file, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        dataset = [json.loads(line) for line in f if line.strip()]

    # De-duplicate with a dict rather than a set: dict keys keep first-seen
    # order, whereas set iteration order for strings changes between
    # interpreter runs (hash randomization). A stable utterance order is
    # needed for --random-state to give reproducible clusters.
    themed_utterances = {}
    for dialogue in dataset:
        for turn in dialogue['turns']:
            if turn['theme_label'] is not None:
                themed_utterances[turn['utterance']] = None

    with open(args.preferences_file, encoding='utf-8') as prefs_in:
        linking_preferences = json.load(prefs_in)
    cluster_label_map = main(
        list(themed_utterances),
        linking_preferences,
        args.embedding_model_name,
        args.llm_name,
        args.n_clusters,
        args.random_state
    )
    dataset_predicted = copy.deepcopy(dataset)
    for dialogue in dataset_predicted:
        for turn in dialogue['turns']:
            if turn['theme_label'] is not None:
                turn['theme_label_predicted'] = cluster_label_map[turn['utterance']]
    with open(args.result_file, 'w', encoding='utf-8') as result_out:
        for dialogue in dataset_predicted:
            print(json.dumps(dialogue), file=result_out)

set_paths.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Append ./src/ and ./scripts/ (resolved against the current directory) to
# PYTHONPATH. Quoted so directories containing spaces work, and the existing
# value is only prefixed when it is non-empty, avoiding a leading empty entry.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src/:$(pwd)/scripts/"

src/dstc12/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0

src/dstc12/eval.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
import tqdm
5+
import numpy as np
6+
from coclust.evaluation.external import accuracy
7+
from sklearn.metrics import normalized_mutual_info_score
8+
from sklearn.metrics.pairwise import cosine_similarity
9+
from rouge_score import rouge_scorer
10+
from rouge_score.scoring import Score
11+
from langchain_core.runnables import RunnableParallel
12+
13+
from dstc12.utils import DotAllRegexParser
14+
from dstc12.prompts import STYLEGUIDE_SECTION_1_PROMPT, STYLEGUIDE_SECTION_2_PROMPT, STYLEGUIDE_SECTION_3_PROMPT
15+
16+
17+
def acc(references=None, predictions=None):
    """Return the clustering accuracy of ``predictions`` against ``references``.

    Both arguments must be non-empty and of equal length; the score itself is
    computed by ``coclust.evaluation.external.accuracy``.
    """
    assert references
    assert predictions
    assert len(references) == len(predictions)
    clustering_accuracy = accuracy(references, predictions)
    return clustering_accuracy
20+
21+
22+
def nmi(references=None, predictions=None):
    """Return the normalized mutual information between the two labelings.

    Both arguments must be non-empty and of equal length; the score itself is
    computed by ``sklearn.metrics.normalized_mutual_info_score``.
    """
    assert references
    assert predictions
    assert len(references) == len(predictions)
    mutual_information = normalized_mutual_info_score(references, predictions)
    return mutual_information
25+
26+
27+
def rouge(references=None, predictions=None, metrics=('rouge1', 'rouge2', 'rouge3'), average=False):
    """Score each prediction against its reference with ROUGE.

    Args:
        references: reference strings.
        predictions: predicted strings, aligned with ``references``.
        metrics: ROUGE variants to compute (passed to ``RougeScorer``).
        average: when True, return per-metric means over all pairs instead of
            the per-pair scores.

    Returns:
        If ``average`` is False, a list with one ``{metric: Score}`` dict per
        (reference, prediction) pair. If True, a single ``{metric: Score}``
        dict whose precision/recall/fmeasure are averaged over all pairs.
    """
    assert len(references) == len(predictions)
    metrics = list(metrics)
    scorer = rouge_scorer.RougeScorer(metrics, use_stemmer=True)
    scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]
    if not average:
        return scores

    # One accumulator per metric name. The previous code keyed this on the
    # per-pair score dicts (unhashable) and then read an undefined variable.
    scores_aggregated = {
        metric: {'precision': 0, 'recall': 0, 'fmeasure': 0} for metric in metrics
    }
    for score in scores:
        for metric, value in score.items():
            scores_aggregated[metric]['precision'] += value.precision / len(scores)
            scores_aggregated[metric]['recall'] += value.recall / len(scores)
            scores_aggregated[metric]['fmeasure'] += value.fmeasure / len(scores)
    return {
        metric_name: Score(totals['precision'], totals['recall'], totals['fmeasure'])
        for metric_name, totals in scores_aggregated.items()
    }
45+
46+
47+
def rouge_with_multiple_references(references_list, predictions):
    """Average ROUGE over several alternative reference sets.

    Args:
        references_list: list of reference lists; each inner list is aligned
            with ``predictions``.
        predictions: predicted strings.

    Returns:
        Dict ``{metric_name: {'precision', 'recall', 'fmeasure'}}`` holding the
        mean of the per-reference-set averaged ROUGE scores.
    """
    # `rouge` names its flag `average`; the former `aggregate=True` keyword
    # raised TypeError on every call.
    scores = [rouge(refs_i, predictions, average=True) for refs_i in references_list]
    scores_averaged = {
        metric_name: {
            'precision': 0,
            'recall': 0,
            'fmeasure': 0
        } for metric_name in scores[0]
    }

    for score_i in scores:
        for metric_name, score in score_i.items():
            scores_averaged[metric_name]['precision'] += score.precision / len(scores)
            scores_averaged[metric_name]['recall'] += score.recall / len(scores)
            scores_averaged[metric_name]['fmeasure'] += score.fmeasure / len(scores)
    return scores_averaged
63+
64+
65+
def cosine_similarity_with_multiple_references(references_list, predictions):
    """Average the cosine-similarity results over all reference sets.

    ``cosine_similarity(refs, predictions)`` is evaluated once per entry of
    ``references_list`` and the results are averaged element-wise.
    """
    per_reference_similarity = []
    for reference_set in references_list:
        per_reference_similarity.append(cosine_similarity(reference_set, predictions))
    reference_count = len(per_reference_similarity)
    return sum(per_reference_similarity) / reference_count
69+
70+
71+
def process_llm_judge_output(output):
    """Turn the judge output for one label into the fraction of 'Good' sections.

    ``output`` must contain ``section_1``..``section_3``, each with a ``score``
    entry whose ``value`` is compared against the string ``'Good'``.
    """
    def section_passed(section_name):
        # 1 when the judge rated this section 'Good', otherwise 0.
        assert section_name in output and 'score' in output[section_name]
        verdict = output[section_name]['score']['value']
        return int(verdict == 'Good')

    section_results = [section_passed(name) for name in ('section_1', 'section_2', 'section_3')]
    return np.mean(section_results)
77+
78+
79+
def llm_score(predictions, llm):
    """Score theme labels with an LLM judge and return the mean score.

    Each prediction is sent through three style-guide prompts in parallel;
    every branch extracts a ``<score>`` and an ``<explanation>`` from the LLM
    output. The per-prediction result of ``process_llm_judge_output`` is then
    averaged over all predictions.
    """
    def judge_branch(section_prompt):
        # prompt -> LLM -> parallel extraction of score and explanation
        verdict_parsers = RunnableParallel(
            score=DotAllRegexParser(regex=r'<score>\s*(.*?)\s*</score>', output_keys=['value']),
            explanation=DotAllRegexParser(regex=r'<explanation>\s*(.*?)\s*</explanation>', output_keys=['value']),
        )
        return section_prompt | llm | verdict_parsers

    chain = RunnableParallel(
        section_1=judge_branch(STYLEGUIDE_SECTION_1_PROMPT),
        section_2=judge_branch(STYLEGUIDE_SECTION_2_PROMPT),
        section_3=judge_branch(STYLEGUIDE_SECTION_3_PROMPT),
    )
    per_prediction_scores = [
        process_llm_judge_output(chain.invoke({'theme_label': prediction}))
        for prediction in tqdm.tqdm(predictions)
    ]
    return np.mean(per_prediction_scores)

src/dstc12/prompts/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
from .extract_themes import PROMPT as EXTRACT_THEMES_PROMPT
5+
from .label_utterances import PROMPT as LABEL_UTTERANCES_PROMPT
6+
from .label_clusters import PROMPT as LABEL_CLUSTERS_PROMPT
7+
from .styleguide import (
8+
SECTION_1_PROMPT as STYLEGUIDE_SECTION_1_PROMPT,
9+
SECTION_2_PROMPT as STYLEGUIDE_SECTION_2_PROMPT,
10+
SECTION_3_PROMPT as STYLEGUIDE_SECTION_3_PROMPT
11+
)

src/dstc12/prompts/extract_themes.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
from langchain_core.prompts import PromptTemplate
5+
6+
7+
# Prompt asking the model to cluster the utterances substituted for
# `{utterances}` into exhaustive, mutually exclusive themes and to list each
# theme inside <theme> tags. The <task> block is now closed with </task>,
# consistent with the other prompt templates in this package.
PROMPT = PromptTemplate.from_template(
    '''<task>
You are an expert call center assistant. You will be given a set of utterances in <utterances> </utterances> tags, each one on a new line. Read through them carefully and cluster them into themes. The themes should be exhaustive and mutually exclusive and should cover the dataset completely.
Output a full set of themes you identified. One utterance can only belong to one theme.

<guidance>
Write your output in the following format:
Unique themes number: n
<theme>theme label 1</theme>
<theme>theme label 2</theme>
...
<theme>theme label n</theme>
</guidance>
</task>

H:
<utterances>
{utterances}
</utterances>
'''
)

src/dstc12/prompts/label_clusters.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
from langchain_core.prompts import PromptTemplate
5+
6+
7+
# Prompt asking the model for one short theme label (under 5 words) covering
# all utterances substituted for `{utterances}`. The requested output format is
# a <theme_label_explanation> element followed by a <theme_label> element.
PROMPT = PromptTemplate.from_template(
    '''<task>
You are an expert call center assistant. You will be given a set of utterances in <utterances> </utterances> tags, each one on a new line.
The utterances are part of callcenter conversations between the customer and the support agent.
Your task is to generate a short label describing the theme of all the given utterances. The theme label should be under 5 words and describe the desired customer's action in the call.


<guidance>
Output your response in the following way.
<theme_label_explanation>Your short step-by-step explanation behind the theme</theme_label_explanation>
<theme_label>your theme label</theme_label>
</guidance>
</task>

H:
<utterances>
{utterances}
</utterances>
'''
)
src/dstc12/prompts/label_utterances.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: CC-BY-NC-4.0
3+
4+
from langchain_core.prompts import PromptTemplate
5+
6+
7+
# Prompt asking the model to assign each utterance in `{utterances}` the index
# of a matching theme from `{themes}`. The answer is requested as
# comma-separated indices inside <theme_indices> tags, with index 0 meaning
# "no theme matches"; the template includes one worked example.
PROMPT = PromptTemplate.from_template(
    '''<task>
You are an expert call center assistant. You will be given a set of utterances in <utterances> </utterances> tags, each one on a new line.
You will also receive a set of theme labels in <themes> </themes> tags. Read through them carefully and associate each utterance with the corresponding theme label index.

<example>
H:
<utterances>
I want to cancel my account
I never received my order
I want to get some information about your insurance offerings
</utterances>

<themes>
1. book a flight
2. information about insurance
3. return product
4. cancel account
5. request refund
6. open account
</themes>

A:
<theme_indices>4, 0, 2</theme_indices>
</example>

<guidance>
Write output in the following format: <theme_indices>comma separated theme indices for every input utterance</theme_indices>
If no theme matches an utterance, assign it the index 0. If multiple themes match an utterance, assign it the theme you thought of first.
</guidance>
</task>

H:
<utterances>
{utterances}
</utterances>

<themes>
{themes}
</themes>
'''
)

0 commit comments

Comments
 (0)