Skip to content

Commit

Permalink
runtime: remove DisallowUnknownFields()
Browse files Browse the repository at this point in the history
This is now entirely controlled by the marshaller
options.
  • Loading branch information
johanbrandhorst committed May 23, 2020
1 parent 885c0bd commit 12f0a6d
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 47 deletions.
32 changes: 17 additions & 15 deletions examples/internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func testEcho(t *testing.T, port int, apiPrefix string, contentType string) {
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -135,7 +135,7 @@ func testEchoOneof(t *testing.T, port int, apiPrefix string, contentType string)
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -168,7 +168,7 @@ func testEchoOneof1(t *testing.T, port int, apiPrefix string, contentType string
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand Down Expand Up @@ -201,7 +201,7 @@ func testEchoOneof2(t *testing.T, port int, apiPrefix string, contentType string
t.Logf("%s", buf)
}

var msg examplepb.SimpleMessage
var msg examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &msg); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
Expand All @@ -216,7 +216,7 @@ func testEchoOneof2(t *testing.T, port int, apiPrefix string, contentType string
}

func testEchoBody(t *testing.T, port int, apiPrefix string) {
sent := examplepb.SimpleMessage{Id: "example"}
sent := examplepb.UnannotatedSimpleMessage{Id: "example"}
payload, err := marshaler.Marshal(&sent)
if err != nil {
t.Fatalf("marshaler.Marshal(%#v) failed with %v; want success", payload, err)
Expand All @@ -240,12 +240,12 @@ func testEchoBody(t *testing.T, port int, apiPrefix string) {
t.Logf("%s", buf)
}

var received examplepb.SimpleMessage
var received examplepb.UnannotatedSimpleMessage
if err := marshaler.Unmarshal(buf, &received); err != nil {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(received, sent, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&received, &sent, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}

Expand Down Expand Up @@ -334,7 +334,7 @@ func testABECreate(t *testing.T, port int) {
t.Error("msg.Uuid is empty; want not empty")
}
msg.Uuid = ""
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -442,7 +442,7 @@ func testABECreateBody(t *testing.T, port int) {
t.Error("msg.Uuid is empty; want not empty")
}
msg.Uuid = ""
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -673,7 +673,7 @@ func testABELookup(t *testing.T, port int) {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}

Expand Down Expand Up @@ -1340,7 +1340,7 @@ func testABERepeated(t *testing.T, port int) {
t.Errorf("marshaler.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}
if diff := cmp.Diff(msg, want, protocmp.Transform()); diff != "" {
if diff := cmp.Diff(&msg, &want, protocmp.Transform()); diff != "" {
t.Errorf(diff)
}
}
Expand Down Expand Up @@ -1590,15 +1590,16 @@ func testResponseBodies(t *testing.T, port int) {
t.Logf("%s", buf)
}

var got []*examplepb.ResponseBodyOut_Response
var got []*examplepb.RepeatedResponseBodyOut_Response
err = marshaler.Unmarshal(buf, &got)
if err != nil {
t.Errorf("marshaler.Unmarshal failed with %v; want success", err)
return
}
want := []*examplepb.ResponseBodyOut_Response{
want := []*examplepb.RepeatedResponseBodyOut_Response{
{
Data: "foo",
Type: examplepb.RepeatedResponseBodyOut_Response_UNKNOWN,
},
}
if diff := cmp.Diff(got, want, protocmp.Transform()); diff != "" {
Expand Down Expand Up @@ -1708,15 +1709,16 @@ func testResponseStrings(t *testing.T, port int) {
t.Logf("%s", buf)
}

var got []*examplepb.ResponseBodyOut_Response
var got []*examplepb.RepeatedResponseBodyOut_Response
err = marshaler.Unmarshal(buf, &got)
if err != nil {
t.Errorf("marshaler.Unmarshal failed with %v; want success", err)
return
}
want := []*examplepb.ResponseBodyOut_Response{
want := []*examplepb.RepeatedResponseBodyOut_Response{
{
Data: "foo",
Type: examplepb.RepeatedResponseBodyOut_Response_UNKNOWN,
},
}
if diff := cmp.Diff(got, want, protocmp.Transform()); diff != "" {
Expand Down
7 changes: 6 additions & 1 deletion examples/internal/server/responsebody.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"

"github.com/golang/glog"
examples "github.com/grpc-ecosystem/grpc-gateway/v2/examples/internal/proto/examplepb"
)

Expand All @@ -16,6 +17,7 @@ func newResponseBodyServer() examples.ResponseBodyServiceServer {
}

func (s *responseBodyServer) GetResponseBody(ctx context.Context, req *examples.ResponseBodyIn) (*examples.ResponseBodyOut, error) {
glog.Info(req)
return &examples.ResponseBodyOut{
Response: &examples.ResponseBodyOut_Response{
Data: req.Data,
Expand All @@ -24,16 +26,18 @@ func (s *responseBodyServer) GetResponseBody(ctx context.Context, req *examples.
}

func (s *responseBodyServer) ListResponseBodies(ctx context.Context, req *examples.ResponseBodyIn) (*examples.RepeatedResponseBodyOut, error) {
glog.Info(req)
return &examples.RepeatedResponseBodyOut{
Response: []*examples.RepeatedResponseBodyOut_Response{
&examples.RepeatedResponseBodyOut_Response{
{
Data: req.Data,
},
},
}, nil
}

func (s *responseBodyServer) ListResponseStrings(ctx context.Context, req *examples.ResponseBodyIn) (*examples.RepeatedResponseStrings, error) {
glog.Info(req)
if req.Data == "empty" {
return &examples.RepeatedResponseStrings{
Values: []string{},
Expand All @@ -45,6 +49,7 @@ func (s *responseBodyServer) ListResponseStrings(ctx context.Context, req *examp
}

func (s *responseBodyServer) GetResponseBodyStream(req *examples.ResponseBodyIn, stream examples.ResponseBodyService_GetResponseBodyStreamServer) error {
glog.Info(req)
if err := stream.Send(&examples.ResponseBodyOut{
Response: &examples.ResponseBodyOut_Response{
Data: fmt.Sprintf("first %s", req.Data),
Expand Down
42 changes: 14 additions & 28 deletions runtime/marshal_jsonpb.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,29 @@ func (j *JSONPb) marshalNonProtoField(v interface{}) ([]byte, error) {

// Unmarshal unmarshals JSON "data" into "v"
func (j *JSONPb) Unmarshal(data []byte, v interface{}) error {
return unmarshalJSONPb(data, v)
return unmarshalJSONPb(data, j.UnmarshalOptions, v)
}

// NewDecoder returns a Decoder which reads JSON stream from "r".
func (j *JSONPb) NewDecoder(r io.Reader) Decoder {
d := json.NewDecoder(r)
return DecoderWrapper{Decoder: d}
return DecoderWrapper{
Decoder: d,
UnmarshalOptions: j.UnmarshalOptions,
}
}

// DecoderWrapper is a wrapper around a *json.Decoder that adds
// support for protos to the Decode method.
type DecoderWrapper struct {
*json.Decoder
protojson.UnmarshalOptions
}

// Decode wraps the embedded decoder's Decode method to support
// protos using a jsonpb.Unmarshaler.
func (d DecoderWrapper) Decode(v interface{}) error {
return decodeJSONPb(d.Decoder, v)
return decodeJSONPb(d.Decoder, d.UnmarshalOptions, v)
}

// NewEncoder returns an Encoder which writes JSON stream into "w".
Expand All @@ -171,15 +175,15 @@ func (j *JSONPb) NewEncoder(w io.Writer) Encoder {
})
}

func unmarshalJSONPb(data []byte, v interface{}) error {
func unmarshalJSONPb(data []byte, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
d := json.NewDecoder(bytes.NewReader(data))
return decodeJSONPb(d, v)
return decodeJSONPb(d, unmarshaler, v)
}

func decodeJSONPb(d *json.Decoder, v interface{}) error {
func decodeJSONPb(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
p, ok := v.(proto.Message)
if !ok {
return decodeNonProtoField(d, v)
return decodeNonProtoField(d, unmarshaler, v)
}

// Decode into bytes for marshalling
Expand All @@ -189,13 +193,10 @@ func decodeJSONPb(d *json.Decoder, v interface{}) error {
return err
}

unmarshaler := &protojson.UnmarshalOptions{
DiscardUnknown: allowUnknownFields,
}
return unmarshaler.Unmarshal([]byte(b), p)
}

func decodeNonProtoField(d *json.Decoder, v interface{}) error {
func decodeNonProtoField(d *json.Decoder, unmarshaler protojson.UnmarshalOptions, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("%T is not a pointer", v)
Expand All @@ -212,9 +213,6 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
return err
}

unmarshaler := &protojson.UnmarshalOptions{
DiscardUnknown: allowUnknownFields,
}
return unmarshaler.Unmarshal([]byte(b), rv.Interface().(proto.Message))
}
rv = rv.Elem()
Expand All @@ -239,7 +237,7 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
}
bk := result[0]
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(*v), bv.Interface()); err != nil {
if err := unmarshalJSONPb([]byte(*v), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.SetMapIndex(bk, bv.Elem())
Expand All @@ -256,7 +254,7 @@ func decodeNonProtoField(d *json.Decoder, v interface{}) error {
}
for _, item := range sl {
bv := reflect.New(rv.Type().Elem())
if err := unmarshalJSONPb([]byte(item), bv.Interface()); err != nil {
if err := unmarshalJSONPb([]byte(item), unmarshaler, bv.Interface()); err != nil {
return err
}
rv.Set(reflect.Append(rv, bv.Elem()))
Expand Down Expand Up @@ -294,18 +292,6 @@ func (j *JSONPb) Delimiter() []byte {
return []byte("\n")
}

// allowUnknownFields helps not to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
var allowUnknownFields = true

// DisallowUnknownFields enables option in decoder (unmarshaller) to
// return an error when it finds an unknown field. This function must be
// called before using the JSON marshaller.
func DisallowUnknownFields() {
allowUnknownFields = false
}

var (
convFromType = map[reflect.Kind]reflect.Value{
reflect.String: reflect.ValueOf(String),
Expand Down
8 changes: 5 additions & 3 deletions runtime/marshal_jsonpb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,16 +471,18 @@ func TestJSONPbDecoderFields(t *testing.T) {

func TestJSONPbDecoderUnknownField(t *testing.T) {
var (
m runtime.JSONPb
m = runtime.JSONPb{
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: false,
},
}
got examplepb.ABitOfEverything
)
data := `{
"uuid": "6EC2446F-7E89-4127-B3E6-5C05E6BECBA7",
"unknownField": "111"
}`

runtime.DisallowUnknownFields()

r := strings.NewReader(data)
dec := m.NewDecoder(r)
if err := dec.Decode(&got); err == nil {
Expand Down
3 changes: 3 additions & 0 deletions runtime/marshaler_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ var (
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}
)
Expand Down

0 comments on commit 12f0a6d

Please sign in to comment.