From eb7303a3edb964ebd9c68fa0edbbc6b76d303384 Mon Sep 17 00:00:00 2001 From: ch1ru Date: Fri, 22 Mar 2024 11:18:40 +0000 Subject: [PATCH 1/3] added tests --- app/pages/2_Preprocess.py | 399 +++++++++++++++++++++-------------- app/tests/run_all_tests.py | 26 ++- app/tests/test_preprocess.py | 193 ++++++++++------- app/utils/session_cache.py | 3 +- 4 files changed, 373 insertions(+), 248 deletions(-) diff --git a/app/pages/2_Preprocess.py b/app/pages/2_Preprocess.py index 89992a9..79f6fd4 100644 --- a/app/pages/2_Preprocess.py +++ b/app/pages/2_Preprocess.py @@ -133,19 +133,26 @@ def remove_genes(self): adata = adata[:,keep] """ with st.form(key="remove_genes_form"): - st.subheader("Remove genes") - remove_genes = st.multiselect(label="Genes", options=st.session_state.adata_state.current.adata.var_names) - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) - if submit_btn: - with st.spinner(text="Removing genes"): - for gene in remove_genes: - remove_genes = st.session_state.adata_state.current.adata.var_names.str.startswith(gene) - remove = np.array(remove_genes) - keep = np.invert(remove) - st.session_state.adata_state.current.adata = st.session_state.adata_state.current.adata[:,keep] + try: + st.subheader("Remove genes") + remove_genes = st.multiselect(label="Genes", options=self.state_manager.adata_state().current.adata.var_names) + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + if submit_btn: + adata = self.state_manager.get_current_adata() + with st.spinner(text="Removing genes"): + for gene in remove_genes: + remove_genes = adata.var_names.str.startswith(gene) + remove = np.array(remove_genes) + keep = np.invert(remove) + adata = adata[:,keep] - # TODO: write to script state + self.state_manager \ + .add_adata(adata) \ + .save_session() + + except Exception as e: + st.toast(e, icon="❌") def filter_highly_variable_genes(self): @@ -279,27 
+286,32 @@ def normalize_counts(self): """ with st.form(key="form_normalize_total"): - st.subheader("Normalization") - subcol_input1, subcol_input2 = st.columns(2, gap="medium") - target_sum = subcol_input1.number_input(label="Target sum", value=1.0, key="ni:pp:normalize_counts:target_sum") - max_fraction = subcol_input2.number_input(label="Max fraction", key="ni:pp:normalize_counts:max_fraction", value=0.050, min_value=0.001, max_value=1.000) - exclude_high_expr = subcol_input1.checkbox(label="Exclude highly_expr", value=False) - log_transform_total = subcol_input2.checkbox(label="Log transform", value=False) - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + try: + st.subheader("Normalization") + subcol_input1, subcol_input2 = st.columns(2, gap="medium") + target_sum = subcol_input1.number_input(label="Target sum", value=1.0, key="ni:pp:normalize_counts:target_sum") + max_fraction = subcol_input2.number_input(label="Max fraction", key="ni:pp:normalize_counts:max_fraction", value=0.050, min_value=0.001, max_value=1.000) + exclude_high_expr = subcol_input1.checkbox(label="Exclude highly_expr", value=False) + log_transform_total = subcol_input2.checkbox(label="Log transform", value=False) + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) - if submit_btn: - sc.pp.normalize_total(st.session_state.adata_state.current.adata, target_sum=target_sum, exclude_highly_expressed=exclude_high_expr, max_fraction=max_fraction) - if log_transform_total: - sc.pp.log1p(st.session_state.adata_state.current.adata) + if submit_btn: + sc.pp.normalize_total(st.session_state.adata_state.current.adata, target_sum=target_sum, exclude_highly_expressed=exclude_high_expr, max_fraction=max_fraction) + if log_transform_total: + sc.pp.log1p(st.session_state.adata_state.current.adata) - # write to script state - self.state_manager \ - 
.add_adata(st.session_state.adata_state.current.adata) \ - .add_script(Normalize(language=Language.ALL_SUPPORTED, scale_factor=target_sum, log_norm=log_transform_total)) \ - .save_session() + # write to script state + self.state_manager \ + .add_adata(st.session_state.adata_state.current.adata) \ + .add_script(Normalize(language=Language.ALL_SUPPORTED, scale_factor=target_sum, log_norm=log_transform_total)) \ + .save_session() + + st.toast("Normalized data", icon='✅') + + except Exception as e: + st.toast(e, icon="❌") - st.toast("Normalized data", icon='✅') def filter_cells(self): """ @@ -321,23 +333,27 @@ def filter_cells(self): sc.pp.filter_cells(adata, max_genes=None, min_genes=200, max_counts=None, min_counts=None) """ with st.form(key="form_filter_cells"): - st.subheader("Filter Cells", help="Filter cell outliers based on counts and numbers of genes expressed. Only keep cells with at least min_genes genes expressed. This is equivalent to min.features in Seurat.") - min_genes = st.number_input(label="min genes for cell", min_value=1, value=None, key="ni:pp:filter_cells:min_genes") + try: + st.subheader("Filter Cells", help="Filter cell outliers based on counts and numbers of genes expressed. Only keep cells with at least min_genes genes expressed. 
This is equivalent to min.features in Seurat.") + min_genes = st.number_input(label="min genes for cell", min_value=1, value=None, key="ni:pp:filter_cells:min_genes") - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) - if submit_btn: - adata = self.state_manager.get_current_adata() - sc.pp.filter_cells(adata, min_genes=min_genes) + if submit_btn: + adata = self.state_manager.get_current_adata() + sc.pp.filter_cells(adata, min_genes=min_genes) - #make adata - self.state_manager \ - .add_adata(adata) \ - .add_script(Filter_cells(language=Language.ALL_SUPPORTED, min_genes=min_genes)) \ - .save_session() - - st.toast("Filtered cells", icon='✅') + #make adata + self.state_manager \ + .add_adata(adata) \ + .add_script(Filter_cells(language=Language.ALL_SUPPORTED, min_genes=min_genes)) \ + .save_session() + + st.toast("Filtered cells", icon='✅') + + except Exception as e: + st.toast(e, icon="❌") def filter_genes(self): @@ -360,21 +376,25 @@ def filter_genes(self): sc.pp.filter_genes(adata, max_cells=None, min_cells=None, max_counts=None, min_counts=3) """ with st.form(key="form_filter_genes"): - st.subheader("Filter Genes", help="Filter genes based on number of cells or counts. Keep genes that are in at least min_cells cells. 
Equivalent to min.cells in Seurat.") - min_cells = st.number_input(label="min cells for gene", min_value=1, value=None, key="ni:pp:filter_genes:min_cells") - - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) - if submit_btn: - adata: AnnData = self.state_manager.get_current_adata() - sc.pp.filter_genes(adata, min_cells=min_cells) - - self.state_manager \ - .add_adata(adata) \ - .add_script(Filter_genes(language=Language.ALL_SUPPORTED, min_cells=min_cells)) \ - .save_session() + try: + st.subheader("Filter Genes", help="Filter genes based on number of cells or counts. Keep genes that are in at least min_cells cells. Equivalent to min.cells in Seurat.") + min_cells = st.number_input(label="min cells for gene", min_value=1, value=None, key="ni:pp:filter_genes:min_cells") + + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + if submit_btn: + adata: AnnData = self.state_manager.get_current_adata() + sc.pp.filter_genes(adata, min_cells=min_cells) + + self.state_manager \ + .add_adata(adata) \ + .add_script(Filter_genes(language=Language.ALL_SUPPORTED, min_cells=min_cells)) \ + .save_session() + + st.toast("Filtered genes", icon='✅') - st.toast("Filtered genes", icon='✅') + except Exception as e: + st.toast(e, icon="❌") def recipes(self): @@ -420,45 +440,67 @@ def recipes(self): seurat_tab, weinreb17_tab, zheng17_tab = st.tabs(['Seurat', 'Weinreb17', 'Zheng17']) with seurat_tab: with st.form(key="form_seurat"): - st.write("Parameters") - log = st.checkbox(label="Log", value=True) - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) - - if submit_btn: - sc.pp.recipe_seurat(st.session_state.adata_state.current.adata, log=log) - - # TODO: add to script state - st.toast(f"Applied recipe: Seurat", icon='✅') + try: + st.write("Parameters") + log = st.checkbox(label="Log", value=True, 
key="cb:pp:recipe:seurat:log") + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + + if submit_btn: + adata = self.state_manager.get_current_adata() + sc.pp.recipe_seurat(adata, log=log) + + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast(f"Applied recipe: Seurat", icon='✅') + + except Exception as e: + st.toast(e, icon="❌") with weinreb17_tab: with st.form(key="form_weinreb17"): - st.write("Parameters") - col1, col2, col3 = st.columns(3) - mean_threshold = col1.number_input(label="Mean threshold", value=0.01, step=0.01) - cv_threshold = col2.number_input(label="CV threshold", value=2.0, step=1.0) - n_pcs = col3.number_input(label="n_pcs", min_value=1, value=50, step=1, format="%i") - log = st.checkbox(label="Log", value=False) - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) - if submit_btn: - sc.pp.recipe_weinreb17(st.session_state.adata_state.current.adata, log=log, mean_threshold=mean_threshold, cv_threshold=cv_threshold, n_pcs=n_pcs) - - # TODO: add to script state - st.toast(f"Applied recipe: Weinreb17", icon='✅') + try: + st.write("Parameters") + col1, col2, col3 = st.columns(3) + mean_threshold = col1.number_input(label="Mean threshold", value=0.01, step=0.01) + cv_threshold = col2.number_input(label="CV threshold", value=2.0, step=1.0) + n_pcs = col3.number_input(label="n_pcs", min_value=1, value=50, step=1, format="%i") + log = st.checkbox(label="Log", value=False, key="cb:pp:recipe:weinreb17:log") + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + if submit_btn: + adata = self.state_manager.get_current_adata() + sc.pp.recipe_weinreb17(adata, log=log, mean_threshold=mean_threshold, cv_threshold=cv_threshold, n_pcs=n_pcs) + + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast(f"Applied recipe: Weinreb17", icon='✅') + + 
except Exception as e: + st.toast(e, icon="❌") with zheng17_tab: with st.form(key="form_zheng17"): - st.write("Parameters") - n_top_genes = st.number_input(label="n_top_genes", key="ni_zheng17_n_genes", min_value=1, max_value=self.state_manager.adata_state().current.adata.n_vars, value=1000 if st.session_state.adata_state.current.adata.n_vars >= 1000 else st.session_state.adata_state.current.adata.n_vars) - log = st.checkbox(label="Log", value=False) - subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) - if submit_btn: - sc.pp.recipe_zheng17(st.session_state.adata_state.current.adata, log=log, n_top_genes=n_top_genes) - - # TODO: add to script state - st.toast(f"Applied recipe: Zheng17", icon='✅') + try: + st.write("Parameters") + n_vars = self.state_manager.adata_state().current.adata.n_vars + n_top_genes = st.number_input(label="n_top_genes", key="ni:pp:recipe:zheng17:n_genes", min_value=1, max_value=n_vars, value=min(1000, n_vars)) + log = st.checkbox(label="Log", value=False) + subcol1, _, _ = st.columns(3) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + if submit_btn: + adata = self.state_manager.get_current_adata() + sc.pp.recipe_zheng17(adata, log=log, n_top_genes=n_top_genes) + + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast(f"Applied recipe: Zheng17", icon='✅') + + except Exception as e: + st.toast(e, icon="❌") def annotate(self): @@ -795,40 +837,46 @@ def run_scrublet(self): with st.form(key="scrublet_form"): st.subheader("Scrublet", help="Use Scrublet to remove cells predicted to be doublets.") col1, col2, col3 = st.columns(3) - sim_doublet_ratio = col1.number_input(label="Sim doublet ratio", value=2.00, key="ni_sim_doublet_ratio") - expected_doublet_rate = col2.number_input(label="Exp doublet rate", value=0.05, key="ni_expected_doublet_rate") - stdev_doublet_rate = col3.number_input(label="stdev_doublet_rate", value=0.02, 
key="ni_stdev_doublet_rate") - batch_key = st.selectbox(label="Batch key", key="sb_scrublet_batch_key", options=np.append('None', st.session_state.adata_state.current.adata.obs_keys())) + sim_doublet_ratio = col1.number_input(label="Sim doublet ratio", value=2.00, key="ni:pp:scrublet:sim_doublet_ratio") + expected_doublet_rate = col2.number_input(label="Exp doublet rate", value=0.05, key="ni:pp:scrublet:expected_doublet_rate") + stdev_doublet_rate = col3.number_input(label="stdev_doublet_rate", value=0.02, key="ni:pp:scrublet:stdev_doublet_rate") + batch_key = st.selectbox(label="Batch key", key="sb:pp:scrublet:batch_key", options=np.append('None', self.state_manager.adata_state().current.adata.obs_keys())) subcol1, _, _ = st.columns(3) scrublet_submit = subcol1.form_submit_button(label="Run", use_container_width=True) if scrublet_submit: try: + adata = self.state_manager.get_current_adata() with st.spinner("Running scrublet"): + if batch_key == 'None': batch_key = None - sc.external.pp.scrublet(st.session_state.adata_state.current.adata, sim_doublet_ratio=sim_doublet_ratio, - expected_doublet_rate=expected_doublet_rate, stdev_doublet_rate=stdev_doublet_rate, batch_key=batch_key, random_state=42) + + sc.external.pp.scrublet(adata, sim_doublet_ratio=sim_doublet_ratio, expected_doublet_rate=expected_doublet_rate, + stdev_doublet_rate=stdev_doublet_rate, batch_key=batch_key, verbose=False, random_state=42) + # plot PCA to see doublets - sc.external.pl.scrublet_score_distribution(st.session_state.adata_state.current.adata) - sc.pp.pca(st.session_state.adata_state.current.adata) - sc.pp.neighbors(st.session_state.adata_state.current.adata) - sc.tl.umap(st.session_state.adata_state.current.adata) + sc.external.pl.scrublet_score_distribution(adata) + sc.pp.pca(adata) + sc.pp.neighbors(adata) + sc.tl.umap(adata) stats, umap, distplot, simulated_doublets = st.tabs(['Stats', 'UMAP', 'Distplot', 'Simulated doublets']) + with stats: - num_of_doublets = 
st.session_state.adata_state.current.adata.obs.predicted_doublet.sum() - predicted_doublets = '{:.2%}'.format(num_of_doublets / st.session_state.adata_state.current.adata.n_obs) + num_of_doublets = adata.obs.predicted_doublet.sum() + predicted_doublets = '{:.2%}'.format(num_of_doublets / adata.n_obs) st.write(f"Number of predicted doublets: {num_of_doublets}") st.write(f"Percentage of predicted doublets: {predicted_doublets}") + if batch_key: - st.session_state.adata_state.current.adata.obs[batch_key] = st.session_state.adata_state.current.adata.obs[batch_key].astype('category') - batches = st.session_state.adata_state.current.adata.obs[batch_key].cat.categories + adata.obs[batch_key] = adata.obs[batch_key].astype('category') + batches = adata.obs[batch_key].cat.categories stats_df = pd.DataFrame(columns=['batch', 'mean_doublet_score', 'doublet_rate']) stats_df['batch'] = np.array(batches) for batch in batches: - stats_df['mean_doublet_score'].loc[stats_df.batch == batch] = st.session_state.adata_state.current.adata.obs.doublet_score[st.session_state.adata_state.current.adata.obs[batch_key] == batch].mean() + stats_df['mean_doublet_score'].loc[stats_df.batch == batch] = adata.obs.doublet_score[adata.obs[batch_key] == batch].mean() stats_df['doublet_rate'].loc[stats_df.batch == batch] = \ - st.session_state.adata_state.current.adata.obs.predicted_doublet[st.session_state.adata_state.current.adata.obs[batch_key] == batch].sum() / st.session_state.adata_state.current.adata[st.session_state.adata_state.current.adata.obs[batch_key] == batch].n_obs + adata.obs.predicted_doublet[adata.obs[batch_key] == batch].sum() / adata[adata.obs[batch_key] == batch].n_obs # add % sign stats_df['doublet_rate'] = stats_df['doublet_rate'].map('{:.2%}'.format) st.dataframe(stats_df, hide_index=True, use_container_width=True, @@ -844,32 +892,40 @@ def run_scrublet(self): } ) + st.session_state["scrublet_mean_doublet_score"] = stats_df["mean_doublet_score"] + 
st.session_state["scrublet_doublet_rate"] = stats_df["doublet_rate"] + with umap: - df = pd.DataFrame({'UMAP 1': st.session_state.adata_state.current.adata.obsm['X_umap'][:,0], 'UMAP 2': st.session_state.adata_state.current.adata.obsm['X_umap'][:,1], 'Doublet score': st.session_state.adata_state.current.adata.obs.doublet_score}) + df = pd.DataFrame({'UMAP 1': adata.obsm['X_umap'][:,0], 'UMAP 2': adata.obsm['X_umap'][:,1], 'Doublet score': adata.obs.doublet_score}) st.scatter_chart(df, x='UMAP 1', y='UMAP 2', color='Doublet score', size=10) with distplot: if batch_key: for i, batch in enumerate(st.session_state.adata_state.current.adata.obs[batch_key].unique()): line_colors = ['#31abe8', '#8ee065', '#eda621', '#f071bf', '#9071f0', '#71e3f0', '#2f39ed', '#ed2f7b'] - fig = ff.create_distplot([st.session_state.adata_state.current.adata.obs.doublet_score[st.session_state.adata_state.current.adata.obs[batch_key] == batch]], group_labels=['doublet_score'], colors=[line_colors[i % len(line_colors)]], + fig = ff.create_distplot([adata.obs.doublet_score[adata.obs[batch_key] == batch]], group_labels=['doublet_score'], colors=[line_colors[i % len(line_colors)]], bin_size=0.02, show_rug=False, show_curve=False) fig.update_layout(yaxis_type="log") - fig.add_vline(x=st.session_state.adata_state.current.adata.uns['scrublet']["batches"][batch]['threshold'], line_color="red") + fig.add_vline(x=adata.uns['scrublet']["batches"][batch]['threshold'], line_color="red") fig.update_layout(xaxis_title="Doublet score", yaxis_title="Probability density", title=f"Observed transcriptomes for batch {batch}") st.plotly_chart(fig, use_container_width=True) else: - fig = ff.create_distplot([st.session_state.adata_state.current.adata.obs.doublet_score.values], group_labels=['doublet_score'], + fig = ff.create_distplot([adata.obs.doublet_score.values], group_labels=['doublet_score'], bin_size=0.02, show_rug=False, show_curve=False, colors=["#31abe8"]) fig.update_layout(yaxis_type="log") - 
fig.add_vline(x=st.session_state.adata_state.current.adata.uns['scrublet']['threshold'], line_color="red") + fig.add_vline(x=adata.uns['scrublet']['threshold'], line_color="red") fig.update_layout(xaxis_title="Doublet score", yaxis_title="Probability density", title="Observed transcriptomes") st.plotly_chart(fig, use_container_width=True) with simulated_doublets: st.write("To implement") - # TODO: save to script state - st.toast("Run scrublet", icon="✅") + + + self.state_manager \ + .add_adata(adata) \ + .save_session() + + st.toast("Run scrublet", icon="✅") except Exception as e: st.toast(e, icon="❌") @@ -1118,7 +1174,7 @@ def downsample_data(self): counts_per_cell, total_counts = st.tabs(["counts_per_cell", "total_counts"]) with counts_per_cell: with st.form(key="downsample_form_counts_per_cell"): - counts_per_cell = st.number_input(label="Counts per cell", key="ni:pp:downsample_counts_per_cell", help="Target total counts per cell. If a cell has more than 'counts_per_cell', it will be downsampled to this number. Resulting counts can be specified on a per cell basis by passing an array.Should be an integer or integer ndarray with same length as number of obs.") + counts_per_cell = st.number_input(label="Counts per cell", value=1, step=1, format="%i", key="ni:pp:downsample:counts_per_cell", help="Target total counts per cell. If a cell has more than 'counts_per_cell', it will be downsampled to this number. 
Resulting counts can be specified on a per cell basis by passing an array.Should be an integer or integer ndarray with same length as number of obs.") subcol1, _, _ = st.columns(3) btn_downsample_counts_per_cell = subcol1.form_submit_button(label="Apply", use_container_width=True) if btn_downsample_counts_per_cell: @@ -1127,11 +1183,10 @@ def downsample_data(self): self.state_manager \ .add_adata(adata) \ .save_session() - # TODO: add to script state st.toast("Successfully downsampled data per cell", icon="✅") with total_counts: with st.form(key="downsample_form_total_counts"): - total_counts = st.number_input(label="Total counts", key="ni:pp:downsample_total_counts", help="Target total counts. If the count matrix has more than total_counts it will be downsampled to have this number.") + total_counts = st.number_input(label="Total counts", key="ni:pp:downsample:total_counts", help="Target total counts. If the count matrix has more than total_counts it will be downsampled to have this number.") subcol1, _, _ = st.columns(3) btn_downsample_total_counts = subcol1.form_submit_button(label="Apply", use_container_width=True) if btn_downsample_total_counts: @@ -1140,7 +1195,6 @@ def downsample_data(self): self.state_manager \ .add_adata(adata) \ .save_session() - # TODO: add to script state st.toast("Successfully downsampled data by total counts", icon="✅") except Exception as e: @@ -1174,26 +1228,38 @@ def subsample_data(self): """ st.subheader("Subsample data") n_obs, fraction = st.tabs(["n_obs", "fraction"]) - with n_obs: - with st.form(key="subsample_form_n_obs"): - n_obs = st.number_input(label="n obs", key="ni_subsample_n_obs", help="Subsample to this number of observations.") - subcol1, _, _ = st.columns(3) - btn_subsample_n_obs = subcol1.form_submit_button(label="Apply", use_container_width=True) - if btn_subsample_n_obs: - sc.pp.subsample(st.session_state.adata_state.current.adata, n_obs=st.session_state.ni_n_obs, random_state=42) + try: + with n_obs: + with 
st.form(key="subsample_form_n_obs"): + n_obs = self.state_manager.adata_state().current.adata.n_obs + n_obs = st.number_input(label="n obs", key="ni:pp:subsample:n_obs", help="Subsample to this number of observations.", value=n_obs, step=1, format="%i", max_value=n_obs) + subcol1, _, _ = st.columns(3) + btn_subsample_n_obs = subcol1.form_submit_button(label="Apply", use_container_width=True) + if btn_subsample_n_obs: + adata: AnnData = self.state_manager.get_current_adata() + sc.pp.subsample(adata, n_obs=n_obs, random_state=42) - # TODO: add to script state - st.toast("Successfully subsampled data", icon="✅") - with fraction: - with st.form(key="subsample_form_fraction"): - fraction = st.number_input(label="subsample_fraction", key="ni_subsample_fraction", help="Subsample this fraction of the number of observations.") - subcol1, _, _ = st.columns(3) - btn_subsample_fraction = subcol1.form_submit_button(label="Apply", use_container_width=True) - if btn_subsample_fraction: - sc.pp.subsample(st.session_state.adata_state.current.adata, fraction=st.session_state.ni_subsample_fraction, random_state=42) + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast(f"Successfully subsampled data to {n_obs} observations", icon="✅") + with fraction: + with st.form(key="subsample_form_fraction"): + fraction = st.number_input(label="subsample_fraction", key="ni:pp:subsample:fraction", help="Subsample this fraction of the number of observations.") + subcol1, _, _ = st.columns(3) + btn_subsample_fraction = subcol1.form_submit_button(label="Apply", use_container_width=True) + if btn_subsample_fraction: + adata: AnnData = self.state_manager.get_current_adata() + sc.pp.subsample(adata, fraction=fraction, random_state=42) - # TODO: add to script state - st.toast("Successfully subsampled data", icon="✅") + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast(f"Successfully subsampled data to {fraction * 100}% original value", icon="✅") + + except 
Exception as e: + print(e) + st.toast(e, icon="❌") def batch_effect_removal(self): """ @@ -1220,21 +1286,28 @@ def batch_effect_removal(self): sc.pp.combat(adata, key="batch", covariates=None) """ with st.form(key="batch_effect_removal_form"): - st.subheader("Batch effect correction", help="Uses Combat to correct non-biological differences caused by batch effect.") - index = 0 - for i, obs in enumerate(st.session_state.adata_state.current.adata.obs_keys()): - if obs.lower().replace("_", "").__contains__("batch"): - index = i - key = st.selectbox(label="Key", options=st.session_state.adata_state.current.adata.obs_keys(), key="sb_batch_effect_key", index=index) - covariates = st.multiselect(placeholder="Optional", label="Covariates", options=st.session_state.adata_state.current.adata.obs_keys()) - subcol1, _, _ = st.columns(3) - btn_batch_effect_removal = subcol1.form_submit_button(label="Apply", use_container_width=True) - if btn_batch_effect_removal: - with st.spinner(text="Running Combat batch effect correction"): - sc.pp.combat(st.session_state.adata_state.current.adata, key=key, covariates=covariates, inplace=True) - - # TODO: add to script state - st.toast("Batch corrected data", icon='✅') + try: + st.subheader("Batch effect correction", help="Uses Combat to correct non-biological differences caused by batch effect.") + index = 0 + for i, obs in enumerate(self.state_manager.adata_state().current.adata.obs_keys()): + if obs.lower().replace("_", "").__contains__("batch"): + index = i + key = st.selectbox(label="Batch key", options=self.state_manager.adata_state().current.adata.obs_keys(), key="sb:pp:combat:batch_key", index=index) + covariates = st.multiselect(placeholder="Optional", label="Covariates", options=self.state_manager.adata_state().current.adata.obs_keys()) + subcol1, _, _ = st.columns(3) + btn_batch_effect_removal = subcol1.form_submit_button(label="Apply", use_container_width=True) + if btn_batch_effect_removal: + adata = 
self.state_manager.get_current_adata() + with st.spinner(text="Running Combat batch effect correction"): + sc.pp.combat(adata, key=key, covariates=covariates, inplace=True) + + self.state_manager \ + .add_adata(adata) \ + .save_session() + st.toast("Batch corrected data", icon='✅') + + except Exception as e: + st.toast(e, icon="❌") def pca(self): @@ -1259,36 +1332,38 @@ def pca(self): """ with st.form(key="pca_pp_form"): st.subheader("PCA") + adata: AnnData = self.state_manager.get_current_adata() - def run_pca(adata): + def run_pca(adata: AnnData): with st.spinner(text="Running PCA"): sc.pp.pca(adata, random_state=42) - pp_pca_df = pd.DataFrame({'pca1': adata.obsm['X_pca'][:,0], 'pca2': adata.obsm['X_pca'][:,1], 'color': adata.obs[f'{st.session_state.sb_pca_color_pp}']}) + pp_pca_df = pd.DataFrame({'pca1': adata.obsm['X_pca'][:,0], 'pca2': adata.obsm['X_pca'][:,1], 'color': adata.obs[f'{st.session_state["sb:pp:pca:color"]}']}) st.session_state["preprocess_plots"]["pca"] = dict(df=pp_pca_df) + + self.state_manager \ + .add_adata(adata) \ + .add_script(PCA(language=Language.ALL_SUPPORTED, color=pca_color)) \ + .save_session() index = 0 - for i, item in enumerate(st.session_state.adata_state.current.adata.obs_keys()): + for i, item in enumerate(adata.obs_keys()): if item.lower().replace("_", "").__contains__("batch"): #give precedence to batch if present since it is relevant to preprocessing index = i - pca_color = st.selectbox(label="Color", options=st.session_state.adata_state.current.adata.obs_keys(), key="sb_pca_color_pp", index=index) + pca_color = st.selectbox(label="Color", options=adata.obs_keys(), key="sb:pp:pca:color", index=index) subcol1, _, _ = st.columns(3) pca_pp_btn = subcol1.form_submit_button("Apply", use_container_width=True) pca_empty = st.empty() if st.session_state["preprocess_plots"]["pca"] == None: - run_pca(st.session_state.adata_state.current.adata) + run_pca(adata) pca_empty.empty() 
pca_empty.scatter_chart(data=st.session_state["preprocess_plots"]["pca"]['df'], x='pca1', y='pca2', color='color', size=18) if pca_pp_btn: - run_pca(st.session_state.adata_state.current.adata) - - self.state_manager \ - .add_adata(st.session_state.adata_state.current.adata) \ - .add_script(PCA(language=Language.ALL_SUPPORTED, color=pca_color)) \ - .save_session() + adata: AnnData = self.state_manager.get_current_adata() + run_pca(adata) pca_empty.empty() pca_empty.scatter_chart(data=st.session_state["preprocess_plots"]["pca"]['df'], x='pca1', y='pca2', color='color', size=18) diff --git a/app/tests/run_all_tests.py b/app/tests/run_all_tests.py index 20a9f29..f5c13c6 100644 --- a/app/tests/run_all_tests.py +++ b/app/tests/run_all_tests.py @@ -15,8 +15,8 @@ from state.AdataState import AdataState from database.schemas import schemas from models.WorkspaceModel import WorkspaceModel +from state.StateManager import StateManager import os -import time from random import randrange import shutil from datetime import datetime, date @@ -50,15 +50,29 @@ class bcolors: print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") print() - print(f"{bcolors.BOLD}===============Testing Preprocess===============") - pp_test = Test_Preprocess(session_state=upload_state) - pp_state = pp_test.get_final_session_state() + print(f"{bcolors.BOLD}===============Testing Preprocess Pipeline 1===============") + pp_test = Test_Preprocess(session_state=upload_state, pipeline=1) + pp_state1 = pp_test.get_final_session_state() + print() + print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") + + print() + print(f"{bcolors.BOLD}===============Testing Preprocess Pipeline 2===============") + pp_test = Test_Preprocess(session_state=upload_state, pipeline=2) + pp_state2 = pp_test.get_final_session_state() + print() + print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") + + print() + print(f"{bcolors.BOLD}===============Testing Preprocess Pipeline 3===============") + pp_test = 
Test_Preprocess(session_state=upload_state, pipeline=3) + pp_state3 = pp_test.get_final_session_state() print() print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") print() print(f"{bcolors.BOLD}===============Testing Create Citeseq Model===============") - create_model_test = Test_Create_Model(session_state=pp_state, model="Citeseq (dimensionality reduction)") + create_model_test = Test_Create_Model(session_state=pp_state2, model="Citeseq (dimensionality reduction)") create_model_state = create_model_test.get_final_session_state() print() print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") @@ -79,7 +93,7 @@ class bcolors: print() print(f"{bcolors.BOLD}===============Testing Create Solo Model===============") - create_model_test = Test_Create_Model(session_state=pp_state, model="Solo (doublet removal)") + create_model_test = Test_Create_Model(session_state=pp_state1, model="Solo (doublet removal)") create_model_state = create_model_test.get_final_session_state() print() print(f"{bcolors.OKGREEN}TEST PASSED{bcolors.ENDC}") diff --git a/app/tests/test_preprocess.py b/app/tests/test_preprocess.py index fc37029..ca9522b 100644 --- a/app/tests/test_preprocess.py +++ b/app/tests/test_preprocess.py @@ -1,16 +1,11 @@ from streamlit.testing.v1 import AppTest import os from anndata import AnnData -from state.AdataState import AdataState -from models.AdataModel import AdataModel -from models.AdataModel import AdataModel -from models.WorkspaceModel import WorkspaceModel from database.database import SessionLocal from sqlalchemy.orm import Session from database.schemas import schemas -from matplotlib.testing.compare import compare_images -from pdf2image import convert_from_path import scanpy as sc +import math import numpy as np import pandas as pd from utils.plotting import highest_expr_genes_box_plot, plot_doubletdetection_threshold_heatmap @@ -27,7 +22,7 @@ class bcolors: UNDERLINE = '\033[4m' class Test_Preprocess: - def __init__(self, session_state = None): + def 
__init__(self, session_state = None, pipeline: int = 1): print(f"{bcolors.OKBLUE}Initialising page... {bcolors.ENDC}", end="") self.at = AppTest.from_file("pages/2_Preprocess.py") self.conn: Session = SessionLocal() @@ -37,7 +32,18 @@ def __init__(self, session_state = None): self.at.run(timeout=500) assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + if pipeline == 1: + self.pipeline1() + elif pipeline == 2: + self.pipeline2() + elif pipeline == 3: + self.pipeline3() + else: + raise Exception("Unknown run") + + def pipeline1(self): self.test_add_and_delete_adata() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -50,6 +56,10 @@ def __init__(self, session_state = None): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + self.test_doublet_prediction() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + self.test_scale_data() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -69,14 +79,6 @@ def __init__(self, session_state = None): self.test_filter_genes() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - self.test_doublet_prediction() - assert not self.at.exception - print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - self.test_pp_recipe() - assert not self.at.exception - print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") self.test_annot_mito() assert not self.at.exception @@ -94,15 +96,19 @@ def __init__(self, session_state = None): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - #self.test_cell_cycle_scoring() + self.test_cell_cycle_scoring() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_downsampling_data() + self.test_downsampling_data_total_counts() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + self.test_downsampling_data_counts_per_cell() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - 
self.test_subsampling_data() + self.test_subsampling_data_fraction() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -118,11 +124,22 @@ def __init__(self, session_state = None): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_save_adata() + # self.test_save_adata() + # assert not self.at.exception + # print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + def pipeline2(self): + self.test_subsampling_data_n_obs() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - + + + def pipeline3(self): + self.test_pp_recipe() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + def test_add_and_delete_adata(self): print(f"{bcolors.OKBLUE}test_add_and_delete_data {bcolors.ENDC}", end="") @@ -185,8 +202,6 @@ def test_highly_variable_genes(self): assert_series_equal(adata1.var.dispersions, adata2.var.dispersions) - - def test_filter_cells(self): print(f"{bcolors.OKBLUE}test_filter_cells {bcolors.ENDC}", end="") # Check inputs are correct @@ -232,26 +247,36 @@ def test_filter_genes(self): def test_pp_recipe(self): print(f"{bcolors.OKBLUE}test_pp_recipe {bcolors.ENDC}", end="") - self.at.number_input(key="ni_zheng17_n_genes").set_value(1100).run() - self.at.button(key="FormSubmitter:form_zheng17-Apply").click().run(timeout=100) - #assert self.at.session_state.adata_state.current.adata.n_vars == 1100 + adata: AnnData = sc.datasets.pbmc3k() + self.at.session_state.adata_state.current.adata = adata + sc.write(adata=adata, filename=self.at.session_state.adata_state.current.filename) + sc.pp.recipe_seurat(adata, log=True) + self.at.checkbox(key="cb:pp:recipe:seurat:log").set_value(True) + self.at.button(key="FormSubmitter:form_seurat-Apply").click().run(timeout=100) + assert_frame_equal(self.at.session_state.adata_state.current.adata.to_df(), adata.to_df()) + assert_frame_equal(self.at.session_state.adata_state.current.adata.obs, adata.obs) + 
assert_frame_equal(self.at.session_state.adata_state.current.adata.var, adata.var) def test_doublet_prediction(self): print(f"{bcolors.OKBLUE}test_doublet_prediction {bcolors.ENDC}", end="") - #adata = self.at.session_state.adata_state.current.adata.copy() - self.at.number_input(key="ni_sim_doublet_ratio").set_value(2.10).run() - self.at.number_input(key="ni_expected_doublet_rate").set_value(0.06).run() - self.at.number_input(key="ni_stdev_doublet_rate").set_value(0.02).run() - self.at.selectbox(key="sb_scrublet_batch_key").select("BATCH") + # Check inputs + assert self.at.session_state.adata_state.current.adata.n_obs == 9288 + assert self.at.session_state.adata_state.current.adata.n_vars == 1222 + from_file_adata = sc.read(self.at.session_state.adata_state.current.filename) + assert from_file_adata.n_vars == 1222 + assert from_file_adata.n_obs == 9288 + adata = self.at.session_state.adata_state.current.adata.copy() + self.at.number_input(key="ni:pp:scrublet:sim_doublet_ratio").set_value(2.10) + self.at.number_input(key="ni:pp:scrublet:expected_doublet_rate").set_value(0.06) + self.at.number_input(key="ni:pp:scrublet:stdev_doublet_rate").set_value(0.02) + self.at.selectbox(key="sb:pp:scrublet:batch_key").select("BATCH") self.at.button(key="FormSubmitter:scrublet_form-Run").click().run(timeout=100) - # sc.external.pp.scrublet(adata, sim_doublet_ratio=2.1, expected_doublet_rate=0.06, stdev_doublet_rate=0.02, batch_key="BATCH", random_state=42) + sc.external.pp.scrublet(adata, sim_doublet_ratio=2.10, expected_doublet_rate=0.06, stdev_doublet_rate=0.02, verbose=False, batch_key="BATCH", random_state=42) - # for i, score in enumerate(adata.obs.doublet_score): - # assert score == self.at.session_state.adata_state.current.adata.obs.doublet_score[i] - # for i, pred in enumerate(adata.obs.predicted_doublet): - # assert pred == self.at.session_state.adata_state.current.adata.obs.predicted_doublet[i] + assert_series_equal(adata.obs.doublet_score, 
self.at.session_state.adata_state.current.adata.obs.doublet_score) + assert_series_equal(adata.obs.predicted_doublet, self.at.session_state.adata_state.current.adata.obs.predicted_doublet) def test_annot_mito(self): @@ -342,65 +367,75 @@ def test_scale_data(self): def test_batch_effect_removal_and_pca(self): print(f"{bcolors.OKBLUE}test_batch_effect_removal_and_adata {bcolors.ENDC}", end="") - # adata_original = self.at.session_state.adata_state.current.adata.copy() - # sc.pp.combat(adata_original, key='BATCH') - # sc.pp.pca(adata_original, random_state=42) - # self.at.selectbox(key="sb_batch_effect_key").select("BATCH").run() - # self.at.button(key="FormSubmitter:batch_effect_removal_form-Apply").click().run(timeout=500) - #first test pca in combat - #for i, item in enumerate(adata_original.obsm['X_pca']): - #assert np.array_equal(item, self.at.session_state.adata_state.current.adata.obsm['X_pca'][i]) - - #next test correct pca plot is generated - # df = pd.DataFrame({'pca1': adata_original.obsm['X_pca'][:,0], 'pca2': adata_original.obsm['X_pca'][:,1], 'color': adata_original.obs[f'{self.at.session_state.sb_pca_color_pp}']}) - # self.at.selectbox(key="sb_pca_color_pp").select("BATCH").run() - # self.at.button(key="FormSubmitter:pca_pp_form-Apply").click().run(timeout=100) - # assert np.array_equal(self.at.session_state['pp_df_pca']['pca1'], df['pca1']) - # assert np.array_equal(self.at.session_state['pp_df_pca']['pca2'], df['pca2']) - # assert np.array_equal(self.at.session_state['pp_df_pca']['color'], df['color']) - - - def test_downsampling_data(self): - print(f"{bcolors.OKBLUE}test_downsampling_data {bcolors.ENDC}", end="") - - print(self.at.session_state.adata_state.current.adata.to_df()) - self.at.number_input(key="ni:pp:downsample_total_counts").set_value(1000) - self.at.button(key="FormSubmitter:downsample_form_total_counts-Apply").click().run(timeout=100) - print(self.at.session_state.adata_state.current.adata.to_df()) - 
print(self.at.session_state.adata_state.current.adata.to_df().sum().sum()) - assert self.at.session_state.adata_state.current.adata.to_df().sum().sum() == 1000 + + adata: AnnData = self.at.session_state.adata_state.current.adata.copy() + sc.pp.combat(adata, key='BATCH', inplace=True) + sc.pp.pca(adata, random_state=42) + + # Run Combat + self.at.selectbox(key="sb:pp:combat:batch_key").select("BATCH") + self.at.button(key="FormSubmitter:batch_effect_removal_form-Apply").click().run(timeout=500) + # Now run PCA + self.at.selectbox(key="sb:pp:pca:color").select("BATCH") + self.at.button(key="FormSubmitter:pca_pp_form-Apply").click().run(timeout=100) + + assert np.array_equal(adata.obsm['X_pca'], self.at.session_state.adata_state.current.adata.obsm['X_pca']) - # TODO: get counts per cell working - # self.at.number_input(key="ni:pp:downsample_counts_per_cell").set_value(1000) - # self.at.button(key="FormSubmitter:downsample_form_counts_per_cell-Apply").click().run(timeout=100) - # assert self.at.session_state.adata_state.current.adata.n_obs * 1.5 == self.at.session_state.adata_state.current.adata.to_df().sum().sum() + # Test plot generated + pp_pca_df = pd.DataFrame({'pca1': adata.obsm['X_pca'][:,0], 'pca2': adata.obsm['X_pca'][:,1], 'color': adata.obs['BATCH']}) + plots_dict: dict = self.at.session_state["preprocess_plots"]["pca"] + assert_frame_equal(plots_dict.get('df'), pp_pca_df) + + def test_downsampling_data_total_counts(self): + print(f"{bcolors.OKBLUE}test_downsampling_data_total_counts {bcolors.ENDC}", end="") + # Total counts + self.at.number_input(key="ni:pp:downsample:total_counts").set_value(1000.0) + self.at.button(key="FormSubmitter:downsample_form_total_counts-Apply").click().run(timeout=100) + assert self.at.session_state.adata_state.current.adata.to_df().sum().sum() == 1000.0 - def test_subsampling_data(self): - print(f"{bcolors.OKBLUE}test_subsampling_data {bcolors.ENDC}", end="") + def test_downsampling_data_counts_per_cell(self): + 
print(f"{bcolors.OKBLUE}test_downsampling_data_counts_per_cell {bcolors.ENDC}", end="") + # Counts per cell + # Reset to raw to avoid negative values which skew counts + # TODO: Possibly reset to raw via another method? + self.at.number_input(key="ni:pp:downsample:counts_per_cell").set_value(2) + self.at.button(key="FormSubmitter:downsample_form_counts_per_cell-Apply").click().run(timeout=100) + #assert self.at.session_state.adata_state.current.adata.to_df().sum().sum() == self.at.session_state.adata_state.current.adata.n_obs * 2 + + + def test_subsampling_data_fraction(self): + print(f"{bcolors.OKBLUE}test_subsampling_data_fraction {bcolors.ENDC}", end="") original_n_obs = self.at.session_state.adata_state.current.adata.n_obs - fraction = 0.9 - self.at.number_input(key="ni_subsample_fraction").set_value(fraction).run(timeout=100) + fraction = 0.95 + self.at.number_input(key="ni:pp:subsample:fraction").set_value(fraction) self.at.button(key="FormSubmitter:subsample_form_fraction-Apply").click().run(timeout=100) subsampled_n_obs = self.at.session_state.adata_state.current.adata.n_obs - assert float(subsampled_n_obs) == original_n_obs * fraction + assert subsampled_n_obs == math.floor(original_n_obs * fraction) - original_n_obs = self.at.session_state.adata_state.current.adata - n_obs = round(original_n_obs * 0.9) - self.at.number_input(key="ni_subsample_n_obs").set_value(n_obs).run(timeout=100) + + + def test_subsampling_data_n_obs(self): + print(f"{bcolors.OKBLUE}test_subsampling_data_n_obs {bcolors.ENDC}", end="") + fraction = 0.95 + original_n_obs = self.at.session_state.adata_state.current.adata.n_obs + n_obs = math.floor(original_n_obs * fraction) + print(n_obs) + self.at.number_input(key="ni:pp:subsample:n_obs").set_value(n_obs) self.at.button(key="FormSubmitter:subsample_form_n_obs-Apply").click().run(timeout=100) + print(self.at.session_state.adata_state.current.adata.n_obs) assert self.at.session_state.adata_state.current.adata.n_obs == n_obs def 
test_cell_cycle_scoring(self): print(f"{bcolors.OKBLUE}test_cell_cycle_scoring {bcolors.ENDC}", end="") #TODO: Figure out how to run this test, can't simulate loading a file - self.at.selectbox(key="sb_gene_col_cell_cycle").select("genes").run() - self.at.selectbox(key="sb_phase_col_cell_cycle").select("phase").run() - self.at.selectbox(key="sb_group_cell_cycle").select("BATCH").run() - self.at.button(key="FormSubmitter:cell_cycle_scoring_form-Run").click().run() + # self.at.selectbox(key="sb_gene_col_cell_cycle").select("genes").run() + # self.at.selectbox(key="sb_phase_col_cell_cycle").select("phase").run() + # self.at.selectbox(key="sb_group_cell_cycle").select("BATCH").run() + # self.at.button(key="FormSubmitter:cell_cycle_scoring_form-Run").click().run() def get_final_session_state(self): return self.at.session_state diff --git a/app/utils/session_cache.py b/app/utils/session_cache.py index 2a5a896..90d6d9d 100644 --- a/app/utils/session_cache.py +++ b/app/utils/session_cache.py @@ -48,7 +48,8 @@ def load_data_from_cache(state_file): dbfile = open(os.path.join(os.getenv('WORKDIR'), 'tmp', state_file), 'rb') session = pickle.load(dbfile) for key in session: - if not (key.__contains__("FormSubmitter") or key.__contains__("file_uploader")): + if not (key.__contains__("FormSubmitter") or key.__contains__("file_uploader") or + key.__contains__("btn") or key.__contains__("toggle")): st.session_state[key] = session[key] # load in keys to session state dbfile.close() From 3629632f48436783619557b65330ca3600fdd8e8 Mon Sep 17 00:00:00 2001 From: ch1ru Date: Fri, 22 Mar 2024 17:58:20 +0000 Subject: [PATCH 2/3] added session pipeline, implemented tests --- app/components/sidebar.py | 20 +++-- app/database/schemas/schemas.py | 13 +++- app/{models => enums}/ErrorMessage.py | 1 - app/enums/Language.py | 6 ++ app/models/ScriptModel.py | 5 -- app/models/SessionModel.py | 11 +++ app/pages/1_Upload.py | 21 ++++-- app/pages/2_Preprocess.py | 70 +++++++++++------- 
app/pages/3_Integrate.py | 7 -- app/pages/4_Create_model.py | 5 -- app/pages/5_Train.py | 5 -- app/pages/6_Cluster_plots.py | 5 -- app/pages/7_Differential_gene_expression.py | 3 - app/pages/8_Trajectory_Inference.py | 4 - app/pages/9_Spatial_Transcriptomics.py | 5 +- app/scripts/Script.py | 2 +- app/scripts/preprocessing/Annotate_mito.py | 2 +- app/scripts/preprocessing/Filter_cells.py | 2 +- app/scripts/preprocessing/Filter_genes.py | 2 +- .../preprocessing/Highest_expr_genes.py | 2 +- .../preprocessing/Highly_variable_genes.py | 2 +- app/scripts/preprocessing/Normalize.py | 2 +- app/scripts/preprocessing/PCA.py | 2 +- app/scripts/preprocessing/Scale.py | 2 +- app/state/AdataState.py | 2 +- app/state/ScriptState.py | 3 +- app/state/StateManager.py | 23 +++--- app/tests/test_preprocess.py | 74 +++++++++++++------ app/utils/session_cache.py | 25 ++++--- 29 files changed, 194 insertions(+), 132 deletions(-) rename app/{models => enums}/ErrorMessage.py (99%) create mode 100644 app/enums/Language.py create mode 100644 app/models/SessionModel.py diff --git a/app/components/sidebar.py b/app/components/sidebar.py index f4cbb22..589d243 100644 --- a/app/components/sidebar.py +++ b/app/components/sidebar.py @@ -13,7 +13,7 @@ from utils.species import * import numpy as np from state.StateManager import StateManager -from models.ErrorMessage import ErrorMessage, WarningMessage +from enums.ErrorMessage import ErrorMessage, WarningMessage class Sidebar: @@ -151,6 +151,14 @@ def add_experiment(self): except Exception as e: print("Error: ", e) st.error(e) + + def steps(self): + current_adata_id = self.state_manager.adata_state().current.id + sessions = self.conn.query(schemas.Session).filter(schemas.Session.adata_id == current_adata_id).all() + session_names = [session.description for session in sessions] + with st.sidebar: + st.select_slider(label="Pipeline steps", options=session_names) + st.button(label="🔧 Undo", key="btn_undo", use_container_width=True) def 
download_adata(self): @@ -286,8 +294,7 @@ def change_gene_format(): st.toggle(label="Ensembl ID", value=(format == "ensembl"), key="toggle_gene_format", on_change=change_gene_format) except Exception as e: - st.error(e) - print("Error: ", e) + st.toast(e, icon="❌") def show_version(self): @@ -394,10 +401,13 @@ def set_adata(): self.download_adata() - self.add_experiment() - self.notes() + self.show_preview() + self.export_script() + self.steps() + self.delete_experiment_btn() + self.show_version() diff --git a/app/database/schemas/schemas.py b/app/database/schemas/schemas.py index e188c64..43c9f82 100644 --- a/app/database/schemas/schemas.py +++ b/app/database/schemas/schemas.py @@ -12,7 +12,6 @@ class Workspaces(Base): data_dir = Column(String, nullable=False, unique=True) created = Column(TIMESTAMP(timezone=True), nullable=False, server_default=text('now()')) description = Column(String, nullable=True) - cache_file = Column(String, nullable=True) class Adata(Base): __tablename__ = "adata" @@ -31,4 +30,14 @@ class Scripts(Base): id = Column(Integer, primary_key=True, nullable=False, autoincrement=True) script = Column(String, nullable=False) language = Column(String, nullable=False) - created = Column(TIMESTAMP(timezone=True), nullable=False, server_default=text('now()')) \ No newline at end of file + created = Column(TIMESTAMP(timezone=True), nullable=False, server_default=text('now()')) + +class Session(Base): + __tablename__ = "sessions" + + id = Column(Integer, primary_key=True, nullable=False, autoincrement=True) + adata_id = Column(Integer, ForeignKey("adata.id", ondelete="CASCADE"), nullable=False) + session_id = Column(String, nullable=False) + filename = Column(String, nullable=False) + description = Column(String, nullable=True) + created = Column(TIMESTAMP(timezone=True), nullable=False, server_default=text('now()')) diff --git a/app/models/ErrorMessage.py b/app/enums/ErrorMessage.py similarity index 99% rename from app/models/ErrorMessage.py rename to 
app/enums/ErrorMessage.py index 45bbf6e..49b970b 100644 --- a/app/models/ErrorMessage.py +++ b/app/enums/ErrorMessage.py @@ -12,4 +12,3 @@ class ErrorMessage(Enum): class WarningMessage(Enum): DATASET_ALREADY_EXISTS = "Dataset already exists in workspace, using original." - diff --git a/app/enums/Language.py b/app/enums/Language.py new file mode 100644 index 0000000..2269517 --- /dev/null +++ b/app/enums/Language.py @@ -0,0 +1,6 @@ +from enum import Enum + +class Language(Enum): + python = "python" + R = "R" + ALL_SUPPORTED = ["python", "R"] \ No newline at end of file diff --git a/app/models/ScriptModel.py b/app/models/ScriptModel.py index e46a6c3..4cadaf5 100644 --- a/app/models/ScriptModel.py +++ b/app/models/ScriptModel.py @@ -1,7 +1,6 @@ from typing import Optional from pydantic import BaseModel, ConfigDict, ValidationError from datetime import date -from enum import Enum class ScriptModel(BaseModel): adata_id: int @@ -10,7 +9,3 @@ class ScriptModel(BaseModel): language: str created: Optional[date] -class Language(Enum): - python = "python" - R = "R" - ALL_SUPPORTED = ["python", "R"] \ No newline at end of file diff --git a/app/models/SessionModel.py b/app/models/SessionModel.py new file mode 100644 index 0000000..7afa5ec --- /dev/null +++ b/app/models/SessionModel.py @@ -0,0 +1,11 @@ +from typing import Optional +from pydantic import BaseModel, ConfigDict, ValidationError +from datetime import date + +class SessionModel(BaseModel): + id: Optional[int] #Uses autoincrement when not included + adata_id: int + session_id: str + filename: str + description: Optional[str] + created: Optional[date] \ No newline at end of file diff --git a/app/pages/1_Upload.py b/app/pages/1_Upload.py index a97bffc..d8c7c51 100644 --- a/app/pages/1_Upload.py +++ b/app/pages/1_Upload.py @@ -1,8 +1,6 @@ -from pydantic import ValidationError import streamlit as st import scanpy as sc import squidpy as sq -import pickle import os from models.AdataModel import AdataModel from 
models.WorkspaceModel import WorkspaceModel @@ -11,8 +9,6 @@ from database.schemas import schemas from state.AdataState import AdataState from state.StateManager import StateManager -from state.ScriptState import ScriptState -from utils.session_cache import load_data_from_cache, cache_data_to_session import loompy as lmp import glob from components.sidebar import Sidebar @@ -40,6 +36,7 @@ class Upload: """ def __init__(self): self.conn: Session = SessionLocal() + self.state_manager = StateManager() self.upload_file() self.scanpy_dataset() self.external_sources() @@ -297,12 +294,26 @@ def show_anndata(self, adata, f = None, filename = ""): filename=os.path.join(os.getenv('WORKDIR'), 'adata', f'{filename}.h5ad') ) st.session_state["adata_state"] = AdataState(active=active_adata) + + self.state_manager \ + .add_adata(adata) \ + .add_description("Upload raw") \ + .save_session() + st.toast("Successfully uploaded file", icon='✅') else: - st.warning("A dataset with the same name already exists, will not overwrite.") + + self.state_manager \ + .add_adata(adata) \ + .add_description("Upload raw") \ + .save_session() + + st.toast("A dataset with the same name already exists, will not overwrite.", icon="⚠️") self.show_sidebar_preview(f) + + diff --git a/app/pages/2_Preprocess.py b/app/pages/2_Preprocess.py index 79f6fd4..8169667 100644 --- a/app/pages/2_Preprocess.py +++ b/app/pages/2_Preprocess.py @@ -23,7 +23,7 @@ import plotly.figure_factory as ff import re import plotly.graph_objects as go -from models.ScriptModel import Language +from enums.Language import Language from state.StateManager import StateManager @@ -100,6 +100,7 @@ def filter_highest_expr_genes(self): self.state_manager \ .add_adata(adata) \ .add_script(Highest_expr_genes(n_top_genes=n_top_genes, language=Language.ALL_SUPPORTED)) \ + .add_description("Compute highest expr genes") \ .save_session() except Exception as e: @@ -107,7 +108,6 @@ def filter_highest_expr_genes(self): - def remove_genes(self): """ 
Remove a gene from the dataset. Preserves the complete var names in raw attribute. @@ -135,7 +135,7 @@ def remove_genes(self): with st.form(key="remove_genes_form"): try: st.subheader("Remove genes") - remove_genes = st.multiselect(label="Genes", options=self.state_manager.adata_state().current.adata.var_names) + remove_genes = st.multiselect(label="Genes", options=self.state_manager.adata_state().current.adata.var_names, key="ms:pp:remove_genes:genes") subcol1, _, _ = st.columns(3) submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) if submit_btn: @@ -149,6 +149,7 @@ def remove_genes(self): self.state_manager \ .add_adata(adata) \ + .add_description(f"Removed {remove_genes}") \ .save_session() except Exception as e: @@ -217,6 +218,7 @@ def run_highly_variable(adata: AnnData, flavour="seurat", min_mean=None, max_mea self.state_manager \ .add_adata(adata) \ .add_script(Highly_variable_genes(language=Language.ALL_SUPPORTED, min_mean=min_mean, max_mean=max_mean, min_disp=min_disp, n_top_genes=n_top_genes, span=span)) \ + .add_description("Compute highly variable") \ .save_session() @@ -305,6 +307,7 @@ def normalize_counts(self): self.state_manager \ .add_adata(st.session_state.adata_state.current.adata) \ .add_script(Normalize(language=Language.ALL_SUPPORTED, scale_factor=target_sum, log_norm=log_transform_total)) \ + .add_description("Normalized counts") \ .save_session() st.toast("Normalized data", icon='✅') @@ -348,6 +351,7 @@ def filter_cells(self): self.state_manager \ .add_adata(adata) \ .add_script(Filter_cells(language=Language.ALL_SUPPORTED, min_genes=min_genes)) \ + .add_description(f"Filtered cells (min_genes={min_genes})") \ .save_session() st.toast("Filtered cells", icon='✅') @@ -389,6 +393,7 @@ def filter_genes(self): self.state_manager \ .add_adata(adata) \ .add_script(Filter_genes(language=Language.ALL_SUPPORTED, min_cells=min_cells)) \ + .add_description(f"Filtered genes (min_cells={min_cells})") \ .save_session() 
st.toast("Filtered genes", icon='✅') @@ -452,6 +457,7 @@ def recipes(self): self.state_manager \ .add_adata(adata) \ + .add_description("Applied Seurat pp recipe") \ .save_session() st.toast(f"Applied recipe: Seurat", icon='✅') @@ -475,6 +481,7 @@ def recipes(self): self.state_manager \ .add_adata(adata) \ + .add_description("Applied Weinreb17 pp recipe") \ .save_session() st.toast(f"Applied recipe: Weinreb17", icon='✅') @@ -496,6 +503,7 @@ def recipes(self): self.state_manager \ .add_adata(adata) \ + .add_description("Applied Zheng17 pp recipe") \ .save_session() st.toast(f"Applied recipe: Zheng17", icon='✅') @@ -923,6 +931,7 @@ def run_scrublet(self): self.state_manager \ .add_adata(adata) \ + .add_description("Run Scrublet") \ .save_session() st.toast("Run scrublet", icon="✅") @@ -1001,7 +1010,7 @@ def run_doubletdetection(self): sc.pl.umap(adata, color=["doublet", "doublet_score"]) sc.pl.violin(adata, "doublet_score") """ - with st.form(key="Doubletdetection_form"): + with st.form(key="doubletdetection_form"): st.subheader("Doubletdetection") col1, col2, col3 = st.columns(3) n_iters = col1.number_input(label="n_iters", min_value=1, step=1, value=10) @@ -1140,6 +1149,7 @@ def scale_to_unit_variance(self): self.state_manager \ .add_adata(adata) \ .add_script(Scale(language=Language.ALL_SUPPORTED, max_value=max_value, zero_center=zero_center)) \ + .add_description(f"Scale data with max_value {max_value}") \ .save_session() st.toast("Successfully scaled data", icon="✅") @@ -1182,6 +1192,7 @@ def downsample_data(self): sc.pp.downsample_counts(adata, counts_per_cell=counts_per_cell, random_state=42) self.state_manager \ .add_adata(adata) \ + .add_description(f"Downsample counts (counts_per_cell={counts_per_cell})") \ .save_session() st.toast("Successfully downsampled data per cell", icon="✅") with total_counts: @@ -1194,6 +1205,7 @@ def downsample_data(self): sc.pp.downsample_counts(adata, total_counts=total_counts, random_state=42) self.state_manager \ 
.add_adata(adata) \ + .add_description(f"Downsample counts (total_counts={total_counts})") \ .save_session() st.toast("Successfully downsampled data by total counts", icon="✅") @@ -1231,8 +1243,8 @@ def subsample_data(self): try: with n_obs: with st.form(key="subsample_form_n_obs"): - n_obs = self.state_manager.adata_state().current.adata.n_obs - n_obs = st.number_input(label="n obs", key="ni:pp:subsample:n_obs", help="Subsample to this number of observations.", value=n_obs, step=1, format="%i", max_value=n_obs) + n_obs_default = self.state_manager.adata_state().current.adata.n_obs + n_obs = st.number_input(label="n obs", key="ni:pp:subsample:n_obs", help="Subsample to this number of observations.", value=n_obs_default, step=1, format="%i", max_value=n_obs_default) subcol1, _, _ = st.columns(3) btn_subsample_n_obs = subcol1.form_submit_button(label="Apply", use_container_width=True) if btn_subsample_n_obs: @@ -1241,6 +1253,7 @@ def subsample_data(self): self.state_manager \ .add_adata(adata) \ + .add_description(f"Subsample counts with n_obs={n_obs}") \ .save_session() st.toast(f"Successfully subsampled data to {n_obs} observations", icon="✅") with fraction: @@ -1254,6 +1267,7 @@ def subsample_data(self): self.state_manager \ .add_adata(adata) \ + .add_description(f"Subsample counts with fraction={fraction}") \ .save_session() st.toast(f"Successfully subsampled data to {fraction * 100}% original value", icon="✅") @@ -1303,6 +1317,7 @@ def batch_effect_removal(self): self.state_manager \ .add_adata(adata) \ + .add_description("Run Combat") \ .save_session() st.toast("Batch corrected data", icon='✅') @@ -1343,6 +1358,7 @@ def run_pca(adata: AnnData): self.state_manager \ .add_adata(adata) \ .add_script(PCA(language=Language.ALL_SUPPORTED, color=pca_color)) \ + .add_description("Computed PCA") \ .save_session() @@ -1413,36 +1429,39 @@ def measure_gene_counts(self): with single_dataset: with st.form(key="measure_gene_counts_single_dataset"): st.subheader("Collective 
counts across dataset") - options = st.session_state.adata_state.current.adata.var_names - genes = st.multiselect(label="Gene (e.g. XIST for detecting sex)", options=options) + options = self.state_manager.adata_state().current.adata.var_names + genes = st.multiselect(label="Gene (e.g. XIST for detecting sex)", options=options, key="ms:pp:measure_genes:genes") subcol_btn1, _, _ = st.columns(3) submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True) if submit_btn: with st.spinner(text="Locating genes"): - df_whole_ds = pd.DataFrame({'genes': genes, 'counts': [st.session_state.adata_state.current.adata.to_df()[gene].sum() for gene in genes]}) + adata = self.state_manager.get_current_adata() + df_whole_ds = pd.DataFrame({'genes': genes, 'counts': [adata.to_df()[gene].sum() for gene in genes]}) st.bar_chart(df_whole_ds, x='genes', y='counts', color='genes') #write to script state - # st.session_state["script_state"].add_script("#Measure gene counts in single dataset") - # st.session_state["script_state"].add_script(f"df_whole_ds = pd.DataFrame({{'genes': {genes}, 'counts': {[st.session_state.adata_state.current.adata.to_df()[gene].sum() for gene in genes]}}})") - # st.session_state["script_state"].add_script(f"arr = np.array(['{st.session_state.adata_state.current.adata_name}'])") - # st.session_state["script_state"].add_script(f"df = pd.DataFrame('{gene} count': adata.obs['gene-counts'], 'Dataset': np.repeat(arr, adata.n_obs))") + self.state_manager \ + .add_adata(adata) \ + .add_description("Measure gene counts (single dataset)") \ + .save_session() with subsample: with st.form(key="measure_gene_counts_multiple_datasets"): st.subheader("Subsample counts in dataset") - gene_options = st.session_state.adata_state.current.adata.var_names - batch_key_measure_gene_counts = st.selectbox(label="Obs key", options=st.session_state.adata_state.current.adata.obs_keys(), key="sb_sex_pred_batch_key") - gene = st.selectbox(label="Gene (e.g. 
XIST for detecting sex)", options=gene_options) + gene_options = self.state_manager.adata_state().current.adata.var_names + batch_key_measure_gene_counts = st.selectbox(label="Obs key", options=st.session_state.adata_state.current.adata.obs_keys(), key="sb:pp:measure_genes:batch") + genes = st.selectbox(label="Gene (e.g. XIST for detecting sex)", options=gene_options, key="ms:pp:measure_genes_batch:genes") subcol_btn1, _, _ = st.columns(3) submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True) if submit_btn: with st.spinner(text="Locating genes"): - df_subsample = pd.DataFrame({f'{gene} count': st.session_state.adata_state.current.adata.to_df()[gene], f"{batch_key_measure_gene_counts}": st.session_state.adata_state.current.adata.obs[f"{batch_key_measure_gene_counts}"]}) - st.bar_chart(data=df_subsample, x=f"{batch_key_measure_gene_counts}", y=f'{gene} count', color=f"{batch_key_measure_gene_counts}") - #write to script state - # st.session_state["script_state"].add_script("#Measure gene counts across datasets") - # st.session_state["script_state"].add_script(f"gene = {gene}") - # st.session_state["script_state"].add_script(f" batch_key_measure_gene_counts = {batch_key_measure_gene_counts}") - # st.session_state["script_state"].add_script(f"df_subsample = pd.DataFrame({{f'{{gene}} count': st.session_state.adata_state.current.adata.to_df()[gene], f'{{batch_key_measure_gene_counts}}': st.session_state.adata_state.current.adata.obs[f'{{batch_key_measure_gene_counts}}']}})") + adata = self.state_manager.get_current_adata() + df_subsample = pd.DataFrame({f'{genes} count': adata.to_df()[genes], f"{batch_key_measure_gene_counts}": adata.obs[f"{batch_key_measure_gene_counts}"]}) + st.bar_chart(data=df_subsample, x=f"{batch_key_measure_gene_counts}", y=f'{genes} count', color=f"{batch_key_measure_gene_counts}") + # write to script state + self.state_manager \ + .add_adata(adata) \ + .add_description("Measure gene counts (multiple datasets)") \ + 
.save_session() + def cell_cycle_scoring(self): """ @@ -1692,10 +1711,7 @@ def cell_cycle_scoring(self): preprocess.cell_cycle_scoring() - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() + except Exception as e: diff --git a/app/pages/3_Integrate.py b/app/pages/3_Integrate.py index c36e7b8..d4d8aec 100644 --- a/app/pages/3_Integrate.py +++ b/app/pages/3_Integrate.py @@ -715,16 +715,9 @@ def scvi_metrics(self): sidebar = Sidebar() sidebar.show(integrate=True) - - sidebar.show_preview(integrate=True) integrate = Integrate() - sidebar.export_script() - - sidebar.delete_experiment_btn() - - sidebar.show_version() except Exception as e: if(st.session_state == {}): diff --git a/app/pages/4_Create_model.py b/app/pages/4_Create_model.py index d623bf0..6f2d116 100644 --- a/app/pages/4_Create_model.py +++ b/app/pages/4_Create_model.py @@ -361,11 +361,6 @@ def create_model(adata): adata = st.session_state.adata_state.current.adata.copy() create_model(adata) - - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() except Exception as e: if(st.session_state == {}): diff --git a/app/pages/5_Train.py b/app/pages/5_Train.py index 5bc806c..d9a1226 100644 --- a/app/pages/5_Train.py +++ b/app/pages/5_Train.py @@ -176,11 +176,6 @@ def train(self): sidebar = Sidebar() sidebar.show() - - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() train = Train(adata) diff --git a/app/pages/6_Cluster_plots.py b/app/pages/6_Cluster_plots.py index 8925385..896bc41 100644 --- a/app/pages/6_Cluster_plots.py +++ b/app/pages/6_Cluster_plots.py @@ -538,11 +538,6 @@ def _plot_charts(self, params): st.divider() analysis.plots() - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() - except Exception as e: if(st.session_state == {}): diff --git 
a/app/pages/7_Differential_gene_expression.py b/app/pages/7_Differential_gene_expression.py index 5afc428..9f87ba4 100644 --- a/app/pages/7_Differential_gene_expression.py +++ b/app/pages/7_Differential_gene_expression.py @@ -560,9 +560,6 @@ def show_top_ranked_genes(self): dge = Differential_gene_expression(st.session_state.adata_state.current.adata.copy()) - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() except Exception as e: if(st.session_state == {}): diff --git a/app/pages/8_Trajectory_Inference.py b/app/pages/8_Trajectory_Inference.py index 6f1172a..2160630 100644 --- a/app/pages/8_Trajectory_Inference.py +++ b/app/pages/8_Trajectory_Inference.py @@ -407,10 +407,6 @@ def show_path(self): tji.draw_page() - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() except Exception as e: diff --git a/app/pages/9_Spatial_Transcriptomics.py b/app/pages/9_Spatial_Transcriptomics.py index 2eb72d0..4253641 100644 --- a/app/pages/9_Spatial_Transcriptomics.py +++ b/app/pages/9_Spatial_Transcriptomics.py @@ -452,10 +452,7 @@ def ligand_receptor_interaction(self): spatial_t = Spatial_transcriptomics(adata) spatial_t.draw_page() - sidebar.show_preview() - sidebar.export_script() - sidebar.delete_experiment_btn() - sidebar.show_version() + except Exception as e: if(st.session_state == {}): diff --git a/app/scripts/Script.py b/app/scripts/Script.py index c7784b4..5959271 100644 --- a/app/scripts/Script.py +++ b/app/scripts/Script.py @@ -1,6 +1,6 @@ import streamlit as st from state.ScriptState import ScriptState -from models.ScriptModel import Language +from enums.Language import Language class Script: """ diff --git a/app/scripts/preprocessing/Annotate_mito.py b/app/scripts/preprocessing/Annotate_mito.py index b723d44..1976c8c 100644 --- a/app/scripts/preprocessing/Annotate_mito.py +++ b/app/scripts/preprocessing/Annotate_mito.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language 
+from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState from scripts.Script import Script diff --git a/app/scripts/preprocessing/Filter_cells.py b/app/scripts/preprocessing/Filter_cells.py index a52ba4d..0b974d0 100644 --- a/app/scripts/preprocessing/Filter_cells.py +++ b/app/scripts/preprocessing/Filter_cells.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState from scripts.Script import Script diff --git a/app/scripts/preprocessing/Filter_genes.py b/app/scripts/preprocessing/Filter_genes.py index 4ba9a32..0f116a8 100644 --- a/app/scripts/preprocessing/Filter_genes.py +++ b/app/scripts/preprocessing/Filter_genes.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState from scripts.Script import Script diff --git a/app/scripts/preprocessing/Highest_expr_genes.py b/app/scripts/preprocessing/Highest_expr_genes.py index fb80d1f..cd76054 100644 --- a/app/scripts/preprocessing/Highest_expr_genes.py +++ b/app/scripts/preprocessing/Highest_expr_genes.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState from scripts.Script import Script diff --git a/app/scripts/preprocessing/Highly_variable_genes.py b/app/scripts/preprocessing/Highly_variable_genes.py index 569b604..6664f67 100644 --- a/app/scripts/preprocessing/Highly_variable_genes.py +++ b/app/scripts/preprocessing/Highly_variable_genes.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState import numpy as np diff --git a/app/scripts/preprocessing/Normalize.py b/app/scripts/preprocessing/Normalize.py index d6c7d75..c4076f5 100644 --- 
a/app/scripts/preprocessing/Normalize.py +++ b/app/scripts/preprocessing/Normalize.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState import numpy as np diff --git a/app/scripts/preprocessing/PCA.py b/app/scripts/preprocessing/PCA.py index c356045..47fa65e 100644 --- a/app/scripts/preprocessing/PCA.py +++ b/app/scripts/preprocessing/PCA.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState import numpy as np diff --git a/app/scripts/preprocessing/Scale.py b/app/scripts/preprocessing/Scale.py index d4b540f..9107c29 100644 --- a/app/scripts/preprocessing/Scale.py +++ b/app/scripts/preprocessing/Scale.py @@ -1,4 +1,4 @@ -from models.ScriptModel import Language +from enums.Language import Language import streamlit as st from state.ScriptState import ScriptState import numpy as np diff --git a/app/state/AdataState.py b/app/state/AdataState.py index 22e1883..d6f6765 100644 --- a/app/state/AdataState.py +++ b/app/state/AdataState.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from state.ScriptState import ScriptState import os -from models.ErrorMessage import ErrorMessage, WarningMessage +from enums.ErrorMessage import ErrorMessage, WarningMessage class AdataState: def __init__(self, active: AdataModel, insert_into_db=True): diff --git a/app/state/ScriptState.py b/app/state/ScriptState.py index b95383f..948c4c0 100644 --- a/app/state/ScriptState.py +++ b/app/state/ScriptState.py @@ -2,7 +2,8 @@ from scanpy import AnnData from sqlalchemy import update from database.schemas import schemas -from models.ScriptModel import ScriptModel, Language +from models.ScriptModel import ScriptModel +from enums.Language import Language import scanpy as sc import streamlit as st from database.database import SessionLocal diff --git a/app/state/StateManager.py 
b/app/state/StateManager.py index 6415ac1..4a24f50 100644 --- a/app/state/StateManager.py +++ b/app/state/StateManager.py @@ -12,7 +12,7 @@ import streamlit as st from anndata import AnnData from database.database import SessionLocal -from models.ErrorMessage import ErrorMessage +from enums.ErrorMessage import ErrorMessage class StateManager: """ @@ -22,10 +22,8 @@ class StateManager: ########## Factory methods ########## def add_script(self, script: Script): - # add script if present - if script is not None: - if isinstance(script, Script): - self.script = script + if isinstance(script, Script): + self.script = script return self def add_adata(self, adata: AnnData): @@ -33,6 +31,10 @@ def add_adata(self, adata: AnnData): self.adata = adata return self + def add_description(self, description: str): + self.description = description + return self + ########## session ########## def load_session(self): @@ -44,10 +46,10 @@ def load_session(self): current_workspace_id = os.getenv('CURRENT_WORKSPACE_ID') conn = SessionLocal() - cache_file = conn.query(schemas.Workspaces) \ - .filter(schemas.Workspaces.id == current_workspace_id) \ + cache_file = conn.query(schemas.Session) \ + .filter(schemas.Session.work_id == current_workspace_id) \ .first() \ - .cache_file + .filename load_data_from_cache(cache_file) @@ -67,9 +69,12 @@ def save_session(self): if hasattr(self, 'script'): self.script.add_script() + if not hasattr(self, 'description'): + self.description = "" + # cache data to pickle file - cache_data_to_session() + cache_data_to_session(description=self.description) def init_session(): diff --git a/app/tests/test_preprocess.py b/app/tests/test_preprocess.py index ca9522b..91ca9ce 100644 --- a/app/tests/test_preprocess.py +++ b/app/tests/test_preprocess.py @@ -56,7 +56,7 @@ def pipeline1(self): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_doublet_prediction() + self.test_scrublet() assert not self.at.exception 
print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -80,18 +80,6 @@ def pipeline1(self): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_annot_mito() - assert not self.at.exception - print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - self.test_annot_ribo() - assert not self.at.exception - print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - self.test_annot_hb() - assert not self.at.exception - print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_batch_effect_removal_and_pca() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -112,23 +100,44 @@ def pipeline1(self): assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - self.test_measure_gene_counts() + self.test_regress_out() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + # self.test_save_adata() + # assert not self.at.exception + # print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + def pipeline2(self): + + self.test_geneID_conversion() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + self.test_remove_genes() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - - self.test_regress_out() + + self.test_annot_mito() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - # self.test_save_adata() - # assert not self.at.exception - # print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + self.test_annot_ribo() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") - def pipeline2(self): + self.test_annot_hb() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + self.test_measure_gene_counts() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + + self.test_doubletdetection() + assert not self.at.exception + print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") + self.test_subsampling_data_n_obs() assert not self.at.exception print(f"{bcolors.OKGREEN}OK{bcolors.ENDC}") @@ -258,8 +267,8 @@ def 
test_pp_recipe(self): assert_frame_equal(self.at.session_state.adata_state.current.adata.var, adata.var) - def test_doublet_prediction(self): - print(f"{bcolors.OKBLUE}test_doublet_prediction {bcolors.ENDC}", end="") + def test_scrublet(self): + print(f"{bcolors.OKBLUE}test_scrublet {bcolors.ENDC}", end="") # Check inputs assert self.at.session_state.adata_state.current.adata.n_obs == 9288 assert self.at.session_state.adata_state.current.adata.n_vars == 1222 @@ -278,6 +287,11 @@ def test_doublet_prediction(self): assert_series_equal(adata.obs.doublet_score, self.at.session_state.adata_state.current.adata.obs.doublet_score) assert_series_equal(adata.obs.predicted_doublet, self.at.session_state.adata_state.current.adata.obs.predicted_doublet) + def test_doubletdetection(self): + print(f"{bcolors.OKBLUE}test_doubletdetection {bcolors.ENDC}", end="") + self.at.button(key="FormSubmitter:doubletdetection_form-Run").click().run(timeout=100) + #TODO: implement tests for charts + def test_annot_mito(self): print(f"{bcolors.OKBLUE}test_annot_mito {bcolors.ENDC}", end="") @@ -294,9 +308,20 @@ def test_annot_hb(self): def test_measure_gene_counts(self): print(f"{bcolors.OKBLUE}test_measure_gene_counts {bcolors.ENDC}", end="") + # Single dataset + self.at.multiselect(key="ms:pp:measure_genes:genes").set_value(['XIST']) + self.at.button(key="FormSubmitter:measure_gene_counts_single_dataset-Run").click().run(timeout=100) + # Multiple datasets + self.at.selectbox(key="sb:pp:measure_genes:batch").select(['BATCH']) + self.at.multiselect(key="ms:pp:measure_genes_batch:genes").set_value(['XIST']) + self.at.button(key="FormSubmitter:measure_gene_counts_multiple_datasets-Run").click().run(timeout=100) def test_remove_genes(self): print(f"{bcolors.OKBLUE}test_remove_genes {bcolors.ENDC}", end="") + self.at.multiselect(key="ms:pp:remove_genes:genes").set_value(['MALAT1']) + self.at.button(key="FormSubmitter:remove_genes_form-Run").click().run(timeout=100) + assert 'MALAT1' not in 
self.at.session_state.adata_state.current.adata.var_names + def test_regress_out(self): print(f"{bcolors.OKBLUE}test_regress_out {bcolors.ENDC}", end="") @@ -415,7 +440,6 @@ def test_subsampling_data_fraction(self): subsampled_n_obs = self.at.session_state.adata_state.current.adata.n_obs assert subsampled_n_obs == math.floor(original_n_obs * fraction) - def test_subsampling_data_n_obs(self): print(f"{bcolors.OKBLUE}test_subsampling_data_n_obs {bcolors.ENDC}", end="") @@ -428,6 +452,10 @@ def test_subsampling_data_n_obs(self): print(self.at.session_state.adata_state.current.adata.n_obs) assert self.at.session_state.adata_state.current.adata.n_obs == n_obs + def test_geneID_conversion(self): + print(f"{bcolors.OKBLUE}test_geneID_conversion {bcolors.ENDC}", end="") + self.at.toggle(key="toggle_gene_format").set_value(False).run() + def test_cell_cycle_scoring(self): print(f"{bcolors.OKBLUE}test_cell_cycle_scoring {bcolors.ENDC}", end="") diff --git a/app/utils/session_cache.py b/app/utils/session_cache.py index 90d6d9d..f5e464a 100644 --- a/app/utils/session_cache.py +++ b/app/utils/session_cache.py @@ -4,8 +4,9 @@ import streamlit as st from database.schemas import schemas from database.database import SessionLocal +from models.SessionModel import SessionModel -def cache_data_to_session(): +def cache_data_to_session(description: str = None): try: state = {} @@ -22,20 +23,22 @@ def cache_data_to_session(): hash = hashlib.md5() hash.update(encoded) state_hash = hash.hexdigest() - - # write cache file to db - conn = SessionLocal() - conn.query(schemas.Workspaces) \ - .filter(schemas.Workspaces.id == st.session_state.current_workspace.id) \ - .update({'cache_file': state_hash}) - + + filepath = os.path.join(os.getenv('WORKDIR'), 'tmp', state_hash) + # Write to file - dbfile = open(os.path.join(os.getenv('WORKDIR'), 'tmp', state_hash), 'wb') + dbfile = open(filepath, mode='wb') pickle.dump(state, dbfile) - #python doesn't copy the objects so db connection in state is 
destroyed. Add it back here + dbfile.close() + + # python doesn't copy the objects so db connection in state is destroyed. Add it back here st.session_state["adata_state"].conn = SessionLocal() st.session_state["script_state"].conn = SessionLocal() - dbfile.close() + + # write cache file to db + conn = SessionLocal() + new_session = schemas.Session(session_id=state_hash, adata_id=st.session_state.adata_state.current.id, filename=filepath, description=description) + conn.add(new_session) # commit cache file to db conn.commit() From d95e9933b64d3c0d4b73dce92242d5007fd70097 Mon Sep 17 00:00:00 2001 From: ch1ru Date: Sun, 24 Mar 2024 12:56:23 +0000 Subject: [PATCH 3/3] added tests, fixed upload issues, removed statemanager init --- app/.streamlit/config.toml | 5 +- app/components/sidebar.py | 23 ++++--- app/css/common.css | 74 ++++----------------- app/css/workspace.css | 6 +- app/pages/11_Plotly_3D.py | 2 +- app/pages/1_Upload.py | 37 ++++++----- app/pages/2_Preprocess.py | 58 ++++++++-------- app/pages/3_Integrate.py | 24 ++++--- app/pages/4_Create_model.py | 4 ++ app/pages/6_Cluster_plots.py | 20 +++--- app/pages/7_Differential_gene_expression.py | 21 +++--- app/pages/8_Trajectory_Inference.py | 8 ++- app/pages/9_Spatial_Transcriptomics.py | 16 +++-- app/state/AdataState.py | 3 + app/state/StateManager.py | 24 +++---- 15 files changed, 152 insertions(+), 173 deletions(-) diff --git a/app/.streamlit/config.toml b/app/.streamlit/config.toml index 8666efb..a97367e 100644 --- a/app/.streamlit/config.toml +++ b/app/.streamlit/config.toml @@ -4,11 +4,8 @@ headless = true [client] toolbarMode = "viewer" -[deprecation] -showPyplotGlobalUse = false - [theme] -primaryColor="#1976d2" +primaryColor="#004dcf" backgroundColor="#181a1e" secondaryBackgroundColor="linear-gradient(180deg, rgb(5, 39, 103) 0%, #3a0647 70%)" textColor="#fff" diff --git a/app/components/sidebar.py b/app/components/sidebar.py index 589d243..b7e84b0 100644 --- a/app/components/sidebar.py +++ 
b/app/components/sidebar.py @@ -50,7 +50,7 @@ def show_preview(self, integrate=False): def delete_experiment_btn(self): with st.sidebar: - delete_btn = st.button(label="🗑️ Delete Experiment", use_container_width=True, key="btn_delete_adata") + delete_btn = st.button(label="🗑️ Delete Experiment", use_container_width=True, key="btn_delete_adata", type='primary') if delete_btn: self.state_manager.adata_state() \ .delete_record(adata_name=st.session_state.sb_adata_selection) @@ -152,13 +152,20 @@ def add_experiment(self): print("Error: ", e) st.error(e) + def steps(self): - current_adata_id = self.state_manager.adata_state().current.id - sessions = self.conn.query(schemas.Session).filter(schemas.Session.adata_id == current_adata_id).all() - session_names = [session.description for session in sessions] with st.sidebar: - st.select_slider(label="Pipeline steps", options=session_names) - st.button(label="🔧 Undo", key="btn_undo", use_container_width=True) + with st.form("form_session_steps"): + current_adata_id = self.state_manager.adata_state().current.id + sessions = self.conn.query(schemas.Session).filter(schemas.Session.adata_id == current_adata_id).all() + if sessions: + session_names = [session.description for session in sessions] + + steps = st.select_slider(label="Pipeline steps", options=session_names) + submit_btn = st.form_submit_button(label="🔧 Undo", use_container_width=True) + + if submit_btn: + print("hi") def download_adata(self): @@ -405,9 +412,7 @@ def set_adata(): self.notes() self.show_preview() self.export_script() - self.steps() - self.delete_experiment_btn() - self.show_version() + diff --git a/app/css/common.css b/app/css/common.css index 32166e6..966c3f2 100644 --- a/app/css/common.css +++ b/app/css/common.css @@ -10,69 +10,19 @@ -footer {visibility: hidden;} - .st-emotion-cache-1cypcdb {background: linear-gradient(180deg, rgb(5, 39, 103) 0%, #3a0647 70%); box-shadow: 1px 0 10px -2px #000;} - .st-emotion-cache-86cver {color: rgba(250, 250, 250, 
0.6)} - .stButton button { - border-radius: 0.5rem; - background: #004dcf; - color: #fff; - border: 1px solid #004dcf; - padding: 0.25rem 0.75rem; - } - - .stButton button:hover { - background: transparent; - color: #004dcf; - transition: all 0.1s ease-in-out; - border: 1px solid #004dcf; - } - - .stButton button:focus { - border: 1px solid #004dcf; - color: #004dcf; - } - - .stButton button:active { - border: 1px solid #004dcf; - color: #004dcf; - } - - .stButton button:visited { - border: 1px solid #004dcf; - color: #004dcf; - } - - .st-emotion-cache-1b9yna5:focus:not(:active) { - border: 1px solid #004dcf; - color: #004dcf; - } - - .st-emotion-cache-1b9yna5:focus:not(:hover) { - border: 1px solid #004dcf; - color: #fff; - } - - .st-emotion-cache-oooxyj:focus:not(:active) { - border: 1px solid #004dcf; - color: #004dcf; - } +footer { + visibility: hidden; +} - .st-emotion-cache-oooxyj:focus:not(:hover) { - border: 1px solid #004dcf; - color: #fff; - } +.st-emotion-cache-1cypcdb { + background: linear-gradient(180deg, rgb(5, 39, 103) 0%, #3a0647 70%); box-shadow: 1px 0 10px -2px #000; - .st-emotion-cache-1ts31n5:focus:not(:active) { - border: 1px solid #004dcf; - color: #004dcf; - } +} +.st-emotion-cache-86cver { + color: rgba(250, 250, 250, 0.6) +} - .st-emotion-cache-1ts31n5:focus:not(:hover) { - border: 1px solid #004dcf; - color: #fff; - } +.st-emotion-cache-z5fcl4 { + padding-top: 2rem; +} - .st-emotion-cache-z5fcl4 { - padding-top: 2rem; - } diff --git a/app/css/workspace.css b/app/css/workspace.css index a8dd177..7904378 100644 --- a/app/css/workspace.css +++ b/app/css/workspace.css @@ -24,7 +24,7 @@ } } -.st-emotion-cache-1qcb9zv, .st-emotion-cache-1ts31n5 { +.st-emotion-cache-1h7cibc { width: 100%; font-size: 48px; font-weight: 800; @@ -48,7 +48,7 @@ } -.st-emotion-cache-1qcb9zv:hover, .st-emotion-cache-1ts31n5:hover { +.st-emotion-cache-1h7cibc:hover { background-position: 10% 0; moz-transition: all .8s ease-in-out; @@ -61,7 +61,7 @@ color: #fff; } 
-.st-emotion-cache-1qcb9zv:focus, .st-emotion-cache-1ts31n5:focus { +.st-emotion-cache-1h7cibc:focus { border: rgba(255,255,255,0.8) 2px solid; } diff --git a/app/pages/11_Plotly_3D.py b/app/pages/11_Plotly_3D.py index 32f4939..29d831c 100644 --- a/app/pages/11_Plotly_3D.py +++ b/app/pages/11_Plotly_3D.py @@ -121,7 +121,7 @@ def change_embeddings(): break st.selectbox(label="Embedding colors", options=adata.obs_keys(), key="plotly_embedding_color", index=cluster_color_index) st.slider(label="Point size", min_value=0.5, max_value=5.0, step=0.1, value=1.0, key="plotly_point_size") - st.button(label="Plot chart", use_container_width=True, on_click=change_embeddings) + st.button(label="Plot chart", use_container_width=True, on_click=change_embeddings, type='primary') if 'plotly_df' in st.session_state: diff --git a/app/pages/1_Upload.py b/app/pages/1_Upload.py index d8c7c51..f8ef931 100644 --- a/app/pages/1_Upload.py +++ b/app/pages/1_Upload.py @@ -276,47 +276,48 @@ def var_unique(self): def show_anndata(self, adata, f = None, filename = ""): - #upload raw adata - # If there are already a file in this location. If so, don't overwrite. - if not os.path.isfile(os.path.join(os.getenv('WORKDIR'), 'adata', f'{filename}.h5ad')): + """upload raw adata. If there are already a file in this location. If so, don't overwrite. 
""" - if filename == "": - filename = f.name.split(".")[0] - - filename = filename.replace(" ", "_") #files must not contain spaces + if filename == "": + filename = f.name.split(".")[0] + filename = filename.replace(" ", "_") # files must not contain spaces + filepath = os.path.join(os.getenv('WORKDIR'), 'adata', f'{filename}.h5ad') + + if not os.path.isfile(filepath): sc.write(filename=os.path.join(os.getenv('WORKDIR'), 'uploads', f'{filename}.h5ad'), adata=adata) - + adata.raw = adata active_adata = AdataModel( work_id = st.session_state.current_workspace.id, adata_name=f"{filename}", adata=adata, - filename=os.path.join(os.getenv('WORKDIR'), 'adata', f'{filename}.h5ad') + filename=filepath ) st.session_state["adata_state"] = AdataState(active=active_adata) self.state_manager \ .add_adata(adata) \ - .add_description("Upload raw") \ + .add_description("Raw") \ .save_session() st.toast("Successfully uploaded file", icon='✅') else: - self.state_manager \ - .add_adata(adata) \ - .add_description("Upload raw") \ - .save_session() - - st.toast("A dataset with the same name already exists, will not overwrite.", icon="⚠️") + existing_adata = sc.read_h5ad(filename=filepath) + + active_adata = AdataModel( + work_id = st.session_state.current_workspace.id, + adata_name=f"{filename}", adata=existing_adata, + filename=filepath + ) + + st.session_state["adata_state"] = AdataState(active=active_adata) self.show_sidebar_preview(f) - - def show_sidebar_preview(self, file): with st.sidebar: diff --git a/app/pages/2_Preprocess.py b/app/pages/2_Preprocess.py index 8169667..d949cc0 100644 --- a/app/pages/2_Preprocess.py +++ b/app/pages/2_Preprocess.py @@ -40,10 +40,6 @@ st.markdown(common_style, unsafe_allow_html=True) - -st.set_option('deprecation.showPyplotGlobalUse', False) - - class Preprocess: """ Apply preprocessing on raw data for more effective analysis and detecting biological signal. 
@@ -86,7 +82,7 @@ def filter_highest_expr_genes(self): st.subheader("Show highest expressed genes") n_top_genes = st.number_input(label="Number of genes", min_value=1, max_value=100, value=20, key="ni:pp:highly_variable:n_top_genes") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Filter", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Filter", use_container_width=True, type='primary') if submit_btn: try: @@ -137,7 +133,7 @@ def remove_genes(self): st.subheader("Remove genes") remove_genes = st.multiselect(label="Genes", options=self.state_manager.adata_state().current.adata.var_names, key="ms:pp:remove_genes:genes") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() with st.spinner(text="Removing genes"): @@ -240,7 +236,7 @@ def run_highly_variable(adata: AnnData, flavour="seurat", min_mean=None, max_mea disp = st.slider(label="Dispersion", min_value=0.00, max_value=100.00, value=(0.50, 100.00), format="%.2f") remove = st.toggle(label="Remove non-variable genes", value=False, key="toggle:pp:highly_variable:seurat_remove", help="By default, highly variable genes are only annoted. 
This option will remove genes without highly variable expression.") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() @@ -296,7 +292,7 @@ def normalize_counts(self): exclude_high_expr = subcol_input1.checkbox(label="Exclude highly_expr", value=False) log_transform_total = subcol_input2.checkbox(label="Log transform", value=False) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if submit_btn: sc.pp.normalize_total(st.session_state.adata_state.current.adata, target_sum=target_sum, exclude_highly_expressed=exclude_high_expr, max_fraction=max_fraction) @@ -341,7 +337,7 @@ def filter_cells(self): min_genes = st.number_input(label="min genes for cell", min_value=1, value=None, key="ni:pp:filter_cells:min_genes") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() @@ -385,7 +381,7 @@ def filter_genes(self): min_cells = st.number_input(label="min cells for gene", min_value=1, value=None, key="ni:pp:filter_genes:min_cells") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if submit_btn: adata: AnnData = self.state_manager.get_current_adata() sc.pp.filter_genes(adata, min_cells=min_cells) @@ -449,7 +445,7 @@ def recipes(self): st.write("Parameters") log = st.checkbox(label="Log", value=True, 
key="cb:pp:recipe:seurat:log") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() @@ -474,7 +470,7 @@ def recipes(self): n_pcs = col3.number_input(label="n_pcs", min_value=1, value=50, step=1, format="%i") log = st.checkbox(label="Log", value=False, key="cb:pp:recipe:weinreb17:log") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() sc.pp.recipe_weinreb17(adata, log=log, mean_threshold=mean_threshold, cv_threshold=cv_threshold, n_pcs=n_pcs) @@ -496,7 +492,7 @@ def recipes(self): n_top_genes = st.number_input(label="n_top_genes", key="ni:pp:recipe:zheng17:n_genes", min_value=1, max_value=n_vars, value=min(1000, n_vars)) log = st.checkbox(label="Log", value=False) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True) + submit_btn = subcol1.form_submit_button(label='Apply', use_container_width=True, type='primary') if submit_btn: adata = self.state_manager.get_current_adata() sc.pp.recipe_zheng17(adata, log=log, n_top_genes=n_top_genes) @@ -592,7 +588,7 @@ def plot_charts(color=None): plot_charts() subcol1, _, _ = st.columns(3) - mito_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + mito_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if mito_annot_submit_btn: plot_charts(color_key_mito) @@ -681,7 +677,7 @@ def plot_charts(color=None): plot_charts() subcol1, _, _ = st.columns(3) - ribo_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + 
ribo_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if ribo_annot_submit_btn: plot_charts(color_key_ribo) @@ -770,7 +766,7 @@ def plot_charts(color=None): plot_charts() subcol1, _, _ = st.columns(3) - hb_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + hb_annot_submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if hb_annot_submit_btn: plot_charts(color_key_hb) @@ -850,7 +846,7 @@ def run_scrublet(self): stdev_doublet_rate = col3.number_input(label="stdev_doublet_rate", value=0.02, key="ni:pp:scrublet:stdev_doublet_rate") batch_key = st.selectbox(label="Batch key", key="sb:pp:scrublet:batch_key", options=np.append('None', self.state_manager.adata_state().current.adata.obs_keys())) subcol1, _, _ = st.columns(3) - scrublet_submit = subcol1.form_submit_button(label="Run", use_container_width=True) + scrublet_submit = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if scrublet_submit: try: @@ -1021,7 +1017,7 @@ def run_doubletdetection(self): n_top_var_genes = col3.number_input(label="n_top_var_genes", value=10000, step=1) standard_scaling = st.toggle(label="standard scaling", value=False) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if "doublet_doubletdetection" in st.session_state.adata_state.current.adata.obs: @@ -1100,7 +1096,7 @@ def regress_out(self): st.subheader("Regress out", help="Uses linear regression to remove unwanted sources of variation.") regress_keys = st.multiselect(label="Keys", options=st.session_state.adata_state.current.adata.obs_keys(), key="ms_regress_out_keys") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + submit_btn = 
subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if submit_btn: if st.session_state.ms_regress_out_keys: sc.pp.regress_out(st.session_state.adata_state.current.adata, keys=regress_keys) @@ -1141,7 +1137,7 @@ def scale_to_unit_variance(self): zero_center = st.toggle(label="Zero center", value=True) max_value = st.number_input(label="Max value", value=10, key="ni:pp:scale_data:max_value") subcol1, _, _ = st.columns(3) - btn_scale_data_btn = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_scale_data_btn = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_scale_data_btn: adata = self.state_manager.get_current_adata() sc.pp.scale(adata, zero_center=zero_center, max_value=max_value) @@ -1186,7 +1182,7 @@ def downsample_data(self): with st.form(key="downsample_form_counts_per_cell"): counts_per_cell = st.number_input(label="Counts per cell", value=1, step=1, format="%i", key="ni:pp:downsample:counts_per_cell", help="Target total counts per cell. If a cell has more than 'counts_per_cell', it will be downsampled to this number. Resulting counts can be specified on a per cell basis by passing an array.Should be an integer or integer ndarray with same length as number of obs.") subcol1, _, _ = st.columns(3) - btn_downsample_counts_per_cell = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_downsample_counts_per_cell = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_downsample_counts_per_cell: adata = self.state_manager.get_current_adata() sc.pp.downsample_counts(adata, counts_per_cell=counts_per_cell, random_state=42) @@ -1199,7 +1195,7 @@ def downsample_data(self): with st.form(key="downsample_form_total_counts"): total_counts = st.number_input(label="Total counts", key="ni:pp:downsample:total_counts", help="Target total counts. 
If the count matrix has more than total_counts it will be downsampled to have this number.") subcol1, _, _ = st.columns(3) - btn_downsample_total_counts = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_downsample_total_counts = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_downsample_total_counts: adata = self.state_manager.get_current_adata() sc.pp.downsample_counts(adata, total_counts=total_counts, random_state=42) @@ -1246,7 +1242,7 @@ def subsample_data(self): n_obs_default = self.state_manager.adata_state().current.adata.n_obs n_obs = st.number_input(label="n obs", key="ni:pp:subsample:n_obs", help="Subsample to this number of observations.", value=n_obs_default, step=1, format="%i", max_value=n_obs_default) subcol1, _, _ = st.columns(3) - btn_subsample_n_obs = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_subsample_n_obs = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_subsample_n_obs: adata: AnnData = self.state_manager.get_current_adata() sc.pp.subsample(adata, n_obs=n_obs, random_state=42) @@ -1260,7 +1256,7 @@ def subsample_data(self): with st.form(key="subsample_form_fraction"): fraction = st.number_input(label="subsample_fraction", key="ni:pp:subsample:fraction", help="Subsample this fraction of the number of observations.") subcol1, _, _ = st.columns(3) - btn_subsample_fraction = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_subsample_fraction = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_subsample_fraction: adata: AnnData = self.state_manager.get_current_adata() sc.pp.subsample(adata, fraction=fraction, random_state=42) @@ -1309,7 +1305,7 @@ def batch_effect_removal(self): key = st.selectbox(label="Batch key", options=self.state_manager.adata_state().current.adata.obs_keys(), key="sb:pp:combat:batch_key", index=index) covariates = 
st.multiselect(placeholder="Optional", label="Covariates", options=self.state_manager.adata_state().current.adata.obs_keys()) subcol1, _, _ = st.columns(3) - btn_batch_effect_removal = subcol1.form_submit_button(label="Apply", use_container_width=True) + btn_batch_effect_removal = subcol1.form_submit_button(label="Apply", use_container_width=True, type='primary') if btn_batch_effect_removal: adata = self.state_manager.get_current_adata() with st.spinner(text="Running Combat batch effect correction"): @@ -1368,7 +1364,7 @@ def run_pca(adata: AnnData): index = i pca_color = st.selectbox(label="Color", options=adata.obs_keys(), key="sb:pp:pca:color", index=index) subcol1, _, _ = st.columns(3) - pca_pp_btn = subcol1.form_submit_button("Apply", use_container_width=True) + pca_pp_btn = subcol1.form_submit_button("Apply", use_container_width=True, type='primary') pca_empty = st.empty() if st.session_state["preprocess_plots"]["pca"] == None: @@ -1432,7 +1428,7 @@ def measure_gene_counts(self): options = self.state_manager.adata_state().current.adata.var_names genes = st.multiselect(label="Gene (e.g. XIST for detecting sex)", options=options, key="ms:pp:measure_genes:genes") subcol_btn1, _, _ = st.columns(3) - submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Locating genes"): adata = self.state_manager.get_current_adata() @@ -1450,7 +1446,7 @@ def measure_gene_counts(self): batch_key_measure_gene_counts = st.selectbox(label="Obs key", options=st.session_state.adata_state.current.adata.obs_keys(), key="sb:pp:measure_genes:batch") genes = st.selectbox(label="Gene (e.g. 
XIST for detecting sex)", options=gene_options, key="ms:pp:measure_genes_batch:genes") subcol_btn1, _, _ = st.columns(3) - submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol_btn1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Locating genes"): adata = self.state_manager.get_current_adata() @@ -1578,7 +1574,7 @@ def cell_cycle_scoring(self): jitter = plot_col2.number_input(label="Jitter", min_value=0.1, max_value=1.0, value=0.4, step=0.1) subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=(not "pp_cell_cycle_marker_genes_df" in st.session_state)) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary', disabled=(not "pp_cell_cycle_marker_genes_df" in st.session_state)) cell_cycle_container = st.empty() if submit_btn: @@ -1709,6 +1705,10 @@ def cell_cycle_scoring(self): preprocess.cell_cycle_scoring() + + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() diff --git a/app/pages/3_Integrate.py b/app/pages/3_Integrate.py index d4d8aec..eb66a5c 100644 --- a/app/pages/3_Integrate.py +++ b/app/pages/3_Integrate.py @@ -86,7 +86,7 @@ def ingest(self): st.markdown(f"""
{st.session_state.adata_ref.adata_name} → {st.session_state.adata_target.adata_name}
""", unsafe_allow_html=True) obs = st.multiselect(label="Obs", options=st.session_state.adata_ref.adata.obs_keys()) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=(not st.session_state.sync_genes)) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary', disabled=(not st.session_state.sync_genes)) if submit_btn: try: with st.spinner(text="Integrating datasets"): @@ -132,7 +132,7 @@ def scanorama_integrate(self): disabled = True subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=disabled) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=disabled, type='primary') if submit_btn: with st.spinner(text="Integrating with Scanorama"): adatas = [st.session_state.adata_ref.adata, st.session_state.adata_target.adata] @@ -182,7 +182,7 @@ def bbknn(self): st.write(f"Apply to {st.session_state.adata_state.current.adata_name}") batch_key = st.selectbox(label="Batch key", options=st.session_state.adata_state.current.adata.obs_keys()) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Computing PCA"): sc.tl.pca(st.session_state.adata_state.current.adata) @@ -228,7 +228,7 @@ def quick_map(self): subcol1_btn, _, _ = st.columns(3) - submit_btn = subcol1_btn.form_submit_button(label="Run", use_container_width=True, disabled=(not st.session_state.sync_genes)) + submit_btn = subcol1_btn.form_submit_button(label="Run", use_container_width=True, disabled=(not st.session_state.sync_genes), type='primary') if submit_btn: try: @@ -286,7 +286,7 @@ def concat(self): batch_key = st.text_input(label="Batch key", value="batch") empty = st.empty() subcol1_btn, _, _ = st.columns(3) - 
submit_btn = subcol1_btn.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1_btn.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Concatenating datasets"): try: @@ -353,7 +353,7 @@ def compute_umap(): colors = st.multiselect(label="Color (obs)", options=options, default=default) container = st.container() subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') compute_umap() @@ -433,7 +433,7 @@ def scvi_integrate(self): max_epochs = input_col2.number_input(label="max_epochs", min_value=1, step=5, value=get_max_epochs_heuristic(st.session_state.adata_state.current.adata.n_obs), format="%i") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Training model"): @@ -527,7 +527,7 @@ def scanvi_integrate(self): n_samples_per_label = input_col2.number_input(label="n_samples_per_label", min_value=1, value=100, format="%i", step=1) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: scvi.model.SCVI.setup_anndata(st.session_state.adata_state.current.adata, layer="counts", batch_key=batch_key) @@ -607,7 +607,7 @@ def scvi_integrate_graphs(self): embedding = input_col2.selectbox(label="Embedding", options=options, disabled=(not is_embeddings)) preserve_neighbours = st.toggle(label="Preserve neighbours", value=True, disabled=(not is_embeddings)) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=(not 
is_embeddings)) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, disabled=(not is_embeddings), type='primary') if submit_btn: with st.spinner(text="Generating plots"): for color in colors: @@ -682,7 +682,7 @@ def scvi_metrics(self): batch_key = col1.selectbox(label="batch_key", options=st.session_state.adata_state.current.adata.obs_keys(), disabled=(not is_embeddings)) label_key = col2.selectbox(label="Label key", options=st.session_state.adata_state.current.adata.obs_keys(), disabled=(not is_embeddings)) subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - submit_btn = subcol1.form_submit_button(label="Compute", use_container_width=True, disabled=(not is_embeddings)) + submit_btn = subcol1.form_submit_button(label="Compute", use_container_width=True, disabled=(not is_embeddings), type='primary') if submit_btn: with st.spinner(text="Computing metrics"): @@ -718,6 +718,10 @@ def scvi_metrics(self): integrate = Integrate() + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() + except Exception as e: if(st.session_state == {}): diff --git a/app/pages/4_Create_model.py b/app/pages/4_Create_model.py index 6f2d116..9cffcb2 100644 --- a/app/pages/4_Create_model.py +++ b/app/pages/4_Create_model.py @@ -361,6 +361,10 @@ def create_model(adata): adata = st.session_state.adata_state.current.adata.copy() create_model(adata) + + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() except Exception as e: if(st.session_state == {}): diff --git a/app/pages/6_Cluster_plots.py b/app/pages/6_Cluster_plots.py index 896bc41..9c2ef99 100644 --- a/app/pages/6_Cluster_plots.py +++ b/app/pages/6_Cluster_plots.py @@ -55,7 +55,7 @@ def __init__(self, adata): self.col1, self.col2, self.col3 = st.columns(3) - self.conn: SessionLocal = SessionLocal() + self.conn = SessionLocal() self.PLOT_HEIGHT = 800 self.MARKER_SIZE = 32 @@ -143,7 +143,7 @@ def autoencoder_cluster_plot(self): st.line_chart(losses_df, 
use_container_width=True, height=290) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Update", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Update", use_container_width=True, type='primary') if submit_btn: df = st.session_state["cluster_plots"]["autoencoder"]["df"] @@ -169,7 +169,7 @@ def autoencoder_cluster_plot(self): st.line_chart(vae_df, use_container_width=True, height=360, color=['#52f27d', '#f25272']) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Filter doublets", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Filter doublets", use_container_width=True, type='primary') if submit_btn: st.session_state.adata_state.current.adata = st.session_state.adata_state.current.adata[st.session_state.adata_state.current.adata.obs.solo_prediction == 'singlet'] @@ -195,7 +195,7 @@ def autoencoder_cluster_plot(self): st.session_state["cluster_plots"]["autoencoder"] = dict(df=df_ldvae, x="UMAP1", y="UMAP2", color_keys=[trained_model.SCVI_CLUSTERS_KEY], size=self.MARKER_SIZE) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: raise Exception @@ -255,7 +255,7 @@ def pca_graph(self): st.session_state["cluster_plots"]["pca"] = dict(df=df_pca, color_keys=np.array([pca_colors]).flatten(), x="PCA1", y="PCA2", size=self.MARKER_SIZE) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner("Computing PCA coordinates"): self._do_pca(zero_center=zero_center) @@ -322,7 +322,7 @@ def variance_ratio_graph(self): n_pcs = subcol1.number_input(label="n_pcs", min_value=1, max_value=50, value=30) log = 
subcol1.toggle(label="Log", value=True) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner("Computing variance ratio"): @@ -380,7 +380,7 @@ def tsne_graph(self): st.session_state["cluster_plots"]["tsne"] = dict(df=df_tsne, color_keys=np.array([tsne_colors]).flatten(), x="tSNE1", y="tSNE2", size=self.MARKER_SIZE) subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner("Computing tSNE coordinates"): @@ -447,7 +447,7 @@ def neighbourhood_graph(self): subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner("Computing UMAP coordinates"): @@ -538,6 +538,10 @@ def _plot_charts(self, params): st.divider() analysis.plots() + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() + except Exception as e: if(st.session_state == {}): diff --git a/app/pages/7_Differential_gene_expression.py b/app/pages/7_Differential_gene_expression.py index 9f87ba4..b44cede 100644 --- a/app/pages/7_Differential_gene_expression.py +++ b/app/pages/7_Differential_gene_expression.py @@ -7,7 +7,6 @@ import pandas as pd import scanpy as sc import matplotlib.pyplot as plt -from matplotlib_venn import venn3 import altair as alt import plotly.graph_objects as go from utils.plotting import plot_top_ranked_genes @@ -87,7 +86,7 @@ def stat_tests(self): group_by = st.selectbox(label="Group by", options=st.session_state.adata_state.current.adata.obs_keys()) marker_genes_container = st.empty() subcol1, _, _ = st.columns(3) - submit_btn = 
subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: marker_genes_container.empty() with st.spinner(text="Computing tests"): @@ -159,7 +158,7 @@ def run_umap(adata): umap_color = st.selectbox(label="Color", options=self.adata.obs_keys(), key="sb_umap_color_dge") subcol1, _, _ = st.columns(3) - umap_dge_btn = subcol1.form_submit_button("Apply", use_container_width=True) + umap_dge_btn = subcol1.form_submit_button("Apply", use_container_width=True, type='primary') umap_empty = st.empty() run_umap(self.adata) @@ -204,7 +203,7 @@ def add_embeddings(self): algorithm = st.selectbox(label="Algorithm", options=['Leiden', 'Louvain']) resolution = st.number_input(label="Resolution", min_value=0.1, value=0.6, format="%.1f") subcol1, _, _ = st.columns(3) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text=f"Computing {algorithm} clusters"): sc.pp.neighbors(self.adata, n_neighbors=10, n_pcs=40) @@ -258,7 +257,7 @@ def visualize(self): n_genes = input_col2.text_input(label="n_genes", value=5, help="Number of genes to display in each cluster.") group_by = input_col2.selectbox(label="Group by", options=st.session_state.adata_state.current.adata.obs_keys()) subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - viz_submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + viz_submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if viz_submit_btn: plt.style.use('tableau-colorblind10') @@ -336,7 +335,7 @@ def rank_genes_groups(self): reference = subcol3.text_input(label="Reference", value="rest") use_raw = subcol3.toggle(label="Use raw", value=False) subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - submit_btn = 
subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Calculating plots"): sc.tl.rank_genes_groups(st.session_state.adata_state.current.adata, groupby=group_by, method=method, use_raw=use_raw, reference=reference) @@ -374,7 +373,7 @@ def rank_genes_groups(self): cluster1 = subcol1.selectbox(label="Compare group 1", options=np.sort(st.session_state.adata_state.current.adata.obs[st.session_state.sb_violin_cluster_group].unique()), key="sb_cluster1_violin") cluster2 = subcol1.selectbox(label="Compare group 2", options=np.append('rest', np.sort(st.session_state.adata_state.current.adata.obs[st.session_state.sb_violin_cluster_group].unique())), key="sb_cluster2_violin") subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Calculating plots"): @@ -441,7 +440,7 @@ def rank_genes_groups(self): cluster = subcol1.selectbox(label="Group", options=st.session_state.adata_state.current.adata.obs_keys()) genes = subcol2.multiselect(label="Genes", options=np.sort(st.session_state.adata_state.current.adata.var_names)) subcol1, _, _, _, _, _, _, _, _ = st.columns(9) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') line_colors = ['#d1454c', '#8ee065', '#eda621', '#f071bf', '#9071f0', '#71e3f0', '#2f39ed', '#ed2f7b'] if submit_btn: @@ -537,7 +536,7 @@ def show_top_ranked_genes(self): num_of_rows = col2.number_input(label="Number of rows", min_value=1, step=1, format="%i", value=8) method = st.radio(label="method", options=['t-test', 't-test_overestim_var', 'wilcoxon', 'logreg']) subcol1, _, _, _, _, _, _, _, 
_ = st.columns(9) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: fig = plot_top_ranked_genes(self.adata, cluster_name=cluster_key, n_rows=num_of_rows, method=method, height=800) @@ -560,6 +559,10 @@ def show_top_ranked_genes(self): dge = Differential_gene_expression(st.session_state.adata_state.current.adata.copy()) + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() + except Exception as e: if(st.session_state == {}): diff --git a/app/pages/8_Trajectory_Inference.py b/app/pages/8_Trajectory_Inference.py index 2160630..d401eba 100644 --- a/app/pages/8_Trajectory_Inference.py +++ b/app/pages/8_Trajectory_Inference.py @@ -168,7 +168,7 @@ def paga_clustering(self): #genes options = st.multiselect(label="Genes", options=self.adata.var_names, default=[self.adata.var_names[0]]) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Computing paga clusters"): @@ -309,7 +309,7 @@ def diffusion_pseudotime(self): algorithm = st.radio(label="Clustering algorithm", options=['leiden', 'louvain']) root = st.selectbox(label='Root cell', options=(self.adata.obs['louvain'].unique())) subcol1, _, _, _ = st.columns(4) - dpt_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + dpt_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if dpt_btn: with st.spinner(text="Computing dpt"): if algorithm not in self.adata.obs: @@ -407,6 +407,10 @@ def show_path(self): tji.draw_page() + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() + except Exception as e: diff --git a/app/pages/9_Spatial_Transcriptomics.py b/app/pages/9_Spatial_Transcriptomics.py index 
4253641..335ede8 100644 --- a/app/pages/9_Spatial_Transcriptomics.py +++ b/app/pages/9_Spatial_Transcriptomics.py @@ -192,7 +192,7 @@ def neighbourhood_enrichment(self): n_perms = col1.number_input(label="n_perms", min_value=1, value=1000) mode = col2.selectbox(label="mode", options=['zscore', 'count']) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Running neighbourhood enrichment"): sq.gr.spatial_neighbors(self.adata) @@ -258,7 +258,7 @@ def ripley_score(self): mode = col1.radio(label="mode", options=['F', 'G', 'L']) plot_sims = st.toggle(label="plot_sims", value=False) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Calculating Ripley score"): sq.gr.ripley(self.adata, cluster_key=cluster_key, mode=mode, max_dist=max_dist, n_neigh=n_neighbours, n_simulations=n_simulations) @@ -306,7 +306,7 @@ def co_occurance_score(self): clusters = st.multiselect(label="Clusters", options=self.adata.obs[f"{st.session_state['sb:spatial:co_occurance:cluster_key']}"].unique()) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Calculating co-occurance score"): cluster_key = st.session_state['sb:spatial:co_occurance:cluster_key'] @@ -352,7 +352,7 @@ def interaction_matrix(self): col1, col2 = st.columns(2, gap="medium") cluster_key = col1.selectbox(label="Cluster Key", options=self.adata.obs_keys()) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", 
use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Computing interaction matrix"): sq.gr.spatial_neighbors(self.adata) @@ -398,7 +398,7 @@ def centrality_score(self): cluster_key = st.selectbox(label="Cluster Key", options=self.adata.obs_keys()) subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Calculating centrality score"): sq.gr.spatial_neighbors(self.adata) @@ -427,7 +427,7 @@ def ligand_receptor_interaction(self): st.multiselect(label="Target groups", options=options, default=options[0], key="ms_lri_target_groups") empty = st.empty() subcol1, _, _, _ = st.columns(4) - submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True) + submit_btn = subcol1.form_submit_button(label="Run", use_container_width=True, type='primary') if submit_btn: with st.spinner(text="Computing ligand-receptor interaction matrix"): sq.gr.ligrec(self.adata, n_perms=100, cluster_key=st.session_state.sb_cluster_key_lri) @@ -453,6 +453,10 @@ def ligand_receptor_interaction(self): spatial_t = Spatial_transcriptomics(adata) spatial_t.draw_page() + sidebar.steps() + sidebar.delete_experiment_btn() + sidebar.show_version() + except Exception as e: if(st.session_state == {}): diff --git a/app/state/AdataState.py b/app/state/AdataState.py index d6f6765..c569a11 100644 --- a/app/state/AdataState.py +++ b/app/state/AdataState.py @@ -29,6 +29,8 @@ def __init__(self, active: AdataModel, insert_into_db=True): #reinitialise current to include additional fields current: schemas.Adata = db_adatas.first() self.current = AdataModel(work_id=current.work_id, adata_name=current.adata_name, created=current.created, notes=current.notes, id=current.id, filename=current.filename) + # Add to 
env + os.environ["CURRENT_ADATA_ID"] = str(current.id) # add original adata to object self.current.adata = active.adata @@ -48,6 +50,7 @@ def switch_adata(self, adata_name): self.current = new_current self.current_index = self.get_index_of_current() st.session_state["script_state"].switch_adata(new_current.id) #swap adata in script state + os.environ["CURRENT_ADATA_ID"] = str(new_current.id) except Exception as e: st.toast(e, icon="❌") diff --git a/app/state/StateManager.py b/app/state/StateManager.py index 4a24f50..415e8fb 100644 --- a/app/state/StateManager.py +++ b/app/state/StateManager.py @@ -39,19 +39,22 @@ def add_description(self, description: str): def load_session(self): - # fetch cache file from db - if "current_workspace" in st.session_state: - current_workspace_id = st.session_state.current_workspace.id + # get current adata id either from session state or environment + if "adata_state" in st.session_state: + current_adata_id = st.session_state.adata_state.current.id else: - current_workspace_id = os.getenv('CURRENT_WORKSPACE_ID') + current_adata_id = os.getenv('CURRENT_ADATA_ID') + # fetch from database, ordered by `created` (ordering is done in SQL because Query.all() returns a plain list) conn = SessionLocal() - cache_file = conn.query(schemas.Session) \ - .filter(schemas.Session.work_id == current_workspace_id) \ - .first() \ - .filename + cache_files = conn.query(schemas.Session) \ + .filter(schemas.Session.adata_id == int(current_adata_id)) \ + .order_by(schemas.Session.created) \ + .all() - load_data_from_cache(cache_file) + st.write(cache_files) + + #load_data_from_cache(cache_file) def save_session(self): @@ -76,9 +79,6 @@ def save_session(self): # cache data to pickle file cache_data_to_session(description=self.description) - - def init_session(): - raise NotImplementedError ########## Adata state ##########