diff --git a/src/tdamapper/app.py b/src/tdamapper/app.py index aac6adc..27b9869 100644 --- a/src/tdamapper/app.py +++ b/src/tdamapper/app.py @@ -4,7 +4,7 @@ import logging import os -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from typing import Any, Callable, Literal, Optional import networkx as nx @@ -91,11 +91,12 @@ RANDOM_SEED = 42 -@dataclass +@dataclass(frozen=True) class MapperConfig: """ Configuration for the Mapper algorithm. + :param excluded_columns: List of columns to exclude from the input data. :param lens_type: Type of lens to use for dimensionality reduction. :param cover_scale_data: Whether to scale the data before covering. :param cover_type: Type of cover to use for the Mapper algorithm. @@ -119,6 +120,7 @@ class MapperConfig: :param plot_seed: Random seed for reproducibility. """ + excluded_columns: list[str] = field(default_factory=list) lens_type: str = LENS_PCA cover_scale_data: bool = COVER_SCALE_DATA cover_type: str = COVER_CUBICAL @@ -222,6 +224,7 @@ def run_mapper( params: dict[str, Any] = kwargs mapper_config = MapperConfig(**params) + excluded_columns = mapper_config.excluded_columns lens_type = mapper_config.lens_type cover_scale_data = mapper_config.cover_scale_data cover_type = mapper_config.cover_type @@ -284,7 +287,11 @@ def run_mapper( return None mapper = MapperAlgorithm(cover=cover, clustering=clustering) - X = df.to_numpy() + if excluded_columns: + X = df.drop(columns=excluded_columns, errors="ignore").to_numpy() + else: + X = df.to_numpy() + y = lens(X) df_y = pd.DataFrame(y, columns=[f"{lens_type} {i}" for i in range(y.shape[1])]) if cover_scale_data: @@ -435,6 +442,9 @@ def __init__(self, storage: dict[str, Any]) -> None: color="themelight", ).classes("w-full text-themedark") + with ui.column().classes("w-full gap-2"): + self._init_columns() + with ui.column().classes("w-full gap-2"): self._init_lens() @@ -512,6 +522,17 @@ def _init_file_upload(self) -> None: value=LOAD_EXAMPLE, ) + def _init_columns(self) -> None: + ui.label("⚙️ Settings").classes("text-h6") + self._init_columns_settings() + + def _init_columns_settings(self) -> None: + self.excluded_columns = ui.select( + options=[], + label="Exclude columns", + multiple=True, + ).classes("w-full") + def _init_lens(self) -> None: ui.label("🔎 Lens").classes("text-h6") self._init_lens_settings() @@ -726,6 +747,9 @@ def get_mapper_config(self) -> MapperConfig: else: plot_dimensions = PLOT_DIMENSIONS return MapperConfig( + excluded_columns=( + self.excluded_columns.value if self.excluded_columns.value else [] + ), lens_type=str(self.lens_type.value) if self.lens_type.value else LENS_PCA, cover_type=( str(self.cover_type.value) if self.cover_type.value else COVER_CUBICAL @@ -854,6 +878,8 @@ def load_data(self) -> None: df = self.storage.get("df", pd.DataFrame()) if df is not None and not df.empty: logger.info("Load data completed.") + self.excluded_columns.set_options(list(df.columns)) + self.excluded_columns.set_value([]) ui.notify("Load data completed.", type="positive") else: error = "Load data failed: no data found, please upload a file first." @@ -910,21 +936,26 @@ async def async_run_mapper(self) -> None: self.notification_running_stop(notification, error, type="warning") return mapper_config = self.get_mapper_config() - result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config)) - if result is None: - error = "Run Mapper failed: something went wrong." + try: + result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config)) + if result is None: + error = "Run Mapper failed: something went wrong." + logger.error(error) + self.notification_running_stop(notification, error, type="error") + return + mapper_graph, df_y = result + if mapper_graph is not None: + self.storage["mapper_graph"] = mapper_graph + if df_y is not None: + self.storage["df_y"] = df_y + self.notification_running_stop( + notification, "Run Mapper completed.", type="positive" + ) + await self.async_draw_mapper() + except Exception as e: + error = f"Run Mapper failed: {e}" logger.error(error) self.notification_running_stop(notification, error, type="error") - return - mapper_graph, df_y = result - if mapper_graph is not None: - self.storage["mapper_graph"] = mapper_graph - if df_y is not None: - self.storage["df_y"] = df_y - self.notification_running_stop( - notification, "Run Mapper completed.", type="positive" - ) - await self.async_draw_mapper() async def async_draw_mapper(self) -> None: """ diff --git a/tests/test_unit_app.py b/tests/test_unit_app.py index 2f1d0de..2957b31 100644 --- a/tests/test_unit_app.py +++ b/tests/test_unit_app.py @@ -28,6 +28,7 @@ async def test_run_app_fail(user: User) -> None: async def test_run_app_success(user: User) -> None: app.startup() await user.open("/") + await user.should_see("Load Data") await user.should_see("Lens") await user.should_see("Cover") @@ -35,9 +36,21 @@ async def test_run_app_success(user: User) -> None: await user.should_see("Run Mapper") await user.should_see("Redraw") + user.find("Dataset").click() + user.find("Digits").click() + user.find("Load Data").click() + await user.should_see("Load data completed") + await user.should_not_see("Load data failed") + + user.find("Dataset").click() + user.find("Iris").click() user.find("Load Data").click() await user.should_see("Load data completed") await user.should_not_see("Load data failed") + + user.find("Exclude columns").click() + user.find("sepal length (cm)").click() + user.find("Run Mapper").click() await user.should_see("Running Mapper...") await user.should_not_see("Run Mapper failed") @@ -45,6 +58,7 @@ async def test_run_app_success(user: User) -> None: await user.should_see("Drawing Mapper...") await user.should_not_see("Draw Mapper failed") await user.should_see("Draw Mapper completed", retries=RETRIES) + user.find("Redraw").click() await user.should_see("Drawing Mapper...") await user.should_not_see("Draw Mapper failed")