# Source code for zarr.registry

from __future__ import annotations

import warnings
from collections import defaultdict
from importlib.metadata import entry_points as get_entry_points
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from zarr.core.config import BadConfigError, config

if TYPE_CHECKING:
    from importlib.metadata import EntryPoint

    from zarr.abc.codec import (
        ArrayArrayCodec,
        ArrayBytesCodec,
        BytesBytesCodec,
        Codec,
        CodecPipeline,
    )
    from zarr.core.buffer import Buffer, NDBuffer
    from zarr.core.common import JSON

# Names exported as the public API of this module.
__all__ = [
    "Registry",
    "get_buffer_class",
    "get_codec_class",
    "get_ndbuffer_class",
    "get_pipeline_class",
    "register_buffer",
    "register_codec",
    "register_ndbuffer",
    "register_pipeline",
]

# Type variable for the kind of class stored in a ``Registry``.
T = TypeVar("T")


class Registry(dict[str, type[T]], Generic[T]):
    """
    A mapping from fully qualified class names to classes.

    Besides the registered classes, a registry keeps a list of entry points
    that have not been imported yet; they are loaded and registered on demand
    by ``lazy_load``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Entry points waiting to be loaded by ``lazy_load``.
        self.lazy_load_list: list[EntryPoint] = []

    def lazy_load(self) -> None:
        """Load every pending entry point, register the result, then empty the pending list."""
        for entry_point in self.lazy_load_list:
            loaded_cls = entry_point.load()
            self.register(loaded_cls)

        self.lazy_load_list.clear()

    def register(self, cls: type[T]) -> None:
        """Store ``cls`` under its fully qualified name, replacing any previous entry."""
        qualified_name = fully_qualified_name(cls)
        self[qualified_name] = cls
# The registry module is responsible for managing implementations of codecs,
# pipelines, buffers and ndbuffers and collecting them from entrypoints.
# The implementation used is determined by the config.

# One registry per codec name; a registry is created on first access to a name.
__codec_registries: dict[str, Registry[Codec]] = defaultdict(Registry)
__pipeline_registry: Registry[CodecPipeline] = Registry()
__buffer_registry: Registry[Buffer] = Registry()
__ndbuffer_registry: Registry[NDBuffer] = Registry()


def _collect_entrypoints() -> list[Registry[Any]]:
    """
    Queue codecs, pipelines, buffers and ndbuffers found in entry points.

    Entry points are only appended to each registry's ``lazy_load_list`` here;
    nothing is loaded by this function.

    Entry points can either be single items or groups of items.
    Allowed syntax for entry_points.txt is e.g.

        [zarr.codecs]
        gzip = package:EntrypointGzipCodec1
        [zarr.codecs.gzip]
        some_name = package:EntrypointGzipCodec2
        another = package:EntrypointGzipCodec3

        [zarr]
        buffer = package:TestBuffer1
        [zarr.buffer]
        xyz = package:TestBuffer2
        abc = package:TestBuffer3
        ...

    Returns
    -------
    list of Registry
        All codec registries followed by the pipeline, buffer and ndbuffer
        registries.
    """
    entry_points = get_entry_points()

    # (registry, dedicated group, item name inside the generic "zarr" group)
    singleton_sources = (
        (__buffer_registry, "zarr.buffer", "buffer"),
        (__ndbuffer_registry, "zarr.ndbuffer", "ndbuffer"),
        (__pipeline_registry, "zarr.codec_pipeline", "codec_pipeline"),
    )
    for registry, group_name, item_name in singleton_sources:
        pending = registry.lazy_load_list
        pending.extend(entry_points.select(group=group_name))
        pending.extend(entry_points.select(group="zarr", name=item_name))

    # Single codecs: the entry point name is the codec name.
    for entry_point in entry_points.select(group="zarr.codecs"):
        __codec_registries[entry_point.name].lazy_load_list.append(entry_point)

    # Codec groups: the codec name is the third dot-separated part of the group.
    # NOTE(review): a group with further dots after the codec name is keyed by
    # that third part only - confirm this is intended.
    for group in entry_points.groups:
        if not group.startswith("zarr.codecs."):
            continue
        codec_name = group.split(".")[2]
        __codec_registries[codec_name].lazy_load_list.extend(entry_points.select(group=group))

    registries: list[Registry[Any]] = list(__codec_registries.values())
    registries.extend((__pipeline_registry, __buffer_registry, __ndbuffer_registry))
    return registries


def _reload_config() -> None:
    """Refresh the global ``config`` object."""
    config.refresh()


def fully_qualified_name(cls: type) -> str:
    """Return ``cls.__module__`` and ``cls.__qualname__`` joined by a dot."""
    return ".".join((cls.__module__, cls.__qualname__))
def register_codec(key: str, codec_cls: type[Codec]) -> None:
    """
    Register ``codec_cls`` as an implementation of the codec named ``key``.

    A registry for ``key`` is created if none exists yet. Within that registry
    the class is stored under its fully qualified name.
    """
    registry = __codec_registries.get(key)
    if registry is None:
        registry = Registry()
        __codec_registries[key] = registry
    registry.register(codec_cls)
def register_pipeline(pipe_cls: type[CodecPipeline]) -> None:
    """Add ``pipe_cls`` to the codec pipeline registry under its fully qualified name."""
    __pipeline_registry.register(pipe_cls)
def register_ndbuffer(cls: type[NDBuffer]) -> None:
    """Add ``cls`` to the ndbuffer registry under its fully qualified name."""
    __ndbuffer_registry.register(cls)
def register_buffer(cls: type[Buffer]) -> None:
    """Add ``cls`` to the buffer registry under its fully qualified name."""
    __buffer_registry.register(cls)
def get_codec_class(key: str, reload_config: bool = False) -> type[Codec]:
    """
    Return the codec class registered for ``key``.

    Pending entry points for ``key`` are loaded first. If the config names an
    implementation under ``codecs.<key>``, that implementation is returned.
    Otherwise the only registered implementation is returned, or, when several
    are registered, a warning is emitted and the most recently registered one
    is returned.

    Parameters
    ----------
    key : str
        Name of the codec.
    reload_config : bool, default False
        If True, refresh the config before looking up the codec.

    Returns
    -------
    type[Codec]
        The selected codec class.

    Raises
    ------
    KeyError
        If no implementation is registered for ``key``, or if the
        implementation named in the config is not registered for ``key``.
    """
    if reload_config:
        _reload_config()

    # Look the registry up with ``.get``: subscripting the defaultdict with an
    # unknown key would insert an empty registry as a side effect of a failed
    # lookup.
    codec_classes = __codec_registries.get(key)
    if codec_classes is not None:
        codec_classes.lazy_load()
    if not codec_classes:
        raise KeyError(key)

    config_entry = config.get("codecs", {}).get(key)
    if config_entry is None:
        if len(codec_classes) == 1:
            return next(iter(codec_classes.values()))
        warnings.warn(
            f"Codec '{key}' not configured in config. Selecting any implementation.",
            stacklevel=2,
        )
        return list(codec_classes.values())[-1]

    # Raises KeyError(config_entry) when the configured implementation has not
    # been registered for this codec.
    return codec_classes[config_entry]
def _resolve_codec(data: dict[str, JSON]) -> Codec:
    """
    Build a codec instance from its dict representation.

    The codec class is looked up by ``data["name"]`` and instantiated through
    its ``from_dict`` method.
    """
    # TODO: narrow the type of the input to only those dicts that map on to codec class instances.
    codec_cls = get_codec_class(data["name"])  # type: ignore[arg-type]
    return codec_cls.from_dict(data)


def _parse_bytes_bytes_codec(data: dict[str, JSON] | Codec) -> BytesBytesCodec:
    """
    Normalize the input to a ``BytesBytesCodec`` instance.

    A ``BytesBytesCodec`` is returned unchanged. A dict is converted with
    ``_resolve_codec`` and the result must be a ``BytesBytesCodec``.

    Raises
    ------
    TypeError
        If the input, or the codec resolved from a dict, is not a
        ``BytesBytesCodec``.
    """
    from zarr.abc.codec import BytesBytesCodec

    if not isinstance(data, dict):
        if isinstance(data, BytesBytesCodec):
            return data
        raise TypeError(f"Expected a BytesBytesCodec. Got {type(data)} instead.")

    result = _resolve_codec(data)
    if isinstance(result, BytesBytesCodec):
        return result
    msg = f"Expected a dict representation of a BytesBytesCodec; got a dict representation of a {type(result)} instead."
    raise TypeError(msg)


def _parse_array_bytes_codec(data: dict[str, JSON] | Codec) -> ArrayBytesCodec:
    """
    Normalize the input to an ``ArrayBytesCodec`` instance.

    An ``ArrayBytesCodec`` is returned unchanged. A dict is converted with
    ``_resolve_codec`` and the result must be an ``ArrayBytesCodec``.

    Raises
    ------
    TypeError
        If the input, or the codec resolved from a dict, is not an
        ``ArrayBytesCodec``.
    """
    from zarr.abc.codec import ArrayBytesCodec

    if not isinstance(data, dict):
        if isinstance(data, ArrayBytesCodec):
            return data
        raise TypeError(f"Expected a ArrayBytesCodec. Got {type(data)} instead.")

    result = _resolve_codec(data)
    if isinstance(result, ArrayBytesCodec):
        return result
    msg = f"Expected a dict representation of a ArrayBytesCodec; got a dict representation of a {type(result)} instead."
    raise TypeError(msg)


def _parse_array_array_codec(data: dict[str, JSON] | Codec) -> ArrayArrayCodec:
    """
    Normalize the input to an ``ArrayArrayCodec`` instance.

    An ``ArrayArrayCodec`` is returned unchanged. A dict is converted with
    ``_resolve_codec`` and the result must be an ``ArrayArrayCodec``.

    Raises
    ------
    TypeError
        If the input, or the codec resolved from a dict, is not an
        ``ArrayArrayCodec``.
    """
    from zarr.abc.codec import ArrayArrayCodec

    if not isinstance(data, dict):
        if isinstance(data, ArrayArrayCodec):
            return data
        raise TypeError(f"Expected a ArrayArrayCodec. Got {type(data)} instead.")

    result = _resolve_codec(data)
    if isinstance(result, ArrayArrayCodec):
        return result
    msg = f"Expected a dict representation of a ArrayArrayCodec; got a dict representation of a {type(result)} instead."
    raise TypeError(msg)
def get_pipeline_class(reload_config: bool = False) -> type[CodecPipeline]:
    """
    Return the codec pipeline class named by the ``codec_pipeline.path`` config entry.

    Pending pipeline entry points are loaded before the lookup.

    Parameters
    ----------
    reload_config : bool, default False
        If True, refresh the config before the lookup.

    Raises
    ------
    BadConfigError
        If no registered pipeline class matches the configured path.
    """
    if reload_config:
        _reload_config()

    __pipeline_registry.lazy_load()
    path = config.get("codec_pipeline.path")
    pipeline_class = __pipeline_registry.get(path)
    if not pipeline_class:
        raise BadConfigError(
            f"Pipeline class '{path}' not found in registered pipelines: {list(__pipeline_registry)}."
        )
    return pipeline_class
def get_buffer_class(reload_config: bool = False) -> type[Buffer]:
    """
    Return the buffer class named by the ``buffer`` config entry.

    Pending buffer entry points are loaded before the lookup.

    Parameters
    ----------
    reload_config : bool, default False
        If True, refresh the config before the lookup.

    Raises
    ------
    BadConfigError
        If no registered buffer class matches the configured path.
    """
    if reload_config:
        _reload_config()

    __buffer_registry.lazy_load()
    path = config.get("buffer")
    buffer_class = __buffer_registry.get(path)
    if not buffer_class:
        raise BadConfigError(
            f"Buffer class '{path}' not found in registered buffers: {list(__buffer_registry)}."
        )
    return buffer_class
def get_ndbuffer_class(reload_config: bool = False) -> type[NDBuffer]:
    """
    Return the ndbuffer class named by the ``ndbuffer`` config entry.

    Pending ndbuffer entry points are loaded before the lookup.

    Parameters
    ----------
    reload_config : bool, default False
        If True, refresh the config before the lookup.

    Raises
    ------
    BadConfigError
        If no registered ndbuffer class matches the configured path.
    """
    if reload_config:
        _reload_config()

    __ndbuffer_registry.lazy_load()
    path = config.get("ndbuffer")
    ndbuffer_class = __ndbuffer_registry.get(path)
    if not ndbuffer_class:
        raise BadConfigError(
            f"NDBuffer class '{path}' not found in registered buffers: {list(__ndbuffer_registry)}."
        )
    return ndbuffer_class
# Runs at import time: queues the discovered entry points on each registry's
# lazy-load list. The returned list of registries is not used here.
_collect_entrypoints()