Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions src/tdamapper/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import os
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Literal, Optional

import networkx as nx
Expand Down Expand Up @@ -91,11 +91,12 @@
RANDOM_SEED = 42


@dataclass
@dataclass(frozen=True)
class MapperConfig:
"""
Configuration for the Mapper algorithm.

:param excluded_columns: List of columns to exclude from the input data.
:param lens_type: Type of lens to use for dimensionality reduction.
:param cover_scale_data: Whether to scale the data before covering.
:param cover_type: Type of cover to use for the Mapper algorithm.
Expand All @@ -119,6 +120,7 @@ class MapperConfig:
:param plot_seed: Random seed for reproducibility.
"""

excluded_columns: list[str] = field(default_factory=list)
lens_type: str = LENS_PCA
cover_scale_data: bool = COVER_SCALE_DATA
cover_type: str = COVER_CUBICAL
Expand Down Expand Up @@ -222,6 +224,7 @@ def run_mapper(
params: dict[str, Any] = kwargs
mapper_config = MapperConfig(**params)

excluded_columns = mapper_config.excluded_columns
lens_type = mapper_config.lens_type
cover_scale_data = mapper_config.cover_scale_data
cover_type = mapper_config.cover_type
Expand Down Expand Up @@ -284,7 +287,11 @@ def run_mapper(
return None

mapper = MapperAlgorithm(cover=cover, clustering=clustering)
X = df.to_numpy()
if excluded_columns:
X = df.drop(columns=excluded_columns, errors="ignore").to_numpy()
else:
X = df.to_numpy()

y = lens(X)
df_y = pd.DataFrame(y, columns=[f"{lens_type} {i}" for i in range(y.shape[1])])
if cover_scale_data:
Expand Down Expand Up @@ -435,6 +442,9 @@ def __init__(self, storage: dict[str, Any]) -> None:
color="themelight",
).classes("w-full text-themedark")

with ui.column().classes("w-full gap-2"):
self._init_columns()

with ui.column().classes("w-full gap-2"):
self._init_lens()

Expand Down Expand Up @@ -512,6 +522,17 @@ def _init_file_upload(self) -> None:
value=LOAD_EXAMPLE,
)

def _init_columns(self) -> None:
ui.label("⚙️ Settings").classes("text-h6")
self._init_columns_settings()

def _init_columns_settings(self) -> None:
self.excluded_columns = ui.select(
options=[],
label="Exclude columns",
multiple=True,
).classes("w-full")

def _init_lens(self) -> None:
ui.label("🔎 Lens").classes("text-h6")
self._init_lens_settings()
Expand Down Expand Up @@ -726,6 +747,9 @@ def get_mapper_config(self) -> MapperConfig:
else:
plot_dimensions = PLOT_DIMENSIONS
return MapperConfig(
excluded_columns=(
self.excluded_columns.value if self.excluded_columns.value else []
),
lens_type=str(self.lens_type.value) if self.lens_type.value else LENS_PCA,
cover_type=(
str(self.cover_type.value) if self.cover_type.value else COVER_CUBICAL
Expand Down Expand Up @@ -854,6 +878,8 @@ def load_data(self) -> None:
df = self.storage.get("df", pd.DataFrame())
if df is not None and not df.empty:
logger.info("Load data completed.")
self.excluded_columns.set_options(list(df.columns))
self.excluded_columns.set_value([])
ui.notify("Load data completed.", type="positive")
else:
error = "Load data failed: no data found, please upload a file first."
Expand Down Expand Up @@ -910,21 +936,26 @@ async def async_run_mapper(self) -> None:
self.notification_running_stop(notification, error, type="warning")
return
mapper_config = self.get_mapper_config()
result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config))
if result is None:
error = "Run Mapper failed: something went wrong."
try:
result = await run.cpu_bound(run_mapper, df_X, **asdict(mapper_config))
if result is None:
error = "Run Mapper failed: something went wrong."
logger.error(error)
self.notification_running_stop(notification, error, type="error")
return
mapper_graph, df_y = result
if mapper_graph is not None:
self.storage["mapper_graph"] = mapper_graph
if df_y is not None:
self.storage["df_y"] = df_y
self.notification_running_stop(
notification, "Run Mapper completed.", type="positive"
)
await self.async_draw_mapper()
except Exception as e:
error = f"Run Mapper failed: {e}"
logger.error(error)
self.notification_running_stop(notification, error, type="error")
return
mapper_graph, df_y = result
if mapper_graph is not None:
self.storage["mapper_graph"] = mapper_graph
if df_y is not None:
self.storage["df_y"] = df_y
self.notification_running_stop(
notification, "Run Mapper completed.", type="positive"
)
await self.async_draw_mapper()

async def async_draw_mapper(self) -> None:
"""
Expand Down
14 changes: 14 additions & 0 deletions tests/test_unit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,37 @@ async def test_run_app_fail(user: User) -> None:
async def test_run_app_success(user: User) -> None:
app.startup()
await user.open("/")

await user.should_see("Load Data")
await user.should_see("Lens")
await user.should_see("Cover")
await user.should_see("Clustering")
await user.should_see("Run Mapper")
await user.should_see("Redraw")

user.find("Dataset").click()
user.find("Digits").click()
user.find("Load Data").click()
await user.should_see("Load data completed")
await user.should_not_see("Load data failed")

user.find("Dataset").click()
user.find("Iris").click()
user.find("Load Data").click()
await user.should_see("Load data completed")
await user.should_not_see("Load data failed")

user.find("Exclude columns").click()
user.find("sepal length (cm)").click()

user.find("Run Mapper").click()
await user.should_see("Running Mapper...")
await user.should_not_see("Run Mapper failed")
await user.should_see("Run Mapper completed", retries=RETRIES)
await user.should_see("Drawing Mapper...")
await user.should_not_see("Draw Mapper failed")
await user.should_see("Draw Mapper completed", retries=RETRIES)

user.find("Redraw").click()
await user.should_see("Drawing Mapper...")
await user.should_not_see("Draw Mapper failed")
Expand Down