
Implementing Custom Aggregate Repositories (e.g., MongoDB with Pydantic)

The castlecraft_engineer framework is designed with flexibility in mind. While it provides convenient base classes for repositories using SQLModel and SQLAlchemy, the core Aggregate concept is persistence-agnostic. This means you can implement repositories tailored to different databases or ORM/ODM tools.

This guide demonstrates how to create a custom aggregate repository for MongoDB using Motor (an asynchronous MongoDB driver) and Pydantic for defining the structure of your MongoDB documents.

Prerequisites

  • Familiarity with the Aggregate and general Repository patterns.
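  • Motor and Pydantic (v2) libraries installed. The bson module that provides ObjectId ships with pymongo, which Motor installs as a dependency. Do not install the separate bson package from PyPI, because it conflicts with pymongo's module:
    uv pip install motor pydantic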
    
  • A MongoDB instance running and accessible.

Core Components

  1. Your Aggregate: Defined as usual, inheriting from castlecraft_engineer.abstractions.aggregate.Aggregate.
  2. Pydantic Model: Represents the BSON document structure in MongoDB. It must include an id field (typically aliased to MongoDB's _id) for identity and a version field for optimistic concurrency.
  3. Custom Repository Class: Manages persistence logic using Motor. It will implement methods like get_by_id, save, and delete_by_id.
  4. Mapping Logic: Functions to convert between your Aggregate instances and their Pydantic/MongoDB document representations.
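Because the repository in this guide is a plain class rather than a subclass of a framework base, the contract it has to honor stays implicit. One way to make it explicit is a typing.Protocol. The sketch below is an assumption for illustration: the name AggregateRepository and its signatures are ours, mirroring the get_by_id, save, and delete_by_id methods used in this guide, and are not an interface shipped by castlecraft_engineer.

```python
from typing import Optional, Protocol, TypeVar, runtime_checkable

# The ID only appears in parameter positions, so it is contravariant;
# the aggregate type is both accepted and returned, so it stays invariant.
TId = TypeVar("TId", contravariant=True)
TAgg = TypeVar("TAgg")


@runtime_checkable
class AggregateRepository(Protocol[TId, TAgg]):
    """Hypothetical contract: async load, save, and delete by aggregate ID."""

    async def get_by_id(self, id_val: TId) -> Optional[TAgg]: ...

    async def save(self, aggregate: TAgg) -> TAgg: ...

    async def delete_by_id(self, id_val: TId) -> bool: ...
```

Annotating a variable with this protocol, for example `repo: AggregateRepository[ObjectId, ProductAggregate] = ProductMongoRepository(...)`, should let a static type checker confirm the method signatures line up. The runtime isinstance check only verifies that the three methods exist, not their signatures.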

Step-by-Step Implementation

1. Define Your Aggregate

Let's assume we have a simple ProductAggregate.

# product_aggregate.py
from dataclasses import dataclass  # Needed for the @dataclass event below
from bson import ObjectId  # For MongoDB ObjectIds (module provided by pymongo)
from castlecraft_engineer.abstractions.aggregate import Aggregate
from castlecraft_engineer.abstractions.event import Event  # If you use events

# Define a specific ID type for this aggregate
ProductAggregateId = ObjectId

@dataclass
class ProductCreated(Event):
    id: ProductAggregateId
    name: str
    sku: str

class ProductAggregate(Aggregate[ProductAggregateId]):
    name: str
    sku: str
    price: float
    stock_quantity: int

    def __init__(self, id: ProductAggregateId, version: int = -1, name: str = "", sku: str = "", price: float = 0.0, stock_quantity: int = 0):
        super().__init__(id, version)
        self.name = name
        self.sku = sku
        self.price = price
        self.stock_quantity = stock_quantity

    @classmethod
    def create(cls, id: ProductAggregateId, name: str, sku: str, price: float, stock_quantity: int) -> "ProductAggregate":
        product = cls(id=id, name=name, sku=sku, price=price, stock_quantity=stock_quantity)
        # product._record_event(ProductCreated(id=id, name=name, sku=sku)) # Example event
        return product

    def update_price(self, new_price: float):
        if new_price < 0:
            raise ValueError("Price cannot be negative.")
        self.price = new_price
        # self._record_event(...)

    def change_stock(self, quantity_change: int):
        if self.stock_quantity + quantity_change < 0:
            raise ValueError("Stock quantity cannot be negative.")
        self.stock_quantity += quantity_change
        # self._record_event(...)

2. Define Your Pydantic Data Model

This model defines how ProductAggregate data is stored in a MongoDB collection.

# product_document.py
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field

# Reuse the ID type from the aggregate
from .product_aggregate import ProductAggregateId

class ProductDocument(BaseModel):
    # Pydantic v2 configuration (replaces the v1-style inner `class Config`)
    model_config = ConfigDict(
        populate_by_name=True,  # Accept the field name "id" as well as the alias "_id" on input
        arbitrary_types_allowed=True,  # ObjectId is not a native Pydantic type
        json_encoders={ObjectId: str},  # Deprecated in v2 but still honored for JSON output
    )

    id: ProductAggregateId = Field(default_factory=ObjectId, alias="_id")
    version: int
    name: str
    sku: str
    price: float
    stock_quantity: int
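The alias means the Python side works with id while the stored document keys its identity as _id. To make that concrete, the stdlib-only sketch below writes the round trip out by hand; it is roughly what ProductDocument's model_dump(by_alias=True) and model_validate do for you. ProductRecord, to_mongo, and from_mongo are hypothetical names, and a plain string stands in for ObjectId so the example runs without pymongo.

```python
from dataclasses import asdict, dataclass


@dataclass
class ProductRecord:
    """Hypothetical stand-in for ProductDocument (str instead of ObjectId)."""

    id: str
    version: int
    name: str
    sku: str
    price: float
    stock_quantity: int


def to_mongo(record: ProductRecord) -> dict:
    """Build the stored document: rename 'id' to '_id', as an alias-aware dump would."""
    data = asdict(record)
    data["_id"] = data.pop("id")
    return data


def from_mongo(doc: dict) -> ProductRecord:
    """Rebuild the record: rename '_id' back to 'id'."""
    data = dict(doc)  # Copy so the caller's document is not mutated
    data["id"] = data.pop("_id")
    return ProductRecord(**data)


record = ProductRecord("64b0c0ffee", 0, "Awesome Gadget", "GADGET001", 99.99, 100)
doc = to_mongo(record)
print(doc)
assert from_mongo(doc) == record
```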

3. Create the Custom MongoDB Repository

This class will handle interactions with a MongoDB collection via Motor.

# product_mongo_repository.py
import logging
from typing import Type, Optional, Generic, TypeVar
from motor.motor_asyncio import AsyncIOMotorCollection
from bson import ObjectId
from pydantic import BaseModel

from castlecraft_engineer.abstractions.aggregate import Aggregate
# Reuse exceptions from the framework or define your own if needed
from castlecraft_engineer.abstractions.repository import (
    RepositoryError,
    AggregateNotFoundError,
    OptimisticConcurrencyError,
)

from .product_aggregate import ProductAggregate, ProductAggregateId
from .product_document import ProductDocument

# Generic type variables for the repository
TAggregateId = TypeVar("TAggregateId")
TAggregate = TypeVar("TAggregate", bound=Aggregate)
TPersistenceModel = TypeVar("TPersistenceModel", bound=BaseModel) # For Pydantic models

class ProductMongoRepository(Generic[TAggregateId, TAggregate, TPersistenceModel]):
    """
    A MongoDB-backed repository for Aggregates, using Pydantic for document modeling.
    """
    def __init__(
        self,
        motor_collection: AsyncIOMotorCollection,
        aggregate_cls: Type[TAggregate],
        persistence_model_cls: Type[TPersistenceModel],
    ):
        if not issubclass(aggregate_cls, Aggregate):
            raise ValueError("aggregate_cls must be an Aggregate subclass.")
        if not issubclass(persistence_model_cls, BaseModel):
            raise ValueError("persistence_model_cls must be a Pydantic BaseModel subclass.")

        # Validate presence of 'id' (or '_id' alias) and 'version' in the Pydantic model
        has_id_field = False
        for field_name, field_info in persistence_model_cls.model_fields.items():
            if field_name == "id" or field_info.alias == "_id":
                has_id_field = True
                break
        if not has_id_field:
            raise TypeError(
                f"{persistence_model_cls.__name__} needs an 'id' field (can be aliased to '_id')."
            )
        if "version" not in persistence_model_cls.model_fields:
            raise TypeError(
                f"{persistence_model_cls.__name__} needs a 'version' field for optimistic concurrency."
            )

        self.collection = motor_collection
        self.aggregate_cls = aggregate_cls
        self.persistence_model_cls = persistence_model_cls
        self._logger = logging.getLogger(
            f"{self.__class__.__name__}[{self.aggregate_cls.__name__}]"
        )
        self._logger.debug(
            f"Initialized for Aggregate: {self.aggregate_cls.__name__}, "
            f"Persistence Model: {self.persistence_model_cls.__name__}"
        )

    def _map_document_to_aggregate(self, doc_dict: dict) -> TAggregate:
        """Maps a MongoDB document (dictionary) to an Aggregate instance."""
        self._logger.debug(f"Mapping document to {self.aggregate_cls.__name__} (ID: {doc_dict.get('_id')})")
        try:
            # Validate and structure data using the Pydantic model
            model_instance = self.persistence_model_cls.model_validate(doc_dict)

            # Pydantic model's 'id' field is aliased to '_id' from MongoDB
            agg_id = model_instance.id
            agg_version = model_instance.version

            # Create the aggregate, passing only the model fields that its
            # constructor accepts. hasattr(self.aggregate_cls, field) is not a
            # reliable filter: annotation-only attributes such as `name: str`
            # do not exist on the class object.
            import inspect

            ctor_params = inspect.signature(self.aggregate_cls).parameters
            agg_data = {
                field: getattr(model_instance, field)
                for field in self.persistence_model_cls.model_fields
                if field in ctor_params
            }
            # id and version always come from the stored document
            agg_data["id"] = agg_id
            agg_data["version"] = agg_version

            agg = self.aggregate_cls(**agg_data)

            self._logger.debug(f"Finished mapping document to Aggregate (Version: {agg.version})")
            return agg
        except Exception as e:
            self._logger.error(f"Document->Aggregate mapping failed: {e!s}", exc_info=True)
            raise RepositoryError(
                f"Failed to map document to aggregate {self.aggregate_cls.__name__}: {e!s}"
            ) from e

    def _map_aggregate_to_document_data(self, aggregate: TAggregate) -> dict:
        """
        Maps an Aggregate instance to a dictionary suitable for MongoDB storage,
        respecting the Pydantic model's structure.
        Version is handled by the save method. ID (_id) is used in queries.
        """
        self._logger.debug(f"Mapping {self.aggregate_cls.__name__} to document data (ID: {aggregate.id})")
        try:
            data = {}
            for field_name in self.persistence_model_cls.model_fields.keys():
                # MongoDB's _id is typically the aggregate.id. It's used in the query, not usually in $set.
                # Version is also handled explicitly by the save method during $set.
                if field_name == "id" or self.persistence_model_cls.model_fields[field_name].alias == "_id":
                    continue
                if field_name == "version":
                    continue

                if hasattr(aggregate, field_name):
                    data[field_name] = getattr(aggregate, field_name)

            self._logger.debug("Finished mapping Aggregate to document data")
            return data
        except Exception as e:
            self._logger.error(f"Aggregate->Document data mapping failed: {e!s}", exc_info=True)
            raise RepositoryError(
                f"Failed to map aggregate {self.aggregate_cls.__name__} to document data: {e!s}"
            ) from e

    async def get_by_id(self, id_val: TAggregateId) -> Optional[TAggregate]:
        self._logger.debug(f"Getting aggregate by ID: {id_val}")
        doc_dict = await self.collection.find_one({"_id": id_val})
        if not doc_dict:
            self._logger.warning(f"Aggregate ID {id_val} not found in database.")
            return None

        self._logger.debug(f"Found document for ID {id_val}. Mapping to aggregate.")
        return self._map_document_to_aggregate(doc_dict)

    async def save(self, aggregate: TAggregate) -> TAggregate:
        agg_id = aggregate.id
        current_agg_version = aggregate.version
        is_new = current_agg_version == -1

        self._logger.debug(
            f"Attempting to save aggregate ID: {agg_id}, Current Version: {current_agg_version}, Is New: {is_new}"
        )

        try:
            doc_data_to_persist = self._map_aggregate_to_document_data(aggregate)

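            if is_new:
                # Persist a new aggregate with the next version (-1 -> 0)
                new_version = current_agg_version + 1
                doc_data_to_persist["version"] = new_version
                # For new documents, _id is the aggregate.id
                doc_data_to_persist["_id"] = agg_id

                await self.collection.insert_one(doc_data_to_persist)
                # Bump the in-memory version only once the insert has succeeded,
                # mirroring the update branch below
                aggregate._increment_version()
                self._logger.info(
                    f"Inserted new aggregate ID: {agg_id} with version {aggregate.version}."
                )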
            else:
                # Optimistic concurrency: update only if version matches
                next_version = current_agg_version + 1
                doc_data_to_persist["version"] = next_version

                result = await self.collection.update_one(
                    {"_id": agg_id, "version": current_agg_version}, # Match current ID and version
                    {"$set": doc_data_to_persist}
                )

                if result.matched_count == 0:
                    # Document not found with the expected version, or not found at all
                    current_db_doc = await self.collection.find_one({"_id": agg_id})
                    if current_db_doc is None:
                        self._logger.error(f"Aggregate ID {agg_id} not found during update attempt.")
                        raise AggregateNotFoundError(agg_id)
                    else:
                        db_version = current_db_doc.get("version")
                        self._logger.warning(
                            f"Optimistic lock failed for ID: {agg_id}. "
                            f"Expected DB version {current_agg_version}, found {db_version}."
                        )
                        raise OptimisticConcurrencyError(agg_id, current_agg_version, db_version)

                aggregate._increment_version() # Commit version increment on aggregate object
                self._logger.info(
                    f"Updated aggregate ID: {agg_id} to version {aggregate.version}."
                )
            return aggregate
        except (OptimisticConcurrencyError, AggregateNotFoundError) as e:
            self._logger.error(f"Save failed for aggregate ID {agg_id}: {e!s}")
            raise  # Re-raise unchanged, preserving the original traceback
        except Exception as e:
            self._logger.exception(f"Unexpected error during save for aggregate ID {agg_id}: {e!s}")
            raise RepositoryError(f"Save failed for aggregate ID {agg_id}: {e!s}") from e

    async def delete_by_id(self, id_val: TAggregateId) -> bool:
        self._logger.debug(f"Attempting to delete aggregate ID: {id_val}")
        result = await self.collection.delete_one({"_id": id_val})
        if result.deleted_count == 0:
            self._logger.warning(f"Aggregate ID: {id_val} not found for deletion.")
            return False

        self._logger.info(f"Deleted aggregate ID: {id_val}.")
        return True

# Example of how to instantiate and use (e.g., in a DI setup or command handler)
# from motor.motor_asyncio import AsyncIOMotorDatabase
#
# def get_product_repository(db: AsyncIOMotorDatabase) -> ProductMongoRepository:
#     collection = db["products"]
#     return ProductMongoRepository(
#         motor_collection=collection,
#         aggregate_cls=ProductAggregate,
#         persistence_model_cls=ProductDocument
#     )
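When choosing which document fields to pass to an aggregate's constructor, checking hasattr on the aggregate class is unreliable: a bare annotation such as `name: str` only records a type hint and creates no class attribute, and attributes assigned in `__init__` exist only on instances. Filtering on the constructor signature avoids both problems. The standalone check below uses simplified, hypothetical stand-in classes, not the framework's Aggregate (which may, for instance, expose id and version as properties).

```python
import inspect


class Base:
    def __init__(self, id, version=-1):
        self.id = id  # Instance attribute: not visible on the class
        self.version = version


class ProductLike(Base):
    name: str  # Annotation only: no class attribute is created
    price: float

    def __init__(self, id, version=-1, name="", price=0.0):
        super().__init__(id, version)
        self.name = name
        self.price = price


document_fields = ["id", "version", "name", "price", "internal_note"]

# hasattr() on the class misses every one of these fields
print([f for f in document_fields if hasattr(ProductLike, f)])  # -> []

# The constructor signature (excluding self) lists exactly what the class accepts
params = inspect.signature(ProductLike).parameters
print([f for f in document_fields if f in params])  # -> ['id', 'version', 'name', 'price']
```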

4. Using the Custom Repository

You would typically inject an instance of ProductMongoRepository into your application services or command handlers.

# Example usage (conceptual)
# from motor.motor_asyncio import AsyncIOMotorClient
# from .product_aggregate import ProductAggregate, ProductAggregateId
# from .product_document import ProductDocument
# from .product_mongo_repository import ProductMongoRepository

# async def main_logic():
#     # Setup Motor client and get collection
#     client = AsyncIOMotorClient("mongodb://localhost:27017")
#     db = client["mydatabase"]
#     products_collection = db["products"]

#     # Instantiate repository
#     repo = ProductMongoRepository(
#         motor_collection=products_collection,
#         aggregate_cls=ProductAggregate,
#         persistence_model_cls=ProductDocument
#     )

#     # Create a new product
#     new_product_id = ProductAggregateId() # Generates a new ObjectId
#     product = ProductAggregate.create(
#         id=new_product_id,
#         name="Awesome Gadget",
#         sku="GADGET001",
#         price=99.99,
#         stock_quantity=100
#     )
#     await repo.save(product)
#     print(f"Saved product: {product.id}, version: {product.version}")

#     # Retrieve and update
#     retrieved_product = await repo.get_by_id(new_product_id)
#     if retrieved_product:
#         retrieved_product.update_price(89.99)
#         await repo.save(retrieved_product)
#         print(f"Updated product: {retrieved_product.id}, new version: {retrieved_product.version}")

#     # ... more operations
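Because a command handler only needs the repository's methods, the MongoDB repository can be injected in production and swapped for an in-memory double in unit tests. A minimal sketch, assuming a hypothetical UpdateProductPriceHandler and InMemoryProductRepository (neither is part of castlecraft_engineer); the stand-in Product class mimics only the parts of ProductAggregate the handler uses, including the -1 "new aggregate" version convention.

```python
import asyncio
from typing import Optional, Protocol


class Product:
    """Hypothetical stand-in for ProductAggregate."""

    def __init__(self, id: str, version: int, price: float):
        self.id, self.version, self.price = id, version, price

    def update_price(self, new_price: float) -> None:
        if new_price < 0:
            raise ValueError("Price cannot be negative.")
        self.price = new_price


class ProductRepository(Protocol):
    async def get_by_id(self, id_val: str) -> Optional[Product]: ...
    async def save(self, aggregate: Product) -> Product: ...


class UpdateProductPriceHandler:
    """Hypothetical command handler: the repository is injected, not constructed."""

    def __init__(self, repository: ProductRepository):
        self._repository = repository

    async def handle(self, product_id: str, new_price: float) -> Product:
        product = await self._repository.get_by_id(product_id)
        if product is None:
            raise LookupError(f"Product {product_id} not found")
        product.update_price(new_price)  # Domain rule enforced by the aggregate
        return await self._repository.save(product)


class InMemoryProductRepository:
    """Test double: keeps aggregates in a dict and bumps the version on save."""

    def __init__(self):
        self._items: dict[str, Product] = {}

    async def get_by_id(self, id_val: str) -> Optional[Product]:
        return self._items.get(id_val)

    async def save(self, aggregate: Product) -> Product:
        aggregate.version += 1
        self._items[aggregate.id] = aggregate
        return aggregate


async def demo() -> Product:
    repo = InMemoryProductRepository()
    await repo.save(Product("p-1", -1, 99.99))  # New aggregate -> version 0
    handler = UpdateProductPriceHandler(repo)
    return await handler.handle("p-1", 89.99)


updated = asyncio.run(demo())
print(updated.price, updated.version)  # 89.99 1
```

The handler never imports Motor, so the same class works with ProductMongoRepository in production; only the wiring (for example, your DI container) decides which implementation it receives.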

Key Considerations

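  • Optimistic Concurrency: The save method implements optimistic locking by filtering update_one on both _id and the version the aggregate was loaded with, while setting version to the next value. If no document matches, the repository re-reads the document: it raises AggregateNotFoundError when the ID no longer exists and OptimisticConcurrencyError when another writer has already advanced the version.
  • Transactions: For operations spanning multiple aggregates or with stronger consistency requirements, consider MongoDB's multi-document transactions, which require a replica set (MongoDB 4.0+) or a sharded cluster (MongoDB 4.2+). Motor exposes them through client sessions (start_session and the session's start_transaction).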
  • Error Handling: The provided example includes basic error handling and logging. Adapt this to your project's specific needs.
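  • Indexing: MongoDB indexes _id automatically, and every query in this repository filters on _id (the optimistic update only adds version to an _id match), so a separate index on version is not needed. Create indexes for any other fields you query on, for example a unique index on sku if SKUs must be unique.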
  • Dependency Injection: In a real application, you'd likely use a dependency injection framework to manage the lifecycle and provision of your repository and its dependencies (like the Motor collection).
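The compare-and-set behind the save method can be seen in isolation below. This is a stdlib-only simulation, not the repository itself: a dict stands in for the collection, and conditional_update, with_retry, and ConflictError are hypothetical names, with ConflictError playing the role of OptimisticConcurrencyError. It also sketches one common way to handle a conflict, reloading and reapplying the change, which is only safe when the change can be recomputed from fresh state.

```python
import copy


class ConflictError(Exception):
    """Stand-in for OptimisticConcurrencyError in this sketch."""


# A dict standing in for the collection: _id -> document
collection = {"p-1": {"_id": "p-1", "version": 3, "stock_quantity": 10}}


def load(doc_id):
    return copy.deepcopy(collection[doc_id])


def conditional_update(doc_id, expected_version, changes):
    """Mimics update_one({"_id": doc_id, "version": expected_version}, {"$set": ...})."""
    current = collection.get(doc_id)
    if current is None or current["version"] != expected_version:
        raise ConflictError(f"expected version {expected_version}")
    current.update(changes, version=expected_version + 1)


# Two writers load the same version...
a = load("p-1")
b = load("p-1")
conditional_update("p-1", a["version"], {"stock_quantity": a["stock_quantity"] - 1})
# ...so the second write, still based on version 3, is rejected.
try:
    conditional_update("p-1", b["version"], {"stock_quantity": b["stock_quantity"] - 2})
except ConflictError as exc:
    print("rejected:", exc)


def with_retry(doc_id, mutate, attempts=3):
    """Reload and reapply the change on conflict, up to `attempts` times."""
    for _ in range(attempts):
        doc = load(doc_id)
        try:
            conditional_update(doc_id, doc["version"], mutate(doc))
            return
        except ConflictError:
            continue
    raise ConflictError(f"gave up after {attempts} attempts")


with_retry("p-1", lambda d: {"stock_quantity": d["stock_quantity"] - 2})
print(collection["p-1"])  # {'_id': 'p-1', 'version': 5, 'stock_quantity': 7}
```

With the real repository, the equivalent loop would call get_by_id, re-run the domain method (for example change_stock), and call save again when OptimisticConcurrencyError is raised, keeping the number of attempts small and bounded.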

Conclusion

This example demonstrates that the Aggregate pattern is highly adaptable. By creating a custom repository, you can use persistence technologies like MongoDB while keeping your domain logic (within the Aggregate) clean and independent of database concerns. The key is to implement the repository interface (get_by_id, save, delete_by_id) and the mapping logic between your aggregates and their persisted form.