Skip to content

Commit f965992

Browse files
fix(workflows): add task to evaluate trajectory metrics
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
1 parent 1c80ccb commit f965992

File tree

1 file changed

+94
-3
lines changed

1 file changed

+94
-3
lines changed

src/pyrovelocity/workflows/main_workflow.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import asdict
33
from datetime import timedelta
44
from pathlib import Path
5+
from typing import Tuple
56

67
from beartype.typing import List
78
from flytekit import Resources, current_context, dynamic, task
@@ -28,6 +29,7 @@
2829
)
2930
from pyrovelocity.logging import configure_logging
3031
from pyrovelocity.tasks.data import download_dataset
32+
from pyrovelocity.tasks.evaluate import calculate_cross_boundary_correctness
3133
from pyrovelocity.tasks.postprocess import postprocess_dataset
3234
from pyrovelocity.tasks.preprocess import preprocess_dataset
3335
from pyrovelocity.tasks.summarize import summarize_dataset
@@ -48,6 +50,7 @@
4850
SummarizeConfiguration,
4951
SummarizeOutputs,
5052
TrainingOutputs,
53+
TrajectoryEvaluationOutputs,
5154
WorkflowConfiguration,
5255
bonemarrow_configuration,
5356
default_resource_limits,
@@ -75,6 +78,7 @@
7578
"upload_summary",
7679
"map_model_configurations_over_data_set",
7780
"training_workflow",
81+
"evaluate_trajectory_metrics",
7882
]
7983

8084
logger = configure_logging(__name__)
@@ -87,6 +91,7 @@
8791
SUMMARIZE_CACHE_VERSION = f"{CACHE_VERSION}.3"
8892
UPLOAD_CACHE_VERSION = f"{CACHE_VERSION}.7"
8993
LINEAGE_FATE_CORRELATION_CACHE_VERSION = f"{CACHE_VERSION}.8"
94+
TRAJECTORY_EVALUATION_CACHE_VERSION = f"{CACHE_VERSION}.0"
9095
COMBINE_METRICS_CACHE_VERSION = f"{CACHE_VERSION}.5"
9196
DEFAULT_ACCELERATOR_TYPE: GPUAccelerator = T4
9297

@@ -655,6 +660,73 @@ def combine_all_metrics(
655660
)
656661

657662

663+
@task(
664+
cache=PYROVELOCITY_CACHE_FLAG,
665+
cache_version=TRAJECTORY_EVALUATION_CACHE_VERSION,
666+
retries=3,
667+
interruptible=False,
668+
timeout=timedelta(minutes=120),
669+
requests=Resources(cpu="8", mem="30Gi", ephemeral_storage="50Gi"),
670+
limits=Resources(cpu="16", mem="60Gi", ephemeral_storage="200Gi"),
671+
enable_deck=False,
672+
)
673+
def evaluate_trajectory_metrics(
674+
results: List[List[SummarizeOutputs]],
675+
) -> TrajectoryEvaluationOutputs:
676+
logger.info("Evaluating trajectory metrics for datasets")
677+
678+
output_dir = Path("reports/trajectory_metrics")
679+
output_dir.mkdir(parents=True, exist_ok=True)
680+
681+
model_results = []
682+
683+
for dataset_results in results:
684+
for model_output in dataset_results:
685+
postprocessed_data_path = model_output.postprocessed_data.download()
686+
687+
model_results.append(
688+
{
689+
"data_model": model_output.data_model,
690+
"postprocessed_data": postprocessed_data_path,
691+
}
692+
)
693+
694+
summary_file, results_dir, plot_file = calculate_cross_boundary_correctness(
695+
model_results=model_results,
696+
output_dir=output_dir,
697+
)
698+
699+
ctx = current_context()
700+
execution_id = ctx.execution_id.name
701+
702+
uploaded_files = []
703+
for file_path in [summary_file, plot_file]:
704+
for ext in ["", ".png"]:
705+
if Path(f"{file_path}{ext}").exists():
706+
upload_result = upload_file_concurrently(
707+
bucket_name=f"pyrovelocity/reports/{execution_id}",
708+
source_filename=f"{file_path}{ext}",
709+
destination_blob_name=f"{file_path.name}{ext}",
710+
)
711+
uploaded_files.append(upload_result)
712+
713+
if all(isinstance(result, Success) for result in uploaded_files):
714+
logger.info("All trajectory metrics files uploaded successfully")
715+
else:
716+
failed_uploads = [
717+
str(i)
718+
for i, result in enumerate(uploaded_files)
719+
if isinstance(result, Failure)
720+
]
721+
logger.warning(f"Failed uploads: {', '.join(failed_uploads)}")
722+
723+
return TrajectoryEvaluationOutputs(
724+
summary_file=FlyteFile(path=str(summary_file)),
725+
results_directory=FlyteDirectory(path=str(results_dir)),
726+
plot_file=FlyteFile(path=str(plot_file)),
727+
)
728+
729+
658730
@dynamic
659731
def training_workflow(
660732
simulated_configuration: WorkflowConfiguration = simulated_configuration,
@@ -668,7 +740,11 @@ def training_workflow(
668740
larry_neu_configuration: WorkflowConfiguration = larry_neu_configuration,
669741
larry_mono_configuration: WorkflowConfiguration = larry_mono_configuration,
670742
larry_multilineage_configuration: WorkflowConfiguration = larry_multilineage_configuration,
671-
) -> list[list[SummarizeOutputs]]:
743+
) -> Tuple[
744+
List[List[SummarizeOutputs]],
745+
TrajectoryEvaluationOutputs,
746+
CombinedMetricsOutputs,
747+
]:
672748
"""
673749
Apply the primary workflow to a collection of configurations.
674750
Conditionally executes configurations based on the value of PYROVELOCITY_DATA_SUBSET.
@@ -688,6 +764,7 @@ def training_workflow(
688764
]
689765

690766
lineage_traced_results = []
767+
developmental_results = []
691768
lineage_traced_configurations = [
692769
(larry_mono_configuration, "larry_mono"),
693770
(larry_neu_configuration, "larry_neu"),
@@ -721,16 +798,30 @@ def training_workflow(
721798
accelerator_type=config.accelerator_type,
722799
upload_results=config.upload_results,
723800
)
801+
724802
if "larry" in data_set_name:
725803
lineage_traced_results.append(result)
804+
805+
if (
806+
data_set_name in ["bonemarrow", "pancreas", "pons"]
807+
or "larry" in data_set_name
808+
):
809+
developmental_results.append(result)
810+
726811
results.append(result)
727812

728-
combine_all_metrics(results=results)
813+
metrics_outputs = combine_all_metrics(results=results)
729814

730815
if len(lineage_traced_results) > 0:
731816
combine_time_lineage_fate_correlation(results=lineage_traced_results)
732817

733-
return results
818+
trajectory_outputs = None
819+
if len(developmental_results) > 0:
820+
trajectory_outputs = evaluate_trajectory_metrics(
821+
results=developmental_results
822+
)
823+
824+
return results, trajectory_outputs, metrics_outputs
734825

735826

736827
if __name__ == "__main__":

0 commit comments

Comments
 (0)