Skip to content

Commit b2f2445

Browse files
committed
feat: cleaned up diversity calculator tool
1 parent 4188898 commit b2f2445

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

examples/diversity_calculation.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
from langdiversity.utils import DiversityCalculator
2+
from langdiversity.measures import ShannonEntropyMeasure, GiniImpurityMeasure
23

3-
# Calculate diversity measures
4-
diversity_calculator = DiversityCalculator(default_measures=["entropy", "gini"])
5-
print("DIVERSITY MEASURES:", diversity_calculator.calculate([1, 1, 1, 2]))
4+
# Define a custom measure class
5+
class CustomMeasure:
6+
def generate(self, values):
7+
return sum(values) // len(values)
8+
9+
# Create instances
10+
entropy = ShannonEntropyMeasure()
11+
gini = GiniImpurityMeasure()
12+
custom = CustomMeasure()
13+
14+
# Use built-in and custom measures
15+
diversity_calculator = DiversityCalculator(measures=[entropy, gini, custom])
16+
17+
print("Diversity Measures:", diversity_calculator.calculate([1, 1, 1, 2]))
Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,14 @@
1-
from ..measures import ShannonEntropyMeasure, GiniImpurityMeasure
1+
from langdiversity.extras.spinner import loading_spinner
22

33
class DiversityCalculator:
4-
# by default, include entropy as part of diversity measures
5-
def __init__(self, default_measures=["entropy"]):
6-
self.default_measures = default_measures
7-
8-
# returns diversity measure results for each question
9-
def calculate(self, values, measures=None):
10-
if measures is None:
11-
measures = self.default_measures
4+
def __init__(self, measures=[]):
5+
self.measures = measures
126

7+
def calculate(self, values):
138
results = {}
14-
15-
if "entropy" in measures:
16-
entropy_measure = ShannonEntropyMeasure()
17-
results["entropy"] = entropy_measure.generate(values)
18-
19-
if "gini" in measures:
20-
gini_measure = GiniImpurityMeasure()
21-
results["gini"] = gini_measure.generate(values)
22-
9+
for measure in self.measures:
10+
display_name = measure.__class__.__name__
11+
with loading_spinner(f"Calculating '{display_name}' for {len(values)} values..."):
12+
results[display_name] = measure.generate(values)
13+
2314
return results
24-

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
setup(
99
name='langdiversity',
1010
packages=find_packages(exclude=['tests']),
11-
version='1.0.3',
11+
version='1.0.5',
1212
description='A tool to elevate your language models with insightful diversity metrics.',
1313
long_description=long_description,
1414
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)