# Source code for zarr.codecs.zstd

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

import numcodecs
from numcodecs.zstd import Zstd
from packaging.version import Version

from zarr.abc.codec import BytesBytesCodec
from zarr.core.buffer.cpu import as_numpy_array_wrapper
from zarr.core.common import JSON, parse_named_configuration
from zarr.registry import register_codec

if TYPE_CHECKING:
    from typing import Self

    from zarr.core.array_spec import ArraySpec
    from zarr.core.buffer import Buffer


def parse_zstd_level(data: JSON) -> int:
    """Validate a zstd compression level read from JSON-like configuration.

    Parameters
    ----------
    data
        Candidate compression level.

    Returns
    -------
    int
        ``data`` unchanged, when it is an integer no greater than 22.

    Raises
    ------
    TypeError
        If ``data`` is not an ``int``. ``bool`` values are rejected as well.
    ValueError
        If ``data`` is greater than 22.
    """
    # bool is a subclass of int; without this explicit check True/False would
    # slip through as levels 1/0 and be serialised as JSON booleans.
    if isinstance(data, bool) or not isinstance(data, int):
        raise TypeError(f"Got value with type {type(data)}, but expected an int.")
    if data > 22:
        raise ValueError(f"Value must be less than or equal to 22. Got {data} instead.")
    return data


def parse_checksum(data: JSON) -> bool:
    """Validate the ``checksum`` flag read from JSON-like configuration.

    Parameters
    ----------
    data
        Candidate flag value.

    Returns
    -------
    bool
        ``data`` unchanged, when it is a ``bool``.

    Raises
    ------
    TypeError
        If ``data`` is anything other than a ``bool``.
    """
    if not isinstance(data, bool):
        raise TypeError(f"Expected bool. Got {type(data)}.")
    return data


@dataclass(frozen=True)
class ZstdCodec(BytesBytesCodec):
    """Bytes-to-bytes codec that delegates compression to ``numcodecs``' ``Zstd``.

    Both options are validated in ``__init__`` before being stored on the
    frozen instance.
    """

    # Deliberately left un-annotated so it stays a plain class attribute rather
    # than becoming a dataclass field.
    # NOTE(review): set to True although compute_encoded_size is not
    # implemented below — confirm this is intended.
    is_fixed_size = True

    # Compression level passed to the numcodecs Zstd configuration.
    level: int = 0
    # Checksum flag passed to the numcodecs Zstd configuration.
    checksum: bool = False

    def __init__(self, *, level: int = 0, checksum: bool = False) -> None:
        """Validate and store the codec options.

        Parameters
        ----------
        level
            Compression level; checked by ``parse_zstd_level``.
        checksum
            Checksum flag; checked by ``parse_checksum``.

        Raises
        ------
        RuntimeError
            If the installed numcodecs is older than 0.13.0.
        """
        # numcodecs 0.13.0 introduces the checksum attribute for the zstd codec
        installed = Version(numcodecs.__version__)
        if installed < Version("0.13.0"):
            raise RuntimeError(
                "numcodecs version >= 0.13.0 is required to use the zstd codec. "
                f"Version {installed} is currently installed."
            )

        # Validate both options first (level, then checksum) and only then
        # store them; the dataclass is frozen, hence object.__setattr__.
        validated = {
            "level": parse_zstd_level(level),
            "checksum": parse_checksum(checksum),
        }
        for field_name, field_value in validated.items():
            object.__setattr__(self, field_name, field_value)

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        """Build a codec from a named-configuration dict whose name is ``"zstd"``."""
        _name, options = parse_named_configuration(data, "zstd")
        return cls(**options)  # type: ignore[arg-type]

    def to_dict(self) -> dict[str, JSON]:
        """Return the codec as a ``{"name": ..., "configuration": ...}`` dict."""
        configuration: dict[str, JSON] = {
            "level": self.level,
            "checksum": self.checksum,
        }
        return {"name": "zstd", "configuration": configuration}

    @cached_property
    def _zstd_codec(self) -> Zstd:
        """The numcodecs ``Zstd`` instance for this configuration, built once."""
        return Zstd.from_config({"level": self.level, "checksum": self.checksum})

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> Buffer:
        """Decode one chunk, running the numcodecs call in a worker thread."""
        decode = self._zstd_codec.decode
        return await asyncio.to_thread(
            as_numpy_array_wrapper, decode, chunk_bytes, chunk_spec.prototype
        )

    async def _encode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        """Encode one chunk, running the numcodecs call in a worker thread."""
        encode = self._zstd_codec.encode
        return await asyncio.to_thread(
            as_numpy_array_wrapper, encode, chunk_bytes, chunk_spec.prototype
        )

    def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        """Not implemented for this codec; always raises ``NotImplementedError``."""
        raise NotImplementedError
# Register ZstdCodec with the codec registry under the name "zstd".
register_codec("zstd", ZstdCodec)