Skip to content

Commit

Permalink
Support bool datatype in decode_json op (#1139)
Browse files Browse the repository at this point in the history
* support bool datatype in decode_json

* lint fixes
  • Loading branch information
kvignesh1420 authored Oct 1, 2020
1 parent 9e98604 commit 6579047
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 44 deletions.
8 changes: 8 additions & 0 deletions tensorflow_io/core/kernels/serialization_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class DecodeJSONOp : public OpKernel {
case DT_STRING:
writeToTensor(entry, value_tensor, flat_index, writeString);
break;
case DT_BOOL:
writeToTensor(entry, value_tensor, flat_index, writeBool);
break;
default:
OP_REQUIRES(
context, false,
Expand Down Expand Up @@ -135,6 +138,11 @@ class DecodeJSONOp : public OpKernel {
value_tensor->flat<tstring>()(flat_index) = (*entry).GetString();
}

static void writeBool(rapidjson::Value* entry, Tensor* value_tensor,
int64& flat_index) {
value_tensor->flat<bool>()(flat_index) = (*entry).GetBool();
}

// Full Tensor Write

template <class T>
Expand Down
70 changes: 26 additions & 44 deletions tests/test_serialization_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,27 @@ def _fixture_lookup(name):
def fixture_json():
"""fixture_json"""
data = """{
"R": {
"Foo": 208.82240295410156,
"Bar": 93
},
"Background": [
"~/0109.jpg",
"~/0110.jpg"
],
"Focal Length": 36.9439697265625,
"Location": [
7.8685874938964844,
-4.7373886108398438,
-0.038147926330566406],
"Rotation": [
-4.592592716217041,
-4.4698805809020996,
-6.9197754859924316]
}
"""
"R": {
"Foo": 208.82240295410156,
"Bar": 93
},
"Background": [
"~/0109.jpg",
"~/0110.jpg"
],
"Focal Length": 36.9439697265625,
"Location": [
7.8685874938964844,
-4.7373886108398438,
-0.038147926330566406],
"Rotation": [
-4.592592716217041,
-4.4698805809020996,
-6.9197754859924316],
"Valid": true,
"Boundary": [[10, 20],[30, 40]]
}
"""
value = {
"R": {
"Foo": tf.constant(208.82240295410156, tf.float64),
Expand All @@ -70,6 +72,8 @@ def fixture_json():
"Rotation": tf.constant(
[-4.592592716217041, -4.4698805809020996, -6.9197754859924316], tf.float64
),
"Valid": tf.constant(True, tf.bool),
"Boundary": tf.constant([[10, 20], [30, 40]], tf.int64),
}
specs = {
"R": {
Expand All @@ -82,6 +86,8 @@ def fixture_json():
),
"Location": tf.TensorSpec(tf.TensorShape([3]), tf.float64),
"Rotation": tf.TensorSpec(tf.TensorShape([3]), tf.float64),
"Valid": tf.TensorSpec(tf.TensorShape([]), tf.bool),
"Boundary": tf.TensorSpec(tf.TensorShape([2, 2]), tf.int64),
}

return data, value, specs
Expand Down Expand Up @@ -188,7 +194,7 @@ def test_serialization_decode_in_dataset(
)


def test_json_partial_shape():
def test_decode_json_partial_shape():
"""Test case for partial shape GitHub 918."""
r = json.dumps({"foo": [1, 2, 3, 4, 5]})

Expand All @@ -200,27 +206,3 @@ def parse_json(json_text):

v = parse_json(r)
assert np.array_equal(v, [1, 2, 3, 4, 5])


def test_json_multiple_dimension_tensor():

# Test case is to resolve the issue where multiple dimension tensor
# was not supported for decode_json.
# The issue was initially raised in:
# https://github.com/tensorflow/io/pull/695#issuecomment-683270751
r = '{"x": [[[1.0], [2.0]]], "y": ["index", "count"], "z": 0.5}'

@tf.function(autograph=False)
def parse_json(json_text):
specs = {
"x": tf.TensorSpec(tf.TensorShape([1, 2, 1]), tf.float32),
"y": tf.TensorSpec(tf.TensorShape([2]), tf.string),
"z": tf.TensorSpec(tf.TensorShape([]), tf.float32),
}
parsed = tfio.experimental.serialization.decode_json(json_text, specs)
return parsed

v = parse_json(r)
assert np.array_equal(v["x"], [[[1.0], [2.0]]])
assert np.array_equal(v["y"], [b"index", b"count"])
assert np.array_equal(v["z"], 0.5)

0 comments on commit 6579047

Please sign in to comment.