129 lines
4.7 KiB
129 lines
4.7 KiB
import asyncio
|
|
import contextlib
|
|
import contextvars
|
|
import threading
|
|
from typing import Any, Dict, Union
|
|
|
|
|
|
class _CVar:
|
|
"""Storage utility for Local."""
|
|
|
|
def __init__(self) -> None:
|
|
self._data: "contextvars.ContextVar[Dict[str, Any]]" = contextvars.ContextVar(
|
|
"asgiref.local"
|
|
)
|
|
|
|
def __getattr__(self, key):
|
|
storage_object = self._data.get({})
|
|
try:
|
|
return storage_object[key]
|
|
except KeyError:
|
|
raise AttributeError(f"{self!r} object has no attribute {key!r}")
|
|
|
|
def __setattr__(self, key: str, value: Any) -> None:
|
|
if key == "_data":
|
|
return super().__setattr__(key, value)
|
|
|
|
storage_object = self._data.get({})
|
|
storage_object[key] = value
|
|
self._data.set(storage_object)
|
|
|
|
def __delattr__(self, key: str) -> None:
|
|
storage_object = self._data.get({})
|
|
if key in storage_object:
|
|
del storage_object[key]
|
|
self._data.set(storage_object)
|
|
else:
|
|
raise AttributeError(f"{self!r} object has no attribute {key!r}")
|
|
|
|
|
|
# Attribute names stored on the Local instance itself rather than in its
# per-thread / per-context storage. ("_local" is accepted for compatibility;
# nothing in Local assigns it.)
_LOCAL_INTERNAL_ATTRS = frozenset(
    {"_local", "_storage", "_thread_critical", "_thread_lock"}
)


class Local:
    """Local storage for async tasks.

    This is a namespace object (similar to `threading.local`) where data is
    also local to the current async task (if there is one).

    In async threads, local means in the same sense as the `contextvars`
    module - i.e. a value set in an async frame will be visible:

    - to other async code `await`-ed from this frame.
    - to tasks spawned using `asyncio` utilities (`create_task`, `wait_for`,
      `gather` and probably others).
    - to code scheduled in a sync thread using `sync_to_async`

    In "sync" threads (a thread with no async event loop running), the
    data is thread-local, but additionally shared with async code executed
    via the `async_to_sync` utility, which schedules async code in a new thread
    and copies context across to that thread.

    If `thread_critical` is True, then the local will only be visible per-thread,
    behaving exactly like `threading.local` if the thread is sync, and as
    `contextvars` if the thread is async. This allows genuinely thread-sensitive
    code (such as DB handles) to be kept strictly to their initial thread and
    disable the sharing across `sync_to_async` and `async_to_sync` wrapped calls.

    Unlike plain `contextvars` objects, this utility is threadsafe.
    """

    def __init__(self, thread_critical: bool = False) -> None:
        """Create the local.

        Args:
            thread_critical: if True, back the local with a ``threading.local``
                (never shared across threads); otherwise back it with a single
                ``_CVar`` shared by all threads and guarded by an ``RLock``.
        """
        self._thread_critical = thread_critical
        self._thread_lock = threading.RLock()

        self._storage: "Union[threading.local, _CVar]"

        if thread_critical:
            # Thread-local storage
            self._storage = threading.local()
        else:
            # Contextvar storage
            self._storage = _CVar()

    @contextlib.contextmanager
    def _lock_storage(self):
        """Yield the storage object to use for the current thread/context.

        With ``thread_critical=False`` the shared ``_CVar`` is yielded while
        ``_thread_lock`` is held. With ``thread_critical=True`` no lock is
        taken. In a thread with no running event loop, the ``threading.local``
        itself is yielded. In a thread with a running loop, a ``_CVar`` stored
        on that ``threading.local`` is yielded; it is created on first use.
        """
        if not self._thread_critical:
            # Lock for thread_critical=False as other threads
            # can access the exact same storage object
            with self._thread_lock:
                yield self._storage
            return

        try:
            # Test whether this is an async or a sync thread:
            # get_running_loop() raises RuntimeError when no loop is running.
            asyncio.get_running_loop()
        except RuntimeError:
            in_async_thread = False
        else:
            in_async_thread = True

        # The mode is decided above so that the yields below (and thus the
        # caller's `with` body) are not nested inside the `except` handler.
        if not in_async_thread:
            # Sync thread: the storage is just the plain thread local (i.e.
            # "global within this thread" - it doesn't matter where you are
            # in a call stack, you see the same storage).
            yield self._storage
            return

        # Async thread: storage is still local to this thread, but should
        # additionally behave like a context var (only visible within the
        # same async call stack). Ensure that context storage exists here.
        if not hasattr(self._storage, "cvar"):
            self._storage.cvar = _CVar()

        # self._storage is a thread local, so its members can't be accessed
        # from another thread (no lock needed).
        yield self._storage.cvar

    def __getattr__(self, key):
        """Read *key* from the storage for the current thread/context.

        Raises:
            AttributeError: if *key* is not set, or if *key* is an internal
                attribute missing from this instance.
        """
        # Internal attributes only reach __getattr__ when they are missing from
        # the instance, e.g. on a bare instance built by copy.copy before its
        # state is restored. Routing them through _lock_storage would read
        # self._thread_critical again and recurse until RecursionError.
        if key in _LOCAL_INTERNAL_ATTRS:
            raise AttributeError(f"{self!r} object has no attribute {key!r}")
        with self._lock_storage() as storage:
            return getattr(storage, key)

    def __setattr__(self, key, value):
        """Store internal attributes on the instance, others in storage."""
        if key in _LOCAL_INTERNAL_ATTRS:
            return super().__setattr__(key, value)
        with self._lock_storage() as storage:
            setattr(storage, key, value)

    def __delattr__(self, key):
        """Delete internal attributes from the instance, others from storage."""
        # Mirrors __setattr__ so internal names never touch user storage.
        if key in _LOCAL_INTERNAL_ATTRS:
            return super().__delattr__(key)
        with self._lock_storage() as storage:
            delattr(storage, key)
|