# SPDX-License-Identifier: GPLv3-or-later
# Copyright © 2025 pygaindalf Rui Pinheiro
import annotationlib
import sys
from abc import ABCMeta
from collections.abc import Callable, Iterable, MutableMapping, MutableSet
from collections.abc import Set as AbstractSet
from functools import cached_property
from typing import TYPE_CHECKING, Any, ClassVar, Self, override
from typing import cast as typing_cast
from frozendict import frozendict
from pydantic import (
ConfigDict,
PositiveInt,
PrivateAttr,
ValidationInfo,
field_serializer,
field_validator,
model_validator,
)
from ....util.callguard import CallguardClassOptions
from ....util.helpers import generics, script_info, type_hints
from ....util.mixins import HierarchicalMixinMinimal, HierarchicalProtocol, NamedMixinMinimal, NamedProtocol
from ....util.models import LoggableHierarchicalRootModel
from ...util.superseded import superseded_check
from ...util.uid import Uid, UidProtocol
from .dependency_event_handler.base import EntityDependencyEventHandlerBase
from .dependency_event_handler.type_enum import EntityDependencyEventType
from .entity_dependents import EntityDependents
from .entity_impl import EntityImpl
from .entity_log import EntityLog, EntityModificationType
from .entity_schema import EntitySchema
if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from ...collections.uid_proxy import UidProxySet
from ...journal.journal import Journal
from ...journal.session import Session
from ...journal.session_manager import SessionManager
from ..annotation import AnnotationRecord
from ..store.entity_store import EntityStore
from .entity import Entity
# Every subclass of EntityRecordBase adds itself to this set from __init_subclass__.
ENTITY_RECORD_SUBCLASSES: MutableSet[type[EntityRecordBase]] = set()
# Maps an entity record class to the Entity class registered for it through register_entity_class().
ENTITY_CLASSES: MutableMapping[type[EntityRecordBase], type[Entity]] = {}
# We need this class to swallow the 'init' kwarg in __init_subclass__ calls from EntityRecordBase
[docs]
class EntityRecordBase[
T_Journal: Journal,
](
EntityRecordMeta,
type_hints.CachedTypeHintsMixin,
LoggableHierarchicalRootModel,
EntityImpl,
EntitySchema,
NamedMixinMinimal,
metaclass=ABCMeta,
# We need init=False here to ensure that pyright looks at the __init__ method for the entity-specific EntitySchema base,
# e.g. InstrumentSchema, TransactionSchema, etc
init=False,
):
__callguard_class_options__ = CallguardClassOptions["EntityRecordBase"](
decorator=superseded_check,
decorate_public_methods=True,
ignore_patterns=(
"superseding",
"superseded",
"deleted",
"reverted",
"exists",
"marked_for_deletion",
"uid",
"version",
"entity_or_none",
"entity",
"entity_log",
"instance_name",
),
)
model_config = ConfigDict(
extra="forbid",
frozen=True,
validate_assignment=True,
)
# MARK: Metaclass infrastructure
def __init_subclass__(cls, *, init: bool = False, unsafe_hash: bool = False) -> None:
    """Register the new subclass and give it its own dependency event handler set.

    The ``init`` and ``unsafe_hash`` keyword arguments are accepted but not used here
    and are not forwarded to ``super().__init_subclass__()``.
    """
    super().__init_subclass__()
    ENTITY_RECORD_SUBCLASSES.add(cls)
    # Initialise dependencies
    cls.__init_dependencies__()
# TODO: Move to entity?
[docs]
@classmethod
def is_update_allowed(cls, *, in_commit_only: bool = True, allow_in_abort: bool = False, force_session: bool = False) -> bool:
    """Report whether entity records may currently be created or modified.

    Args:
        in_commit_only: Require the active session to be in a commit.
        allow_in_abort: Accept an active session that is aborting; this is checked before ``in_commit_only``.
        force_session: Refuse when there is no global session manager, even during unit tests.

    Returns:
        ``True`` when the current session state permits the update, ``False`` otherwise.
    """
    from ...journal.session_manager import SessionManager

    manager = SessionManager.get_global_manager_or_none()
    if manager is None:
        # No manager at all: only unit tests that did not insist on a session may proceed
        return not (force_session or not script_info.is_unit_test())

    if not manager.in_session:
        return False
    active = manager.session
    if active is None:
        return False

    # An aborting session is accepted before the commit requirement is looked at
    if allow_in_abort and active.in_abort:
        return True
    return not (in_commit_only and not active.in_commit)
# MARK: Initialization / Destruction
[docs]
def __new__(cls, *args: Any, **kwargs: Any) -> Self:
    """Refuse to allocate a record unless ``is_update_allowed(in_commit_only=False)`` permits it.

    Raises:
        RuntimeError: If the update check fails.
    """
    if not cls.is_update_allowed(in_commit_only=False):
        msg = f"Not allowed to create {cls.__name__} instances outside of a session."
        raise RuntimeError(msg)
    return super().__new__(cls, *args, **kwargs)
[docs]
@override
def model_post_init(self, context: Any) -> None:
    """Announce the freshly initialised record to its entity log, dependents tracker and entity.

    When the most recent entity log entry is a creation and a session is available,
    the session is also told that a record was created.
    """
    super().model_post_init(context)
    # Imported here rather than at module level
    from .entity_record import EntityRecord
    assert isinstance(self, EntityRecord), f"Expected EntityRecord, got {type(self).__name__}"
    # Notification order: entity log, then dependents, then the entity itself
    self.entity_log.on_init_record(self)
    self.entity_dependents.on_init_record(self)
    self.entity.on_init_record(self)
    if self.entity_log.most_recent.what == EntityModificationType.CREATED and (session := self.session_or_none) is not None:
        session.on_entity_record_created(self)
# MARK: Deletion
@property
def deleted(self) -> bool:
    """Whether this record was reverted, or superseded by a log entry that deleted it.

    Raises:
        ValueError: If the record is superseded but the entity log has no entry for the next version.
    """
    if not self.superseded:
        return False
    if self._reverted:
        return True

    # The log entry right after this record's version says how it was superseded
    next_entry = self.entity_log.get_entry_by_version(self.version + 1)
    if next_entry is not None:
        return next_entry.record_deleted

    msg = f"Entity record {self} is marked as superseded but no audit log entry found for version {self.version + 1}."
    raise ValueError(msg)
@property
def exists(self) -> bool:
return not self.deleted
@property
def marked_for_deletion(self) -> bool:
    """Whether the record is deleted, or an existing journal marks it for deletion.

    The journal is looked up with ``create=False, fail=False``.
    """
    if self.deleted:
        return True
    pending = self.get_journal(create=False, fail=False)
    if pending is None:
        return False
    return pending.marked_for_deletion
def __del__(self) -> None:
# No need to track deletion in finalizing state
if sys.is_finalizing:
return
self.log.debug(t"Entity record __del__ called for {self}.")
if not self.superseded:
self.log.warning(
t"Entity record {self} is being garbage collected without being superseded. This may indicate a logic error or improper session management."
)
self._apply_deletion(who="system", why="__del__")
[docs]
def delete(self) -> None:
if self.marked_for_deletion:
return
if not self.is_update_allowed(in_commit_only=False):
msg = f"Not allowed to delete {type(self).__name__} instances outside of a session."
raise RuntimeError(msg)
self.log.debug(t"Entity record delete called for {self}.")
# Only situation is_update_allowed returns True but we are not in a session is during unit tests without a session manager,
# in which case it is fine to immediately delete the entity
if self.in_session:
self.journal.delete()
else:
assert script_info.is_unit_test(), f"Unexpected non-session deletion of {self} outside of unit test."
self._apply_deletion()
[docs]
def propagate_deletion(self) -> None:
    """Delete every child record not yet marked for deletion, once this record itself is."""
    if not self.marked_for_deletion:
        return
    for child_uid in self.children_uids:
        child = EntityRecordBase.by_uid_or_none(child_uid)
        # Children that cannot be resolved, or are already on their way out, are left alone
        if child is None or child.marked_for_deletion:
            continue
        child.delete()
[docs]
def apply_deletion(self, *, who: str | None = None, why: str | None = None) -> None:
if not self.is_update_allowed(allow_in_abort=True):
msg = f"Not allowed to apply deletion to {type(self).__name__} instances outside of a session commit or abort."
raise RuntimeError(msg)
self._apply_deletion(who=who, why=why)
def _apply_deletion(self, *, who: str | None = None, why: str | None = None) -> None:
    """Notify the entity log, the dependents tracker and the entity that this record is deleted.

    Performs no permission check of its own.

    Args:
        who: Passed on to the entity log's ``on_delete_record``.
        why: Passed on to the entity log's ``on_delete_record``.
    """
    # Imported here rather than at module level
    from .entity_record import EntityRecord
    assert isinstance(self, EntityRecord), f"Expected EntityRecord, got {type(self).__name__}"
    # Notification order: entity log, then dependents, then the entity itself
    self.entity_log.on_delete_record(self, who=who, why=why)
    self.entity_dependents.on_delete_record(self)
    self.entity.on_delete_record()
    self.log.info(t"Entity record {self} has been deleted.")
# MARK: Reversion
_reverted: bool = PrivateAttr(default=False)
@property
def reverted(self) -> bool:
return self._reverted
[docs]
def revert(self) -> None:
    """Flag this record as reverted; a no-op when it already is.

    Raises:
        RuntimeError: If updates are not permitted (commit or abort required), or if the
            entity log version is still at or beyond this record's version.
    """
    if self._reverted:
        return

    permitted = self.is_update_allowed(in_commit_only=True, allow_in_abort=True)
    if not permitted:
        msg = f"Not allowed to revert {type(self).__name__} instances outside of a session commit or abort."
        raise RuntimeError(msg)

    still_tracked = self.entity_log.version >= self.version
    if still_tracked:
        msg = f"Cannot revert entity record {self} because its entity log is still tracking it."
        raise RuntimeError(msg)

    self._reverted = True
# MARK: Instance Name
PROPAGATE_INSTANCE_NAME_FROM_PARENT: ClassVar[bool] = False
STRICT_INSTANCE_NAME_VALIDATION: ClassVar[bool] = True
@property
@override
def instance_name(self) -> str:
"""Get the instance name, or class name if not set."""
if (entity := self.entity_or_none) is None:
return f"{type(self).__name__}@unknown"
return entity.instance_name
# MARK: Uid
@model_validator(mode="before")
@classmethod
def _validate_uid_before(cls, data: Any) -> Any:
    """Check the ``uid`` (and, for an existing entity, the ``version``) of the raw input.

    Args:
        data: Raw validator input; it is read with ``get`` and written by item assignment.

    Returns:
        The same ``data`` object.

    Raises:
        ValueError: If ``uid`` is missing, has the wrong namespace, or an existing entity is
            being re-created without the expected next version.
        TypeError: If ``uid`` is not a ``Uid``.
    """
    if (uid := data.get("uid", None)) is None:
        msg = "Entity record must have a valid 'uid' to be created. None found."
        raise ValueError(msg)
    # This type check used to be repeated verbatim; a single check is sufficient.
    if not isinstance(uid, Uid):
        msg = f"Expected 'uid' to be of type Uid, got {type(uid).__name__}."
        raise TypeError(msg)
    uid_namespace = cls.get_entity_class().uid_namespace()
    if uid.namespace != uid_namespace:
        msg = f"Uid namespace '{uid.namespace}' does not match expected namespace '{uid_namespace}'."
        raise ValueError(msg)
    # If this entity already exists in the store, confirm a version was explicitly passed
    if (existing := cls._get_entity_store().get_entity_record(uid)) is not None:
        version = data.get("version", None)
        if version is None:
            msg = f"Entity record with UID {uid} already exists. You must provide an explicit version to create a new version of the entity."
            raise ValueError(msg)
        elif version != existing.entity_log.next_version:
            msg = f"Entity record with UID {uid} already exists with version {existing.version}. You must provide the next version {existing.entity_log.next_version} to create a new version of the entity."
            raise ValueError(msg)
    data["uid"] = uid
    return data
@classmethod
def _get_entity_store(cls) -> EntityStore:
    """Return the global ``EntityStore``.

    Raises:
        ValueError: If no global store has been set.
    """
    from ..store.entity_store import EntityStore

    store = EntityStore.get_global_store()
    if store is not None:
        return store
    msg = f"Global EntityStore is not set. Please create an EntityStore instance and call set_as_global_store() on it before creating {cls.__name__} instances."
    raise ValueError(msg)
[docs]
@classmethod
def by_uid_or_none[T: EntityRecordBase](cls: type[T], uid: Uid) -> T | None:
if not isinstance(uid, Uid):
msg = f"Expected 'uid' to be of type Uid, got {type(uid).__name__}."
raise TypeError(msg)
result = cls._get_entity_store().get_entity_record(uid)
if result is None:
return None
if not isinstance(result, cls):
msg = f"UID storage returned an instance of {type(result).__name__} instead of {cls.__name__}."
raise TypeError(msg)
return result
[docs]
@classmethod
def by_uid[T: EntityRecordBase](cls: type[T], uid: Uid) -> T:
if (result := cls.by_uid_or_none(uid)) is None:
msg = f"Could not find an entity record of type {cls.__name__} for UID {uid}."
raise ValueError(msg)
return result
[docs]
@classmethod
def narrow_to_uid[T: EntityRecordBase](cls: type[T], value: T | Uid) -> Uid:
if isinstance(value, Uid):
# try to convert to concrete entity record so we can test isinstance
record = cls.by_uid_or_none(value)
if record is None:
# We cannot sanity check in this case - we want to support narrowing UIDs that may not yet exist in the store
pass
elif not isinstance(record, cls):
msg = f"UID {value} does not correspond to an instance of class {cls.__name__}. Found instance of {type(record).__name__}."
raise TypeError(msg)
return value
elif isinstance(value, cls):
return value.uid
else:
msg = f"Value must be a {cls.__name__} or Uid, got {type(value)}"
raise TypeError(msg)
[docs]
@classmethod
def narrow_to_instance_or_none[T: EntityRecordBase](cls: type[T], value: T | Uid) -> T | None:
if isinstance(value, cls):
return value
elif isinstance(value, Uid):
record = cls.by_uid(value)
if record is None:
return None
if not isinstance(record, cls):
msg = f"UID {value} does not correspond to an instance of {cls.__name__}. Found instance of {type(record).__name__}."
raise TypeError(msg)
return record
else:
msg = f"Value must be a {cls.__name__} or Uid, got {type(value)}"
raise TypeError(msg)
[docs]
@classmethod
def narrow_to_instance[T: EntityRecordBase](cls: type[T], value: T | Uid) -> T:
if (result := cls.narrow_to_instance_or_none(value)) is None:
msg = f"Could not find an entity record of type {cls.__name__} for value {value}."
raise ValueError(msg)
return result
# MARK: Entity
[docs]
@classmethod
def register_entity_class(cls, entity_class: type[Entity]) -> None:
    """Associate this record class with its ``Entity`` class in ``ENTITY_CLASSES``.

    Raises:
        TypeError: If ``entity_class`` is not a subclass of ``Entity``.
        ValueError: If this record class already has a registered entity class.
    """
    from .entity import Entity

    is_entity_subclass = issubclass(entity_class, Entity)
    if not is_entity_subclass:
        msg = f"Expected 'entity_class' to be a subclass of EntityBase, got {entity_class.__name__}."
        raise TypeError(msg)

    already_registered = cls in ENTITY_CLASSES
    if already_registered:
        msg = f"Entity record class {cls.__name__} is already registered with entity class {ENTITY_CLASSES[cls].__name__}."
        raise ValueError(msg)

    ENTITY_CLASSES[cls] = entity_class
[docs]
@classmethod
def get_entity_class(cls) -> type[Entity]:
    """Return the ``Entity`` class registered for this record class.

    Raises:
        ValueError: If nothing has been registered for this class.
    """
    registered = ENTITY_CLASSES.get(cls)
    if registered is not None:
        return registered
    msg = f"Entity record class {cls.__name__} is not registered with any entity class. Please call 'register_entity_class' to register it."
    raise ValueError(msg)
@property
def entity_or_none(self) -> Entity | None:
if self._reverted:
return None
from .entity import Entity
return Entity.by_uid_or_none(self.uid)
@property
@override
def entity(self) -> Entity:
    """The entity returned by ``entity_or_none``.

    Raises:
        ValueError: If ``entity_or_none`` is ``None``.
    """
    owner = self.entity_or_none
    if owner is not None:
        return owner
    msg = f"Entity record with UID {self.uid} is not associated with any Entity instance."
    raise ValueError(msg)
[docs]
def call_entity_method(self, name: str, *args, **kwargs) -> Any:
    """Call the attribute ``name`` of this record's entity with the given arguments and return its result."""
    return getattr(self.entity, name)(*args, **kwargs)
# MARK: Parent
PROPAGATE_INSTANCE_PARENT_FROM_PARENT_TO_CHILDREN: ClassVar[bool] = True
@property
@override
def instance_parent(self) -> HierarchicalProtocol | NamedProtocol | None:
return self.entity_or_none
@property
def record_parent_or_none(self) -> EntityRecordBase | None:
    """The current record of the entity's instance parent, or ``None`` when it cannot be determined."""
    entity = self.entity_or_none
    if entity is None:
        return None
    parent = entity.instance_parent
    if parent is None:
        return None

    from .entity_base import EntityBase

    # An entity parent is reduced to its record; anything else is taken as is
    if isinstance(parent, EntityBase):
        candidate = parent.record_or_none
    else:
        candidate = parent

    # None is not an EntityRecordBase instance either, so one test covers both cases
    if not isinstance(candidate, EntityRecordBase):
        return None
    return candidate.superseding_or_none
@property
def record_parent(self) -> EntityRecordBase:
    """The record returned by ``record_parent_or_none``.

    Raises:
        ValueError: If ``record_parent_or_none`` is ``None``.
    """
    parent = self.record_parent_or_none
    if parent is not None:
        return parent
    msg = f"{type(self).__name__} instance {self.uid} has no valid entity record parent."
    raise ValueError(msg)
# MARK: Version / Entity Log
if TYPE_CHECKING:
entity_log: EntityLog
else:
@property
def entity_log(self) -> EntityLog:
return self.entity.entity_log
@field_validator("version", mode="before")
@classmethod
def _validate_version_before(cls, version: PositiveInt | None, info: ValidationInfo) -> PositiveInt:
if version is None:
version = typing_cast("PositiveInt", EntityLog(info.data["uid"]).next_version)
return version
@field_validator("version", mode="after")
@classmethod
def _validate_version(cls, version: PositiveInt, info: ValidationInfo) -> PositiveInt:
    """Accept ``version`` only when it equals the next version of the entity log for the validated ``uid``.

    Raises:
        ValueError: If the version does not match.
    """
    log = EntityLog(info.data["uid"])
    if version == log.next_version:
        return version
    msg = f"Entity record version '{version}' does not match the next audit log version '{log.version + 1}'. The version should be incremented when the entity is cloned as part of an update action."
    raise ValueError(msg)
[docs]
def is_newer_version_than(self, other: EntityRecordBase) -> bool:
    """Tell whether this record's version is greater than that of ``other``.

    Raises:
        TypeError: If ``other`` is not an ``EntityRecordBase``.
        ValueError: If the two records have different UIDs.
    """
    if not isinstance(other, EntityRecordBase):
        msg = f"Expected EntityRecordBase, got {type(other)}"
        raise TypeError(msg)

    same_entity = not (self.uid != other.uid)
    if not same_entity:
        msg = f"Cannot compare versions of entities with different UIDs: {self.uid} vs {other.uid}"
        raise ValueError(msg)

    return self.version > other.version
@property
def superseded(self) -> bool:
    """Whether this record was reverted, or the entity log has moved past this record's version."""
    reverted = self._reverted
    if reverted:
        return reverted
    return self.entity_log.version > self.version
@property
def superseding_or_none[T: EntityRecordBase](self: T) -> T | None:
if not self.superseded:
return self
return type(self).by_uid_or_none(self.uid)
@property
def superseding[T: EntityRecordBase](self: T) -> T:
if (result := self.superseding_or_none) is None:
msg = f"Entity record {self} has been superseded but the superseding entity record could not be found."
raise ValueError(msg)
return result
[docs]
def update[T: EntityRecordBase](self: T, **kwargs: Any) -> T:
    """Create a new instance of the entity record with the updated data.

    The new instance will have an incremented version and the same UID, superseding the current instance.

    Args:
        **kwargs: New field values. ``uid`` and ``version`` are rejected.

    Returns:
        The newly constructed record, of the same class as this one.

    Raises:
        RuntimeError: If ``is_update_allowed()`` reports that updates are not currently permitted.
        ValueError: If no data is provided, if ``uid`` or ``version`` is passed, or if the
            update would change the instance name.
        TypeError: If the constructed record is not an instance of this record's class.
    """
    # Check if we are in the middle of a commit
    if not self.is_update_allowed():
        msg = f"Not allowed to update {type(self).__name__} instances outside of a session commit."
        raise RuntimeError(msg)
    # Validate data
    if not kwargs:
        msg = "No data provided to update the entity record."
        raise ValueError(msg)
    if "uid" in kwargs:
        msg = "Cannot update the 'uid' of an entity record. The UID is immutable and should not be changed."
        raise ValueError(msg)
    if "version" in kwargs:
        msg = "Cannot update the 'version' of an entity record. The version is managed by the entity record itself and should not be changed directly."
        raise ValueError(msg)
    # Start from every model field, stored under the name reverse_field_alias() gives it,
    # taking the value from kwargs when the field name was passed and from this record otherwise
    args = {}
    for field_name in type(self).model_fields:
        target_name = self.reverse_field_alias(field_name)
        if field_name in kwargs:
            args[target_name] = kwargs[field_name]
        else:
            args[target_name] = getattr(self, field_name)
    # Then overlay kwargs verbatim, so keys that are not field names end up in args too
    args.update(kwargs)
    # uid and version are always set last and cannot come from the caller
    args["uid"] = self.uid
    args["version"] = self.entity_log.next_version
    # Sanity check - name won't change
    if (new_name := self.entity.calculate_instance_name_from_dict(args)) != self.instance_name:
        msg = f"Updating the entity record cannot change its instance name. Original: '{self.instance_name}', New: '{new_name}'."
        raise ValueError(msg)
    # Update entity record
    new_record = type(self)(**args)
    # Sanity check
    if not isinstance(new_record, type(self)):
        msg = f"Expected new entity record to be an instance of {type(self).__name__}, got {type(new_record).__name__}."
        raise TypeError(msg)
    if new_record.instance_name != self.instance_name:
        msg = f"Updating the entity record cannot change its instance name. Original: '{self.instance_name}', New: '{new_record.instance_name}'."
        raise ValueError(msg)
    # Return updated entity record
    return new_record
# MARK: Session
@property
def session_manager_or_none(self) -> SessionManager | None:
return self.entity.session_manager_or_none
@property
def session_manager(self) -> SessionManager:
return self.entity.session_manager
@property
def session_or_none(self) -> Session | None:
return self.entity.session_or_none
@property
def session(self) -> Session:
return self.entity.session
@property
def in_session(self) -> bool:
return self.entity.in_session
# MARK: Journal
get_journal_class = generics.GenericIntrospectionMethod[T_Journal]()
@property
@override
def is_journal(self) -> bool:
return False
[docs]
def get_journal(self, *, create: bool = True, fail: bool = True) -> Journal | None:
    """Ask the session for this record's journal.

    Args:
        create: Forwarded to the session's ``get_record_journal``.
        fail: When true the ``session`` property is used, otherwise ``session_or_none``.

    Returns:
        The result of ``get_record_journal``, or ``None`` when there is no session.
    """
    session = self.session if fail else self.session_or_none
    # Imported here rather than at module level
    from .entity_record import EntityRecord
    assert isinstance(self, EntityRecord), f"Expected EntityRecord, got {type(self).__name__} instead."
    return session.get_record_journal(record=self, create=create) if session is not None else None
@property
def journal(self) -> T_Journal:
    """This record's journal, obtained with ``get_journal(create=True)``.

    Raises:
        RuntimeError: If no journal is returned, or its exact type is not the class
            reported by ``get_journal_class()``.
    """
    found = self.get_journal(create=True)
    if found is None:
        msg = f"No journal found for entity record {self}."
        raise RuntimeError(msg)

    expected_cls = self.get_journal_class()
    # Exact type match is required, subclasses are not accepted
    if type(found) is expected_cls:
        return found

    msg = f"Expected journal of type {expected_cls}, got {type(found)}."
    raise RuntimeError(msg)
@property
def j(self) -> T_Journal:
return self.journal
@property
def has_journal(self) -> bool:
return self.get_journal(create=False) is not None
@staticmethod
def _is_entity_record_attribute(name: str) -> bool:
return hasattr(EntityRecordBase, name) or name in EntityRecordBase.model_fields or name in EntityRecordBase.model_computed_fields
[docs]
@classmethod
def get_model_field_aliases(cls) -> frozendict[str, str]:
    """Return the alias -> field name mapping, building and caching it on first use.

    The mapping is cached in the ``model_field_aliases`` class attribute.
    """
    cached = getattr(cls, "model_field_aliases", None)
    if cached is not None:
        return cached

    # Only fields that declare a (truthy) alias take part
    computed = frozendict({info.alias: name for name, info in cls.model_fields.items() if info.alias})
    setattr(cls, "model_field_aliases", computed)
    return computed
[docs]
@classmethod
def is_model_field_alias(cls, alias: str) -> bool:
return cls.get_model_field_aliases().get(alias, None) is not None
[docs]
@classmethod
def resolve_field_alias(cls, alias: str) -> str:
return cls.get_model_field_aliases().get(alias, alias)
[docs]
@classmethod
def get_model_field_reverse_aliases(cls) -> frozendict[str, str]:
    """Return the field name -> alias mapping, building and caching it on first use.

    The mapping is cached in the ``model_field_reverse_aliases`` class attribute.
    """
    cached = getattr(cls, "model_field_reverse_aliases", None)
    if cached is not None:
        return cached

    # Only fields that declare a (truthy) alias take part
    computed = frozendict({name: info.alias for name, info in cls.model_fields.items() if info.alias})
    setattr(cls, "model_field_reverse_aliases", computed)
    return computed
[docs]
@classmethod
def reverse_field_alias(cls, name: str) -> str:
return cls.get_model_field_reverse_aliases().get(name, name)
[docs]
@classmethod
def is_model_field(cls, field: str) -> bool:
return field in cls.model_fields
[docs]
@classmethod
def is_computed_field(cls, field: str) -> bool:
return field in cls.model_computed_fields
@property
def dirty(self) -> bool:
    """The ``dirty`` flag of this record's existing journal; ``False`` outside a session or without a journal."""
    if not self.in_session:
        return False
    existing = self.get_journal(create=False)
    if existing is None:
        return False
    return existing.dirty
@property
def has_diff(self) -> bool:
    """The ``has_diff`` flag of this record's existing journal; ``False`` outside a session or without a journal."""
    if not self.in_session:
        return False
    existing = self.get_journal(create=False)
    if existing is None:
        return False
    return existing.has_diff
[docs]
def is_journal_field_edited(self, field: str) -> bool:
    """Whether an existing journal reports ``field`` as edited; ``False`` when there is no journal."""
    existing = self.get_journal(create=False)
    if existing is None:
        return False
    return existing.is_field_edited(field)
[docs]
def get_journal_field(self, field: str, *, create: bool = False) -> Any:
    """Read ``field`` from the journal when the journal has it edited, otherwise from this record.

    Args:
        field: Attribute name to read.
        create: Forwarded to ``get_journal``.
    """
    journal = self.get_journal(create=create)
    use_journal = journal is not None and journal.is_field_edited(field)
    source = journal if use_journal else self
    return getattr(source, field)
_PROTECTED_FIELD_TYPES: ClassVar[tuple[type, ...]]
@classmethod
def _get_protected_field_types(cls) -> tuple[type, ...]:
    """Return the tuple of protected field types, building it on first use.

    The tuple is stored on ``EntityRecordBase`` itself (not on ``cls``), so it is
    computed once for all record classes.
    """
    cached = getattr(EntityRecordBase, "_PROTECTED_FIELD_TYPES", None)
    if cached is not None:
        return cached
    from ...collections import JournalledCollection, OrderedViewMutableSet

    # NOTE(review): OrderedViewMutableSet used to be listed twice here. The duplicate is
    # dropped; confirm whether a different collection type was meant in its place.
    default = (JournalledCollection, OrderedViewMutableSet, EntityLog, EntityDependents)
    setattr(EntityRecordBase, "_PROTECTED_FIELD_TYPES", default)
    return default
@classmethod
def _get_field_annotation(cls, field: str) -> Any | None:
    """Return the first non-``None`` annotation for ``field`` found along the MRO, or ``None``."""
    for klass in cls.__mro__:
        found = annotationlib.get_annotations(klass, format=annotationlib.Format.VALUE).get(field, None)
        if found is not None:
            return found
    return None
_PROTECTED_FIELD_LOOKUP: ClassVar[MutableMapping[str, bool]]
[docs]
@classmethod
def is_protected_field_type(cls, field: str) -> bool:
    """Tell whether the annotation of ``field`` involves one of the protected field types.

    Results are memoised per field name in ``_PROTECTED_FIELD_LOOKUP``.

    Raises:
        RuntimeError: If no annotation for ``field`` is found.
    """
    # getattr follows the MRO, so a lookup dict already created on a base class is reused here
    protected_field_lookup = getattr(cls, "_PROTECTED_FIELD_LOOKUP", None)
    if protected_field_lookup is None:
        protected_field_lookup = cls._PROTECTED_FIELD_LOOKUP = {}
    elif (result := protected_field_lookup.get(field)) is not None:
        # Both True and False are cached answers; only a missing key falls through
        return result
    annotation = cls._get_field_annotation(field)
    if annotation is None:
        msg = f"Field '{field}' not found in entity record type {cls.__name__} annotations."
        raise RuntimeError(msg)
    forbidden_types = cls._get_protected_field_types()
    result = False
    # One protected origin among the annotation's type hints is enough
    for hint in type_hints.iterate_type_hints(annotation):
        origin = generics.get_origin(hint, passthrough=True)
        if issubclass(origin, forbidden_types):
            result = True
            break
    protected_field_lookup[field] = result
    return result
# MARK: Children
# These are all entities that are considered reachable (and therefore not garbage collected) by the existence of this entity record.
# I.e. those referenced by fields in the entity record as well as annotations.
def _get_children_field_ignore(self, field_name: str) -> bool:
    """Whether ``field_name`` is skipped when collecting children: underscore-prefixed names, ``uid`` and ``extra_dependency_uids``."""
    if field_name.startswith("_"):
        return True
    return field_name in ("uid", "extra_dependency_uids")
[docs]
def iter_children_uids(self, *, use_journal: bool = False) -> Iterable[Uid]:
    """Yield the UIDs referenced by this record's model fields.

    A field value contributes when it is a ``Uid``, an object satisfying ``UidProtocol``,
    or a non-string iterable containing such items.

    Args:
        use_journal: Read a field from the existing journal when the journal has it edited.
    """
    # Inspect all fields of the entity record for UIDs or Entities
    for attr in type(self).model_fields:
        if self._get_children_field_ignore(attr):
            continue

        source = self
        if use_journal:
            journal = self.get_journal(create=False)
            if journal is not None and journal.is_field_edited(attr):
                source = journal
        value = getattr(source, attr, None)
        if value is None:
            continue

        # A single reference is handled like a one-element collection
        if isinstance(value, (Uid, UidProtocol)):
            candidates = (value,)
        elif isinstance(value, Iterable) and not isinstance(value, (str, bytes, bytearray)):
            candidates = value
        else:
            continue

        for item in candidates:
            if isinstance(item, Uid):
                yield item
            elif isinstance(item, UidProtocol):
                yield item.uid
[docs]
@cached_property
def children_uids(self) -> Iterable[Uid]:
return frozenset(self.iter_children_uids())
@property
def journal_children_uids(self) -> Iterable[Uid]:
    """The children UIDs of the existing journal, falling back to this record's own when there is none."""
    existing = self.get_journal(create=False)
    if existing is not None:
        return existing.children_uids
    return self.children_uids
[docs]
def get_children_uids(self, *, use_journal: bool = False) -> Iterable[Uid]:
return self.journal_children_uids if use_journal else self.children_uids
@property
def children(self) -> Iterable[EntityRecordBase]:
for uid in self.children_uids:
yield EntityRecordBase.by_uid(uid)
[docs]
def iter_hierarchy(
self, *, condition: Callable[[EntityRecordBase], bool] | None = None, use_journal: bool = False, check_condition_on_return: bool = True
) -> Iterable[EntityRecordBase]:
"""Return a flat ordered set of all entities in this hierarchy."""
if condition is not None and not condition(self):
return
# Iterate dirty children journals
for uid in self.get_children_uids(use_journal=use_journal):
child = EntityRecordBase.by_uid_or_none(uid)
if child is None:
continue
if condition is not None and not condition(child):
continue
yield from child.iter_hierarchy(condition=condition, use_journal=use_journal, check_condition_on_return=check_condition_on_return)
if check_condition_on_return and condition is not None and not condition(self):
msg = f"Entity record {self} failed condition check on return of yield_hierarchy."
raise RuntimeError(msg)
# Yield self, then return
yield self
[docs]
def is_reachable(self, *, recursive: bool = True, use_journal: bool = False) -> bool:
return self.entity.is_reachable(recursive=recursive, use_journal=use_journal)
# MARK: Annotations
[docs]
def on_annotation_record_created(self, annotation_or_uid: AnnotationRecord | Uid) -> None:
    """Add a newly created annotation to this record.

    Inside a session the annotation is added to the journal's annotations (if not
    already there); otherwise the record is updated directly with the extended set.

    Raises:
        RuntimeError: If ``is_update_allowed(in_commit_only=False)`` refuses the operation.
    """
    if not self.is_update_allowed(in_commit_only=False):
        msg = f"Not allowed to modify annotations of {type(self).__name__} instances outside of a session."
        raise RuntimeError(msg)

    from ..annotation import Annotation

    annotation = Annotation.narrow_to_instance(annotation_or_uid)

    if self.in_session:
        if annotation not in self.journal.annotations:
            self.journal.annotations.add(annotation)
        return

    assert script_info.is_unit_test(), f"Unexpected non-session annotation addition to {self} outside of unit test."
    extended = set(self.annotations)
    extended.add(annotation)
    self.update(annotations=extended)
[docs]
def on_annotation_record_deleted(self, annotation_or_uid: AnnotationRecord | Uid) -> None:
self.log.debug(t"Entity record {self} received deletion notice for annotation {annotation_or_uid}.")
if not self.is_update_allowed(in_commit_only=False, force_session=True):
msg = f"Not allowed to modify annotations of {type(self).__name__} instances outside of a session."
raise RuntimeError(msg)
from ..annotation import Annotation
annotation = Annotation.narrow_to_instance(annotation_or_uid)
if annotation not in self.journal.annotations:
return
self.journal.annotations.discard(annotation)
@field_validator("annotations", mode="before")
@classmethod
def _validate_annotations_before(cls, annotations: Any) -> frozenset:
    """Normalise the raw ``annotations`` input to a ``frozenset``.

    ``None`` becomes an empty ``frozenset``; an existing ``frozenset`` is returned as is.

    Raises:
        TypeError: If the input is not a set, or contains an item that is not an ``Annotation``.
    """
    from ..annotation import Annotation

    if annotations is None:
        return frozenset()

    if not isinstance(annotations, AbstractSet):
        msg = f"Expected 'annotations' to be an AbstractSet, got {type(annotations).__name__}."
        raise TypeError(msg)

    for item in annotations:
        if isinstance(item, Annotation):
            continue
        msg = f"Expected items in 'annotations' to be of type Annotation or Uid, got {type(item).__name__}."
        raise TypeError(msg)

    return annotations if isinstance(annotations, frozenset) else frozenset(annotations)
# MARK: Dependents
if TYPE_CHECKING:
entity_dependents: EntityDependents
else:
@property
def entity_dependents(self) -> EntityDependents:
return self.entity.entity_dependents
@property
def dependent_uids(self) -> Iterable[Uid]:
return self.entity_dependents.dependent_uids
@property
def dependents(self) -> Iterable[EntityRecordBase]:
return self.entity_dependents.dependents
# MARK: Dependencies
# TODO: Fix this
@property
def extra_dependencies(self) -> UidProxySet[EntityRecordBase]:
from ...collections import UidProxySet
return UidProxySet[EntityRecordBase](instance=self, field="extra_dependency_uids")
# MARK: Dependency Events
__entity_dependency_event_handler_records__: ClassVar[MutableSet[EntityDependencyEventHandlerBase]]
@classmethod
def __init_dependencies__(cls) -> None:
cls.__entity_dependency_event_handler_records__ = set()
[docs]
@classmethod
def register_dependency_event_handler(cls, record: EntityDependencyEventHandlerBase) -> None:
cls.__entity_dependency_event_handler_records__.add(record)
if script_info.is_unit_test():
@classmethod
def clear_dependency_event_handlers(cls) -> None:
if hasattr(cls, "__entity_dependency_event_handler_records__"):
cls.__entity_dependency_event_handler_records__.clear()
cls.__init_dependencies__()
for t in ENTITY_RECORD_SUBCLASSES:
if issubclass(t, cls):
if hasattr(t, "__entity_dependency_event_handler_records__"):
t.__entity_dependency_event_handler_records__.clear()
t.__init_dependencies__()
[docs]
@classmethod
def iter_dependency_event_handlers(cls) -> Iterable[EntityDependencyEventHandlerBase]:
    """Yield the dependency event handlers of every ``EntityRecordBase`` class in the MRO, in MRO order."""
    for klass in cls.__mro__:
        # Only record classes that carry a handler set contribute
        if issubclass(klass, EntityRecordBase) and hasattr(klass, "__entity_dependency_event_handler_records__"):
            yield from klass.__entity_dependency_event_handler_records__
def _call_dependency_event_handlers(self, event: EntityDependencyEventType, record: EntityRecordBase, journal: Journal) -> bool:
    """Run the dependency event handlers for ``event`` and report whether any of them matched.

    Iteration stops early once a matching handler has left this record marked for deletion.
    """
    from .entity_record import EntityRecord

    assert isinstance(self, EntityRecord), f"Expected EntityRecord, got {type(self).__name__} instead."
    assert isinstance(record, EntityRecord), f"Expected EntityRecord, got {type(record).__name__} instead."

    any_matched = False
    for handler in type(self).iter_dependency_event_handlers():
        handled = handler(owner=self, event=event, record=record, journal=journal)
        any_matched |= handled
        if not handled:
            continue
        # Abort if one of the handlers marks this entity for deletion
        if self.marked_for_deletion:
            break
    return any_matched
[docs]
def on_dependency_updated(self, source: Journal) -> None:
if self.marked_for_deletion:
msg = f"Entity record {self} is already marked for deletion or deleted, cannot process dependency deletion."
raise RuntimeError(msg)
parent = self.record_parent_or_none
record = source.record
self.log.debug(t"Entity record {self} received invalidation from dependency entity record {record}.")
# If the source record is the parent of this entity record, then we should confirm we are still reachable
if record is parent:
if not self.is_reachable(use_journal=True, recursive=False):
self.log.warning(t"Entity record {self} is no longer reachable from its parent {parent}, deleting itself.")
self.delete()
return
# Propagate the update to any fields that reference the source entity record
if record is not parent:
self._propagate_dependency_update(source)
# Call event handlers
self._call_dependency_event_handlers(event=EntityDependencyEventType.UPDATED, record=record, journal=source)
def _propagate_dependency_update(self, source: Journal) -> None:
from ...collections import HasJournalledTypeCollectionProtocol, OnItemUpdatedCollectionProtocol
record = source.record
entity = record.entity
# Loop through entity record fields, and search for OrderedViewMutableSets that referenced the source entity record
for nm in type(self).model_fields:
original = getattr(self, nm, None)
if original is None:
continue
elif original is entity:
# Direct reference to the entity / entity record
value = self.journal.get_field(nm, wrap=True)
if value is not record:
continue
self.log.debug(t"Propagating dependency {record} update to field '{nm}'")
self.journal.set_field(nm, entity)
elif isinstance(original, HasJournalledTypeCollectionProtocol):
self.log.debug(t"Checking collection field '{nm}' for dependency {record}")
edited = self.is_journal_field_edited(nm)
value = self.journal.get_field(nm, wrap=False) if edited else original
# Only propagate if the invalidated child is present in both the wrapped set *and* the original set
# (this avoids propagating invalidations for items that have been added or removed from the set in the same session as the invalidation)
if entity not in original or (value is not original and entity not in value):
continue
if not edited:
value = self.journal.get_field(nm, wrap=True)
assert isinstance(value, OnItemUpdatedCollectionProtocol), f"Expected OnItemUpdatedCollectionProtocol, got {type(value)}"
assert entity in value, f"Expected collection to contain {entity}, but it does not."
self.log.debug(t"Propagating dependency {record} update to collection '{nm}'")
value.on_item_updated(record, source)
[docs]
def on_dependency_deleted(self, source: Journal) -> None:
if self.marked_for_deletion:
msg = f"Entity record {self} is already marked for deletion or deleted, cannot process dependency deletion."
raise RuntimeError(msg)
record = source.record
self.log.debug(t"Entity record {self} received deletion notice from dependency entity record {record}.")
# If the source record is the parent of this entity record, then we should delete ourselves too
if (parent := self.record_parent_or_none) is not None and record.uid == parent.uid:
if parent.uid == record.uid:
self.log.warning(t"Entity record {self} is a child of deleted entity record {record}, deleting itself too.")
self.delete()
return
# Sanity check: source record cannot be a child of this entity record
if record.uid in self.journal_children_uids:
msg = f"Entity record {record} is a child of {self} and therefore the latter cannot be deleted."
raise RuntimeError(msg)
# Call event handlers
self._call_dependency_event_handlers(event=EntityDependencyEventType.DELETED, record=record, journal=source)
if self.marked_for_deletion:
return
# If the source record is in our extra dependencies, we remove it
if record.uid in self.extra_dependency_uids:
self.journal.remove_dependency(record.uid)
# MARK: Serialization
@field_serializer("annotations", mode="plain")
@classmethod
def _serialize_annotations(cls, annotations: frozenset) -> tuple[Any, ...]:
return tuple(annotations)
# MARK: Utilities
[docs]
def sort_key(self) -> SupportsRichComparison:
return self.uid
@override
def __hash__(self) -> int:
return hash((self.uid, self.version))
@override
def __eq__(self, other: object) -> bool:
    """Two records are equal when both are ``EntityRecordBase`` instances with the same UID and version."""
    if not isinstance(other, EntityRecordBase):
        return False
    return self.uid == other.uid and self.version == other.version
@override
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def _customize_str_repr(self, spr: str) -> str:
    """Decorate a ``<...>`` representation with the version and a deleted/superseded marker.

    The version is appended as `` v<version>``, followed by `` (X)`` for a deleted record
    or `` (S)`` for a superseded one. Every occurrence of the class name (without a
    trailing ``Record``) followed by ``@`` is removed.
    """
    assert spr and spr[0] == "<", f"Expected string representation to start with '<', got {spr} instead."  # noqa: PT018

    body = spr.removesuffix(">") + f" v{self.version}"

    if self.deleted:
        marker = " (X)"
    elif self.superseded:
        marker = " (S)"
    else:
        marker = ""

    class_prefix = f"{type(self).__name__.removesuffix('Record')}@"
    return (body + marker).replace(class_prefix, "") + ">"
@override
def __str__(self) -> str:
return self._customize_str_repr(super(HierarchicalMixinMinimal, self).__str__())
@override
def __repr__(self) -> str:
return self._customize_str_repr(super(HierarchicalMixinMinimal, self).__repr__())