# SPDX-License-Identifier: GPLv3-or-later
# Copyright © 2025 pygaindalf Rui Pinheiro
import datetime
import weakref
from collections.abc import Callable, Iterable, MutableMapping, MutableSet
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, override
from pydantic import ConfigDict, Field, PrivateAttr, computed_field, field_validator
from ...util.callguard import CallguardClassOptions
from ...util.models import LoggableHierarchicalModel
from ..models.entity import Entity, EntityModificationType, EntityRecord, EntityRecordBase
from ..util.superseded import SupersededError, superseded_check
from ..util.uid import UID_SEPARATOR, IncrementingUidFactory, Uid
from .journal import Journal
if TYPE_CHECKING:
from .protocols import SessionManagerHookLiteral
[docs]
class SessionParams(TypedDict):
    """Typed mapping carrying the ``actor`` and ``reason`` metadata of a session."""

    # Identifier or name of the actor responsible for the session.
    actor: str
    # Reason for starting the session.
    reason: str
[docs]
class Session(LoggableHierarchicalModel):
# Callguard configuration for this class.
# NOTE(review): the exact semantics live in CallguardClassOptions (not visible here) —
# presumably public methods get wrapped with `superseded_check`, except names matching
# "superseded" / "ended"; confirm against the callguard utility.
__callguard_class_options__ = CallguardClassOptions["Session"](
decorator=superseded_check,
decorate_public_methods=True,
decorate_ignore_patterns=("superseded", "ended"),
allow_same_module=True,
ignore_patterns=("_commit_notify_travel_hierarchy_condition", "_commit_apply_travel_hierarchy_condition"),
)
# Pydantic model configuration: unknown fields are rejected, the model is frozen,
# and values are validated on assignment.
model_config = ConfigDict(
extra="forbid",
frozen=True,
validate_assignment=True,
)
# MARK: Instance Parent
@field_validator("instance_parent_weakref", mode="before")
def _validate_instance_parent_is_session_manager(cls, v: Any) -> Any:
    """Accept the parent only if it is, or weakly references, a live SessionManager.

    The incoming value is returned unchanged on success.

    Raises:
        TypeError: If the (dereferenced) value is not a SessionManager, which
            includes a dead weak reference.
    """
    from .session_manager import SessionManager

    # A weak reference is resolved for the check only; the raw value is what gets returned.
    target = v() if isinstance(v, weakref.ref) else v
    if isinstance(target, SessionManager):
        return v

    msg = "Session parent must be a SessionManager object"
    raise TypeError(msg)
def _call_parent_hook(self, hook_name: SessionManagerHookLiteral, *args: Any, **kwargs: Any) -> None:
    """Forward a hook call to the owning SessionManager, passing this session as first argument.

    Raises:
        TypeError: If the instance parent is not a SessionManager.
    """
    from .session_manager import SessionManager

    manager = self.instance_parent
    if isinstance(manager, SessionManager):
        manager.call_owner_hook(hook_name, self, *args, **kwargs)
        return

    msg = "Instance parent is not a SessionManager."
    raise TypeError(msg)
# MARK: Uid
# Class-level UID factory shared by every Session instance.
_UID_FACTORY: ClassVar[IncrementingUidFactory] = IncrementingUidFactory()
# When not supplied, the UID is taken from `_UID_FACTORY.next("Session")`;
# `validate_default=True` asks pydantic to validate that generated default as well.
uid: Uid = Field(default_factory=lambda: Session._UID_FACTORY.next("Session"), validate_default=True, description="Unique identifier for this session.")
@computed_field(description="A human-readable name for this session, derived from its UID.")
@property
def instance_name(self) -> str:
    """Return the UID rendered as text, or a placeholder name if that fails."""
    try:
        name = str(self.uid)
    except Exception:  # noqa: BLE001 as we want to ensure we can use this in exception messages
        name = f"{type(self).__name__}{UID_SEPARATOR}<invalid-uid>"
    return name
# MARK: Metadata
# Who is responsible for the session (no default, so it must be provided).
actor: str = Field(description="Identifier or name of the actor responsible for this session.")
# Why the session was started (no default, so it must be provided).
reason: str = Field(description="Reason for starting this session.")
# Timezone-aware (UTC) timestamp taken at the moment the default is generated.
# NOTE(review): pydantic documents `init=` as applying to pydantic dataclasses only —
# confirm it really prevents callers from passing `start_time` to this model.
start_time: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(tz=datetime.UTC), init=False, description="Timestamp when the session started."
)
# MARK: State
@property
def superseded(self) -> bool:
    """Whether this session is superseded, which is the case exactly when it has ended."""
    has_ended = self.ended
    return has_ended
# MARK: Entities created in this session
# UIDs registered through `on_entity_record_created`; emptied by `_clear()`.
_created: MutableSet[Uid] = PrivateAttr(default_factory=set)
[docs]
def on_entity_record_created(self, record_or_uid: EntityRecord | Uid) -> None:
    """Register a record (given as instance or UID) as created during this session.

    Raises:
        ValueError: If the record's most recent log entry is not a creation, or
            if that entry is older than this session's start time.
    """
    created_uid = EntityRecord.narrow_to_uid(record_or_uid)
    latest_entry = EntityRecord.narrow_to_instance(record_or_uid).entity_log.most_recent

    if latest_entry.what != EntityModificationType.CREATED:
        msg = "EntityRecord was not created, cannot notify session."
        raise ValueError(msg)

    if latest_entry.when < self.start_time:
        msg = "EntityRecord was created before this session started."
        raise ValueError(msg)

    self._created.add(created_uid)
# MARK: EntityRecord Journals
# Journals tracked by this session, keyed by the UID of the record each one belongs to.
_journals: MutableMapping[Uid, Journal] = PrivateAttr(default_factory=dict)
@property
def dirty(self) -> bool:
    """Whether any tracked journal is dirty or any record was created in this session."""
    for journal in self._journals.values():
        if journal.dirty:
            return True
    return bool(self._created)
def _add_record_journal(self, record: EntityRecord) -> Journal | None:
    """Create a journal for ``record``, store it under the record's UID and return it.

    Returns ``None`` (after logging a warning) when the commit notification
    phase has already completed.

    Raises:
        RuntimeError: If the session has ended.
        TypeError: If the record's journal class does not derive from ``Journal``.
    """
    if self.ended:
        msg = "Cannot add an entity journal to an ended session."
        raise RuntimeError(msg)

    if self._after_commit_notify:
        self.log.warning("Cannot add an entity journal after notification phase.")
        return None

    journal_cls = record.get_journal_class()
    if not issubclass(journal_cls, Journal):
        # Name the base class that is actually checked; the message used to
        # refer to "EntityJournal", which is not the class tested above.
        msg = f"{type(record).__name__} journal class {journal_cls} is not a subclass of {Journal.__name__}."
        raise TypeError(msg)

    # The journal only holds a weak reference back to this session.
    journal = journal_cls(instance_parent=weakref.ref(self), record=record)
    self._journals[record.uid] = journal
    # Signal `_commit_notify` that another notify pass is required.
    self._restart_commit_notify = True
    return journal
[docs]
def get_record_journal(self, record: EntityRecord, *, create: bool = True) -> Journal | None:
if self.ended:
msg = "Cannot get an entity journal from an ended session."
raise SupersededError(msg)
if record.superseded:
self.log.warning(t"EntityRecord {record.instance_name} is superseded; cannot create or retrieve journal.")
return None
journal = self._journals.get(record.uid, None)
if journal is not None:
if journal.record is not record:
if not journal.record.superseded:
msg = "EntityRecord journal already exists for a different version of this entity. Use the latest version instead."
raise RuntimeError(msg)
if create:
del self._journals[record.uid]
else:
return journal
if create:
return self._add_record_journal(record)
else:
return None
def _clear_journals(self) -> None:
    """Mark every tracked journal as superseded, then forget them all."""
    tracked = self._journals
    for journal in tracked.values():
        journal.mark_superseded()
    tracked.clear()
def _clear(self) -> None:
    """Drop all session state: journals first, then the set of created UIDs."""
    self._clear_journals()
    created = self._created
    created.clear()
[docs]
def contains(self, uid: Uid) -> bool:
    """Whether ``uid`` has a journal in this session or was created in it."""
    if uid in self._journals:
        return True
    return uid in self._created
def __len__(self) -> int:
    """Return the number of distinct UIDs tracked by this session.

    A UID that both has a journal and was created in this session is counted
    once, so the result agrees with the membership test in ``contains``.
    """
    # Set union de-duplicates UIDs present in both collections.
    return len({*self._journals, *self._created})
# MARK: Start
[docs]
@override
def model_post_init(self, context: Any) -> None:
    """Run the inherited post-init step, then start the session."""
    # Parent initialisation must complete before the start hook fires.
    super().model_post_init(context)
    self._start()
def _start(self) -> None:
# TODO: Log start
self.log.info(t"Starting session: '{self.reason}' by '{self.actor}'.")
self._call_parent_hook("start")
# MARK: Abort
# Set to True by `abort()` while it is discarding changes; reset in its `finally` clause.
_in_abort: bool = PrivateAttr(default=False)
@property
def in_abort(self) -> bool:
    """Whether an abort is currently running; False when the flag cannot be read."""
    try:
        flag = getattr(self, "_in_abort", False)
    except (TypeError, AttributeError, KeyError):
        flag = False
    return flag
[docs]
def abort(self) -> None:
    """Discard every pending change of this session.

    Does nothing when the session is not dirty. Otherwise entities created in
    this session are reverted, all session state is cleared and the parent's
    "abort" hook is called; ``_in_abort`` is set for the duration.

    Raises:
        RuntimeError: If the session has already ended.
    """
    if self.ended:
        msg = "Cannot abort an ended session."
        raise RuntimeError(msg)

    if not self.dirty:
        # No need to do anything if there are no edits to commit
        return

    # TODO: Log abort
    self.log.warning("Aborting session...")

    self._in_abort = True
    try:
        created_here = self._created
        if created_here:
            # Forcefully delete any entity records created in this session
            self._clear_journals()
            for created_uid in created_here:
                entity = Entity.by_uid_or_none(created_uid)
                if entity is not None:
                    entity.revert()

        self._clear()
        self.log.warning("Session aborted.")
        self._call_parent_hook("abort")
    finally:
        self._in_abort = False
# MARK: End
# Set to True by `end()`; nothing in the code visible here resets it.
_ended: bool = PrivateAttr(default=False)
@property
def ended(self) -> bool:
    """Whether the session has ended; False when the flag cannot be read."""
    try:
        flag = getattr(self, "_ended", False)
    except (TypeError, AttributeError, KeyError):
        flag = False
    return flag
[docs]
def end(self) -> None:
    """Finish the session, committing first when there are pending changes.

    Afterwards the session is flagged as ended and the parent's "end" hook is called.

    Raises:
        RuntimeError: If the session has already ended.
    """
    if self.ended:
        msg = "Cannot end an already ended session."
        raise RuntimeError(msg)

    has_pending_changes = self.dirty
    if has_pending_changes:
        self.commit()

    # TODO: Log ended
    self._ended = True
    self.log.debug("Session ended.")
    self._call_parent_hook("end")
# MARK: Commit
# Set to True by `commit()` while it runs; reset in its `finally` clause.
_in_commit: bool = PrivateAttr(default=False)
# Set to True at the end of `_commit_notify()`; reset by `commit()` in its `finally` clause.
_after_commit_notify: bool = PrivateAttr(default=False)
@property
def in_commit(self) -> bool:
    """Whether a commit is currently running; False when the flag cannot be read."""
    try:
        flag = getattr(self, "_in_commit", False)
    except (TypeError, AttributeError, KeyError):
        flag = False
    return flag
[docs]
def commit(self) -> None:
"""Apply the pending changes of this session.

Does nothing when the session is not dirty. Otherwise runs ``_commit()``,
clears the session state via ``_clear()`` and calls the parent's "commit"
hook, with ``_in_commit`` set while that work runs.

Raises:
    SupersededError: If the session has already ended.
"""
if self.ended:
msg = "Cannot commit an ended session."
raise SupersededError(msg)
# No need to do anything if there are no edits to commit
if not self.dirty:
return
# TODO: Log commit
self.log.info("Committing session...")
self._in_commit = True
try:
self._commit()
self._clear()
self._call_parent_hook("commit")
finally:
# Both commit-phase flags are reset even when the commit raised.
self._in_commit = False
self._after_commit_notify = False
self.log.info("Commit concluded.")
def _commit(self) -> None:
    """Run the two commit phases in order: notify, then apply."""
    # Phase 1: notification.
    self._commit_notify()
    # Phase 2: application.
    self._commit_apply()
def _commit_travel_hierarchy(
    self, iterable: Iterable[Uid], *, condition: Callable[[EntityRecordBase], bool] | None = None, copy: bool = False
) -> Iterable[EntityRecordBase]:
    """Yield the hierarchy of every live record named by ``iterable``.

    UIDs that resolve to no record, or to a superseded one, are skipped. When
    ``copy`` is true the UIDs are snapshotted into a list before iterating.
    ``condition`` is forwarded to each record's ``iter_hierarchy``.
    """
    uids = list(iterable) if copy else iterable
    for uid in uids:
        record = EntityRecord.by_uid_or_none(uid)
        if record is not None and not record.superseded:
            yield from record.iter_hierarchy(condition=condition, use_journal=True)
# MARK: Commit - Notify
# Set by `_add_record_journal()` and `on_journal_reset_notified_dependents()`;
# checked (and cleared at the start of each pass) by `_commit_notify()`.
_restart_commit_notify: bool = PrivateAttr(default=False)
[docs]
def on_journal_reset_notified_dependents(self, journal: Journal) -> None:  # noqa: ARG002 as this is for overriding
    """Request another notify pass; only valid during the notify phase of a commit.

    Raises:
        RuntimeError: If no commit is running, or if the notification phase
            has already completed.
    """
    if not self._in_commit:
        reason = "Can only reset notified dependents during commit."
    elif self._after_commit_notify:
        reason = "Cannot reset notified dependents after notification phase."
    else:
        self._restart_commit_notify = True
        return

    raise RuntimeError(reason)
def _commit_notify_travel_hierarchy_condition(self, e: EntityRecordBase) -> bool:
    """True when ``e`` has a journal in this session that has not yet notified its dependents."""
    journal = self._journals.get(e.uid, None)
    if journal is None:
        return False
    return not journal.notified_dependents
def _commit_notify(self) -> None:
"""Notify all journals of changes in dependency order, allowing them to update their diffs accordingly."""
self.log.info("Notifying journals of changes...")
# Passes repeat until one finishes without the restart flag being raised.
pass_count = 0
while True:
pass_count += 1
self.log.debug(t"Starting notify pass {pass_count}...")
# Cleared here; raised again by `_add_record_journal()` or
# `on_journal_reset_notified_dependents()` when the pass must be redone.
self._restart_commit_notify = False
# Walk a snapshot (copy=True) of the journaled UIDs, passing a condition that is
# true only for records whose journal has not yet notified its dependents.
for e in self._commit_travel_hierarchy(self._journals.keys(), condition=self._commit_notify_travel_hierarchy_condition, copy=True):
j = self._journals[e.uid]
j.notify_dependents()
# Abandon the current walk as soon as a restart has been requested.
if self._restart_commit_notify:
break
if self._restart_commit_notify:
continue
# The flag is re-checked after the parent hook — presumably so that the hook
# itself can trigger another pass.
self._call_parent_hook("notify")
if not self._restart_commit_notify:
break
assert not self._restart_commit_notify, "Restart flag should be false after notify loop."
assert all(j.notified_dependents for j in self._journals.values()), "All journals should have notified dependents after notify loop."
# Freeze all journals to prevent further edits
for j in self._journals.values():
j.freeze()
# Done
# From here on `_add_record_journal()` refuses to create journals and
# `on_journal_reset_notified_dependents()` raises.
self._after_commit_notify = True
# MARK: Commit - Apply
def _commit_apply_travel_hierarchy_condition(self, e: EntityRecordBase) -> bool:
    """True when ``e`` has a journal in this session and that journal is not superseded."""
    journal = self._journals.get(e.uid, None)
    if journal is None:
        return False
    return not journal.superseded
def _commit_apply(self) -> None:
"""Iterate through flattened hierarchy, flatten updates and apply them (creating new entity versions, or deleting them as requested)."""
self.log.info("Committing journals...")
# Apply all journals in dependency order
for e in self._commit_travel_hierarchy(self._journals.keys(), condition=self._commit_apply_travel_hierarchy_condition, copy=False):
j = self._journals[e.uid]
j.commit()
# Check for newly-created unreachable entities and revert them
for uid in self._created:
entity = Entity.by_uid(uid)
if not entity.is_reachable(recursive=True, use_journal=True):
self.log.warning(t"Entity {entity.instance_name} created in this session is unreachable; reverting. This may indicate a logic bug.")
entity.revert()
self._call_parent_hook("apply")