2
2
from dataclasses import asdict
3
3
from datetime import timedelta
4
4
from pathlib import Path
5
+ from typing import Tuple
5
6
6
7
from beartype .typing import List
7
8
from flytekit import Resources , current_context , dynamic , task
28
29
)
29
30
from pyrovelocity .logging import configure_logging
30
31
from pyrovelocity .tasks .data import download_dataset
32
+ from pyrovelocity .tasks .evaluate import calculate_cross_boundary_correctness
31
33
from pyrovelocity .tasks .postprocess import postprocess_dataset
32
34
from pyrovelocity .tasks .preprocess import preprocess_dataset
33
35
from pyrovelocity .tasks .summarize import summarize_dataset
48
50
SummarizeConfiguration ,
49
51
SummarizeOutputs ,
50
52
TrainingOutputs ,
53
+ TrajectoryEvaluationOutputs ,
51
54
WorkflowConfiguration ,
52
55
bonemarrow_configuration ,
53
56
default_resource_limits ,
75
78
"upload_summary" ,
76
79
"map_model_configurations_over_data_set" ,
77
80
"training_workflow" ,
81
+ "evaluate_trajectory_metrics" ,
78
82
]
79
83
80
84
logger = configure_logging (__name__ )
87
91
SUMMARIZE_CACHE_VERSION = f"{ CACHE_VERSION } .3"
88
92
UPLOAD_CACHE_VERSION = f"{ CACHE_VERSION } .7"
89
93
LINEAGE_FATE_CORRELATION_CACHE_VERSION = f"{ CACHE_VERSION } .8"
94
+ TRAJECTORY_EVALUATION_CACHE_VERSION = f"{ CACHE_VERSION } .0"
90
95
COMBINE_METRICS_CACHE_VERSION = f"{ CACHE_VERSION } .5"
91
96
DEFAULT_ACCELERATOR_TYPE : GPUAccelerator = T4
92
97
@@ -655,6 +660,73 @@ def combine_all_metrics(
655
660
)
656
661
657
662
663
@task(
    cache=PYROVELOCITY_CACHE_FLAG,
    cache_version=TRAJECTORY_EVALUATION_CACHE_VERSION,
    retries=3,
    interruptible=False,
    timeout=timedelta(minutes=120),
    requests=Resources(cpu="8", mem="30Gi", ephemeral_storage="50Gi"),
    limits=Resources(cpu="16", mem="60Gi", ephemeral_storage="200Gi"),
    enable_deck=False,
)
def evaluate_trajectory_metrics(
    results: List[List[SummarizeOutputs]],
) -> TrajectoryEvaluationOutputs:
    """
    Compute cross-boundary correctness trajectory metrics across datasets.

    Downloads the postprocessed data for each (dataset, model) summary
    output, runs ``calculate_cross_boundary_correctness`` over the pooled
    results, uploads the resulting summary and plot files to the per-execution
    report bucket, and returns the generated artifacts.

    Args:
        results: Nested summary outputs — one inner list of model-level
            ``SummarizeOutputs`` per dataset.

    Returns:
        TrajectoryEvaluationOutputs wrapping the summary file, the results
        directory, and the plot file produced by the metric calculation.
    """
    logger.info("Evaluating trajectory metrics for datasets")

    # Local working directory for metric outputs; created eagerly so the
    # downstream helper can write into it.
    output_dir = Path("reports/trajectory_metrics")
    output_dir.mkdir(parents=True, exist_ok=True)

    model_results = []

    # Flatten the nested per-dataset / per-model structure into the flat
    # list of dicts consumed by calculate_cross_boundary_correctness.
    for dataset_results in results:
        for model_output in dataset_results:
            # Materialize the remote FlyteFile locally before analysis.
            postprocessed_data_path = model_output.postprocessed_data.download()

            model_results.append(
                {
                    "data_model": model_output.data_model,
                    "postprocessed_data": postprocessed_data_path,
                }
            )

    summary_file, results_dir, plot_file = calculate_cross_boundary_correctness(
        model_results=model_results,
        output_dir=output_dir,
    )

    # Scope uploaded artifacts to this workflow execution's report bucket.
    ctx = current_context()
    execution_id = ctx.execution_id.name

    # Best-effort upload of the summary and plot files. Each path is probed
    # both as-is and with a ".png" suffix appended; non-existent variants are
    # skipped. NOTE(review): if a path already ends in ".png" this also probes
    # "<name>.png.png" — presumably that never exists, but confirm against
    # what calculate_cross_boundary_correctness actually returns.
    uploaded_files = []
    for file_path in [summary_file, plot_file]:
        for ext in ["", ".png"]:
            if Path(f"{file_path}{ext}").exists():
                upload_result = upload_file_concurrently(
                    bucket_name=f"pyrovelocity/reports/{execution_id}",
                    source_filename=f"{file_path}{ext}",
                    destination_blob_name=f"{file_path.name}{ext}",
                )
                uploaded_files.append(upload_result)

    # Upload results are returns-style Success/Failure containers; failures
    # are logged (by index into the upload list) but do not fail the task —
    # the locally generated artifacts are still returned below.
    if all(isinstance(result, Success) for result in uploaded_files):
        logger.info("All trajectory metrics files uploaded successfully")
    else:
        failed_uploads = [
            str(i)
            for i, result in enumerate(uploaded_files)
            if isinstance(result, Failure)
        ]
        logger.warning(f"Failed uploads: {', '.join(failed_uploads)}")

    return TrajectoryEvaluationOutputs(
        summary_file=FlyteFile(path=str(summary_file)),
        results_directory=FlyteDirectory(path=str(results_dir)),
        plot_file=FlyteFile(path=str(plot_file)),
    )
658
730
@dynamic
659
731
def training_workflow (
660
732
simulated_configuration : WorkflowConfiguration = simulated_configuration ,
@@ -668,7 +740,11 @@ def training_workflow(
668
740
larry_neu_configuration : WorkflowConfiguration = larry_neu_configuration ,
669
741
larry_mono_configuration : WorkflowConfiguration = larry_mono_configuration ,
670
742
larry_multilineage_configuration : WorkflowConfiguration = larry_multilineage_configuration ,
671
- ) -> list [list [SummarizeOutputs ]]:
743
+ ) -> Tuple [
744
+ List [List [SummarizeOutputs ]],
745
+ TrajectoryEvaluationOutputs ,
746
+ CombinedMetricsOutputs ,
747
+ ]:
672
748
"""
673
749
Apply the primary workflow to a collection of configurations.
674
750
Conditionally executes configurations based on the value of PYROVELOCITY_DATA_SUBSET.
@@ -688,6 +764,7 @@ def training_workflow(
688
764
]
689
765
690
766
lineage_traced_results = []
767
+ developmental_results = []
691
768
lineage_traced_configurations = [
692
769
(larry_mono_configuration , "larry_mono" ),
693
770
(larry_neu_configuration , "larry_neu" ),
@@ -721,16 +798,30 @@ def training_workflow(
721
798
accelerator_type = config .accelerator_type ,
722
799
upload_results = config .upload_results ,
723
800
)
801
+
724
802
if "larry" in data_set_name :
725
803
lineage_traced_results .append (result )
804
+
805
+ if (
806
+ data_set_name in ["bonemarrow" , "pancreas" , "pons" ]
807
+ or "larry" in data_set_name
808
+ ):
809
+ developmental_results .append (result )
810
+
726
811
results .append (result )
727
812
728
- combine_all_metrics (results = results )
813
+ metrics_outputs = combine_all_metrics (results = results )
729
814
730
815
if len (lineage_traced_results ) > 0 :
731
816
combine_time_lineage_fate_correlation (results = lineage_traced_results )
732
817
733
- return results
818
+ trajectory_outputs = None
819
+ if len (developmental_results ) > 0 :
820
+ trajectory_outputs = evaluate_trajectory_metrics (
821
+ results = developmental_results
822
+ )
823
+
824
+ return results , trajectory_outputs , metrics_outputs
734
825
735
826
736
827
if __name__ == "__main__" :
0 commit comments