diff --git a/.gitignore b/.gitignore index 12aece251..6fbff9e5d 100644 --- a/.gitignore +++ b/.gitignore @@ -173,7 +173,6 @@ poetry.toml .ruff_cache/ # LSP config files -pyrightconfig.json # Syncthing .stfolder/ @@ -235,7 +234,7 @@ compile_commands.json *_qmlcache.qrc ### VisualStudioCode ### -.vscode/* +/.vscode # !.vscode/settings.json # !.vscode/tasks.json # !.vscode/launch.json @@ -268,3 +267,4 @@ TagStudio.ini result result-* +uv.lock diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index 1ce4fc85f..12b4aeef6 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -1685,12 +1685,47 @@ def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]: return all_tags - def add_parent_tag(self, parent_id: int, child_id: int) -> bool: + def get_tag_descendants(self, tag_id: int, session: Session | None = None) -> set[int]: + """Return ids of every tag that lists `tag_id` as an ancestor.""" + owns_session = False + if session is None: + session = Session(self.engine) + owns_session = True + + try: + descendants: set[int] = set() + frontier: set[int] = {tag_id} + + while frontier: + stmt = select(TagParent.child_id).where(TagParent.parent_id.in_(frontier)) + children = set(session.scalars(stmt).all()) + children -= descendants + children.discard(tag_id) + descendants.update(children) + frontier = children + + return descendants + finally: + if owns_session: + session.close() + + def _would_create_parent_cycle(self, parent_id: int, child_id: int, session: Session) -> bool: if parent_id == child_id: - return False + return True + + return parent_id in self.get_tag_descendants(child_id, session=session) + def add_parent_tag(self, parent_id: int, child_id: int) -> bool: # open session and save as parent tag with Session(self.engine) as session: + if self._would_create_parent_cycle(parent_id, child_id, session): + logger.warning( + "[Library][add_parent_tag] Prevented cyclical parent assignment", + parent_id=parent_id, + child_id=child_id, + ) + return False + parent_tag = TagParent( parent_id=parent_id, child_id=child_id, @@ -1814,10 +1849,10 @@ def update_aliases( session.add(alias) def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session: Session): - if tag.id in parent_ids: - parent_ids.remove(tag.id) + new_parent_ids: set[int] = set(parent_ids) + new_parent_ids.discard(tag.id) - if tag.disambiguation_id not in parent_ids: + if tag.disambiguation_id not in new_parent_ids: tag.disambiguation_id = None # load all tag's parent tags to know which to remove @@ -1826,14 +1861,22 @@ def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session ).all() for parent_tag in prev_parent_tags: - if parent_tag.parent_id not in parent_ids: + if parent_tag.parent_id not in new_parent_ids: session.delete(parent_tag) else: # no change, remove from list - parent_ids.remove(parent_tag.parent_id) + new_parent_ids.remove(parent_tag.parent_id) + + # create remaining items + for parent_id in list(new_parent_ids): + if self._would_create_parent_cycle(parent_id, tag.id, session): + logger.warning( + "[Library][update_parent_tags] Prevented cyclical parent assignment", + parent_id=parent_id, + child_id=tag.id, + ) + continue - # create remaining items - for parent_id in parent_ids: # add new parent tag parent_tag = TagParent( parent_id=parent_id, diff --git a/src/tagstudio/qt/mixed/build_tag.py b/src/tagstudio/qt/mixed/build_tag.py index cfffb6595..17c364c26 100644 --- a/src/tagstudio/qt/mixed/build_tag.py +++ b/src/tagstudio/qt/mixed/build_tag.py @@ -163,11 +163,12 @@ def __init__(self, library: Library, tag: Tag | None = None) -> None: self.parent_tags_add_button.setText("+") self.parent_tags_layout.addWidget(self.parent_tags_add_button) - exclude_ids: list[int] = list() - if tag is not None: - exclude_ids.append(tag.id) + exclude_ids: set[int] = set() + if tag is not None and tag.id is not None: + exclude_ids.add(tag.id) + exclude_ids.update(self.lib.get_tag_descendants(tag.id)) - self.add_tag_modal = TagSearchModal(self.lib, exclude_ids) + self.add_tag_modal = TagSearchModal(self.lib, list(exclude_ids)) self.add_tag_modal.tsp.tag_chosen.connect(lambda x: self.add_parent_tag_callback(x)) self.parent_tags_add_button.clicked.connect(self.add_tag_modal.show) @@ -562,6 +563,13 @@ def set_tag(self, tag: Tag): logger.info("[BuildTagPanel] Setting Tag", tag=tag) self.tag = tag + if tag.id is not None: + exclude_ids: set[int] = {tag.id} + exclude_ids.update(self.lib.get_tag_descendants(tag.id)) + self.add_tag_modal.tsp.exclude = list(exclude_ids) + else: + self.add_tag_modal.tsp.exclude = [] + self.name_field.setText(tag.name) self.shorthand_field.setText(tag.shorthand or "") diff --git a/src/tagstudio/qt/mixed/field_containers.py b/src/tagstudio/qt/mixed/field_containers.py index ae8df9107..24219c9e6 100644 --- a/src/tagstudio/qt/mixed/field_containers.py +++ b/src/tagstudio/qt/mixed/field_containers.py @@ -166,7 +166,7 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: "Character" -> "Johnny Bravo", "TV" -> Johnny Bravo" """ - loop_cutoff = 1024 # Used for stopping the while loop + visited_tags: set[int] = set() hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags) categories: dict[Tag | None, set[Tag]] = {None: set()} @@ -179,16 +179,15 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]: has_category_parent = False parent_tags = tag.parent_tags - loop_counter = 0 - while len(parent_tags) > 0: - # NOTE: This is for preventing infinite loops in the event a tag is parented - # to itself cyclically. - loop_counter += 1 - if loop_counter >= loop_cutoff: - break + visited_tags.clear() + visited_tags.add(tag.id) + while len(parent_tags) > 0: grandparent_tags: set[Tag] = set() for parent_tag in parent_tags: + if parent_tag.id in visited_tags: + continue + visited_tags.add(parent_tag.id) if parent_tag in categories: categories[parent_tag].add(tag) has_category_parent = True