Skip to content

castlecraft_engineer.testing.aggregate

castlecraft_engineer.testing.aggregate

BaseAggregateRepositoryTest

Bases: Generic[T, M, Repo]

Source code in src/castlecraft_engineer/testing/aggregate.py
class BaseAggregateRepositoryTest(Generic[T, M, Repo]):
    """
    Reusable pytest scaffolding for synchronous aggregate repository tests.

    Subclasses set `repository_class`, `aggregate_class`, and
    `orm_model_class` to the concrete types under test. Fixtures that
    depend on one of these attributes skip the test when it is unset.
    """

    repository_class: Optional[Type[Repo]] = None
    aggregate_class: Optional[Type[T]] = None
    orm_model_class: Optional[Type[M]] = None

    @pytest.fixture
    def mock_session(self) -> MagicMock:
        """Return a MagicMock created with `spec=Session`."""
        session_double = MagicMock(spec=Session)
        return session_double

    @pytest.fixture
    def aggregate_id(self) -> uuid.UUID:
        """Return a freshly generated random UUID."""
        new_id = uuid.uuid4()
        return new_id

    @pytest.fixture
    def sample_aggregate(self, aggregate_id: uuid.UUID) -> T:
        """Build an aggregate carrying `aggregate_id`; skip if no class is set."""
        if self.aggregate_class:
            # The aggregate constructor is expected to accept `id` as a keyword.
            return self.aggregate_class(id=aggregate_id)  # type: ignore[misc]
        pytest.skip("aggregate_class not set")

    @pytest.fixture
    def sample_orm_model(self) -> M:
        """Return a MagicMock spec'd to the ORM model; skip if no class is set."""
        if self.orm_model_class:
            # A mock stands in for the model here; swap in a real instance
            # in a subclass when a test needs one.
            return MagicMock(spec=self.orm_model_class)
        pytest.skip("orm_model_class not set")

    @pytest.fixture
    def repository_instance(
        self,
        mock_session: MagicMock,
    ) -> Repo:
        """
        Instantiate the repository under test.

        Skips the test unless `repository_class`, `aggregate_class`, and
        `orm_model_class` are all set. `mock_session` is requested as a
        fixture but is not handed to the constructor.
        """
        repo_factory = self.repository_class
        if not repo_factory:
            pytest.skip("repository_class not set")
        if not self.aggregate_class:
            pytest.skip("aggregate_class not set for repository test")
        if not self.orm_model_class:
            pytest.skip("orm_model_class not set for repository test")

        # The constructor receives both the aggregate and the model class.
        return repo_factory(
            aggregate_cls=self.aggregate_class,
            model_cls=self.orm_model_class,
        )

    def setup_get_by_id_mock(
        self,
        mock_session: MagicMock,
        orm_model_instance: Optional[M],
    ):
        """Make `mock_session.get(...)` return `orm_model_instance`."""
        get_double = mock_session.get
        get_double.return_value = orm_model_instance

    def assert_session_add_called(
        self,
        mock_session: MagicMock,
        expected_model: Any,
    ):
        """Assert `session.add` ran exactly once, receiving `expected_model`."""
        add_double = mock_session.add
        add_double.assert_called_once_with(expected_model)

    def assert_session_commit_called(self, mock_session: MagicMock):
        """Assert `session.commit` ran exactly once."""
        # Keep in mind that the commit is often issued outside the repo.
        commit_double = mock_session.commit
        commit_double.assert_called_once()

    def assert_session_refresh_called(
        self, mock_session: MagicMock, expected_model: Any
    ):
        """Assert `session.refresh` ran exactly once with `expected_model`."""
        refresh_double = mock_session.refresh
        refresh_double.assert_called_once_with(expected_model)

    def assert_session_delete_called(
        self, mock_session: MagicMock, expected_model: Any
    ):
        """Assert `session.delete` ran exactly once with `expected_model`."""
        delete_double = mock_session.delete
        delete_double.assert_called_once_with(expected_model)

assert_session_add_called(mock_session, expected_model)

Verify session.add was called correctly.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_add_called(
    self,
    mock_session: MagicMock,
    expected_model: Any,
):
    """Assert `session.add` ran exactly once, receiving `expected_model`."""
    add_double = mock_session.add
    add_double.assert_called_once_with(expected_model)

assert_session_commit_called(mock_session)

Verify session.commit was called.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_commit_called(self, mock_session: MagicMock):
    """Assert `session.commit` ran exactly once."""
    # Keep in mind that the commit is often issued outside the repo.
    commit_double = mock_session.commit
    commit_double.assert_called_once()

assert_session_delete_called(mock_session, expected_model)

Verify session.delete was called correctly.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_delete_called(
    self, mock_session: MagicMock, expected_model: Any
):
    """Assert `session.delete` ran exactly once with `expected_model`."""
    delete_double = mock_session.delete
    delete_double.assert_called_once_with(expected_model)

assert_session_refresh_called(mock_session, expected_model)

Verify session.refresh was called.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_refresh_called(
    self, mock_session: MagicMock, expected_model: Any
):
    """Assert `session.refresh` ran exactly once with `expected_model`."""
    refresh_double = mock_session.refresh
    refresh_double.assert_called_once_with(expected_model)

setup_get_by_id_mock(mock_session, orm_model_instance)

Configure mock session for get_by_id.

Source code in src/castlecraft_engineer/testing/aggregate.py
def setup_get_by_id_mock(
    self,
    mock_session: MagicMock,
    orm_model_instance: Optional[M],
):
    """Make `mock_session.get(...)` return `orm_model_instance`."""
    get_double = mock_session.get
    get_double.return_value = orm_model_instance

BaseAsyncAggregateRepositoryTest

Bases: Generic[T, M, AsyncRepo]

Base class for testing AsyncAggregateRepository implementations.

Provides pytest fixtures and helper methods for testing asynchronous repository operations. Subclasses should define repository_class, aggregate_class, and orm_model_class.

Source code in src/castlecraft_engineer/testing/aggregate.py
class BaseAsyncAggregateRepositoryTest(Generic[T, M, AsyncRepo]):
    """
    Base class for testing AsyncAggregateRepository implementations.

    Provides pytest fixtures and helper methods for testing asynchronous
    repository operations. Subclasses should define `repository_class`,
    `aggregate_class`, and `orm_model_class`. Fixtures that depend on one
    of these attributes skip the test when it is unset.
    """

    repository_class: Optional[Type[AsyncRepo]] = None
    aggregate_class: Optional[Type[T]] = None
    orm_model_class: Optional[Type[M]] = None

    @pytest.fixture
    def mock_async_session(self) -> AsyncMock:
        """Provides an AsyncMock instance simulating an AsyncSession."""
        return AsyncMock(spec=AsyncSession)

    @pytest.fixture
    def aggregate_id(self) -> uuid.UUID:
        """Provides a freshly generated random UUID."""
        return uuid.uuid4()

    @pytest.fixture
    def sample_aggregate(self, aggregate_id: uuid.UUID) -> T:
        """Provides an aggregate built with `aggregate_id`.

        Skips the test when `aggregate_class` is not set.
        """
        if not self.aggregate_class:
            pytest.skip("aggregate_class not set")
        # Assumes Aggregate takes id in __init__
        return self.aggregate_class(id=aggregate_id)  # type: ignore[misc]

    @pytest.fixture
    def sample_orm_model(self) -> M:
        """Provides a mock spec'd to `orm_model_class`.

        Skips the test when `orm_model_class` is not set.
        """
        if not self.orm_model_class:
            pytest.skip("orm_model_class not set")
        # An ORM model instance is a plain (non-awaitable) object even when
        # the session is async, so it is doubled with MagicMock rather than
        # AsyncMock. Replace with a real instance in a subclass if needed.
        return MagicMock(spec=self.orm_model_class)

    @pytest.fixture
    def repository_instance(
        self,
        mock_async_session: AsyncMock,  # Injected but not directly used by constructor
    ) -> AsyncRepo:
        """Provides an instance of the repository under test.

        Skips the test unless `repository_class`, `aggregate_class`, and
        `orm_model_class` are all set.
        """
        if not self.repository_class:
            pytest.skip("repository_class not set")
        if not self.aggregate_class:
            pytest.skip("aggregate_class not set for repository test")
        if not self.orm_model_class:
            pytest.skip("orm_model_class not set for repository test")

        # Pass both required arguments to the constructor
        return self.repository_class(
            aggregate_cls=self.aggregate_class,
            model_cls=self.orm_model_class,
        )

    def setup_get_by_id_mock_async(
        self,
        mock_async_session: AsyncMock,
        orm_model_instance: Optional[M],
    ):
        """Configure mock async session for get_by_id.

        Sets the return value of `mock_async_session.get` to
        `orm_model_instance` (which may be None).
        """
        mock_async_session.get.return_value = orm_model_instance

    def assert_session_add_called(
        self,
        mock_async_session: AsyncMock,
        expected_model: Any,
    ):
        """Verify session.add was called correctly (session.add is synchronous)."""
        mock_async_session.add.assert_called_once_with(expected_model)

    def assert_session_commit_awaited(self, mock_async_session: AsyncMock):
        """Verify session.commit was awaited exactly once."""
        mock_async_session.commit.assert_awaited_once()

    def assert_session_refresh_awaited(
        self, mock_async_session: AsyncMock, expected_model: Any
    ):
        """Verify session.refresh was awaited exactly once with `expected_model`."""
        mock_async_session.refresh.assert_awaited_once_with(expected_model)

    def assert_session_delete_awaited(
        self, mock_async_session: AsyncMock, expected_model: Any
    ):
        """Verify session.delete was awaited exactly once with `expected_model`."""
        mock_async_session.delete.assert_awaited_once_with(expected_model)

assert_session_add_called(mock_async_session, expected_model)

Verify session.add was called correctly (session.add is synchronous).

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_add_called(
    self,
    mock_async_session: AsyncMock,
    expected_model: Any,
):
    """Assert `session.add` (a synchronous call) ran once with `expected_model`."""
    add_double = mock_async_session.add
    add_double.assert_called_once_with(expected_model)

assert_session_commit_awaited(mock_async_session)

Verify session.commit was awaited.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_commit_awaited(self, mock_async_session: AsyncMock):
    """Assert `session.commit` was awaited exactly once."""
    commit_double = mock_async_session.commit
    commit_double.assert_awaited_once()

assert_session_delete_awaited(mock_async_session, expected_model)

Verify session.delete was awaited.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_delete_awaited(
    self, mock_async_session: AsyncMock, expected_model: Any
):
    """Assert `session.delete` was awaited exactly once with `expected_model`."""
    delete_double = mock_async_session.delete
    delete_double.assert_awaited_once_with(expected_model)

assert_session_refresh_awaited(mock_async_session, expected_model)

Verify session.refresh was awaited.

Source code in src/castlecraft_engineer/testing/aggregate.py
def assert_session_refresh_awaited(
    self, mock_async_session: AsyncMock, expected_model: Any
):
    """Assert `session.refresh` was awaited exactly once with `expected_model`."""
    refresh_double = mock_async_session.refresh
    refresh_double.assert_awaited_once_with(expected_model)

mock_async_session()

Provides an AsyncMock instance simulating an AsyncSession.

Source code in src/castlecraft_engineer/testing/aggregate.py
@pytest.fixture
def mock_async_session(self) -> AsyncMock:
    """Return an AsyncMock created with `spec=AsyncSession`."""
    session_double = AsyncMock(spec=AsyncSession)
    return session_double

setup_get_by_id_mock_async(mock_async_session, orm_model_instance)

Configure mock async session for get_by_id.

Source code in src/castlecraft_engineer/testing/aggregate.py
def setup_get_by_id_mock_async(
    self,
    mock_async_session: AsyncMock,
    orm_model_instance: Optional[M],
):
    """Make `mock_async_session.get` return `orm_model_instance`."""
    get_double = mock_async_session.get
    get_double.return_value = orm_model_instance