Skip to content

Commit

Permalink
feat(server/v2)!: grpcgateway autoregistration (#22941)
Browse files Browse the repository at this point in the history
  • Loading branch information
technicallyty authored Jan 6, 2025
1 parent 23595c8 commit 884a7a5
Show file tree
Hide file tree
Showing 8 changed files with 616 additions and 17 deletions.
4 changes: 4 additions & 0 deletions server/v2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ Each entry must include the Github issue reference in the following format:

## [Unreleased]

### Features

* [#22715](https://github.com/cosmos/cosmos-sdk/pull/22941) Add custom HTTP handler for grpc-gateway that removes the need to manually register grpc-gateway services.

## [v2.0.0-beta.1](https://github.com/cosmos/cosmos-sdk/releases/tag/server/v2.0.0-beta.1)

Initial tag of `cosmossdk.io/server/v2`.
145 changes: 145 additions & 0 deletions server/v2/api/grpcgateway/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package grpcgateway

import (
"net/http"
"strconv"
"strings"

gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"

"cosmossdk.io/core/transaction"
"cosmossdk.io/log"
"cosmossdk.io/server/v2/appmanager"
)

var _ http.Handler = &gatewayInterceptor[transaction.Tx]{}

// gatewayInterceptor handles routing grpc-gateway queries to the app manager's query router.
type gatewayInterceptor[T transaction.Tx] struct {
logger log.Logger
// gateway is the fallback grpc gateway mux handler.
gateway *runtime.ServeMux

// customEndpointMapping is a mapping of custom GET options on proto RPC handlers, to the fully qualified method name.
//
// example: /cosmos/bank/v1beta1/denoms_metadata -> cosmos.bank.v1beta1.Query.DenomsMetadata
customEndpointMapping map[string]string

// appManager is used to route queries to the application.
appManager appmanager.AppManager[T]
}

// newGatewayInterceptor creates a new gatewayInterceptor.
func newGatewayInterceptor[T transaction.Tx](logger log.Logger, gateway *runtime.ServeMux, am appmanager.AppManager[T]) (*gatewayInterceptor[T], error) {
getMapping, err := getHTTPGetAnnotationMapping()
if err != nil {
return nil, err
}
return &gatewayInterceptor[T]{
logger: logger,
gateway: gateway,
customEndpointMapping: getMapping,
appManager: am,
}, nil
}

// ServeHTTP implements the http.Handler interface. This function will attempt to match http requests to the
// interceptors internal mapping of http annotations to query request type names.
// If no match can be made, it falls back to the runtime gateway server mux.
func (g *gatewayInterceptor[T]) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
g.logger.Debug("received grpc-gateway request", "request_uri", request.RequestURI)
match := matchURL(request.URL, g.customEndpointMapping)
if match == nil {
// no match cases fall back to gateway mux.
g.gateway.ServeHTTP(writer, request)
return
}
g.logger.Debug("matched request", "query_input", match.QueryInputName)
_, out := runtime.MarshalerForRequest(g.gateway, request)
var msg gogoproto.Message
var err error

switch request.Method {
case http.MethodPost:
msg, err = createMessageFromJSON(match, request)
case http.MethodGet:
msg, err = createMessage(match)
default:
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, status.Error(codes.Unimplemented, "HTTP method must be POST or GET"))
return
}
if err != nil {
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
return
}

// extract block height header
var height uint64
heightStr := request.Header.Get(GRPCBlockHeightHeader)
if heightStr != "" {
height, err = strconv.ParseUint(heightStr, 10, 64)
if err != nil {
err = status.Errorf(codes.InvalidArgument, "invalid height: %s", heightStr)
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
return
}
}

query, err := g.appManager.Query(request.Context(), height, msg)
if err != nil {
// if we couldn't find a handler for this request, just fall back to the gateway mux.
if strings.Contains(err.Error(), "no handler") {
g.gateway.ServeHTTP(writer, request)
} else {
// for all other errors, we just return the error.
runtime.DefaultHTTPProtoErrorHandler(request.Context(), g.gateway, out, writer, request, err)
}
return
}
// for no errors, we forward the response.
runtime.ForwardResponseMessage(request.Context(), g.gateway, out, writer, request, query)
}

// getHTTPGetAnnotationMapping returns a mapping of RPC Method HTTP GET annotation to the RPC Handler's Request Input type full name.
//
// example: "/cosmos/auth/v1beta1/account_info/{address}":"cosmos.auth.v1beta1.Query.AccountInfo"
func getHTTPGetAnnotationMapping() (map[string]string, error) {
protoFiles, err := gogoproto.MergedRegistry()
if err != nil {
return nil, err
}

httpGets := make(map[string]string)
protoFiles.RangeFiles(func(fd protoreflect.FileDescriptor) bool {
for i := 0; i < fd.Services().Len(); i++ {
serviceDesc := fd.Services().Get(i)
for j := 0; j < serviceDesc.Methods().Len(); j++ {
methodDesc := serviceDesc.Methods().Get(j)

httpAnnotation := proto.GetExtension(methodDesc.Options(), annotations.E_Http)
if httpAnnotation == nil {
continue
}

httpRule, ok := httpAnnotation.(*annotations.HttpRule)
if !ok || httpRule == nil {
continue
}
if httpRule.GetGet() == "" {
continue
}

httpGets[httpRule.GetGet()] = string(methodDesc.Input().FullName())
}
}
return true
})

return httpGets, nil
}
16 changes: 10 additions & 6 deletions server/v2/api/grpcgateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"cosmossdk.io/core/transaction"
"cosmossdk.io/log"
serverv2 "cosmossdk.io/server/v2"
"cosmossdk.io/server/v2/appmanager"
)

var (
Expand All @@ -37,6 +38,7 @@ func New[T transaction.Tx](
logger log.Logger,
config server.ConfigMap,
ir jsonpb.AnyResolver,
appManager appmanager.AppManager[T],
cfgOptions ...CfgOption,
) (*Server[T], error) {
// The default JSON marshaller used by the gRPC-Gateway is unable to marshal non-nullable non-scalar fields.
Expand Down Expand Up @@ -71,12 +73,14 @@ func New[T transaction.Tx](
}
}

// TODO: register the gRPC-Gateway routes

s.logger = logger.With(log.ModuleKey, s.Name())
s.config = serverCfg
mux := http.NewServeMux()
mux.Handle("/", s.GRPCGatewayRouter)
interceptor, err := newGatewayInterceptor[T](logger, s.GRPCGatewayRouter, appManager)
if err != nil {
return nil, fmt.Errorf("failed to create grpc-gateway interceptor: %w", err)
}
mux.Handle("/", interceptor)

s.server = &http.Server{
Addr: s.config.Address,
Expand Down Expand Up @@ -133,15 +137,15 @@ func (s *Server[T]) Stop(ctx context.Context) error {
return s.server.Shutdown(ctx)
}

// GRPCBlockHeightHeader is the gRPC header for block height.
const GRPCBlockHeightHeader = "x-cosmos-block-height"

// CustomGRPCHeaderMatcher for mapping request headers to
// GRPC metadata.
// HTTP headers that start with 'Grpc-Metadata-' are automatically mapped to
// gRPC metadata after removing prefix 'Grpc-Metadata-'. We can use this
// CustomGRPCHeaderMatcher if headers don't start with `Grpc-Metadata-`
func CustomGRPCHeaderMatcher(key string) (string, bool) {
// GRPCBlockHeightHeader is the gRPC header for block height.
const GRPCBlockHeightHeader = "x-cosmos-block-height"

switch strings.ToLower(key) {
case GRPCBlockHeightHeader:
return GRPCBlockHeightHeader, true
Expand Down
187 changes: 187 additions & 0 deletions server/v2/api/grpcgateway/uri.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package grpcgateway

import (
"io"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"

"github.com/cosmos/gogoproto/jsonpb"
gogoproto "github.com/cosmos/gogoproto/proto"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

const maxBodySize = 1 << 20 // 1 MB

// uriMatch contains information related to a URI match.
type uriMatch struct {
// QueryInputName is the fully qualified name of the proto input type of the query rpc method.
QueryInputName string

// Params are any wildcard/query params found in the request.
//
// example:
// - foo/bar/{baz} - foo/bar/qux -> {baz: qux}
// - foo/bar?baz=qux - foo/bar -> {baz: qux}
Params map[string]string
}

// HasParams reports whether the uriMatch has any params.
func (uri uriMatch) HasParams() bool {
return len(uri.Params) > 0
}

// matchURL attempts to find a match for the given URL.
// NOTE: if no match is found, nil is returned.
func matchURL(u *url.URL, getPatternToQueryInputName map[string]string) *uriMatch {
uriPath := strings.TrimRight(u.Path, "/")
queryParams := u.Query()

params := make(map[string]string)
for key, vals := range queryParams {
if len(vals) > 0 {
// url.Values contains a slice for the values as you are able to specify a key multiple times in URL.
// example: https://localhost:9090/do/something?color=red&color=blue&color=green
// We will just take the first value in the slice.
params[key] = vals[0]
}
}

// for simple cases where there are no wildcards, we can just do a map lookup.
if inputName, ok := getPatternToQueryInputName[uriPath]; ok {
return &uriMatch{
QueryInputName: inputName,
Params: params,
}
}

// attempt to find a match in the pattern map.
for getPattern, queryInputName := range getPatternToQueryInputName {
getPattern = strings.TrimRight(getPattern, "/")

regexPattern, wildcardNames := patternToRegex(getPattern)

regex := regexp.MustCompile(regexPattern)
matches := regex.FindStringSubmatch(uriPath)

if len(matches) > 1 {
// first match is the full string, subsequent matches are capture groups
for i, name := range wildcardNames {
params[name] = matches[i+1]
}

return &uriMatch{
QueryInputName: queryInputName,
Params: params,
}
}
}

return nil
}

// patternToRegex converts a URI pattern with wildcards to a regex pattern.
// Returns the regex pattern and a slice of wildcard names in order
func patternToRegex(pattern string) (string, []string) {
escaped := regexp.QuoteMeta(pattern)
var wildcardNames []string

// extract and replace {param=**} patterns
r1 := regexp.MustCompile(`\\\{([^}]+?)=\\\*\\\*\\}`)
escaped = r1.ReplaceAllStringFunc(escaped, func(match string) string {
// extract wildcard name without the =** suffix
name := regexp.MustCompile(`\\\{(.+?)=`).FindStringSubmatch(match)[1]
wildcardNames = append(wildcardNames, name)
return "(.+)"
})

// extract and replace {param} patterns
r2 := regexp.MustCompile(`\\\{([^}]+)\\}`)
escaped = r2.ReplaceAllStringFunc(escaped, func(match string) string {
// extract wildcard name from the curl braces {}.
name := regexp.MustCompile(`\\\{(.*?)\\}`).FindStringSubmatch(match)[1]
wildcardNames = append(wildcardNames, name)
return "([^/]+)"
})

return "^" + escaped + "$", wildcardNames
}

// createMessageFromJSON creates a message from the uriMatch given the JSON body in the http request.
func createMessageFromJSON(match *uriMatch, r *http.Request) (gogoproto.Message, error) {
requestType := gogoproto.MessageType(match.QueryInputName)
if requestType == nil {
return nil, status.Error(codes.InvalidArgument, "invalid request type")
}

msg, ok := reflect.New(requestType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, status.Error(codes.Internal, "failed to cast to proto message")
}

defer r.Body.Close()
limitedReader := io.LimitReader(r.Body, maxBodySize)
err := jsonpb.Unmarshal(limitedReader, msg)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}

return msg, nil
}

// createMessage creates a message from the given uriMatch. If the match has params, the message will be populated
// with the value of those params. Otherwise, an empty message is returned.
func createMessage(match *uriMatch) (gogoproto.Message, error) {
requestType := gogoproto.MessageType(match.QueryInputName)
if requestType == nil {
return nil, status.Error(codes.InvalidArgument, "unknown request type")
}

msg, ok := reflect.New(requestType.Elem()).Interface().(gogoproto.Message)
if !ok {
return nil, status.Error(codes.Internal, "failed to create message instance")
}

// if the uri match has params, we need to populate the message with the values of those params.
if match.HasParams() {
// convert flat params map to nested structure
nestedParams := make(map[string]any)
for key, value := range match.Params {
parts := strings.Split(key, ".")
current := nestedParams

// step through nested levels
for i, part := range parts {
if i == len(parts)-1 {
// Last part - set the value
current[part] = value
} else {
// continue nestedness
if _, exists := current[part]; !exists {
current[part] = make(map[string]any)
}
current = current[part].(map[string]any)
}
}
}

// Configure decoder to handle the nested structure
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: msg,
TagName: "json", // Use json tags as they're simpler
WeaklyTypedInput: true,
})
if err != nil {
return nil, status.Error(codes.Internal, "failed to create message instance")
}

if err := decoder.Decode(nestedParams); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
}
return msg, nil
}
Loading

0 comments on commit 884a7a5

Please sign in to comment.