Commit 7c18212

refactor: Overhauled utility classes for ease of use
1 parent b2f2445 commit 7c18212

7 files changed: +112 −78 lines changed

7 files changed

+112
-78
lines changed

docs/hello-world.md

Lines changed: 31 additions & 11 deletions

@@ -44,30 +44,50 @@ from langdiversity.parser import extract_math_answer
 model = OpenAIModel(openai_api_key="[API KEY]", extractor=extract_math_answer)
 ```

-### Prompt Selection
+### Collecting Diversity Measures

-Now, we initialize the `PromptSelection` object. This is where we specify how many responses we want from the language model for each prompt, the diversity measure to use, and the selection method.
+In this step, we initialize the `DiversityMeasureCollector` object. This is where we specify how many responses we want from the language model for each prompt and the diversity measure to use.

 ```python
-from langdiversity.utils import PromptSelection
-prompt_selection = PromptSelection(model=model, num_responses=4, diversity_measure=diversity_measure, selection='min')
+from langdiversity.utils import DiversityMeasureCollector
+diversity_collector = DiversityMeasureCollector(model=model, num_responses=4, diversity_measure=diversity_measure)
 ```

-### Generate and Select Prompts
-
-Finally, we pass in a list of prompts to the `PromptSelection` object. It will send these prompts to the language model, calculate the diversity measure for each set of responses, and then select the prompt with the minimum (or maximum) diversity measure.
+### Collecting Data

-The selected prompt and its corresponding diversity measure are stored in `selected_prompt` and `selected_measure`, respectively.
+Next, we pass in a list of prompts to the `DiversityMeasureCollector` object. It will send these prompts to the language model, collect the responses, and calculate the diversity measure for each set of responses.

 Note: The prompts are structured to guide the language model in generating a specific type of response. This makes it easier for the parser to extract clean answers.

 ```python
-selected_prompt, selected_measure = prompt_selection.generate([
+prompts = [
     "At the end, say 'the answer is [put your numbers here separated by commas]'.\nQuestion: What is the speed of the current if Junior's boat can cover 12 miles downstream in the same time it takes to travel 9 miles upstream, given that his boat's speed in still water is 15 miles per hour?",
     "At the end, say 'the answer is [put your numbers here separated by commas]'.\nQuestion: What is the speed of the current if Junior's boat travels at a constant speed of 15 miles per hour in still water and he spends the same amount of time traveling 12 miles downstream as he does traveling 9 miles upstream?.",
     "At the end, say 'the answer is [put your numbers here separated by commas]'.\nQuestion: Juniors boat will go 15 miles per hour in still water . If he can go 12 miles downstream in the same amount of time as it takes to go 9 miles upstream , then what is the speed of the current?",
-])
+]
+
+diversity_collector.collect(prompts, verbose=True)  # Set verbose to True to see intermediate values
+```
+
+### Prompt Selection
+
+Now, we initialize the `PromptSelection` object with the data collected in the previous step.
+
+```python
+from langdiversity.utils import PromptSelection
+prompt_selection = PromptSelection(data=diversity_collector.data, selection='min')
+```
+
+### Selecting Prompts
+
+Finally, we call the `select` method on the `PromptSelection` object to select the prompt with the desired diversity measure based on the configured selection method.
+
+In this example, the selected prompt and its corresponding diversity measure are stored in `selected_prompt` and `selected_diversity`, respectively.
+
+```python
+selected_prompt, selected_diversity = prompt_selection.select()

 print("Selected Prompt:", selected_prompt)
-print("Selected Measure:", selected_measure)
+print("Selected Diversity:", selected_diversity)
 ```
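Taken together, the change to this doc replaces the old one-shot `PromptSelection.generate` call with a collect-then-select workflow. The flow can be sketched standalone as follows (a minimal illustration using a hypothetical stub model and a local entropy function, since the real `OpenAIModel` requires an API key; names here are not from the library):

```python
import math
from collections import Counter

# Hypothetical stand-in for a language model (assumption: the real model
# exposes generate(prompt, num_responses) -> list of parsed answers).
class StubModel:
    def __init__(self, canned):
        self.canned = canned  # prompt -> canned list of responses

    def generate(self, prompt, num_responses):
        return self.canned[prompt][:num_responses]

def shannon_entropy(responses):
    # Entropy over the distribution of parsed answers; 0 means full agreement.
    counts = Counter(responses)
    total = len(responses)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())

canned = {
    "prompt A": ["3", "3", "3", "3"],  # fully consistent answers
    "prompt B": ["3", "4", "3", "5"],  # mixed answers -> higher entropy
}
model = StubModel(canned)

# Collect phase: one {prompt, responses, diversity} record per prompt.
data = []
for prompt in canned:
    responses = model.generate(prompt, 4)
    data.append({"prompt": prompt,
                 "responses": responses,
                 "diversity": shannon_entropy(responses)})

# Select phase: 'min' favors the prompt whose answers agree the most.
selected = min(data, key=lambda d: d["diversity"])
print(selected["prompt"])  # prompt A
```

Splitting collection from selection like this lets the same collected records be re-used with different selection criteria without re-querying the model.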

docs/langdiversity_library.md

Lines changed: 29 additions & 11 deletions

@@ -13,23 +13,37 @@ pip install langdiversity
 Example:

 ```python
+import os
+from dotenv import load_dotenv
+
+from langdiversity.utils import PromptSelection, DiversityMeasureCollector
 from langdiversity.models import OpenAIModel
 from langdiversity.measures import ShannonEntropyMeasure
-from langdiversity.utils import PromptSelection
-from langdiversity.parser import # Select a parser that suits your question set
+from langdiversity.parser import extract_last_letters  # Select a parser that suits your question set
+
+load_dotenv()
+openai_api_key = os.getenv("OPENAI_API_KEY")  # place your language model's API key in a .env file

 # Initialize the OpenAI model and diversity measure
-model = OpenAIModel(openai_api_key="[YOUR API KEY]", extractor="[SELECT YOUR PARSER](optional)")
+model = OpenAIModel(openai_api_key=openai_api_key, extractor=extract_last_letters)
 diversity_measure = ShannonEntropyMeasure()

-# Use the PromptSelection utility
-prompt_selection = PromptSelection(model=model, num_responses=10, diversity_measure=diversity_measure)
+# Define your list of prompts
+prompts = [
+    "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Take the last letter of each word in \"Tal Evan Lesley Sidney\" and concatenate them..",
+    # ... Add more prompts as needed
+]
+
+# Create an instance of DiversityMeasureCollector and collect diversity measures
+diversity_collector = DiversityMeasureCollector(model=model, num_responses=4, diversity_measure=diversity_measure)
+diversity_collector.collect(prompts)

-# Pass in question set to the LLM & selects the prompt with the configured diversity measure criteria from the LLM's 10 responses
-selected_prompt, selected_measure = prompt_selection.generate(["Your list of prompts here..."])
+# Create an instance of PromptSelection and select a prompt
+prompt_selection = PromptSelection(data=diversity_collector.data, selection="min")
+selected_prompt, selected_measure = prompt_selection.select()

-print("Selected Prompt:", selected_prompt)
-print("Selected Measure:", selected_measure)
+print("Selected prompt:", selected_prompt)
+print("Selected measure:", selected_measure)
 ```

@@ -48,21 +62,25 @@ LangDiversity offers a variety of modules for different use-cases. Below are the
 - [Utility Classes](https://github.com/lab-v2/langdiversity/tree/main/langdiversity/utils) (`langdiversity.utils`)

   - `PromptSelection`: Handles the selection of prompts based on diversity measures.
-  - `DiversityCalculator`: Calculates various diversity measures for a given set of values. Supports Shannon's entropy and Gini impurity by default.
+  - `DiversityMeasureCollector`: Collects diversity measures for a given set of prompts using a specified language model and diversity measure algorithm.

 - [Parsers](https://github.com/lab-v2/langdiversity/tree/main/langdiversity/parser) (`langdiversity.parsers`)
   - `extract_last_letters(response: str)`: Extracts the last letters of each word in the response.
   - `extract_math_answer(response: str)`: Extracts numerical answers from a mathematical question in the response.
   - `extract_multi_choice_answer(response: str)`: Extracts the selected choice (A, B, C, D, E) from a multiple-choice question in the response.

-### PromptSelection Paramaters:
+### DiversityMeasureCollector Parameters:

 - `model`: The language model you want to use. In this example, we're using OpenAI's model.

 - `diversity_measure`: The diversity measure you want to use. Here, we're using entropy.

 - `num_responses`: The number of responses you want the model to generate for each prompt. Default is 1.

+### PromptSelection Parameters:
+
+- `data`: A list of dictionaries, each containing information about a prompt, the responses it generated, and its diversity measure. This data is collected using the `DiversityMeasureCollector` class.
+
 - `selection`: Determines how the best prompt is selected based on its diversity measure. It can be:

   - `"min"`: Selects the prompt with the minimum diversity measure. (default)
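As an aside on the parser names documented above, the last-letters task can be illustrated by a tiny helper (a hypothetical sketch, not the library's actual `extract_last_letters`, which parses the model's full response text):

```python
def last_letters(words: str) -> str:
    # Concatenate the last letter of each whitespace-separated word.
    return "".join(w[-1] for w in words.split())

# Matches the example question used throughout the docs:
print(last_letters("Tal Evan Lesley Sidney"))  # lnyy
```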

examples/prompt_selection.py

Lines changed: 14 additions & 12 deletions

@@ -1,7 +1,7 @@
 import os
 from dotenv import load_dotenv

-from langdiversity.utils import PromptSelection
+from langdiversity.utils import PromptSelection, DiversityMeasureCollector
 from langdiversity.models import OpenAIModel
 from langdiversity.measures import ShannonEntropyMeasure
 from langdiversity.parser import extract_last_letters
@@ -11,16 +11,18 @@

 diversity_measure = ShannonEntropyMeasure()
 model = OpenAIModel(openai_api_key=openai_api_key, extractor=extract_last_letters)
-prompt_selection = PromptSelection(
-    model=model, num_responses=4, diversity_measure=diversity_measure, selection="min"
-)
-selected_prompt, selected_diversity = prompt_selection.generate(
-    [
-        "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Take the last letter of each word in \"Tal Evan Lesley Sidney\" and concatenate them..",
-        "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Concatenate the last letter of each word in \"Tal Evan Lesley Sidney\".",
-        "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Combine the last letter of each word in \"Tal Evan Lesley Sidney\".",
-    ]
-)
+
+prompts = [
+    "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Take the last letter of each word in \"Tal Evan Lesley Sidney\" and concatenate them..",
+    "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Concatenate the last letter of each word in \"Tal Evan Lesley Sidney\".",
+    "At the end, say 'the answer is [put the concatenated word here]'.\nQuestion: Combine the last letter of each word in \"Tal Evan Lesley Sidney\".",
+]
+
+diversity_collector = DiversityMeasureCollector(model=model, num_responses=4, diversity_measure=diversity_measure)
+diversity_collector.collect(prompts, verbose=True)
+
+prompt_selection = PromptSelection(data=diversity_collector.data, selection="min")
+selected_prompt, selected_diversity = prompt_selection.select()

 print("SELECTED PROMPT:", selected_prompt)
-print("SELECTED DIVERSITY:", selected_diversity)
+print("SELECTED DIVERSITY:", selected_diversity)

langdiversity/utils/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
 from .calculate_measures import *
+from .diversity_measures import *
 from .prompt_selection import *
langdiversity/utils/diversity_measures.py (new file)

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+from typing import List
+
+from langdiversity.extras.spinner import loading_spinner
+
+from langdiversity.measures import AbstractMeasure
+from langdiversity.models import AbstractBaseModel
+
+class DiversityMeasureCollector:
+    def __init__(self, model: AbstractBaseModel, diversity_measure: AbstractMeasure, num_responses: int = 1):
+        self.model = model
+        self.diversity_measure = diversity_measure
+        self.num_responses = num_responses
+        self.data = []  # A list to store the data (prompt, responses, diversity measure)
+
+    def collect(self, prompts: List[str], verbose: bool = False):
+        total_prompts = len(prompts)
+        for i, prompt in enumerate(prompts):
+            with loading_spinner(f"Collecting {self.num_responses} responses...", current_step=i + 1, total_steps=total_prompts):
+                responses = self.model.generate(prompt, self.num_responses)
+            with loading_spinner("Performing diversity measure calculations...", current_step=i + 1, total_steps=total_prompts):
+                diversity = self.diversity_measure.generate(responses)
+            if verbose:
+                print(f"Prompt {i + 1}: {prompt}")
+                print(f"Responses: {', '.join(responses)}")  # Assuming responses are strings
+                print(f"Diversity: {diversity}")
+            self.data.append(
+                {"prompt": prompt, "responses": responses, "diversity": diversity}
+            )
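To see the collector's contract without the `loading_spinner` dependency, here is a simplified, runnable re-sketch: the spinner calls are dropped and the model and measure are local fakes, so this is an illustration of the committed logic, not the shipped class.

```python
import math
from collections import Counter

# Simplified re-sketch of DiversityMeasureCollector (spinner removed).
class DiversityMeasureCollector:
    def __init__(self, model, diversity_measure, num_responses=1):
        self.model = model
        self.diversity_measure = diversity_measure
        self.num_responses = num_responses
        self.data = []  # one record per prompt: {prompt, responses, diversity}

    def collect(self, prompts, verbose=False):
        for i, prompt in enumerate(prompts):
            responses = self.model.generate(prompt, self.num_responses)
            diversity = self.diversity_measure.generate(responses)
            if verbose:
                print(f"Prompt {i + 1}: {prompt}")
                print(f"Responses: {', '.join(responses)}")
                print(f"Diversity: {diversity}")
            self.data.append(
                {"prompt": prompt, "responses": responses, "diversity": diversity}
            )

# Local fakes standing in for the real model and measure (assumptions).
class FakeModel:
    def generate(self, prompt, n):
        return ["a", "a", "b"][:n]

class EntropyMeasure:
    def generate(self, responses):
        counts = Counter(responses)
        n = len(responses)
        return -sum((c / n) * math.log2(c / n) for c in counts.values())

collector = DiversityMeasureCollector(FakeModel(), EntropyMeasure(), num_responses=3)
collector.collect(["p1"])
print(len(collector.data))  # 1
```

Each call to `collect` appends to `self.data` rather than replacing it, so repeated calls accumulate records across prompt batches.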
langdiversity/utils/prompt_selection.py

Lines changed: 8 additions & 43 deletions

@@ -1,53 +1,18 @@
-from typing import List
-
-from langdiversity.extras.spinner import loading_spinner
-
-from langdiversity.measures import AbstractMeasure
-from langdiversity.models import AbstractBaseModel
+from typing import List, Dict, Union

 class PromptSelection:
-    def __init__(
-        self,
-        model: AbstractBaseModel,
-        diversity_measure: AbstractMeasure,
-        num_responses: int = 1,
-        selection: str = "min",
-    ):
+    def __init__(self, data: List[Dict[str, Union[str, List[str], float]]], selection: str = "min"):
         valid_selections = ["min", "max"]
         if selection not in valid_selections:
             raise ValueError(
                 "Invalid selection type. Expected one of %s" % valid_selections
             )
-        self.model = model
-        self.diversity_measure = diversity_measure
-        self.num_responses = num_responses
+        self.data = data
         self.selection = selection

-    def generate(self, prompts: List[str]):
-        if len(prompts) == 0:
-            raise ValueError("Invalid prompts. There should be at least 1 prompt.")
-
-        selected_prompt = ""
-        selected_diversity = float("inf") if self.selection == "min" else float("-inf")
-
-        info = []
-
-        total_prompts = len(prompts)
-        for i, prompt in enumerate(prompts):
-            with loading_spinner(f"Collecting {self.num_responses} responses...", current_step=i+1, total_steps=total_prompts):
-                responses = self.model.generate(prompt, self.num_responses)
-            with loading_spinner("Performing diversity measure calculations...", current_step=i+1, total_steps=total_prompts):
-                diversity = self.diversity_measure.generate(responses)
-
-            if self.selection == "max" and diversity > selected_diversity:
-                selected_diversity = diversity
-                selected_prompt = prompt
-            if self.selection == "min" and diversity < selected_diversity:
-                selected_diversity = diversity
-                selected_prompt = prompt
-
-            info.append(
-                {"responses": responses, "diversity": diversity, "prompt": prompt}
-            )
+    def select(self):
+        if not self.data:
+            raise ValueError("No data to select from.")

-        return selected_prompt, selected_diversity
+        selected_item = min(self.data, key=lambda x: x['diversity']) if self.selection == "min" else max(self.data, key=lambda x: x['diversity'])
+        return selected_item['prompt'], selected_item['diversity']
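The essence of the new `select()` is a `min`/`max` over the collected records keyed on `'diversity'`. A standalone copy of that logic, runnable without the package:

```python
# Standalone copy of the committed selection logic, for illustration.
def select(data, selection="min"):
    if selection not in ("min", "max"):
        raise ValueError("Invalid selection type. Expected one of ['min', 'max']")
    if not data:
        raise ValueError("No data to select from.")
    key = lambda x: x["diversity"]
    item = min(data, key=key) if selection == "min" else max(data, key=key)
    return item["prompt"], item["diversity"]

data = [
    {"prompt": "p1", "responses": ["x", "y"], "diversity": 1.0},
    {"prompt": "p2", "responses": ["x", "x"], "diversity": 0.0},
]
print(select(data, "min"))  # ('p2', 0.0)
print(select(data, "max"))  # ('p1', 1.0)
```

Note that `min`/`max` return the first record with the extreme value, so ties are broken by prompt order in `data`.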

setup.py

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 setup(
     name='langdiversity',
     packages=find_packages(exclude=['tests']),
-    version='1.0.5',
+    version='1.1.0',
     description='A tool to elevate your language models with insightful diversity metrics.',
     long_description=long_description,
     long_description_content_type="text/markdown",
