Source code for backend.app.utils.query_utils
from sqlalchemy.orm import aliased
from sqlmodel import SQLModel, select
from backend.app.models.voc_subclass import VocSubclass
[docs]
def get_all_voc_subclass_ancestor_ids_by_voc_subclass_id(
voc_subclass_id, session
) -> list[int]:
"""
Retrieve all ancestor IDs for a given VocSubclass ID using a recursive CTE.
:param voc_subclass_id: The ID of the VocSubclass
for which to find ancestors.
:type voc_subclass_id: int
:param session: SQLAlchemy session used to execute the query.
:type session: sqlalchemy.orm.session.Session
:return: A list of ancestor IDs of the specified VocSubclass, including all
direct and indirect parents.
:rtype: list[int]
"""
return __get_all_ancestors__(
VocSubclass, "parent_voc_category_id", voc_subclass_id, session
)
def __get_all_ancestors__(
object_class: SQLModel, parent_id_field_name: str, object_id, session
) -> list[int]:
"""
Find all ancestor IDs for a given object using a recursive CTE.
:param session: SQLAlchemy session
:type session: sqlalchemy.orm.session.Session
:param object_class: The SQLAlchemy model class to query
:type object_class: SQLModel
:param parent_id_field_name: The name of the field
that references the parent ID
:type parent_id_field_name: str
:param object_id: ID of the starting object to find the ancestors for
:type object_id: int
:return: List of ancestor IDs
:rtype: list[int]
"""
# TODO: this function should be more robust and implement some
# error handling, e.g. not matching parent_id_field_name fields...
# I'm not sure if this level of abstraction is even beneficial, maybe just
# create new concrete functions if more tables get hierarchical.
# Create the base CTE to start with the provided object ID
base_cte = (
select(object_class)
.where(object_class.id == object_id)
.cte(name="base_cte", recursive=True)
)
# Aliases for the base_cte to reference later
cte_alias = aliased(base_cte, name="cte_alias")
# Define the recursive part of the CTE, creating a list of all parents
recursive_cte = base_cte.union_all(
select(object_class)
# TODO: rn this only works if the object_class field to join by
# is called id. This could be designed more abstract
.join(cte_alias, cte_alias.c[parent_id_field_name] == object_class.id)
)
# Execute the query and retrieve all ancestor IDs
ancestor_ids = session.execute(select(recursive_cte.c.id)).scalars().all()
return ancestor_ids