class InMemoryEventStore(EventStore[TAggregateId]):
"""
An in-memory implementation of the EventStore for testing purposes.
This store is not thread-safe for concurrent writes from multiple async tasks
if those tasks might interleave operations on the same aggregate ID without
proper external synchronization. However, for typical single-threaded test
execution or sequential operations within a test, it's perfectly suitable.
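
    Typical usage in a test (a sketch; "event" stands for any Event instance
    built by the test and "aggregate_id" for the aggregate's identifier):

        store = InMemoryEventStore()
        await store.append_events(aggregate_id, expected_version=-1, events=[event])
        history = await store.load_events(aggregate_id)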
"""
def __init__(self) -> None:
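        # Event streams keyed by aggregate id; each stream is an append-only list.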
self._streams: Dict[TAggregateId, List[Event]] = defaultdict(list)
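        # Current version per aggregate: the 0-based sequence number of its last event.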
self._versions: Dict[TAggregateId, int] = {}
        # One lock per aggregate id, so operations on the same stream are
        # serialized if a test does exercise concurrent access; purely
        # sequential tests never contend on them.
self._locks: Dict[TAggregateId, asyncio.Lock] = defaultdict(asyncio.Lock)
async def _get_stream_lock(self, aggregate_id: TAggregateId) -> asyncio.Lock:
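        """Return the lock that serializes access to this aggregate's stream."""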
return self._locks[aggregate_id]
async def append_events(
self,
aggregate_id: TAggregateId,
expected_version: int,
events: List[Event],
) -> None:
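        """
        Append events to the aggregate's stream with an optimistic concurrency
        check: expected_version must equal the stream's current version (-1 for
        a stream with no events), otherwise EventStoreConflictError is raised.
        Appending an empty list is a no-op.
        """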
if not events:
return
async with await self._get_stream_lock(aggregate_id):
current_version = self._versions.get(aggregate_id, -1)
if current_version != expected_version:
raise EventStoreConflictError(
aggregate_id, expected_version, current_version
)
stream = self._streams[aggregate_id]
stream.extend(events)
# The new version is the index of the last event in the stream
self._versions[aggregate_id] = current_version + len(events)
async def load_events(
self,
aggregate_id: TAggregateId,
from_version: Optional[int] = None,
) -> List[Event]:
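        """
        Return a copy of the aggregate's events. With from_version=None the
        whole stream is returned; otherwise only events whose sequence number
        is greater than from_version.
        """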
        # Take the stream lock so the read sees a consistent snapshot.
async with await self._get_stream_lock(aggregate_id):
stream = self._streams.get(aggregate_id, [])
if from_version is None:
return list(stream) # Return a copy
            # Events are 0-indexed by sequence number and from_version means
            # "after this version", so the slice starts at from_version + 1.
start_index = from_version + 1
if start_index < 0: # Should not happen with valid from_version
start_index = 0
# Return a copy
return list(stream[start_index:])
async def get_current_version(self, aggregate_id: TAggregateId) -> Optional[int]:
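        """Return the aggregate's version (the sequence number of its last
        event), or None if the aggregate has no events."""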
# Lock for read consistency
async with await self._get_stream_lock(aggregate_id):
            # The version is the 0-based sequence number of the last event; an
            # aggregate with no events has no entry, so .get() returns None.
            return self._versions.get(aggregate_id)
async def clear(self) -> None:
"""Clears all streams and versions from the store."""
        # Test-only teardown: streams, versions and locks are dropped wholesale
        # without acquiring the per-stream locks. That is safe as long as no
        # other task is using the store concurrently, which holds for the usual
        # sequential test teardown.
self._streams.clear()
self._versions.clear()
        # The per-aggregate locks are dropped too; the defaultdict recreates them lazily.
self._locks.clear()
async def get_stream(self, aggregate_id: TAggregateId) -> List[Event]:
"""Returns a copy of the event stream for a given aggregate ID."""
async with await self._get_stream_lock(aggregate_id):
return list(self._streams.get(aggregate_id, []))