Skip to content

Commit

Permalink
Fix zarr append dtype checks (#6476)
Browse files Browse the repository at this point in the history
* fix zarr append dtype check first commit

* use zstore in _validate_datatype

* remove coding.strings.is_unicode_dtype check

* test appending fixed length strings

* test string length mismatch raises for U and S

* add explanatory comment for zarr append dtype checks

Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
cisaacstern and max-sixty authored May 11, 2022
1 parent 770e878 commit 4a53e41
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 19 deletions.
48 changes: 30 additions & 18 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from .. import backends, coding, conventions
from .. import backends, conventions
from ..core import indexing
from ..core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -1277,28 +1277,40 @@ def _validate_region(ds, region):
)


def _validate_datatypes_for_zarr_append(dataset):
"""DataArray.name and Dataset keys must be a string or None"""
def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
"""

existing_vars = zstore.get_variables()

def check_dtype(var):
def check_dtype(vname, var):
if (
not np.issubdtype(var.dtype, np.number)
and not np.issubdtype(var.dtype, np.datetime64)
and not np.issubdtype(var.dtype, np.bool_)
and not coding.strings.is_unicode_dtype(var.dtype)
and not var.dtype == object
vname not in existing_vars
or np.issubdtype(var.dtype, np.number)
or np.issubdtype(var.dtype, np.datetime64)
or np.issubdtype(var.dtype, np.bool_)
or var.dtype == object
):
# and not re.match('^bytes[1-9]+$', var.dtype.name)):
# We can skip dtype equality checks under two conditions: (1) if the var to append is
# new to the dataset, because in this case there is no existing var to compare it to;
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
# we can be confident appending won't cause problems. Examples of dtypes which are not
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
# positive integer character length). For these dtypes, appending dissimilar lengths
# can result in truncation of appended data. Therefore, variables which already exist
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
# exact dtype equality, as checked below.
pass
elif not var.dtype == existing_vars[vname].dtype:
raise ValueError(
"Invalid dtype for data variable: {} "
"dtype must be a subtype of number, "
"datetime, bool, a fixed sized string, "
"a fixed size unicode string or an "
"object".format(var)
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but "
f"dataset to append has dtype {var.dtype}."
)

for k in dataset.data_vars.values():
check_dtype(k)
for vname, var in dataset.data_vars.items():
check_dtype(vname, var)


def to_zarr(
Expand Down Expand Up @@ -1403,7 +1415,7 @@ def to_zarr(
)

if mode in ["a", "r+"]:
_validate_datatypes_for_zarr_append(dataset)
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
existing_dims = zstore.get_dimensions()
if append_dim not in existing_dims:
Expand Down
17 changes: 16 additions & 1 deletion xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@
_NON_STANDARD_CALENDARS,
_STANDARD_CALENDARS,
)
from .test_dataset import create_append_test_data, create_test_data
from .test_dataset import (
create_append_string_length_mismatch_test_data,
create_append_test_data,
create_test_data,
)

try:
import netCDF4 as nc4
Expand Down Expand Up @@ -2112,6 +2116,17 @@ def test_append_with_existing_encoding_raises(self):
encoding={"da": {"compressor": None}},
)

@pytest.mark.parametrize("dtype", ["U", "S"])
def test_append_string_length_mismatch_raises(self, dtype):
    # Appending fixed-width strings of a different width ("<U*" or "|S*")
    # than what is already in the zarr store must be rejected, because zarr
    # would otherwise silently truncate the appended values.
    ds, ds_to_append = create_append_string_length_mismatch_test_data(dtype)
    with self.create_zarr_target() as store_target:
        # Seed the store with the shorter-width strings.
        ds.to_zarr(store_target, mode="w")
        # The append of longer strings should fail the dtype-equality check.
        with pytest.raises(ValueError, match="Mismatched dtypes for variable"):
            ds_to_append.to_zarr(
                store_target,
                append_dim="time",
            )

def test_check_encoding_is_consistent_after_append(self):

ds, ds_to_append, _ = create_append_test_data()
Expand Down
35 changes: 35 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def create_append_test_data(seed=None):
time2 = pd.date_range("2000-02-01", periods=nt2)
string_var = np.array(["ae", "bc", "df"], dtype=object)
string_var_to_append = np.array(["asdf", "asdfg"], dtype=object)
string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2")
string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2")
unicode_var = ["áó", "áó", "áó"]
datetime_var = np.array(
["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]"
Expand All @@ -94,6 +96,9 @@ def create_append_test_data(seed=None):
dims=["lat", "lon", "time"],
),
"string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length, coords=[time1], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var, coords=[time1], dims=["time"]
).astype(np.unicode_),
Expand All @@ -112,6 +117,9 @@ def create_append_test_data(seed=None):
"string_var": xr.DataArray(
string_var_to_append, coords=[time2], dims=["time"]
),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length_to_append, coords=[time2], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var[:nt2], coords=[time2], dims=["time"]
).astype(np.unicode_),
Expand All @@ -137,6 +145,33 @@ def create_append_test_data(seed=None):
return ds, ds_to_append, ds_with_new_var


def create_append_string_length_mismatch_test_data(dtype):
    """Create a pair of Datasets with mismatched fixed-width string dtypes.

    The two datasets share a ``temperature`` variable whose fixed-width string
    dtypes differ in length, so appending the second to a zarr store holding
    the first should raise.

    Parameters
    ----------
    dtype : str
        ``"U"`` for unicode strings (``<U*``) or ``"S"`` for bytes (``|S*``).

    Returns
    -------
    tuple
        ``(ds, ds_to_append)``.

    Raises
    ------
    ValueError
        If ``dtype`` is neither ``"U"`` nor ``"S"``.
    """

    def make_datasets(data, data_to_append):
        ds = xr.Dataset(
            {"temperature": (["time"], data)},
            coords={"time": [0, 1, 2]},
        )
        ds_to_append = xr.Dataset(
            {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]}
        )
        # Guard against read-only backing arrays leaking into the fixtures.
        assert all(objp.data.flags.writeable for objp in ds.variables.values())
        assert all(
            objp.data.flags.writeable for objp in ds_to_append.variables.values()
        )
        return ds, ds_to_append

    u2_strings = ["ab", "cd", "ef"]
    u5_strings = ["abc", "def", "ghijk"]

    s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2")
    s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3")

    if dtype == "U":
        return make_datasets(u2_strings, u5_strings)
    elif dtype == "S":
        return make_datasets(s2_strings, s3_strings)
    else:
        # Previously an unrecognized dtype fell off the end and returned None,
        # producing an opaque unpacking error at the call site; fail fast here.
        raise ValueError(f"unsupported dtype {dtype!r}, expected 'U' or 'S'")


def create_test_multiindex():
mindex = pd.MultiIndex.from_product(
[["a", "b"], [1, 2]], names=("level_1", "level_2")
Expand Down

0 comments on commit 4a53e41

Please sign in to comment.