# SPDX-License-Identifier: GPLv3-or-later
# Copyright © 2025 pygaindalf Rui Pinheiro
from typing import Any, Generator
from dataclasses import dataclass
from contextlib import contextmanager
from contextvars import ContextVar, Token
from pydantic import BaseModel
# Context variable holding the ContextStack node at the top of the stack.
# It is created without a default, so CONFIG_CONTEXT.get() with no argument
# raises LookupError while no value is set; pass an explicit default
# (e.g. CONFIG_CONTEXT.get(None)) to probe for an empty stack.
CONFIG_CONTEXT = ContextVar['ContextStack']('config_context')
[docs]
@dataclass
class ContextStack:
type CurrentType = dict[str, Any] | BaseModel
name : str | None = None
parent : 'ContextStack | None' = None
current : CurrentType | None = None
token : Token | None = None
[docs]
@classmethod
def push(cls, current: CurrentType, name : str | None = None) -> 'ContextStack':
"""
Push a new context onto the stack.
"""
parent = CONFIG_CONTEXT.get(None)
new_context = cls(name=name, parent=parent, current=current)
new_context.token = CONFIG_CONTEXT.set(new_context)
return new_context
[docs]
@classmethod
def pop(cls, token : Token | None = None) -> None:
"""
Pop the current context from the stack and return it.
If the stack is empty, return None.
"""
current = CONFIG_CONTEXT.get()
if current is None:
raise RuntimeError("No context to pop from the stack")
if current.token is None:
raise RuntimeError("Current context does not have a valid token")
if token is not None and current.token != token:
raise RuntimeError("Provided token does not match the current context token")
CONFIG_CONTEXT.reset(current.token)
[docs]
@classmethod
def get(cls) -> 'ContextStack | None':
"""
Get the current context from the stack.
If the stack is empty, return None.
"""
return CONFIG_CONTEXT.get(None)
[docs]
@classmethod
def iterate(cls, skip : int = 0) -> 'Generator[ContextStack]':
"""
Iterate over the current context and all parent contexts.
"""
context = cls.get()
while context is not None:
if skip > 0:
skip -= 1
else:
yield context
context = context.parent
[docs]
@classmethod
def find_inheritance(cls, name : str, skip : int = 0) -> dict[str, Any] | BaseModel | None:
"""
Find a context by name in the stack.
"""
scope = [name]
for context in cls.iterate(skip=skip):
# Search for the name in the current context
_scope = scope + ['default']
for i in range(len(_scope)):
found = True
obj : dict[str, Any] | BaseModel | None = context.current
for s in _scope[i::-1]:
if isinstance(obj, dict):
if s not in obj:
found = False
break
obj = obj[s]
elif isinstance(obj, BaseModel):
if not hasattr(obj, s):
found = False
break
obj = getattr(obj, s)
else:
# If the object is neither a dict nor a BaseModel, we cannot find the scope.
found = False
break
if found:
return obj
if context.name is not None:
scope.append(context.name)
return None
[docs]
@classmethod
def find_name(cls, name : str, skip : int = 0) -> bool:
"""
Check if a context with the given name exists in the stack.
"""
for context in cls.iterate(skip=skip):
if context.name == name:
return True
return False
[docs]
@classmethod
@contextmanager
def with_context(cls, current : CurrentType, name : str | None = None) -> Any:
"""
Context manager to push a new context onto the stack.
"""
context = cls.push(current, name=name)
try:
yield context
finally:
cls.pop(context.token)
[docs]
@classmethod
@contextmanager
def with_updated_name(cls, name : str) -> Any:
context = cls.get()
if context is None:
raise RuntimeError("No context to update the name")
old_name = context.name
context.name = name
try:
yield context
finally:
context.name = old_name