Source code for zarr.storage._zip

from __future__ import annotations

import os
import threading
import time
import zipfile
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

from zarr.abc.store import (
    ByteRequest,
    OffsetByteRequest,
    RangeByteRequest,
    Store,
    SuffixByteRequest,
)
from zarr.core.buffer import Buffer, BufferPrototype

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Iterable

ZipStoreAccessModeLiteral = Literal["r", "w", "a"]


class ZipStore(Store):
    """
    Store using a ZIP file.

    Parameters
    ----------
    path : str
        Location of file.
    mode : str, optional
        One of 'r' to read an existing file, 'w' to truncate and write a new file,
        'a' to append to an existing file, or 'x' to exclusively create and write a new file.
    compression : int, optional
        Compression method to use when writing to the archive.
    allowZip64 : bool, optional
        If True (the default) will create ZIP files that use the ZIP64 extensions when
        the zipfile is larger than 2 GiB. If False will raise an exception when the ZIP
        file would require ZIP64 extensions.

    Attributes
    ----------
    allowed_exceptions
    supports_writes
    supports_deletes
    supports_partial_writes
    supports_listing
    path
    compression
    allowZip64
    """

    supports_writes: bool = True
    supports_deletes: bool = False
    supports_partial_writes: bool = False
    supports_listing: bool = True

    path: Path
    compression: int
    allowZip64: bool

    _zf: zipfile.ZipFile
    _lock: threading.RLock

    def __init__(
        self,
        path: Path | str,
        *,
        mode: ZipStoreAccessModeLiteral = "r",
        read_only: bool | None = None,
        compression: int = zipfile.ZIP_STORED,
        allowZip64: bool = True,
    ) -> None:
        if read_only is None:
            read_only = mode == "r"
        super().__init__(read_only=read_only)
        if isinstance(path, str):
            path = Path(path)
        assert isinstance(path, Path)
        self.path = path  # root?
        self._zmode = mode
        self.compression = compression
        self.allowZip64 = allowZip64

    def _sync_open(self) -> None:
        if self._is_open:
            raise ValueError("store is already open")

        self._lock = threading.RLock()

        self._zf = zipfile.ZipFile(
            self.path,
            mode=self._zmode,
            compression=self.compression,
            allowZip64=self.allowZip64,
        )
        self._is_open = True

    async def _open(self) -> None:
        self._sync_open()

    def __getstate__(self) -> dict[str, Any]:
        # We need a copy to not modify the state of the original store
        state = self.__dict__.copy()
        for attr in ["_zf", "_lock"]:
            state.pop(attr, None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__ = state
        self._is_open = False
        self._sync_open()
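
    # Usage sketch (illustrative): construction is synchronous and the ZIP file is
    # opened lazily. The async ``Store.open`` classmethod from ``zarr.abc.store``
    # and the "data/example.zip" path below are assumptions of this example rather
    # than requirements of the class.
    #
    #     store = ZipStore("data/example.zip", mode="w")  # read_only defaults to (mode == "r")
    #     # or, inside async code:
    #     # store = await ZipStore.open("data/example.zip", mode="w")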

    def close(self) -> None:
        # docstring inherited
        super().close()
        with self._lock:
            self._zf.close()

    async def clear(self) -> None:
        # docstring inherited
        with self._lock:
            self._check_writable()
            self._zf.close()
            os.remove(self.path)
            self._zf = zipfile.ZipFile(
                self.path, mode="w", compression=self.compression, allowZip64=self.allowZip64
            )

    def __str__(self) -> str:
        return f"zip://{self.path}"

    def __repr__(self) -> str:
        return f"ZipStore('{self}')"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, type(self)) and self.path == other.path

    def _get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        if not self._is_open:
            self._sync_open()
        # docstring inherited
        try:
            with self._zf.open(key) as f:  # will raise KeyError
                if byte_range is None:
                    return prototype.buffer.from_bytes(f.read())
                elif isinstance(byte_range, RangeByteRequest):
                    f.seek(byte_range.start)
                    return prototype.buffer.from_bytes(f.read(byte_range.end - f.tell()))
                size = f.seek(0, os.SEEK_END)
                if isinstance(byte_range, OffsetByteRequest):
                    f.seek(byte_range.offset)
                elif isinstance(byte_range, SuffixByteRequest):
                    f.seek(max(0, size - byte_range.suffix))
                else:
                    raise TypeError(f"Unexpected byte_range, got {byte_range}.")
                return prototype.buffer.from_bytes(f.read())
        except KeyError:
            return None

    async def get(
        self,
        key: str,
        prototype: BufferPrototype,
        byte_range: ByteRequest | None = None,
    ) -> Buffer | None:
        # docstring inherited
        assert isinstance(key, str)

        with self._lock:
            return self._get(key, prototype=prototype, byte_range=byte_range)

    async def get_partial_values(
        self,
        prototype: BufferPrototype,
        key_ranges: Iterable[tuple[str, ByteRequest | None]],
    ) -> list[Buffer | None]:
        # docstring inherited
        out = []
        with self._lock:
            for key, byte_range in key_ranges:
                out.append(self._get(key, prototype=prototype, byte_range=byte_range))
        return out
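
    # Read sketch: the ByteRequest variants imported at the top of this module map
    # directly onto the seek/read logic in ``_get`` above. Assuming a key "foo"
    # exists and ``proto`` is a BufferPrototype, inside async code:
    #
    #     await store.get("foo", prototype=proto)                                              # whole value
    #     await store.get("foo", prototype=proto, byte_range=RangeByteRequest(start=0, end=4)) # bytes [0, 4)
    #     await store.get("foo", prototype=proto, byte_range=OffsetByteRequest(offset=2))      # byte 2 to end
    #     await store.get("foo", prototype=proto, byte_range=SuffixByteRequest(suffix=3))      # last 3 bytes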

    def _set(self, key: str, value: Buffer) -> None:
        if not self._is_open:
            self._sync_open()
        # generally, this should be called inside a lock
        keyinfo = zipfile.ZipInfo(filename=key, date_time=time.localtime(time.time())[:6])
        keyinfo.compress_type = self.compression
        if keyinfo.filename[-1] == os.sep:
            keyinfo.external_attr = 0o40775 << 16  # drwxrwxr-x
            keyinfo.external_attr |= 0x10  # MS-DOS directory flag
        else:
            keyinfo.external_attr = 0o644 << 16  # ?rw-r--r--
        self._zf.writestr(keyinfo, value.to_bytes())

    async def set(self, key: str, value: Buffer) -> None:
        # docstring inherited
        self._check_writable()
        if not self._is_open:
            self._sync_open()
        assert isinstance(key, str)
        if not isinstance(value, Buffer):
            raise TypeError(
                f"ZipStore.set(): `value` must be a Buffer instance. Got an instance of {type(value)} instead."
            )
        with self._lock:
            self._set(key, value)

    async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
        raise NotImplementedError

    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
        self._check_writable()
        with self._lock:
            members = self._zf.namelist()
            if key not in members:
                self._set(key, value)
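
    # Write sketch: values must be Buffer instances, so raw bytes are wrapped
    # through a BufferPrototype first. The ``default_buffer_prototype`` import path
    # below is an assumption of this example; partial writes are unsupported, and
    # ``set_if_not_exists`` silently skips keys already present in the archive.
    #
    #     from zarr.core.buffer import default_buffer_prototype
    #
    #     proto = default_buffer_prototype()
    #     await store.set("foo", proto.buffer.from_bytes(b"hello"))
    #     await store.set_if_not_exists("foo", proto.buffer.from_bytes(b"ignored"))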

    async def delete_dir(self, prefix: str) -> None:
        # only raise NotImplementedError if any keys are found
        self._check_writable()
        if prefix != "" and not prefix.endswith("/"):
            prefix += "/"
        async for _ in self.list_prefix(prefix):
            raise NotImplementedError

    async def delete(self, key: str) -> None:
        # docstring inherited
        # we choose to only raise NotImplementedError here if the key exists
        # this allows the array/group APIs to avoid the overhead of existence checks
        self._check_writable()
        if await self.exists(key):
            raise NotImplementedError

    async def exists(self, key: str) -> bool:
        # docstring inherited
        with self._lock:
            try:
                self._zf.getinfo(key)
            except KeyError:
                return False
            else:
                return True

    async def list(self) -> AsyncIterator[str]:
        # docstring inherited
        with self._lock:
            for key in self._zf.namelist():
                yield key

    async def list_prefix(self, prefix: str) -> AsyncIterator[str]:
        # docstring inherited
        async for key in self.list():
            if key.startswith(prefix):
                yield key

    async def list_dir(self, prefix: str) -> AsyncIterator[str]:
        # docstring inherited
        prefix = prefix.rstrip("/")

        keys = self._zf.namelist()
        seen = set()
        if prefix == "":
            keys_unique = {k.split("/")[0] for k in keys}
            for key in keys_unique:
                if key not in seen:
                    seen.add(key)
                    yield key
        else:
            for key in keys:
                if key.startswith(prefix + "/") and key.strip("/") != prefix:
                    k = key.removeprefix(prefix + "/").split("/")[0]
                    if k not in seen:
                        seen.add(k)
                        yield k
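
# Listing sketch: given archive keys "a/0", "a/1" and "b", ``list`` yields every
# key, ``list_prefix("a/")`` yields "a/0" and "a/1", and ``list_dir("")`` yields
# the unique top-level names "a" and "b". Assuming an already-open store:
#
#     async def show(store: ZipStore) -> None:
#         print([k async for k in store.list()])
#         print([k async for k in store.list_prefix("a/")])
#         print([k async for k in store.list_dir("")])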