# SPDX-License-Identifier: GPLv3-or-later
# Copyright © 2025 pygaindalf Rui Pinheiro
import logging
from collections.abc import Iterable, Mapping, MutableSet, Sequence
from collections.abc import Set as AbstractSet
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar, override
from typing import cast as typing_cast
from frozendict import frozendict
from pydantic import ConfigDict, Field, InstanceOf, PrivateAttr, field_validator
from ...util.callguard import CallguardClassOptions
from ...util.helpers import generics
from ...util.models import LoggableHierarchicalModel
from ..collections.journalled import JournalledCollection, JournalledMapping, JournalledSequence, JournalledSet
from ..collections.ordered_view import OrderedViewSet
from ..models.entity import Entity, EntityImpl, EntityRecord
from ..util.superseded import SupersededError, superseded_check
from ..util.uid import Uid
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from ..collections import UidProxyMutableSet
from ..models.annotation import Annotation
from .session import Session
[docs]
# Journal of pending edits for a single EntityRecord (held in the `record` field).
# Reads and writes of the record's model fields are redirected through get_field / set_field.
class Journal(
LoggableHierarchicalModel,
EntityImpl[InstanceOf[MutableSet["Annotation"]], MutableSet[Uid]],
):
# Callguard configuration: `superseded_check` is the decorator applied to public methods,
# except for the names listed in `decorate_ignore_patterns`.
# NOTE(review): exact callguard matching semantics live in util.callguard - confirm there.
__callguard_class_options__ = CallguardClassOptions["Journal"](
decorator=superseded_check,
decorate_public_methods=True,
decorate_ignore_patterns=(
"superseded",
"_marked_superseded",
"dirty",
"has_diff",
"deleted",
"marked_for_deletion",
"uid",
"mark_superseded",
"freeze",
"commit_yield_hierarchy",
"get_diff",
"instance_name",
"instance_hierarchy",
),
)
# Pydantic configuration: unknown fields are rejected, model fields are frozen, assignments are validated.
model_config = ConfigDict(
extra="forbid",
frozen=True,
validate_assignment=True,
)
# NOTE(review): presumably consumed by the hierarchical-model base class; its meaning is not visible here.
PROPAGATE_TO_CHILDREN: ClassVar[bool] = False
# MARK: Subclassing
# Subclasses pass init=False so that the type checker does not expose their fields in the constructor.
# That keyword (and unsafe_hash) must therefore be accepted and swallowed here.
def __init_subclass__(cls, *, init: bool = False, unsafe_hash: bool = False) -> None:
    """Accept and discard the ``init`` / ``unsafe_hash`` class keywords.

    The keywords are swallowed here; the parent hook is invoked without them.
    """
    del init, unsafe_hash  # accepted only so subclasses may pass them
    super().__init_subclass__()
# MARK: EntityRecord
# The record this journal tracks edits for; checked by `_validate_record` (type and superseded state).
record: InstanceOf[EntityRecord] = Field(description="The entity record associated with this journal entry.")
@property
@override
def uid(self) -> Uid:  # pyright: ignore[reportIncompatibleVariableOverride]
    """Uid of the record this journal is attached to."""
    record = self.record
    return record.uid
if not TYPE_CHECKING:

    @property
    def version(self) -> int:
        """Version of the underlying record (defined at runtime only, hidden from type checkers)."""
        record = self.record
        return record.version
@override
def __hash__(self) -> int:
    """Hash the journal from its class name and the hash of its record."""
    class_name = type(self).__name__
    record_hash = hash(self.record)
    return hash((class_name, record_hash))
@field_validator("record", mode="before")
@classmethod
def _validate_record(cls, record: Any) -> EntityRecord:
    """Validate the ``record`` field before assignment.

    Raises:
        TypeError: if ``record`` is not an ``EntityRecord`` instance.
        SupersededError: if the record reports itself as superseded.
    """
    if isinstance(record, EntityRecord):
        if record.superseded:
            msg = f"EntityJournal.record '{record}' is superseded."
            raise SupersededError(msg)
        return record

    msg = f"Expected EntityRecordBase, got {type(record).__name__}"
    raise TypeError(msg)
@property
def entity_or_none(self) -> Entity | None:
    """The record's ``entity_or_none`` value, passed through unchanged."""
    record = self.record
    return record.entity_or_none
@property
@override
def entity(self) -> Entity:
    """The record's ``entity`` value, passed through unchanged."""
    record = self.record
    return record.entity
@property
@override
def is_journal(self) -> bool:
return True
@override
def __str__(self) -> str:
    """Render as ``ClassName(<str of record>)``."""
    class_name = type(self).__name__
    return f"{class_name}({self.record!s})"
@override
def __repr__(self) -> str:
    """Render as ``<ClassName:<repr of record>>``."""
    class_name = type(self).__name__
    return f"<{class_name}:{self.record!r}>"
@property
@override
def instance_name(self) -> str:
return f"{type(self).__name__}({self.record.uid})"
[docs]
def sort_key(self) -> SupportsRichComparison:
    """Return the sort key computed by the record type's ``sort_key``.

    The record class's ``sort_key`` is called with this journal standing in for the record.
    """
    record_type = type(self.record)
    stand_in = typing_cast("EntityRecord", self)
    return record_type.sort_key(stand_in)
# MARK: Superseded
# Set to True by mark_superseded(); read by the `superseded` property.
_marked_superseded: bool = PrivateAttr(default=False)
[docs]
def mark_superseded(self) -> None:
    """Freeze the journal, then flag it as superseded."""
    # Freeze first so the journal is already frozen when the flag flips.
    self.freeze()
    self._marked_superseded = True
@property
def superseded(self) -> bool:
    """Whether this journal is superseded.

    True when explicitly marked, when the journal is deleted, or when the record is superseded.
    """
    # The flag is read defensively: lookup failures are treated as "not marked".
    try:
        explicitly_marked = getattr(self, "_marked_superseded", False)
    except (TypeError, AttributeError, KeyError):
        explicitly_marked = False

    if explicitly_marked:
        return True
    return self.deleted or self.record.superseded
@property
def dirty(self) -> bool:
    """True when any child is recorded as dirty, otherwise the value of ``has_diff``."""
    return True if self._dirty_children else self.has_diff
@property
def has_diff(self) -> bool:
    """Whether this journal holds a pending change.

    True when marked for deletion, or when any pending update is either a plain
    replacement value or a journalled collection that reports itself as edited.
    """
    if self._marked_for_deletion:
        return True
    return any(
        (not isinstance(pending, JournalledCollection)) or pending.edited
        for pending in self._updates.values()
    )
def _on_dirtied(self) -> None:
    """Run the bookkeeping that follows any change to the journal's pending state."""
    # 1. The cached set of children uids may be stale.
    self._reset_children_uids_cache()
    # 2. The parent journal may need to learn about a change in dirty state.
    self._propagate_dirty()
    # 3. Dependents have to be notified again.
    self._reset_notified_dependents()
# MARK: Session
@property
def session(self) -> Session:
    """The Session that is this journal's ``instance_parent``.

    Raises:
        RuntimeError: if the journal has no parent.
        TypeError: if the parent is not a ``Session``.
    """
    owner = self.instance_parent
    if owner is None:
        msg = f"EntityJournal {self} has no parent Session."
        raise RuntimeError(msg)

    # Imported here rather than at module level (Session is only imported under TYPE_CHECKING above).
    from .session import Session

    if isinstance(owner, Session):
        return owner

    msg = f"EntityJournal {self} parent is not a Session."
    raise TypeError(msg)
# MARK: Frozen
# Set to True by freeze(); while set, set_field, _wrap_field and collection edits raise RuntimeError.
_frozen: bool = PrivateAttr(default=False)
@property
def frozen(self) -> bool:
    """Whether ``freeze()`` has been called on this journal."""
    is_frozen = self._frozen
    return is_frozen
[docs]
def freeze(self) -> None:
    """Mark the journal as frozen and freeze every journalled collection it holds."""
    self._frozen = True
    wrapped = (u for u in self._updates.values() if isinstance(u, JournalledCollection))
    for collection in wrapped:
        collection.freeze()
# MARK: Fields API
# Pending state keyed by field name: replacement values stored by set_field,
# or journalled collection wrappers created by _wrap_field.
_updates: dict[str, Any] = PrivateAttr(default_factory=dict)
@staticmethod
def _is_journal_attribute(name: str) -> bool:
    """Whether ``name`` is an attribute, model field or computed field of ``Journal`` itself."""
    if hasattr(Journal, name):
        return True
    if name in Journal.model_fields:
        return True
    return name in Journal.model_computed_fields
[docs]
def is_computed_field(self, field: str) -> bool:
    """Delegate to the record's ``is_computed_field`` for ``field``."""
    record = self.record
    return record.is_computed_field(field)
[docs]
def is_field_alias(self, field: str) -> bool:
    """Delegate to the record's ``is_model_field_alias`` for ``field``."""
    record = self.record
    return record.is_model_field_alias(field)
[docs]
def is_model_field(self, field: str) -> bool:
    """Delegate to the record's ``is_model_field`` for ``field``."""
    record = self.record
    return record.is_model_field(field)
[docs]
def has_field(self, field: str) -> bool:
    """Whether ``field`` is a model field or a computed field of the record."""
    if self.is_model_field(field):
        return True
    return self.is_computed_field(field)
[docs]
def can_modify(self, field: str) -> bool:
    """Whether ``field`` may be modified through the journal.

    Unknown fields are not modifiable. Known fields are modifiable unless their
    ``json_schema_extra`` is a dict with a truthy ``"readOnly"`` entry.
    """
    field_info = type(self.record).model_fields.get(field)
    if field_info is None:
        return False

    schema_extra = field_info.json_schema_extra
    if isinstance(schema_extra, dict):
        return not schema_extra.get("readOnly", False)
    return True
[docs]
def is_field_edited(self, field: str) -> bool:
    """Whether the journal holds an effective edit for ``field``.

    A field counts as edited when it has a pending replacement value (including
    ``None``), or a journalled collection wrapper that reports itself as edited.
    Fields with no pending entry are not edited.
    """
    # Look the key up directly: a stored value of None is a real edit and must not
    # be mistaken for "no entry" (this keeps the answer consistent with has_diff).
    try:
        value = self._updates[field]
    except KeyError:
        return False

    if isinstance(value, JournalledCollection):
        return value.edited
    return True
[docs]
def get_original_field(self, field: str) -> Any:
    """Return the record's own value for ``field``, ignoring any pending update.

    Raises:
        AttributeError: if the record has neither a model field nor a computed field of that name.
    """
    if self.has_field(field):
        # Start the lookup above EntityRecord in the MRO so its __getattribute__ is skipped.
        return super(EntityRecord, self.record).__getattribute__(field)

    msg = f"EntityRecord of type {type(self.record).__name__} does not have field '{field}'."
    raise AttributeError(msg)
def _wrap_field(self, field: str, original: Any) -> Any:
    """Wrap a collection-valued original in a journalled collection and store it.

    Non-collection values (and ``str`` / ``bytes``) are returned unchanged and not stored.

    Raises:
        RuntimeError: if ``field`` already has a pending entry, or if the journal is frozen
            and a wrapper would have to be stored.
    """
    if field in self._updates:
        msg = f"Field '{field}' of journal {self} is already wrapped."
        raise RuntimeError(msg)

    # Choose the wrapper type; the order of the checks matters.
    if isinstance(original, OrderedViewSet):
        wrapper_type = original.get_journalled_type()
    elif isinstance(original, Sequence) and not isinstance(original, (str, bytes)):
        wrapper_type = JournalledSequence
    elif isinstance(original, Mapping):
        wrapper_type = JournalledMapping
    elif isinstance(original, AbstractSet):
        wrapper_type = JournalledSet
    else:
        return original

    wrapped = wrapper_type(original, instance_parent=self, instance_name=field)

    if self._frozen:
        msg = f"Cannot wrap field '{field}' of frozen journal {self}."
        raise RuntimeError(msg)

    self._updates[field] = wrapped
    self._on_dirtied()
    return wrapped
[docs]
def set_field[T](self, field: str, value: T) -> T:
# Record `value` as the pending value of `field` and return `value`.
# Raises AttributeError for unknown, read-only or protected fields, and RuntimeError when frozen.
# Resolve aliases first; every later lookup uses the resolved name.
field = self.record.resolve_field_alias(field)
has_update = field in self._updates
if not has_update and not self.has_field(field):
msg = f"EntityRecord of type {type(self.record).__name__} does not have field '{field}'."
raise AttributeError(msg)
if not self.can_modify(field):
msg = f"Field '{field}' of entity type {type(self.record).__name__} is read-only."
raise AttributeError(msg)
# Current journal value, without wrapping collections as a side effect
current = self.get_field(field, wrap=False)
# No-op when the very same object is already the current value (identity, not equality)
if value is current:
return value
if self._frozen:
msg = f"Cannot modify field '{field}' of frozen journal {self}."
raise RuntimeError(msg)
original = self.get_original_field(field)
# Assigning the record's original object back drops the pending entry instead of storing one
if value is original:
if has_update:
del self._updates[field]
else:
# Protected field types may not be replaced wholesale
if self.record.is_protected_field_type(field):
msg = f"Field '{field}' of record type {type(self.record).__name__} is protected and cannot be modified. Use the collection's methods to modify it instead."
raise AttributeError(msg)
self._updates[field] = value
# Either branch changed the pending state
self._on_dirtied()
return value
[docs]
def get_field(self, field: str, *, wrap: bool = True) -> Any:
    """Return the journal's current value for ``field``.

    A pending entry wins. Otherwise the record's original value is returned,
    passed through ``_wrap_field`` first unless ``wrap`` is false.

    Raises:
        SupersededError: if the journal is superseded.
    """
    if self.superseded:
        msg = "Cannot get field from a superseded journal."
        raise SupersededError(msg)

    resolved = self.record.resolve_field_alias(field)
    pending = self._updates
    if resolved in pending:
        return pending[resolved]

    original = self.get_original_field(resolved)
    if not wrap:
        return original
    return self._wrap_field(resolved, original)
if not TYPE_CHECKING:
# Defined at runtime only (hidden from type checkers): reads of record model fields go through the journal.
@override
def __getattribute__(self, name: str) -> Any:
# Use normal attribute lookup when any of the conditions below holds
if (
# Private names, and attributes / fields defined on Journal itself
(name.startswith("_") or Journal._is_journal_attribute(name))
or
# Names that are not model fields of the record
(not self.is_model_field(name))
or
# The journal is superseded
(self.superseded)
):
return super().__getattribute__(name)
# Otherwise serve the value through the journal (pending entry, or wrapped original)
return self.get_field(name)
# Writes to record model fields are stored as pending updates via set_field.
@override
def __setattr__(self, name: str, value: Any) -> None:
if (
# Private names, and attributes / fields defined on Journal itself
(name.startswith("_") or Journal._is_journal_attribute(name))
or
# Names that are not model fields of the record
(not self.is_model_field(name))
or
# The journal is superseded: fall back to the normal attribute assignment
(self.superseded)
):
return super().__setattr__(name, value)
# Otherwise record the assignment as a pending update on the journal
return self.set_field(name, value)
[docs]
def update(self, **kwargs: Any) -> None:
    """Apply ``set_field`` for each keyword argument, in the order given."""
    for name, new_value in kwargs.items():
        self.set_field(name, new_value)
[docs]
def on_journalled_collection_edit(self, collection: JournalledCollection) -> None:
    """React to an edit of one of this journal's journalled collections.

    Raises:
        RuntimeError: if the journal is frozen.
    """
    if not self._frozen:
        self._on_dirtied()
        return

    msg = f"Cannot modify field '{collection.instance_name}' of frozen journal {self}."
    raise RuntimeError(msg)
[docs]
def get_diff(self) -> frozendict[str, Any]:
    """Return an immutable mapping of field name to pending change.

    Plain replacement values are included as-is. Edited journalled collections
    contribute ``tuple(collection.journal)``; unedited ones are left out.
    Once the journal is frozen the result is cached on the instance.
    """
    cached = getattr(self, "_diff", None)
    if cached is not None:
        return cached

    changes: dict[str, Any] = {}
    for name, pending in self._updates.items():
        if not isinstance(pending, JournalledCollection):
            changes[name] = pending
        elif pending.edited:
            changes[name] = tuple(pending.journal)

    snapshot = frozendict(changes)
    if self._frozen:
        # Pending state can no longer change, so the snapshot can be reused.
        self._diff = snapshot
    return snapshot
# MARK: Dirty Propagation
# Uids of child journals last reported as dirty (maintained by _update_child_dirty_state).
_dirty_children: MutableSet[Uid] = PrivateAttr(default_factory=set)
# The dirty state most recently pushed to the parent journal (see _propagate_dirty).
_propagated_dirty: bool = PrivateAttr(default=False)
def _update_child_dirty_state(self, child: Journal, *, dirty: bool | None = None) -> None:
    """Record whether ``child`` is dirty, then re-evaluate propagation to our own parent.

    When ``dirty`` is ``None`` the child's ``dirty`` property is consulted.
    """
    is_dirty = child.dirty if dirty is None else dirty
    if is_dirty:
        self._dirty_children.add(child.uid)
    else:
        self._dirty_children.discard(child.uid)
    self._propagate_dirty()
def _propagate_dirty(self) -> None:
    """Push our dirty state to the parent record's journal when it has changed."""
    current = self.dirty
    # Nothing to do if the parent already has this state from us.
    if current == self._propagated_dirty:
        return

    parent_record = self.record.record_parent_or_none
    if parent_record is None:
        return

    parent_record.journal._update_child_dirty_state(self, dirty=current)  # noqa: SLF001

    # Remember what was propagated so the next call can short-circuit.
    self._propagated_dirty = current
# MARK: Deletion
# Set to True by delete(); makes has_diff True and selects the deletion path in commit().
_marked_for_deletion: bool = PrivateAttr(default=False)
@property
def marked_for_deletion(self) -> bool:
    """Whether ``delete()`` has been called on this journal."""
    flag = self._marked_for_deletion
    return flag
# Set to True by _commit_delete() once the record's deletion has been applied.
_deleted: bool = PrivateAttr(default=False)
@property
def deleted(self) -> bool:
    """Whether the deletion has been committed; lookup failures count as ``False``."""
    try:
        flag = getattr(self, "_deleted", False)
    except (TypeError, AttributeError, KeyError):
        flag = False
    return flag
[docs]
def delete(self) -> None:
    """Mark the journal for deletion; calling it again is a no-op."""
    if self._marked_for_deletion:
        return

    self._marked_for_deletion = True

    # A journal marked for deletion carries no field updates any more: drop them and freeze.
    self._updates.clear()
    self.freeze()

    # Report the new dirty state upwards, then let the record propagate the deletion.
    self._propagate_dirty()
    self.record.propagate_deletion()
# MARK: Dependent Notifications
# Set by notify_dependents(); cleared by _reset_notified_dependents() when the journal is dirtied again.
_notified_dependents: bool = PrivateAttr(default=False)
@property
def notified_dependents(self) -> bool:
    """Whether dependents have been notified since the last change to this journal."""
    flag = self._notified_dependents
    return flag
[docs]
def notify_dependents(self) -> None:
if self._notified_dependents:
msg = f"Journal {self} has already notified its dependents of changes."
raise RuntimeError(msg)
self._notified_dependents = True
if not self.has_diff:
return
deletion = self._marked_for_deletion
self.log.debug(t"Notifying dependents of pending {'deletion' if deletion else 'update'}...")
for dep in self.record.dependent_uids:
record = EntityRecord.by_uid(dep)
if not record.marked_for_deletion:
if deletion:
record.on_dependency_deleted(self)
else:
record.on_dependency_updated(self)
def _reset_notified_dependents(self) -> None:
    """Clear the notified flag (if set) and inform the session that it was reset."""
    if self._notified_dependents:
        self._notified_dependents = False
        self.session.on_journal_reset_notified_dependents(self)
# MARK: Commit
# Set to True by commit() after an update or deletion has been applied (not set for no-op commits).
_committed: bool = PrivateAttr(default=False)
@property
def committed(self) -> bool:
    """Whether ``commit()`` has applied an update or deletion from this journal."""
    flag = self._committed
    return flag
[docs]
def commit(self) -> EntityRecord | None:
assert self._notified_dependents, f"Cannot commit journal {self} before notifying dependents."
assert not self._committed, f"Journal {self} has already been committed."
self.freeze()
if not self.has_diff:
self.mark_superseded()
return self.record
deletion = self._marked_for_deletion
self.log.debug(t"Committing entity {'deletion' if deletion else 'update'}...")
if self._marked_for_deletion:
self._commit_delete()
result = None
else:
result = self._commit_update()
self._committed = True
self.mark_superseded()
return result
def _commit_update(self) -> EntityRecord:
self.log.debug("Committing update")
# Collect all updates
updates = {}
for attr, update in self._updates.items():
if isinstance(update, (JournalledSequence, JournalledMapping, JournalledSet)):
if not update.edited:
continue
updates[attr] = update
if not updates:
return self.record
if self.log.isEnabledFor(logging.DEBUG):
updates_gen = ", ".join(f"'{k}': {v!s}" for k, v in updates.items())
self.log.debug(t"Updates to apply: {{{updates_gen}}}")
# Update the entity
new_record = self._new_entity = self.record.update(**updates)
self.log.debug("New entity record created: %s version %d", new_record, new_record.version)
# Done
return new_record
def _commit_delete(self) -> None:
    """Apply the deletion to the record and flag the journal as deleted."""
    assert self._marked_for_deletion, f"Cannot commit deletion of journal {self} that is not marked for deletion."
    assert not self._deleted, f"Journal {self} has already been deleted."

    self.log.debug("Committing deletion")

    record = self.record
    record.apply_deletion()
    self._deleted = True
# MARK: Children
[docs]
def iter_children_uids(self) -> Iterable[Uid]:
    """Yield the record's children uids, asking the record to take the journal into account."""
    record = self.record
    yield from record.iter_children_uids(use_journal=True)
[docs]
@cached_property
def children_uids(self) -> Iterable[Uid]:
    """Frozen set of children uids, cached until ``_reset_children_uids_cache`` is called."""
    uids = self.iter_children_uids()
    return frozenset(uids)
def _reset_children_uids_cache(self) -> None:
    """Drop the cached ``children_uids`` value, if any, so it is recomputed on next access."""
    instance_dict = self.__dict__
    instance_dict.pop("children_uids", None)
# MARK: Dependencies
@property
def extra_dependencies(self) -> UidProxyMutableSet[Entity]:
    """A ``UidProxyMutableSet[Entity]`` bound to this journal's ``extra_dependency_uids`` field."""
    from ..collections import UidProxyMutableSet

    proxy_type = UidProxyMutableSet[Entity]
    return proxy_type(instance=self, field="extra_dependency_uids", source=proxy_type)
[docs]
def add_dependency(self, record_or_uid: EntityRecord | Uid) -> None:
    """Add a record (or its uid) to ``extra_dependency_uids`` unless it is already present."""
    uid = EntityRecord.narrow_to_uid(record_or_uid)
    if uid not in self.extra_dependency_uids:
        self.extra_dependency_uids.add(uid)
[docs]
def remove_dependency(self, record_or_uid: EntityRecord | Uid) -> None:
    """Remove a record (or its uid) from ``extra_dependency_uids`` if it is present."""
    uid = EntityRecord.narrow_to_uid(record_or_uid)
    if uid in self.extra_dependency_uids:
        self.extra_dependency_uids.discard(uid)
# Register Journal with the project's generics helper.
# NOTE(review): the effect of register_type is defined in util.helpers.generics - confirm there.
generics.register_type(Journal)