Skip to content

castlecraft_engineer.abstractions.repository

castlecraft_engineer.abstractions.repository

AggregateNotFoundError

Bases: RepositoryError

Aggregate not found.

Source code in src/castlecraft_engineer/abstractions/repository.py
class AggregateNotFoundError(RepositoryError):
    """
    Signals that no stored root model exists for a given aggregate ID.

    Attributes:
        aggregate_id: The ID that could not be found.
    """

    def __init__(self, aggregate_id: Any):
        # Keep the raw ID available for callers that want to inspect it.
        self.aggregate_id = aggregate_id
        message = f"Aggregate ID '{aggregate_id}' not found."
        super().__init__(message)

AggregateRepository

Bases: _RepositoryBase[TAggregateId, TAggregate, TSQLModel]

Generic repository mapping Aggregates to SQLModels.

Handles persistence for Aggregates, potentially spanning multiple related SQLModels if relationships and cascades are configured correctly on the SQLModel classes themselves.

Relies on the SQLAlchemy Unit of Work pattern managed outside the repository (e.g., in a Command Handler or Application Service).

Source code in src/castlecraft_engineer/abstractions/repository.py
class AggregateRepository(
    _RepositoryBase[
        TAggregateId,
        TAggregate,
        TSQLModel,
    ],
):
    """
    Generic repository mapping Aggregates to SQLModels.

    Handles persistence for Aggregates, potentially spanning multiple
    related SQLModels if relationships and cascades are configured
    correctly on the SQLModel classes themselves.

    Relies on the SQLAlchemy Unit of Work pattern managed outside
    the repository (e.g., in a Command Handler or Application Service):
    none of the methods below flush or commit the session.
    """

    def __init__(
        self,
        aggregate_cls: Type[TAggregate],
        model_cls: Type[TSQLModel],
    ):
        super().__init__(aggregate_cls, model_cls)

    def get_by_id(
        self,
        session: Session,
        id: TAggregateId,
    ) -> Optional[TAggregate]:
        """
        Retrieves an Aggregate by its ID.

        Loads the root model with ``session.get`` and relies on ORM
        relationship loading (eager or lazy) for related data while the
        model is mapped to an aggregate.

        Args:
            session: Externally managed session used for the lookup.
            id: Primary key of the root model.

        Returns:
            The mapped aggregate, or None if no root model has this ID.
        """
        self._logger.debug(f"Getting aggregate by ID: {id}")
        model = session.get(self.model_cls, id)
        if model is None:
            self._logger.warning(
                f"Aggregate ID {id} not found in database for root model {self.model_cls.__name__}."  # noqa: E501
            )
            return None

        self._logger.debug(
            f"Found root model for ID {id}. Mapping to aggregate.",
        )
        return self._map_model_to_aggregate(model)

    def save(self, session: Session, aggregate: TAggregate) -> TAggregate:
        """
        Persists Aggregate state (handles create or update).

        Relies on the provided session being managed externally (Unit of
        Work): new root models are added to the session and existing ones
        are modified in place; nothing is flushed or committed here.

        An aggregate whose ``version`` is -1 is treated as new. Otherwise
        the root model is loaded and its ``version`` must equal the
        aggregate's (optimistic concurrency check). On success the version
        is bumped via ``aggregate._increment_version()`` and copied onto
        the root model.

        NOTE(review): the version comparison is done in Python against the
        model returned by ``session.get``, which may be served from the
        session's identity map. Unless the mapper also enforces the version
        at flush time (or the row is locked), two concurrent transactions
        that loaded the same version can both pass this check.

        For updates involving related models (multiple tables), ensure
        SQLModel relationships and cascade options are correctly configured.
        The mapping logic (`_map_aggregate_to_model`) should update the
        state of the model instances, and the SQLAlchemy session flush
        will handle persisting those changes.

        Args:
            session: Externally managed session.
            aggregate: Aggregate to persist; must be an instance of
                ``self.aggregate_cls``.

        Returns:
            The same aggregate instance, with its version bumped.

        Raises:
            TypeError: If ``aggregate`` is not an ``aggregate_cls`` instance.
            AggregateNotFoundError: If an existing aggregate has no row.
            OptimisticConcurrencyError: If the stored version differs from
                the aggregate's version.
            RepositoryError: If the root model has no ``version`` or any
                other error occurs while saving (original error chained).
        """
        if not isinstance(aggregate, self.aggregate_cls):
            raise TypeError(
                f"Input must be an instance of {self.aggregate_cls.__name__}, got {type(aggregate).__name__}"  # noqa: E501
            )

        agg_id = aggregate.id
        current_agg_version = aggregate.version
        # -1 is the sentinel version for an aggregate never persisted.
        is_new = current_agg_version == -1

        self._logger.debug(
            f"Attempting to save aggregate ID: {agg_id}, Current Aggregate Version: {current_agg_version}, Is New: {is_new}"  # noqa: E501
        )

        model: Optional[TSQLModel] = None  # Declare model once with its optional type
        try:
            if is_new:
                self._logger.debug(
                    f"Aggregate ID {agg_id} is new. Creating root model."
                )
                model = self.model_cls(id=agg_id, version=-1)
                self._map_aggregate_to_model(aggregate, model)
                # Add to session first.
                session.add(model)
                # If add was successful, update versions.
                aggregate._increment_version()
                model.version = aggregate.version
                self._logger.info(
                    f"Added new aggregate ID: {agg_id} to session with version {aggregate.version}. Commit required externally."  # noqa: E501
                )
            else:
                self._logger.debug(
                    f"Aggregate ID {agg_id} exists. Loading root model for update."  # noqa: E501
                )
                model = session.get(self.model_cls, agg_id)

                if model is None:
                    self._logger.error(
                        f"Aggregate ID {agg_id} not found in database during update attempt."  # noqa: E501
                    )
                    raise AggregateNotFoundError(agg_id)

                db_version = getattr(model, "version", None)
                if db_version is None:
                    raise RepositoryError(
                        f"Database model {self.model_cls.__name__} for ID {agg_id} is missing 'version'. Cannot perform optimistic lock."  # noqa: E501
                    )

                if db_version != current_agg_version:
                    self._logger.warning(
                        f"Optimistic lock failed for ID: {agg_id}. Expected DB version {current_agg_version}, found {db_version}."  # noqa: E501
                    )
                    raise OptimisticConcurrencyError(
                        agg_id, current_agg_version, db_version
                    )
                self._logger.debug(
                    f"Optimistic lock check passed for ID {agg_id} (Version: {current_agg_version})."  # noqa: E501
                )

                self._map_aggregate_to_model(aggregate, model)
                aggregate._increment_version()
                model.version = aggregate.version
                self._logger.info(
                    f"Updated aggregate ID: {agg_id} in session to version {aggregate.version}. Commit required externally."  # noqa: E501
                )
            return aggregate

        except RepositoryError as e:
            # Covers AggregateNotFoundError, OptimisticConcurrencyError and
            # the missing-'version' error above. They are already
            # descriptive, so re-raise them unchanged instead of letting the
            # generic handler below wrap them in a second RepositoryError.
            self._logger.error(f"Save failed for aggregate ID {agg_id}: {e!s}")
            raise
        except Exception as e:
            self._logger.exception(
                f"Unexpected error during save for aggregate ID {agg_id}: {e!s}"  # noqa: E501
            )
            raise RepositoryError(
                f"Save failed for aggregate ID {agg_id}: {e!s}"
            ) from e

    def delete_by_id(self, session: Session, id: TAggregateId) -> bool:
        """
        Deletes an Aggregate by its ID.

        Marks the root model for deletion; related models are removed only
        if the ORM's cascade delete configuration says so. The DELETE is
        emitted when the externally managed session is flushed/committed.

        Args:
            session: Externally managed session.
            id: Primary key of the root model.

        Returns:
            True if the root model was found and marked for deletion,
            False if no row with this ID exists.

        Raises:
            RepositoryError: If marking the model for deletion fails
                (original error chained).
        """
        self._logger.debug(f"Attempting to delete aggregate ID: {id}")
        model = session.get(self.model_cls, id)
        if model is None:
            self._logger.warning(f"Aggregate ID: {id} not found for deletion.")
            return False

        try:
            session.delete(model)
            self._logger.info(
                f"Marked aggregate ID: {id} for deletion in session. Commit required externally."  # noqa: E501
            )
            return True
        except Exception as e:
            self._logger.exception(
                f"Delete failed for aggregate ID {id}: {e!s}",
            )
            raise RepositoryError(f"Delete failed for {id}: {e!s}") from e

delete_by_id(session, id)

Deletes an Aggregate by its ID. Relies on the ORM's cascade delete configuration for related models.

Source code in src/castlecraft_engineer/abstractions/repository.py
def delete_by_id(self, session: Session, id: TAggregateId) -> bool:
    """
    Mark the aggregate with the given ID for deletion.

    Only the root model is handed to ``session.delete``; related rows go
    away only if the ORM cascade configuration on the model says so.
    Nothing is flushed or committed here.

    Returns:
        True when the root model was found and marked for deletion,
        False when ``session.get`` finds no row for ``id``.

    Raises:
        RepositoryError: When marking the model for deletion fails; the
            underlying error is chained as the cause.
    """
    self._logger.debug(f"Attempting to delete aggregate ID: {id}")
    root_model = session.get(self.model_cls, id)
    if root_model:
        try:
            session.delete(root_model)
            self._logger.info(
                f"Marked aggregate ID: {id} for deletion in session. Commit required externally."  # noqa: E501
            )
        except Exception as err:
            self._logger.exception(
                f"Delete failed for aggregate ID {id}: {err!s}",
            )
            raise RepositoryError(f"Delete failed for {id}: {err!s}") from err
        return True

    self._logger.warning(f"Aggregate ID: {id} not found for deletion.")
    return False

get_by_id(session, id)

Retrieves an Aggregate by its ID. Loads the root model and relies on ORM relationship loading (eager or lazy) for related data.

Source code in src/castlecraft_engineer/abstractions/repository.py
def get_by_id(
    self,
    session: Session,
    id: TAggregateId,
) -> Optional[TAggregate]:
    """
    Look up the root model by primary key and map it to an aggregate.

    Related data is expected to be reachable through the model's ORM
    relationships (eager or lazy loading) while the model is mapped.

    Returns:
        The aggregate produced by ``_map_model_to_aggregate``, or None
        when ``session.get`` finds no row for ``id``.
    """
    self._logger.debug(f"Getting aggregate by ID: {id}")
    root_model = session.get(self.model_cls, id)
    if root_model:
        self._logger.debug(
            f"Found root model for ID {id}. Mapping to aggregate.",
        )
        return self._map_model_to_aggregate(root_model)

    self._logger.warning(
        f"Aggregate ID {id} not found in database for root model {self.model_cls.__name__}."  # noqa: E501
    )
    return None

save(session, aggregate)

Persists Aggregate state (handles create or update).

Relies on the provided session being managed externally (Unit of Work). Adds new aggregates to the session or updates existing ones. Handles optimistic concurrency checking based on the 'version' field of the root model.

For updates involving related models (multiple tables), ensure SQLModel relationships and cascade options are correctly configured. The mapping logic (_map_aggregate_to_model) should update the state of the model instances, and the SQLAlchemy session flush will handle persisting those changes.

Source code in src/castlecraft_engineer/abstractions/repository.py
def save(self, session: Session, aggregate: TAggregate) -> TAggregate:
    """
    Persists Aggregate state (handles create or update).

    Relies on the provided session being managed externally (Unit of
    Work): new root models are added to the session and existing ones
    are modified in place; nothing is flushed or committed here.

    An aggregate whose ``version`` is -1 is treated as new. Otherwise
    the root model is loaded and its ``version`` must equal the
    aggregate's (optimistic concurrency check). On success the version
    is bumped via ``aggregate._increment_version()`` and copied onto
    the root model.

    NOTE(review): the version comparison is done in Python against the
    model returned by ``session.get``, which may be served from the
    session's identity map. Unless the mapper also enforces the version
    at flush time (or the row is locked), two concurrent transactions
    that loaded the same version can both pass this check.

    For updates involving related models (multiple tables), ensure
    SQLModel relationships and cascade options are correctly configured.
    The mapping logic (`_map_aggregate_to_model`) should update the
    state of the model instances, and the SQLAlchemy session flush
    will handle persisting those changes.

    Args:
        session: Externally managed session.
        aggregate: Aggregate to persist; must be an instance of
            ``self.aggregate_cls``.

    Returns:
        The same aggregate instance, with its version bumped.

    Raises:
        TypeError: If ``aggregate`` is not an ``aggregate_cls`` instance.
        AggregateNotFoundError: If an existing aggregate has no row.
        OptimisticConcurrencyError: If the stored version differs from
            the aggregate's version.
        RepositoryError: If the root model has no ``version`` or any
            other error occurs while saving (original error chained).
    """
    if not isinstance(aggregate, self.aggregate_cls):
        raise TypeError(
            f"Input must be an instance of {self.aggregate_cls.__name__}, got {type(aggregate).__name__}"  # noqa: E501
        )

    agg_id = aggregate.id
    current_agg_version = aggregate.version
    # -1 is the sentinel version for an aggregate never persisted.
    is_new = current_agg_version == -1

    self._logger.debug(
        f"Attempting to save aggregate ID: {agg_id}, Current Aggregate Version: {current_agg_version}, Is New: {is_new}"  # noqa: E501
    )

    model: Optional[TSQLModel] = None  # Declare model once with its optional type
    try:
        if is_new:
            self._logger.debug(
                f"Aggregate ID {agg_id} is new. Creating root model."
            )
            model = self.model_cls(id=agg_id, version=-1)
            self._map_aggregate_to_model(aggregate, model)
            # Add to session first.
            session.add(model)
            # If add was successful, update versions.
            aggregate._increment_version()
            model.version = aggregate.version
            self._logger.info(
                f"Added new aggregate ID: {agg_id} to session with version {aggregate.version}. Commit required externally."  # noqa: E501
            )
        else:
            self._logger.debug(
                f"Aggregate ID {agg_id} exists. Loading root model for update."  # noqa: E501
            )
            model = session.get(self.model_cls, agg_id)

            if model is None:
                self._logger.error(
                    f"Aggregate ID {agg_id} not found in database during update attempt."  # noqa: E501
                )
                raise AggregateNotFoundError(agg_id)

            db_version = getattr(model, "version", None)
            if db_version is None:
                raise RepositoryError(
                    f"Database model {self.model_cls.__name__} for ID {agg_id} is missing 'version'. Cannot perform optimistic lock."  # noqa: E501
                )

            if db_version != current_agg_version:
                self._logger.warning(
                    f"Optimistic lock failed for ID: {agg_id}. Expected DB version {current_agg_version}, found {db_version}."  # noqa: E501
                )
                raise OptimisticConcurrencyError(
                    agg_id, current_agg_version, db_version
                )
            self._logger.debug(
                f"Optimistic lock check passed for ID {agg_id} (Version: {current_agg_version})."  # noqa: E501
            )

            self._map_aggregate_to_model(aggregate, model)
            aggregate._increment_version()
            model.version = aggregate.version
            self._logger.info(
                f"Updated aggregate ID: {agg_id} in session to version {aggregate.version}. Commit required externally."  # noqa: E501
            )
        return aggregate

    except RepositoryError as e:
        # Covers AggregateNotFoundError, OptimisticConcurrencyError and
        # the missing-'version' error above. They are already
        # descriptive, so re-raise them unchanged instead of letting the
        # generic handler below wrap them in a second RepositoryError.
        self._logger.error(f"Save failed for aggregate ID {agg_id}: {e!s}")
        raise
    except Exception as e:
        self._logger.exception(
            f"Unexpected error during save for aggregate ID {agg_id}: {e!s}"  # noqa: E501
        )
        raise RepositoryError(
            f"Save failed for aggregate ID {agg_id}: {e!s}"
        ) from e

AsyncAggregateRepository

Bases: _RepositoryBase[TAggregateId, TAggregate, TSQLModel]

Generic asynchronous repository mapping Aggregates to SQLModels using AsyncSession.

Source code in src/castlecraft_engineer/abstractions/repository.py
class AsyncAggregateRepository(
    _RepositoryBase[
        TAggregateId,
        TAggregate,
        TSQLModel,
    ],
):
    """
    Generic asynchronous repository mapping
    Aggregates to SQLModels using AsyncSession.

    The session (Unit of Work) is managed by the caller; none of the
    methods below flush or commit it.

    NOTE(review): the inherited mapping hooks (``_map_model_to_aggregate``
    and ``_map_aggregate_to_model``) are called synchronously. Under
    AsyncSession, relationships they touch cannot be lazy-loaded
    implicitly (SQLAlchemy raises ``MissingGreenlet``), so related data
    presumably needs eager loading configured on the models. Confirm this
    against the concrete models in use.
    """

    def __init__(
        self,
        aggregate_cls: Type[TAggregate],
        model_cls: Type[TSQLModel],
    ):
        super().__init__(aggregate_cls, model_cls)

    # _map_model_to_aggregate and _map_aggregate_to_model are inherited

    async def get_by_id(
        self,
        session: AsyncSession,
        id: TAggregateId,
    ) -> Optional[TAggregate]:
        """
        Asynchronously retrieves an Aggregate by its ID using AsyncSession.

        Args:
            session: Externally managed async session used for the lookup.
            id: Primary key of the root model.

        Returns:
            The mapped aggregate, or None if no root model has this ID.
        """
        self._logger.debug(f"Getting aggregate by ID: {id}")
        model: Optional[TSQLModel] = await session.get(self.model_cls, id)
        if model is None:
            self._logger.warning(
                f"Aggregate ID {id} not found in database for root model {self.model_cls.__name__}."  # noqa: E501
            )
            return None

        self._logger.debug(
            f"Found root model for ID {id}. Mapping to aggregate.",
        )
        return self._map_model_to_aggregate(model)

    async def save(
        self,
        session: AsyncSession,
        aggregate: TAggregate,
    ) -> TAggregate:
        """
        Asynchronously persists Aggregate state using AsyncSession.

        An aggregate whose ``version`` is -1 is treated as new and its root
        model is added to the session. Otherwise the root model is loaded
        and its ``version`` must equal the aggregate's (optimistic
        concurrency check). On success the version is bumped via
        ``aggregate._increment_version()`` and copied onto the root model.
        Nothing is flushed or committed here.

        NOTE(review): the version comparison is done in Python against the
        model returned by ``session.get`` (possibly from the identity map);
        it does not by itself prevent two concurrent transactions that
        loaded the same version from both passing.

        Args:
            session: Externally managed async session.
            aggregate: Aggregate to persist; must be an instance of
                ``self.aggregate_cls``.

        Returns:
            The same aggregate instance, with its version bumped.

        Raises:
            TypeError: If ``aggregate`` is not an ``aggregate_cls`` instance.
            AggregateNotFoundError: If an existing aggregate has no row.
            OptimisticConcurrencyError: If the stored version differs from
                the aggregate's version.
            RepositoryError: If the root model has no ``version`` or any
                other error occurs while saving (original error chained).
        """
        if not isinstance(aggregate, self.aggregate_cls):
            raise TypeError(
                f"Input must be an instance of {self.aggregate_cls.__name__}, got {type(aggregate).__name__}"  # noqa: E501
            )

        agg_id = aggregate.id
        current_agg_version = aggregate.version
        # -1 is the sentinel version for an aggregate never persisted.
        is_new = current_agg_version == -1

        self._logger.debug(
            f"Attempting to save aggregate ID: {agg_id}, Current Aggregate Version: {current_agg_version}, Is New: {is_new}"  # noqa: E501
        )

        model: Optional[TSQLModel] = None  # Declare model once with its optional type
        try:
            if is_new:
                self._logger.debug(
                    f"Aggregate ID {agg_id} is new. Creating root model."
                )
                model = self.model_cls(id=agg_id, version=-1)
                self._map_aggregate_to_model(aggregate, model)
                # Add to session first (AsyncSession.add is synchronous).
                session.add(model)
                # If add was successful, update versions.
                aggregate._increment_version()
                model.version = aggregate.version
                self._logger.info(
                    f"Added new aggregate ID: {agg_id} to session with version {aggregate.version}. Commit required externally."  # noqa: E501
                )
            else:
                self._logger.debug(
                    f"Aggregate ID {agg_id} exists. Loading root model for update."  # noqa: E501
                )
                model = await session.get(self.model_cls, agg_id)

                if model is None:
                    self._logger.error(
                        f"Aggregate ID {agg_id} not found in database during update attempt."  # noqa: E501
                    )
                    raise AggregateNotFoundError(agg_id)

                db_version = getattr(model, "version", None)
                if db_version is None:
                    raise RepositoryError(
                        f"Database model {self.model_cls.__name__} for ID {agg_id} is missing 'version'. Cannot perform optimistic lock."  # noqa: E501
                    )

                if db_version != current_agg_version:
                    self._logger.warning(
                        f"Optimistic lock failed for ID: {agg_id}. Expected DB version {current_agg_version}, found {db_version}."  # noqa: E501
                    )
                    raise OptimisticConcurrencyError(
                        agg_id, current_agg_version, db_version
                    )
                self._logger.debug(
                    f"Optimistic lock check passed for ID {agg_id} (Version: {current_agg_version})."  # noqa: E501
                )

                self._map_aggregate_to_model(aggregate, model)
                aggregate._increment_version()
                model.version = aggregate.version

                self._logger.info(
                    f"Updated aggregate ID: {agg_id} in session to version {aggregate.version}. Commit required externally."  # noqa: E501
                )
            return aggregate

        except RepositoryError as e:
            # Covers AggregateNotFoundError, OptimisticConcurrencyError and
            # the missing-'version' error above. They are already
            # descriptive, so re-raise them unchanged instead of letting the
            # generic handler below wrap them in a second RepositoryError.
            self._logger.error(f"Save failed for aggregate ID {agg_id}: {e!s}")
            raise
        except Exception as e:
            self._logger.exception(
                f"Unexpected error during save for aggregate ID {agg_id}: {e!s}"  # noqa: E501
            )
            raise RepositoryError(
                f"Save failed for aggregate ID {agg_id}: {e!s}"
            ) from e

    async def delete_by_id(
        self,
        session: AsyncSession,
        id: TAggregateId,
    ) -> bool:
        """
        Asynchronously deletes an Aggregate by its ID using AsyncSession.

        Marks the root model for deletion; related models are removed only
        if the ORM's cascade delete configuration says so. The DELETE is
        emitted when the caller flushes/commits the session.

        Args:
            session: Externally managed async session.
            id: Primary key of the root model.

        Returns:
            True if the root model was found and marked for deletion,
            False if no row with this ID exists.

        Raises:
            RepositoryError: If marking the model for deletion fails
                (original error chained).
        """
        self._logger.debug(f"Attempting to delete aggregate ID: {id}")
        model: Optional[TSQLModel] = await session.get(self.model_cls, id)
        if model is None:
            self._logger.warning(f"Aggregate ID: {id} not found for deletion.")
            return False

        try:
            # AsyncSession.delete is a coroutine (it may need to load
            # relationships in order to apply delete cascades), so it must
            # be awaited. It only marks the instance as deleted; the DELETE
            # statement is emitted on the caller's next flush/commit.
            await session.delete(model)
            self._logger.info(
                f"Marked aggregate ID: {id} for deletion in session. Commit/flush required externally."  # noqa: E501
            )
            return True
        except Exception as e:
            self._logger.exception(
                f"Delete failed for aggregate ID {id}: {e!s}",
            )
            raise RepositoryError(f"Delete failed for {id}: {e!s}") from e

delete_by_id(session, id) async

Asynchronously deletes an Aggregate by its ID using AsyncSession.

Source code in src/castlecraft_engineer/abstractions/repository.py
async def delete_by_id(
    self,
    session: AsyncSession,
    id: TAggregateId,
) -> bool:
    """
    Asynchronously deletes an Aggregate by its ID using AsyncSession.

    Marks the root model for deletion; related models are removed only
    if the ORM's cascade delete configuration says so. The DELETE is
    emitted when the caller flushes/commits the session.

    Args:
        session: Externally managed async session.
        id: Primary key of the root model.

    Returns:
        True if the root model was found and marked for deletion,
        False if no row with this ID exists.

    Raises:
        RepositoryError: If marking the model for deletion fails
            (original error chained).
    """
    self._logger.debug(f"Attempting to delete aggregate ID: {id}")
    model: Optional[TSQLModel] = await session.get(self.model_cls, id)
    if model is None:
        self._logger.warning(f"Aggregate ID: {id} not found for deletion.")
        return False

    try:
        # AsyncSession.delete is a coroutine (it may need to load
        # relationships in order to apply delete cascades), so it must be
        # awaited. It only marks the instance as deleted; the DELETE
        # statement is emitted on the caller's next flush/commit.
        await session.delete(model)
        self._logger.info(
            f"Marked aggregate ID: {id} for deletion in session. Commit/flush required externally."  # noqa: E501
        )
        return True
    except Exception as e:
        self._logger.exception(
            f"Delete failed for aggregate ID {id}: {e!s}",
        )
        raise RepositoryError(f"Delete failed for {id}: {e!s}") from e

get_by_id(session, id) async

Asynchronously retrieves an Aggregate by its ID using AsyncSession.

Source code in src/castlecraft_engineer/abstractions/repository.py
async def get_by_id(
    self,
    session: AsyncSession,
    id: TAggregateId,
) -> Optional[TAggregate]:
    """
    Asynchronously retrieves an Aggregate by its ID using AsyncSession.

    NOTE(review): ``_map_model_to_aggregate`` is called synchronously, so
    any relationship it touches must already be loaded; implicit lazy
    loads are not possible under AsyncSession (SQLAlchemy raises
    ``MissingGreenlet``). Eager loading on the model is presumably
    required; confirm this against the concrete models.

    Args:
        session: Externally managed async session used for the lookup.
        id: Primary key of the root model.

    Returns:
        The mapped aggregate, or None if no root model has this ID.
    """
    self._logger.debug(f"Getting aggregate by ID: {id}")
    model: Optional[TSQLModel] = await session.get(self.model_cls, id)
    if model is None:
        self._logger.warning(
            f"Aggregate ID {id} not found in database for root model {self.model_cls.__name__}."  # noqa: E501
        )
        return None

    self._logger.debug(
        f"Found root model for ID {id}. Mapping to aggregate.",
    )
    return self._map_model_to_aggregate(model)

save(session, aggregate) async

Asynchronously persists Aggregate state using AsyncSession.

Source code in src/castlecraft_engineer/abstractions/repository.py
async def save(
    self,
    session: AsyncSession,
    aggregate: TAggregate,
) -> TAggregate:
    """
    Asynchronously persists Aggregate state using AsyncSession.

    An aggregate whose ``version`` is -1 is treated as new and its root
    model is added to the session. Otherwise the root model is loaded
    and its ``version`` must equal the aggregate's (optimistic
    concurrency check). On success the version is bumped via
    ``aggregate._increment_version()`` and copied onto the root model.
    Nothing is flushed or committed here.

    NOTE(review): the version comparison is done in Python against the
    model returned by ``session.get`` (possibly from the identity map);
    it does not by itself prevent two concurrent transactions that
    loaded the same version from both passing.

    Args:
        session: Externally managed async session.
        aggregate: Aggregate to persist; must be an instance of
            ``self.aggregate_cls``.

    Returns:
        The same aggregate instance, with its version bumped.

    Raises:
        TypeError: If ``aggregate`` is not an ``aggregate_cls`` instance.
        AggregateNotFoundError: If an existing aggregate has no row.
        OptimisticConcurrencyError: If the stored version differs from
            the aggregate's version.
        RepositoryError: If the root model has no ``version`` or any
            other error occurs while saving (original error chained).
    """
    if not isinstance(aggregate, self.aggregate_cls):
        raise TypeError(
            f"Input must be an instance of {self.aggregate_cls.__name__}, got {type(aggregate).__name__}"  # noqa: E501
        )

    agg_id = aggregate.id
    current_agg_version = aggregate.version
    # -1 is the sentinel version for an aggregate never persisted.
    is_new = current_agg_version == -1

    self._logger.debug(
        f"Attempting to save aggregate ID: {agg_id}, Current Aggregate Version: {current_agg_version}, Is New: {is_new}"  # noqa: E501
    )

    model: Optional[TSQLModel] = None  # Declare model once with its optional type
    try:
        if is_new:
            self._logger.debug(
                f"Aggregate ID {agg_id} is new. Creating root model."
            )
            model = self.model_cls(id=agg_id, version=-1)
            self._map_aggregate_to_model(aggregate, model)
            # Add to session first (AsyncSession.add is synchronous).
            session.add(model)
            # If add was successful, update versions.
            aggregate._increment_version()
            model.version = aggregate.version
            self._logger.info(
                f"Added new aggregate ID: {agg_id} to session with version {aggregate.version}. Commit required externally."  # noqa: E501
            )
        else:
            self._logger.debug(
                f"Aggregate ID {agg_id} exists. Loading root model for update."  # noqa: E501
            )
            model = await session.get(self.model_cls, agg_id)

            if model is None:
                self._logger.error(
                    f"Aggregate ID {agg_id} not found in database during update attempt."  # noqa: E501
                )
                raise AggregateNotFoundError(agg_id)

            db_version = getattr(model, "version", None)
            if db_version is None:
                raise RepositoryError(
                    f"Database model {self.model_cls.__name__} for ID {agg_id} is missing 'version'. Cannot perform optimistic lock."  # noqa: E501
                )

            if db_version != current_agg_version:
                self._logger.warning(
                    f"Optimistic lock failed for ID: {agg_id}. Expected DB version {current_agg_version}, found {db_version}."  # noqa: E501
                )
                raise OptimisticConcurrencyError(
                    agg_id, current_agg_version, db_version
                )
            self._logger.debug(
                f"Optimistic lock check passed for ID {agg_id} (Version: {current_agg_version})."  # noqa: E501
            )

            self._map_aggregate_to_model(aggregate, model)
            aggregate._increment_version()
            model.version = aggregate.version

            self._logger.info(
                f"Updated aggregate ID: {agg_id} in session to version {aggregate.version}. Commit required externally."  # noqa: E501
            )
        return aggregate

    except RepositoryError as e:
        # Covers AggregateNotFoundError, OptimisticConcurrencyError and
        # the missing-'version' error above. They are already
        # descriptive, so re-raise them unchanged instead of letting the
        # generic handler below wrap them in a second RepositoryError.
        self._logger.error(f"Save failed for aggregate ID {agg_id}: {e!s}")
        raise
    except Exception as e:
        self._logger.exception(
            f"Unexpected error during save for aggregate ID {agg_id}: {e!s}"  # noqa: E501
        )
        raise RepositoryError(
            f"Save failed for aggregate ID {agg_id}: {e!s}"
        ) from e

OptimisticConcurrencyError

Bases: RepositoryError, StaleDataError

Optimistic concurrency conflict.

Source code in src/castlecraft_engineer/abstractions/repository.py
class OptimisticConcurrencyError(RepositoryError, StaleDataError):
    """
    Raised when a stored aggregate's version differs from the expected one.

    Subclasses both ``RepositoryError`` and SQLAlchemy's ``StaleDataError``,
    so handlers written for either hierarchy will catch it.

    Args:
        aggregate_id: ID of the aggregate being saved.
        expected_version: Version carried by the aggregate being saved.
        actual_version: Version found in storage, if known.
    """

    def __init__(
        self,
        aggregate_id: Any,
        expected_version: int,
        actual_version: Optional[int] = None,
    ):
        self.aggregate_id = aggregate_id
        self.expected_version = expected_version
        self.actual_version = actual_version
        super().__init__(
            f"Concurrency error for ID '{aggregate_id}'. "
            f"Expected version {expected_version}, "
            f"but found version {actual_version} in database."
        )

RepositoryError

Bases: Exception

Base repository error.

Source code in src/castlecraft_engineer/abstractions/repository.py
class RepositoryError(Exception):
    """
    Root of the repository exception hierarchy.

    Raised directly, or through a more specific subclass, when saving or
    deleting an aggregate fails.
    """