Skip to content

Commit 6e166e6

Browse files
committed
feat: added column filter for exclusion of columns that are not topologically relevant
1 parent 25f7689 commit 6e166e6

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

src/tdamapper/app.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import os
7-
from dataclasses import asdict, dataclass
7+
from dataclasses import asdict, dataclass, field
88
from typing import Any, Callable, Literal, Optional
99

1010
import networkx as nx
@@ -91,11 +91,12 @@
9191
RANDOM_SEED = 42
9292

9393

94-
@dataclass
94+
@dataclass(frozen=True)
9595
class MapperConfig:
9696
"""
9797
Configuration for the Mapper algorithm.
9898
99+
:param excluded_columns: List of columns to exclude from the input data.
99100
:param lens_type: Type of lens to use for dimensionality reduction.
100101
:param cover_scale_data: Whether to scale the data before covering.
101102
:param cover_type: Type of cover to use for the Mapper algorithm.
@@ -119,6 +120,7 @@ class MapperConfig:
119120
:param plot_seed: Random seed for reproducibility.
120121
"""
121122

123+
excluded_columns: list[str] = field(default_factory=list)
122124
lens_type: str = LENS_PCA
123125
cover_scale_data: bool = COVER_SCALE_DATA
124126
cover_type: str = COVER_CUBICAL
@@ -222,6 +224,7 @@ def run_mapper(
222224
params: dict[str, Any] = kwargs
223225
mapper_config = MapperConfig(**params)
224226

227+
excluded_columns = mapper_config.excluded_columns
225228
lens_type = mapper_config.lens_type
226229
cover_scale_data = mapper_config.cover_scale_data
227230
cover_type = mapper_config.cover_type
@@ -284,7 +287,11 @@ def run_mapper(
284287
return None
285288

286289
mapper = MapperAlgorithm(cover=cover, clustering=clustering)
287-
X = df.to_numpy()
290+
if excluded_columns:
291+
X = df.drop(columns=excluded_columns, errors="ignore").to_numpy()
292+
else:
293+
X = df.to_numpy()
294+
288295
y = lens(X)
289296
df_y = pd.DataFrame(y, columns=[f"{lens_type} {i}" for i in range(y.shape[1])])
290297
if cover_scale_data:
@@ -435,6 +442,9 @@ def __init__(self, storage: dict[str, Any]) -> None:
435442
color="themelight",
436443
).classes("w-full text-themedark")
437444

445+
with ui.column().classes("w-full gap-2"):
446+
self._init_columns()
447+
438448
with ui.column().classes("w-full gap-2"):
439449
self._init_lens()
440450

@@ -512,6 +522,17 @@ def _init_file_upload(self) -> None:
512522
value=LOAD_EXAMPLE,
513523
)
514524

525+
def _init_columns(self) -> None:
526+
ui.label("⚙️ Settings").classes("text-h6")
527+
self._init_columns_settings()
528+
529+
def _init_columns_settings(self) -> None:
530+
self.excluded_columns = ui.select(
531+
options=[],
532+
label="Exclude columns",
533+
multiple=True,
534+
).classes("w-full")
535+
515536
def _init_lens(self) -> None:
516537
ui.label("🔎 Lens").classes("text-h6")
517538
self._init_lens_settings()
@@ -726,6 +747,9 @@ def get_mapper_config(self) -> MapperConfig:
726747
else:
727748
plot_dimensions = PLOT_DIMENSIONS
728749
return MapperConfig(
750+
excluded_columns=(
751+
self.excluded_columns.value if self.excluded_columns.value else []
752+
),
729753
lens_type=str(self.lens_type.value) if self.lens_type.value else LENS_PCA,
730754
cover_type=(
731755
str(self.cover_type.value) if self.cover_type.value else COVER_CUBICAL
@@ -854,6 +878,8 @@ def load_data(self) -> None:
854878
df = self.storage.get("df", pd.DataFrame())
855879
if df is not None and not df.empty:
856880
logger.info("Load data completed.")
881+
self.excluded_columns.set_options(list(df.columns))
882+
self.excluded_columns.set_value([])
857883
ui.notify("Load data completed.", type="positive")
858884
else:
859885
error = "Load data failed: no data found, please upload a file first."
@@ -910,21 +936,26 @@ async def async_run_mapper(self) -> None:
910936
self.notification_running_stop(notification, error, type="warning")
911937
return
912938
mapper_config = self.get_mapper_config()
913-
result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config))
914-
if result is None:
915-
error = "Run Mapper failed: something went wrong."
939+
try:
940+
result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config))
941+
if result is None:
942+
error = "Run Mapper failed: something went wrong."
943+
logger.error(error)
944+
self.notification_running_stop(notification, error, type="error")
945+
return
946+
mapper_graph, df_y = result
947+
if mapper_graph is not None:
948+
self.storage["mapper_graph"] = mapper_graph
949+
if df_y is not None:
950+
self.storage["df_y"] = df_y
951+
self.notification_running_stop(
952+
notification, "Run Mapper completed.", type="positive"
953+
)
954+
await self.async_draw_mapper()
955+
except Exception as e:
956+
error = f"Run Mapper failed: {e}"
916957
logger.error(error)
917958
self.notification_running_stop(notification, error, type="error")
918-
return
919-
mapper_graph, df_y = result
920-
if mapper_graph is not None:
921-
self.storage["mapper_graph"] = mapper_graph
922-
if df_y is not None:
923-
self.storage["df_y"] = df_y
924-
self.notification_running_stop(
925-
notification, "Run Mapper completed.", type="positive"
926-
)
927-
await self.async_draw_mapper()
928959

929960
async def async_draw_mapper(self) -> None:
930961
"""

0 commit comments

Comments
 (0)