Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ poetry.toml
.ruff_cache/

# LSP config files
pyrightconfig.json

# Syncthing
.stfolder/
Expand Down Expand Up @@ -235,7 +234,7 @@ compile_commands.json
*_qmlcache.qrc

### VisualStudioCode ###
.vscode/*
/.vscode
# !.vscode/settings.json
# !.vscode/tasks.json
# !.vscode/launch.json
Expand Down Expand Up @@ -268,3 +267,4 @@ TagStudio.ini

result
result-*
uv.lock
61 changes: 52 additions & 9 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,12 +1685,47 @@ def get_tag_hierarchy(self, tag_ids: Iterable[int]) -> dict[int, Tag]:

return all_tags

def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
def get_tag_descendants(self, tag_id: int, session: Session | None = None) -> set[int]:
"""Return ids of every tag that lists `tag_id` as an ancestor."""
owns_session = False
if session is None:
session = Session(self.engine)
owns_session = True

try:
descendants: set[int] = set()
frontier: set[int] = {tag_id}

while frontier:
stmt = select(TagParent.child_id).where(TagParent.parent_id.in_(frontier))
children = set(session.scalars(stmt).all())
children -= descendants
children.discard(tag_id)
descendants.update(children)
frontier = children

return descendants
finally:
if owns_session:
session.close()

def _would_create_parent_cycle(self, parent_id: int, child_id: int, session: Session) -> bool:
if parent_id == child_id:
return False
return True

return parent_id in self.get_tag_descendants(child_id, session=session)

def add_parent_tag(self, parent_id: int, child_id: int) -> bool:
# open session and save as parent tag
with Session(self.engine) as session:
if self._would_create_parent_cycle(parent_id, child_id, session):
logger.warning(
"[Library][add_parent_tag] Prevented cyclical parent assignment",
parent_id=parent_id,
child_id=child_id,
)
return False

parent_tag = TagParent(
parent_id=parent_id,
child_id=child_id,
Expand Down Expand Up @@ -1814,10 +1849,10 @@ def update_aliases(
session.add(alias)

def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session: Session):
if tag.id in parent_ids:
parent_ids.remove(tag.id)
new_parent_ids: set[int] = set(parent_ids)
new_parent_ids.discard(tag.id)

if tag.disambiguation_id not in parent_ids:
if tag.disambiguation_id not in new_parent_ids:
tag.disambiguation_id = None

# load all tag's parent tags to know which to remove
Expand All @@ -1826,14 +1861,22 @@ def update_parent_tags(self, tag: Tag, parent_ids: list[int] | set[int], session
).all()

for parent_tag in prev_parent_tags:
if parent_tag.parent_id not in parent_ids:
if parent_tag.parent_id not in new_parent_ids:
session.delete(parent_tag)
else:
# no change, remove from list
parent_ids.remove(parent_tag.parent_id)
new_parent_ids.remove(parent_tag.parent_id)

# create remaining items
for parent_id in list(new_parent_ids):
if self._would_create_parent_cycle(parent_id, tag.id, session):
logger.warning(
"[Library][update_parent_tags] Prevented cyclical parent assignment",
parent_id=parent_id,
child_id=tag.id,
)
continue

# create remaining items
for parent_id in parent_ids:
# add new parent tag
parent_tag = TagParent(
parent_id=parent_id,
Expand Down
16 changes: 12 additions & 4 deletions src/tagstudio/qt/mixed/build_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,12 @@ def __init__(self, library: Library, tag: Tag | None = None) -> None:
self.parent_tags_add_button.setText("+")
self.parent_tags_layout.addWidget(self.parent_tags_add_button)

exclude_ids: list[int] = list()
if tag is not None:
exclude_ids.append(tag.id)
exclude_ids: set[int] = set()
if tag is not None and tag.id is not None:
exclude_ids.add(tag.id)
exclude_ids.update(self.lib.get_tag_descendants(tag.id))

self.add_tag_modal = TagSearchModal(self.lib, exclude_ids)
self.add_tag_modal = TagSearchModal(self.lib, list(exclude_ids))
self.add_tag_modal.tsp.tag_chosen.connect(lambda x: self.add_parent_tag_callback(x))
self.parent_tags_add_button.clicked.connect(self.add_tag_modal.show)

Expand Down Expand Up @@ -562,6 +563,13 @@ def set_tag(self, tag: Tag):
logger.info("[BuildTagPanel] Setting Tag", tag=tag)
self.tag = tag

if tag.id is not None:
exclude_ids: set[int] = {tag.id}
exclude_ids.update(self.lib.get_tag_descendants(tag.id))
self.add_tag_modal.tsp.exclude = list(exclude_ids)
else:
self.add_tag_modal.tsp.exclude = []

self.name_field.setText(tag.name)
self.shorthand_field.setText(tag.shorthand or "")

Expand Down
15 changes: 7 additions & 8 deletions src/tagstudio/qt/mixed/field_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]:
"Character" -> "Johnny Bravo",
"TV" -> Johnny Bravo"
"""
loop_cutoff = 1024 # Used for stopping the while loop
visited_tags: set[int] = set()

hierarchy_tags = self.lib.get_tag_hierarchy(t.id for t in tags)
categories: dict[Tag | None, set[Tag]] = {None: set()}
Expand All @@ -179,16 +179,15 @@ def get_tag_categories(self, tags: set[Tag]) -> dict[Tag | None, set[Tag]]:
has_category_parent = False
parent_tags = tag.parent_tags

loop_counter = 0
while len(parent_tags) > 0:
# NOTE: This is for preventing infinite loops in the event a tag is parented
# to itself cyclically.
loop_counter += 1
if loop_counter >= loop_cutoff:
break
visited_tags.clear()
visited_tags.add(tag.id)

while len(parent_tags) > 0:
grandparent_tags: set[Tag] = set()
for parent_tag in parent_tags:
if parent_tag.id in visited_tags:
continue
visited_tags.add(parent_tag.id)
if parent_tag in categories:
categories[parent_tag].add(tag)
has_category_parent = True
Expand Down