Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zarr append dtype checks #6476

Merged
merged 10 commits into from
May 11, 2022
48 changes: 30 additions & 18 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np

from .. import backends, coding, conventions
from .. import backends, conventions
from ..core import indexing
from ..core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -1277,28 +1277,40 @@ def _validate_region(ds, region):
)


def _validate_datatypes_for_zarr_append(dataset):
"""DataArray.name and Dataset keys must be a string or None"""
def _validate_datatypes_for_zarr_append(zstore, dataset):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
"""

existing_vars = zstore.get_variables()

def check_dtype(var):
def check_dtype(vname, var):
if (
not np.issubdtype(var.dtype, np.number)
and not np.issubdtype(var.dtype, np.datetime64)
and not np.issubdtype(var.dtype, np.bool_)
and not coding.strings.is_unicode_dtype(var.dtype)
and not var.dtype == object
vname not in existing_vars
or np.issubdtype(var.dtype, np.number)
or np.issubdtype(var.dtype, np.datetime64)
or np.issubdtype(var.dtype, np.bool_)
or var.dtype == object
):
# and not re.match('^bytes[1-9]+$', var.dtype.name)):
# We can skip dtype equality checks under two conditions: (1) if the var to append is
# new to the dataset, because in this case there is no existing var to compare it to;
# or (2) if var to append's dtype is known to be easy-to-append, because in this case
# we can be confident appending won't cause problems. Examples of dtypes which are not
# easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
# positive integer character length). For these dtypes, appending dissimilar lengths
# can result in truncation of appended data. Therefore, variables which already exist
# in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
# exact dtype equality, as checked below.
pass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a brief comment (e.g, based on your comment here: #6476 (comment)) to summarize why it's OK not to check these cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much for the review!

I've added the requested comment in 9f97f00

elif not var.dtype == existing_vars[vname].dtype:
raise ValueError(
"Invalid dtype for data variable: {} "
"dtype must be a subtype of number, "
"datetime, bool, a fixed sized string, "
"a fixed size unicode string or an "
"object".format(var)
f"Mismatched dtypes for variable {vname} between Zarr store on disk "
f"and dataset to append. Store has dtype {existing_vars[vname].dtype} but "
f"dataset to append has dtype {var.dtype}."
)

for k in dataset.data_vars.values():
check_dtype(k)
for vname, var in dataset.data_vars.items():
check_dtype(vname, var)


def to_zarr(
Expand Down Expand Up @@ -1403,7 +1415,7 @@ def to_zarr(
)

if mode in ["a", "r+"]:
_validate_datatypes_for_zarr_append(dataset)
_validate_datatypes_for_zarr_append(zstore, dataset)
if append_dim is not None:
existing_dims = zstore.get_dimensions()
if append_dim not in existing_dims:
Expand Down
17 changes: 16 additions & 1 deletion xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@
_NON_STANDARD_CALENDARS,
_STANDARD_CALENDARS,
)
from .test_dataset import create_append_test_data, create_test_data
from .test_dataset import (
create_append_string_length_mismatch_test_data,
create_append_test_data,
create_test_data,
)

try:
import netCDF4 as nc4
Expand Down Expand Up @@ -2112,6 +2116,17 @@ def test_append_with_existing_encoding_raises(self):
encoding={"da": {"compressor": None}},
)

@pytest.mark.parametrize("dtype", ["U", "S"])
def test_append_string_length_mismatch_raises(self, dtype):
    """Appending fixed-width strings of a different length must raise."""
    original, to_append = create_append_string_length_mismatch_test_data(dtype)
    with self.create_zarr_target() as target:
        original.to_zarr(target, mode="w")
        with pytest.raises(ValueError, match="Mismatched dtypes for variable"):
            to_append.to_zarr(target, append_dim="time")

def test_check_encoding_is_consistent_after_append(self):

ds, ds_to_append, _ = create_append_test_data()
Expand Down
35 changes: 35 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def create_append_test_data(seed=None):
time2 = pd.date_range("2000-02-01", periods=nt2)
string_var = np.array(["ae", "bc", "df"], dtype=object)
string_var_to_append = np.array(["asdf", "asdfg"], dtype=object)
string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2")
string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2")
unicode_var = ["áó", "áó", "áó"]
datetime_var = np.array(
["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]"
Expand All @@ -94,6 +96,9 @@ def create_append_test_data(seed=None):
dims=["lat", "lon", "time"],
),
"string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length, coords=[time1], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var, coords=[time1], dims=["time"]
).astype(np.unicode_),
Expand All @@ -112,6 +117,9 @@ def create_append_test_data(seed=None):
"string_var": xr.DataArray(
string_var_to_append, coords=[time2], dims=["time"]
),
"string_var_fixed_length": xr.DataArray(
string_var_fixed_length_to_append, coords=[time2], dims=["time"]
),
"unicode_var": xr.DataArray(
unicode_var[:nt2], coords=[time2], dims=["time"]
).astype(np.unicode_),
Expand All @@ -137,6 +145,33 @@ def create_append_test_data(seed=None):
return ds, ds_to_append, ds_with_new_var


def create_append_string_length_mismatch_test_data(dtype):
    """Build two datasets whose string variable dtypes differ only in width.

    Parameters
    ----------
    dtype : {"U", "S"}
        ``"U"`` produces unicode (``<U*``) string variables, ``"S"`` produces
        byte-string (``|S*``) variables.

    Returns
    -------
    ds, ds_to_append : Dataset
        Appending ``ds_to_append`` to ``ds`` along ``time`` should trigger a
        dtype-mismatch error, since the fixed string widths differ.

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the supported codes.
    """

    def make_datasets(data, data_to_append):
        # Two single-variable datasets sharing the "time" dimension.
        ds = xr.Dataset(
            {"temperature": (["time"], data)},
            coords={"time": [0, 1, 2]},
        )
        ds_to_append = xr.Dataset(
            {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]}
        )
        # Guard against read-only backing arrays leaking into the fixtures.
        assert all(objp.data.flags.writeable for objp in ds.variables.values())
        assert all(
            objp.data.flags.writeable for objp in ds_to_append.variables.values()
        )
        return ds, ds_to_append

    u2_strings = ["ab", "cd", "ef"]
    u5_strings = ["abc", "def", "ghijk"]

    s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2")
    s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3")

    if dtype == "U":
        return make_datasets(u2_strings, u5_strings)
    elif dtype == "S":
        return make_datasets(s2_strings, s3_strings)
    # Previously fell through and returned None, which produced a confusing
    # downstream TypeError instead of a clear error at the call site.
    raise ValueError(f"unsupported dtype {dtype!r}; expected 'U' or 'S'")


def create_test_multiindex():
mindex = pd.MultiIndex.from_product(
[["a", "b"], [1, 2]], names=("level_1", "level_2")
Expand Down