# Source code for zarr.testing.store

from __future__ import annotations

import asyncio
import pickle
from abc import abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar

from zarr.storage import WrapperStore

if TYPE_CHECKING:
    from typing import Any

    from zarr.abc.store import ByteRequest
    from zarr.core.buffer.core import BufferPrototype

import pytest

from zarr.abc.store import (
    ByteRequest,
    OffsetByteRequest,
    RangeByteRequest,
    Store,
    SuffixByteRequest,
)
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.sync import _collect_aiterator
from zarr.storage._utils import _normalize_byte_range_index
from zarr.testing.utils import assert_bytes_equal

__all__ = ["StoreTests"]


S = TypeVar("S", bound=Store)
B = TypeVar("B", bound=Buffer)


class StoreTests(Generic[S, B]):
    """
    Reusable test suite for ``Store`` implementations.

    Subclasses set ``store_cls`` / ``buffer_cls`` and implement the abstract
    ``set``/``get`` helpers (which must bypass the store's own methods so those
    methods can themselves be tested) plus the ``store_kwargs`` fixture.
    """

    # The concrete Store subclass under test.
    store_cls: type[S]
    # The Buffer subclass used to build test payloads.
    buffer_cls: type[B]

    @abstractmethod
    async def set(self, store: S, key: str, value: Buffer) -> None:
        """
        Insert a value into a storage backend, with a specific key.
        This should not use any store methods. Bypassing the store methods allows them to be
        tested.
        """
        ...

    @abstractmethod
    async def get(self, store: S, key: str) -> Buffer:
        """
        Retrieve a value from a storage backend, by key.
        This should not use any store methods. Bypassing the store methods allows them to be
        tested.
        """
        ...

    @abstractmethod
    @pytest.fixture
    def store_kwargs(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Kwargs for instantiating a store"""
        ...

    @abstractmethod
    def test_store_repr(self, store: S) -> None: ...

    @abstractmethod
    def test_store_supports_writes(self, store: S) -> None: ...

    @abstractmethod
    def test_store_supports_partial_writes(self, store: S) -> None: ...

    @abstractmethod
    def test_store_supports_listing(self, store: S) -> None: ...

    @pytest.fixture
    def open_kwargs(self, store_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Kwargs passed to ``store_cls.open``; defaults to ``store_kwargs``."""
        return store_kwargs

    @pytest.fixture
    async def store(self, open_kwargs: dict[str, Any]) -> Store:
        """An opened instance of the store under test."""
        return await self.store_cls.open(**open_kwargs)

    @pytest.fixture
    async def store_not_open(self, store_kwargs: dict[str, Any]) -> Store:
        """A constructed-but-not-opened instance of the store under test."""
        return self.store_cls(**store_kwargs)

    def test_store_type(self, store: S) -> None:
        """The fixture store is an instance of ``Store`` and of ``store_cls``."""
        assert isinstance(store, Store)
        assert isinstance(store, self.store_cls)

    def test_store_eq(self, store: S, store_kwargs: dict[str, Any]) -> None:
        """Stores compare equal to themselves and to same-input instances."""
        # check self equality
        assert store == store

        # check store equality with same inputs
        # asserting this is important for being able to compare (de)serialized stores
        store2 = self.store_cls(**store_kwargs)
        assert store == store2

    async def test_serializable_store(self, store: S) -> None:
        """A store survives a pickle round-trip and remains usable."""
        new_store: S = pickle.loads(pickle.dumps(store))
        assert new_store == store
        assert new_store.read_only == store.read_only

        # quickly roundtrip data to a key to test that new store works
        data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
        key = "foo"
        await store.set(key, data_buf)
        observed = await store.get(key, prototype=default_buffer_prototype())
        assert_bytes_equal(observed, data_buf)

    def test_store_read_only(self, store: S) -> None:
        """``read_only`` is False by default and cannot be assigned."""
        assert not store.read_only

        with pytest.raises(AttributeError):
            store.read_only = False  # type: ignore[misc]

    @pytest.mark.parametrize("read_only", [True, False])
    async def test_store_open_read_only(
        self, open_kwargs: dict[str, Any], read_only: bool
    ) -> None:
        """``open`` honors the ``read_only`` kwarg."""
        open_kwargs["read_only"] = read_only
        store = await self.store_cls.open(**open_kwargs)
        assert store._is_open
        assert store.read_only == read_only

    async def test_store_context_manager(self, open_kwargs: dict[str, Any]) -> None:
        """The store works as a context manager and closes on exit."""
        # Test that the context manager closes the store
        with await self.store_cls.open(**open_kwargs) as store:
            assert store._is_open
            # Test trying to open an already open store
            with pytest.raises(ValueError, match="store is already open"):
                await store._open()
        assert not store._is_open

    async def test_read_only_store_raises(self, open_kwargs: dict[str, Any]) -> None:
        """Writes and deletes on a read-only store raise ``ValueError``."""
        kwargs = {**open_kwargs, "read_only": True}
        store = await self.store_cls.open(**kwargs)
        assert store.read_only

        # set
        with pytest.raises(
            ValueError, match="store was opened in read-only mode and does not support writing"
        ):
            await store.set("foo", self.buffer_cls.from_bytes(b"bar"))

        # delete
        with pytest.raises(
            ValueError, match="store was opened in read-only mode and does not support writing"
        ):
            await store.delete("foo")

    @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize(
        ("data", "byte_range"),
        [
            (b"\x01\x02\x03\x04", None),
            (b"\x01\x02\x03\x04", RangeByteRequest(1, 4)),
            (b"\x01\x02\x03\x04", OffsetByteRequest(1)),
            (b"\x01\x02\x03\x04", SuffixByteRequest(1)),
            (b"", None),
        ],
    )
    async def test_get(self, store: S, key: str, data: bytes, byte_range: ByteRequest) -> None:
        """
        Ensure that data can be read from the store using the store.get method.
        """
        data_buf = self.buffer_cls.from_bytes(data)
        await self.set(store, key, data_buf)
        observed = await store.get(key, prototype=default_buffer_prototype(), byte_range=byte_range)
        # compute the expected slice of the written bytes for this byte_range
        start, stop = _normalize_byte_range_index(data_buf, byte_range=byte_range)
        expected = data_buf[start:stop]
        assert_bytes_equal(observed, expected)

    async def test_get_not_open(self, store_not_open: S) -> None:
        """
        Ensure that data can be read from the store that isn't yet open using the store.get method.
        """
        assert not store_not_open._is_open
        data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
        key = "c/0"
        await self.set(store_not_open, key, data_buf)
        observed = await store_not_open.get(key, prototype=default_buffer_prototype())
        assert_bytes_equal(observed, data_buf)

    async def test_get_raises(self, store: S) -> None:
        """
        Ensure that a ValueError is raise for invalid byte range syntax
        """
        data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
        await self.set(store, "c/0", data_buf)
        with pytest.raises((ValueError, TypeError), match=r"Unexpected byte_range, got.*"):
            await store.get("c/0", prototype=default_buffer_prototype(), byte_range=(0, 2))  # type: ignore[arg-type]

    async def test_get_many(self, store: S) -> None:
        """
        Ensure that multiple keys can be retrieved at once with the _get_many method.
        """
        keys = tuple(map(str, range(10)))
        values = tuple(f"{k}".encode() for k in keys)
        for k, v in zip(keys, values, strict=False):
            await self.set(store, k, self.buffer_cls.from_bytes(v))
        observed_buffers = await _collect_aiterator(
            store._get_many(
                zip(
                    keys,
                    (default_buffer_prototype(),) * len(keys),
                    (None,) * len(keys),
                    strict=False,
                )
            )
        )
        observed_kvs = sorted(((k, b.to_bytes()) for k, b in observed_buffers))  # type: ignore[union-attr]
        expected_kvs = sorted(((k, b) for k, b in zip(keys, values, strict=False)))
        assert observed_kvs == expected_kvs

    @pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    async def test_getsize(self, store: S, key: str, data: bytes) -> None:
        """
        Test the result of store.getsize().
        """
        data_buf = self.buffer_cls.from_bytes(data)
        expected = len(data_buf)
        await self.set(store, key, data_buf)
        observed = await store.getsize(key)
        assert observed == expected

    async def test_getsize_prefix(self, store: S) -> None:
        """
        Test the result of store.getsize_prefix().
        """
        data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
        keys = ["c/0/0", "c/0/1", "c/1/0", "c/1/1"]
        keys_values = [(k, data_buf) for k in keys]
        await store._set_many(keys_values)
        expected = len(data_buf) * len(keys)
        observed = await store.getsize_prefix("c")
        assert observed == expected

    async def test_getsize_raises(self, store: S) -> None:
        """
        Test that getsize() raise a FileNotFoundError if the key doesn't exist.
        """
        with pytest.raises(FileNotFoundError):
            await store.getsize("c/1000")

    @pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"])
    @pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
    async def test_set(self, store: S, key: str, data: bytes) -> None:
        """
        Ensure that data can be written to the store using the store.set method.
        """
        assert not store.read_only
        data_buf = self.buffer_cls.from_bytes(data)
        await store.set(key, data_buf)
        observed = await self.get(store, key)
        assert_bytes_equal(observed, data_buf)

    async def test_set_not_open(self, store_not_open: S) -> None:
        """
        Ensure that data can be written to the store that's not yet open using the store.set method.
        """
        assert not store_not_open._is_open
        data_buf = self.buffer_cls.from_bytes(b"\x01\x02\x03\x04")
        key = "c/0"
        await store_not_open.set(key, data_buf)
        observed = await self.get(store_not_open, key)
        assert_bytes_equal(observed, data_buf)

    async def test_set_many(self, store: S) -> None:
        """
        Test that a dict of key : value pairs can be inserted into the store via the
        `_set_many` method.
        """
        keys = ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"]
        data_buf = [self.buffer_cls.from_bytes(k.encode()) for k in keys]
        store_dict = dict(zip(keys, data_buf, strict=True))
        await store._set_many(store_dict.items())
        for k, v in store_dict.items():
            assert (await self.get(store, k)).to_bytes() == v.to_bytes()

    @pytest.mark.parametrize(
        "key_ranges",
        [
            [],
            [("zarr.json", RangeByteRequest(0, 2))],
            [("c/0", RangeByteRequest(0, 2)), ("zarr.json", None)],
            [
                ("c/0/0", RangeByteRequest(0, 2)),
                ("c/0/1", SuffixByteRequest(2)),
                ("c/0/2", OffsetByteRequest(2)),
            ],
        ],
    )
    async def test_get_partial_values(
        self, store: S, key_ranges: list[tuple[str, ByteRequest]]
    ) -> None:
        """``get_partial_values`` agrees with per-key ``get`` calls."""
        # put all of the data
        for key, _ in key_ranges:
            await self.set(store, key, self.buffer_cls.from_bytes(bytes(key, encoding="utf-8")))

        # read back just part of it
        observed_maybe = await store.get_partial_values(
            prototype=default_buffer_prototype(), key_ranges=key_ranges
        )

        observed: list[Buffer] = []
        expected: list[Buffer] = []

        for obs in observed_maybe:
            assert obs is not None
            observed.append(obs)

        for idx in range(len(observed)):
            key, byte_range = key_ranges[idx]
            result = await store.get(
                key, prototype=default_buffer_prototype(), byte_range=byte_range
            )
            assert result is not None
            expected.append(result)

        assert all(
            obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
        )

    async def test_exists(self, store: S) -> None:
        """``exists`` reflects whether a key has been written."""
        assert not await store.exists("foo")
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        assert await store.exists("foo/zarr.json")

    async def test_delete(self, store: S) -> None:
        """``delete`` removes a key (skipped if deletes are unsupported)."""
        if not store.supports_deletes:
            pytest.skip("store does not support deletes")
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        assert await store.exists("foo/zarr.json")
        await store.delete("foo/zarr.json")
        assert not await store.exists("foo/zarr.json")

    async def test_delete_dir(self, store: S) -> None:
        """``delete_dir`` removes only keys under the given prefix."""
        if not store.supports_deletes:
            pytest.skip("store does not support deletes")
        await store.set("zarr.json", self.buffer_cls.from_bytes(b"root"))
        await store.set("foo-bar/zarr.json", self.buffer_cls.from_bytes(b"root"))
        await store.set("foo/zarr.json", self.buffer_cls.from_bytes(b"bar"))
        await store.set("foo/c/0", self.buffer_cls.from_bytes(b"chunk"))
        await store.delete_dir("foo")
        assert await store.exists("zarr.json")
        # "foo-bar" shares a prefix string with "foo" but is a different directory
        assert await store.exists("foo-bar/zarr.json")
        assert not await store.exists("foo/zarr.json")
        assert not await store.exists("foo/c/0")

    async def test_is_empty(self, store: S) -> None:
        """``is_empty`` is prefix-based, not string-prefix-based."""
        assert await store.is_empty("")
        await self.set(
            store, "foo/bar", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
        )
        assert not await store.is_empty("")
        assert await store.is_empty("fo")
        assert not await store.is_empty("foo/")
        assert not await store.is_empty("foo")
        assert await store.is_empty("spam/")

    async def test_clear(self, store: S) -> None:
        """``clear`` leaves the store empty."""
        await self.set(
            store, "key", self.buffer_cls.from_bytes(bytes("something", encoding="utf-8"))
        )
        await store.clear()
        assert await store.is_empty("")

    async def test_list(self, store: S) -> None:
        """``list`` yields every key that was written."""
        assert await _collect_aiterator(store.list()) == ()
        prefix = "foo"
        data = self.buffer_cls.from_bytes(b"")
        store_dict = {
            prefix + "/zarr.json": data,
            **{prefix + f"/c/{idx}": data for idx in range(10)},
        }
        await store._set_many(store_dict.items())
        expected_sorted = sorted(store_dict.keys())
        observed = await _collect_aiterator(store.list())
        observed_sorted = sorted(observed)
        assert observed_sorted == expected_sorted

    async def test_list_prefix(self, store: S) -> None:
        """
        Test that the `list_prefix` method works as intended. Given a prefix, it should return
        all the keys in storage that start with this prefix.
        """
        prefixes = ("", "a/", "a/b/", "a/b/c/")
        data = self.buffer_cls.from_bytes(b"")
        fname = "zarr.json"
        store_dict = {p + fname: data for p in prefixes}

        await store._set_many(store_dict.items())

        for prefix in prefixes:
            observed = tuple(sorted(await _collect_aiterator(store.list_prefix(prefix))))
            expected: tuple[str, ...] = ()
            for key in store_dict:
                if key.startswith(prefix):
                    expected += (key,)
            expected = tuple(sorted(expected))
            assert observed == expected

    async def test_list_empty_path(self, store: S) -> None:
        """
        Verify that list and list_prefix work correctly when path is an empty string,
        i.e. no unwanted replacement occurs.
        """
        data = self.buffer_cls.from_bytes(b"")
        store_dict = {
            "foo/bar/zarr.json": data,
            "foo/bar/c/1": data,
            "foo/baz/c/0": data,
        }
        await store._set_many(store_dict.items())

        # Test list()
        observed_list = await _collect_aiterator(store.list())
        observed_list_sorted = sorted(observed_list)
        expected_list_sorted = sorted(store_dict.keys())
        assert observed_list_sorted == expected_list_sorted

        # Test list_prefix() with an empty prefix
        observed_prefix_empty = await _collect_aiterator(store.list_prefix(""))
        observed_prefix_empty_sorted = sorted(observed_prefix_empty)
        expected_prefix_empty_sorted = sorted(store_dict.keys())
        assert observed_prefix_empty_sorted == expected_prefix_empty_sorted

        # Test list_prefix() with a non-empty prefix
        observed_prefix = await _collect_aiterator(store.list_prefix("foo/bar/"))
        observed_prefix_sorted = sorted(observed_prefix)
        expected_prefix_sorted = sorted(k for k in store_dict if k.startswith("foo/bar/"))
        assert observed_prefix_sorted == expected_prefix_sorted

    async def test_list_dir(self, store: S) -> None:
        """``list_dir`` yields immediate children of a directory, with or without trailing slash."""
        root = "foo"
        store_dict = {
            root + "/zarr.json": self.buffer_cls.from_bytes(b"bar"),
            root + "/c/1": self.buffer_cls.from_bytes(b"\x01"),
        }

        assert await _collect_aiterator(store.list_dir("")) == ()
        assert await _collect_aiterator(store.list_dir(root)) == ()

        await store._set_many(store_dict.items())

        keys_observed = await _collect_aiterator(store.list_dir(root))
        keys_expected = {k.removeprefix(root + "/").split("/")[0] for k in store_dict}

        assert sorted(keys_observed) == sorted(keys_expected)

        keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
        assert sorted(keys_expected) == sorted(keys_observed)

    async def test_set_if_not_exists(self, store: S) -> None:
        """``set_if_not_exists`` writes new keys but never overwrites existing ones."""
        key = "k"
        data_buf = self.buffer_cls.from_bytes(b"0000")
        await self.set(store, key, data_buf)

        new = self.buffer_cls.from_bytes(b"1111")
        await store.set_if_not_exists("k", new)  # no error

        result = await store.get(key, default_buffer_prototype())
        assert result == data_buf

        await store.set_if_not_exists("k2", new)  # no error

        result = await store.get("k2", default_buffer_prototype())
        assert result == new
class LatencyStore(WrapperStore[Store]):
    """
    A wrapper class that takes any store class in its constructor and adds latency to the `set`
    and `get` methods. This can be used for performance testing.
    """

    # Seconds slept before each wrapped ``get`` call.
    get_latency: float
    # Seconds slept before each wrapped ``set`` call.
    set_latency: float

    def __init__(self, cls: Store, *, get_latency: float = 0, set_latency: float = 0) -> None:
        # coerce to float so int/str-like inputs behave consistently in asyncio.sleep
        self.get_latency = float(get_latency)
        self.set_latency = float(set_latency)
        self._store = cls

    async def set(self, key: str, value: Buffer) -> None:
        """
        Add latency to the ``set`` method.

        Calls ``asyncio.sleep(self.set_latency)`` before invoking the wrapped ``set`` method.

        Parameters
        ----------
        key : str
            The key to set
        value : Buffer
            The value to set

        Returns
        -------
        None
        """
        await asyncio.sleep(self.set_latency)
        await self._store.set(key, value)

    async def get(
        self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
    ) -> Buffer | None:
        """
        Add latency to the ``get`` method.

        Calls ``asyncio.sleep(self.get_latency)`` before invoking the wrapped ``get`` method.

        Parameters
        ----------
        key : str
            The key to get
        prototype : BufferPrototype
            The BufferPrototype to use.
        byte_range : ByteRequest, optional
            An optional byte range.

        Returns
        -------
        buffer : Buffer or None
        """
        await asyncio.sleep(self.get_latency)
        return await self._store.get(key, prototype=prototype, byte_range=byte_range)