Skip to content

Commit 82d9960

Browse files
committed
formatted
1 parent c293201 commit 82d9960

File tree

3 files changed

+14
-18
lines changed

3 files changed

+14
-18
lines changed

src/llm_change_agent/evaluations/evaluator.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
logger.info("Evaluating the LLM Change Agent.")
2929

3030

31-
3231
def download_document(url, input_dir):
3332
"""Download the document from the URL."""
3433
if not os.path.exists(input_dir):
@@ -174,44 +173,39 @@ def generate_changes_via_llm(eval_dir, output_dir, provider, model):
174173
print(f"Predicted changes saved to {output_sub_dir}")
175174

176175

177-
def compare_changes(expected_dir: Path, output_dir: Path):
    """Compare the actual changes with the predicted changes.

    Every ``*.yaml`` file found (recursively) under ``output_dir`` is paired,
    by file name, with a file directly inside ``expected_dir``. Both files are
    parsed as mappings keyed by PR id, and each non-empty predicted change list
    is handed to ``compare_output_vs_expected`` together with the expected
    entry for the same PR id.

    Args:
        expected_dir: Directory holding the expected-change YAML files.
        output_dir: Root directory holding the predicted-change YAML files.
            The names of the two directories directly above each file are
            joined into the ``provider_model`` label used for logging.
    """
    for output_file in output_dir.rglob("*.yaml"):
        provider_model = f"{output_file.parts[-3]}_{output_file.parts[-2]}"
        expected_file = expected_dir / output_file.name

        if not expected_file.exists():
            # One missing reference file should not abort the whole evaluation.
            logger.warning(f"No expected file found for {output_file}; skipping.")
            continue

        with open(expected_file, "r", encoding="utf-8") as ex, open(output_file, "r", encoding="utf-8") as out:
            # safe_load returns None for an empty document; treat that as "no PRs".
            expected_yaml = yaml.safe_load(ex) or {}
            output_yaml = yaml.safe_load(out) or {}

        for pr_id, output_changes in output_yaml.items():
            expected_change = expected_yaml.get(pr_id)
            # Skip PRs with nothing predicted or nothing to compare against.
            if output_changes and expected_change is not None:
                compare_output_vs_expected(expected_change, output_changes)

        logger.info(f"Finished comparing changes for {provider_model}")
202200

203201

204-
205-
def compare_output_vs_expected(expected_changes, output_changes: List):
    """Compare the expected changes with the output changes.

    Args:
        expected_changes: Expected change statements for one PR. ``None`` or an
            empty collection means there is nothing to score against.
        output_changes: Predicted change statements for the same PR. They are
            passed through ``normalize_changes`` before matching.

    Returns:
        The fraction (0.0 to 1.0) of expected changes that appear verbatim
        among the normalized output changes, or 0.0 when there are no expected
        changes.
    """
    if not expected_changes:
        # Nothing to score against; also avoids len(None) and division by zero.
        return 0.0

    output_changes = normalize_changes(output_changes)
    total = len(expected_changes)
    # Exact-match scoring: an expected change counts only if it is present verbatim.
    correct = sum(1 for change in expected_changes if change in output_changes)
    return correct / total
212-
213-
214-
215209

216210

217211
def run_evaluate(model: str, provider: str):
@@ -230,4 +224,3 @@ def run_evaluate(model: str, provider: str):
230224
generate_changes_via_llm(model=model, provider=provider, eval_dir=eval_dir, output_dir=output_dir)
231225

232226
compare_changes(expected_dir=expected_dir, output_dir=output_dir)
233-

src/llm_change_agent/utils/general_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def jaccard_similarity(statement1, statement2):
4646
union = set1.union(set2)
4747

4848
# Calculate the Jaccard similarity coefficient
49-
return len(intersection) / len(union)
49+
return len(intersection) / len(union)

src/llm_change_agent/utils/llm_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from langchain.agents import AgentExecutor
1111
from langchain.agents.react.agent import create_react_agent
1212
from langchain.tools.retriever import create_retriever_tool
13-
from langchain_core.tools import tool
1413
from langchain_chroma import Chroma
1514
from langchain_community.document_loaders import WebBaseLoader
1615
from langchain_core.documents import Document
16+
from langchain_core.tools import tool
1717
from langchain_openai import OpenAIEmbeddings
1818
from langchain_text_splitters import RecursiveCharacterTextSplitter
1919
from openai import OpenAI
@@ -319,21 +319,24 @@ def extract_commands(command):
319319
else:
320320
return cleaned_command
321321

322+
322323
def normalize_changes(changes):
    """Convert IRIs to CURIEs in change statements.

    Each statement is split on whitespace, and every token that starts with
    ``http`` or ``<http`` is treated as an IRI. The token is stripped of
    surrounding angle brackets and passed to ``compress_iri``; when that yields
    a truthy value, the original token is replaced by it in the statement.
    Tokens that cannot be compressed are left unchanged.

    Args:
        changes: Mutable sequence of change statements (strings).

    Returns:
        The same ``changes`` object, updated in place.
    """
    for idx, change in enumerate(changes):
        iris = [token for token in change.split() if token.startswith(("<http", "http"))]
        for iri in iris:
            # Compress once per IRI; a falsy result means it could not be compressed.
            curie = compress_iri(iri.strip("<>"))
            if curie:
                change = change.replace(iri, curie)
        changes[idx] = change
    return changes
334336

337+
335338
# Shared OBO converter, created lazily so the prefix map is built only once
# instead of on every compress_iri call.
_OBO_CONVERTER = None


def _get_obo_converter():
    """Return the shared OBO converter, creating it on first use."""
    global _OBO_CONVERTER
    if _OBO_CONVERTER is None:
        _OBO_CONVERTER = curies.get_obo_converter()
    return _OBO_CONVERTER


# NOTE: the docstring below is kept short on purpose; the @tool decorator uses
# it as the tool description, so editing it changes what the agent is told.
# The result of converter.compress() is returned as-is; callers in this file
# treat a falsy result as "could not be compressed".
@tool
def compress_iri(iri: str) -> str:
    """Compress the IRI."""
    return _get_obo_converter().compress(iri)

0 commit comments

Comments
 (0)