Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(client/v2): refactor of flags #17306

Merged
merged 5 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions client/v2/autocli/flag/address.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

type addressStringType struct{}

func (a addressStringType) NewValue(ctx context.Context, b *Builder) Value {
func (a addressStringType) NewValue(_ context.Context, b *Builder) Value {
return &addressValue{addressCodec: b.AddressCodec}
}

Expand All @@ -27,7 +27,7 @@ func (a addressStringType) DefaultValue() string {

type validatorAddressStringType struct{}

func (a validatorAddressStringType) NewValue(ctx context.Context, b *Builder) Value {
func (a validatorAddressStringType) NewValue(_ context.Context, b *Builder) Value {
return &addressValue{addressCodec: b.ValidatorAddressCodec}
}

Expand Down Expand Up @@ -61,7 +61,7 @@ func (a *addressValue) Set(s string) error {
}

func (a addressValue) Type() string {
return "bech32 account address key name"
return "bech32 account address"
}

type consensusAddressStringType struct{}
Expand Down
2 changes: 1 addition & 1 deletion client/v2/autocli/flag/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type binaryType struct{}

var _ Value = (*fileBinaryValue)(nil)

func (f binaryType) NewValue(_ context.Context, _ *Builder) Value {
func (f binaryType) NewValue(context.Context, *Builder) Value {
return &fileBinaryValue{}
}

Expand Down
303 changes: 303 additions & 0 deletions client/v2/autocli/flag/builder.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
package flag

import (
"context"
"fmt"
"strconv"

cosmos_proto "github.com/cosmos/cosmos-proto"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"

autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
"cosmossdk.io/client/v2/internal/util"
"cosmossdk.io/core/address"
)

Expand Down Expand Up @@ -55,3 +65,296 @@ func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) {
b.init()
b.scalarFlagTypes[scalarName] = flagType
}

func (b *Builder) AddMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) {
return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{})
}

// AddMessageFlags adds flags for each field in the message to the flag set.
func (b *Builder) addMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) {
fields := messageType.Descriptor().Fields()
numFields := fields.Len()
handler := &MessageBinder{
messageType: messageType,
}

isPositional := map[string]bool{}
hasVarargs := false
hasOptional := false
n := len(commandOptions.PositionalArgs)
// positional args are also parsed using a FlagSet so that we can reuse all the same parsers
handler.positionalFlagSet = pflag.NewFlagSet("positional", pflag.ContinueOnError)
for i, arg := range commandOptions.PositionalArgs {
isPositional[arg.ProtoField] = true

field := fields.ByName(protoreflect.Name(arg.ProtoField))
if field == nil {
return nil, fmt.Errorf("can't find field %s on %s", arg.ProtoField, messageType.Descriptor().FullName())
}

if arg.Optional && arg.Varargs {
return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField)
}

if arg.Varargs {
if i != n-1 {
return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField)
}

hasVarargs = true
}

if arg.Optional {
if i != n-1 {
return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField)
}

hasOptional = true
}

_, hasValue, err := b.addFieldFlag(
ctx,
handler.positionalFlagSet,
field,
&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", i)},
namingOptions{},
)
if err != nil {
return nil, err
}

handler.positionalArgs = append(handler.positionalArgs, fieldBinding{
field: field,
hasValue: hasValue,
})
}

if hasVarargs {
handler.CobraArgs = cobra.MinimumNArgs(n - 1)
handler.hasVarargs = true
} else if hasOptional {
handler.CobraArgs = cobra.RangeArgs(n-1, n)
handler.hasOptional = true
} else {
handler.CobraArgs = cobra.ExactArgs(n)
}

// validate flag options
for name := range commandOptions.FlagOptions {
if fields.ByName(protoreflect.Name(name)) == nil {
return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName())
}
}

flagOptsByFlagName := map[string]*autocliv1.FlagOptions{}
for i := 0; i < numFields; i++ {
field := fields.Get(i)
if isPositional[string(field.Name())] {
continue
}

flagOpts := commandOptions.FlagOptions[string(field.Name())]
name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options)
flagOptsByFlagName[name] = flagOpts
if err != nil {
return nil, err
}

handler.flagBindings = append(handler.flagBindings, fieldBinding{
hasValue: hasValue,
field: field,
})
}

flagSet.VisitAll(func(flag *pflag.Flag) {
opts := flagOptsByFlagName[flag.Name]
if opts != nil {
// This is a bit of hacking around the pflag API, but
// we need to set these options here using Flag.VisitAll because the flag
// constructors that pflag gives us (StringP, Int32P, etc.) do not
// actually return the *Flag instance
flag.Deprecated = opts.Deprecated
flag.ShorthandDeprecated = opts.ShorthandDeprecated
flag.Hidden = opts.Hidden
}
})

return handler, nil
}

// bindPageRequest create a flag for pagination
func (b *Builder) bindPageRequest(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
return b.addMessageFlags(
ctx,
flagSet,
util.ResolveMessageType(b.TypeResolver, field.Message()),
&autocliv1.RpcCommandOptions{},
namingOptions{Prefix: "page-"},
)
}

// namingOptions specifies internal naming options for flags.
type namingOptions struct {
// Prefix is a prefix to prepend to all flags.
Prefix string
}

// addFieldFlag adds a flag for the provided field to the flag set.
func (b *Builder) addFieldFlag(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) {
if opts == nil {
opts = &autocliv1.FlagOptions{}
}

if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" {
hasValue, err := b.bindPageRequest(ctx, flagSet, field)
return "", hasValue, err
}

name = opts.Name
if name == "" {
name = options.Prefix + util.DescriptorKebabName(field)
}

usage := opts.Usage
if usage == "" {
usage = util.DescriptorDocs(field)
}

shorthand := opts.Shorthand
defaultValue := opts.DefaultValue

if typ := b.resolveFlagType(field); typ != nil {
if defaultValue == "" {
defaultValue = typ.DefaultValue()
}

val := typ.NewValue(ctx, b)
flagSet.AddFlag(&pflag.Flag{
Name: name,
Shorthand: shorthand,
Usage: usage,
DefValue: defaultValue,
Value: val,
})
return name, val, nil
}

// use the built-in pflag StringP, Int32P, etc. functions
var val HasValue

if field.IsList() {
val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage)
} else if field.IsMap() {
keyKind := field.MapKey().Kind()
valKind := field.MapValue().Kind()
val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage)
} else {
val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage)
}

// This is a bit of hacking around the pflag API, but the
// defaultValue is set in this way because this is much easier than trying
// to parse the string into the types that StringSliceP, Int32P, etc. expect
if defaultValue != "" {
err = flagSet.Set(name, defaultValue)
}
return name, val, err
}

func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type {
typ := b.resolveFlagTypeBasic(field)
if field.IsList() {
if typ != nil {
return compositeListType{simpleType: typ}
}
return nil
}
if field.IsMap() {
keyKind := field.MapKey().Kind()
valType := b.resolveFlagType(field.MapValue())
if valType != nil {
switch keyKind {
case protoreflect.StringKind:
ct := new(compositeMapType[string])
ct.keyValueResolver = func(s string) (string, error) { return s, nil }
ct.valueType = valType
ct.keyType = "string"
return ct
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
ct := new(compositeMapType[int32])
ct.keyValueResolver = func(s string) (int32, error) {
i, err := strconv.ParseInt(s, 10, 32)
return int32(i), err
}
ct.valueType = valType
ct.keyType = "int32"
return ct
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
ct := new(compositeMapType[int64])
ct.keyValueResolver = func(s string) (int64, error) {
i, err := strconv.ParseInt(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "int64"
return ct
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
ct := new(compositeMapType[uint32])
ct.keyValueResolver = func(s string) (uint32, error) {
i, err := strconv.ParseUint(s, 10, 32)
return uint32(i), err
}
ct.valueType = valType
ct.keyType = "uint32"
return ct
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
ct := new(compositeMapType[uint64])
ct.keyValueResolver = func(s string) (uint64, error) {
i, err := strconv.ParseUint(s, 10, 64)
return i, err
}
ct.valueType = valType
ct.keyType = "uint64"
return ct
case protoreflect.BoolKind:
ct := new(compositeMapType[bool])
ct.keyValueResolver = strconv.ParseBool
ct.valueType = valType
ct.keyType = "bool"
return ct
}
return nil

}
return nil
}

return typ
}

func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type {
scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar)
if scalar != nil {
b.init()
if typ, ok := b.scalarFlagTypes[scalar.(string)]; ok {
return typ
}
}

switch field.Kind() {
case protoreflect.BytesKind:
return binaryType{}
case protoreflect.EnumKind:
return enumType{enum: field.Enum()}
case protoreflect.MessageKind:
b.init()
if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok {
return flagType
}
return jsonMessageFlagType{
messageDesc: field.Message(),
}
default:
return nil
}
}
2 changes: 1 addition & 1 deletion client/v2/autocli/flag/coin.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type coinValue struct {
value *basev1beta1.Coin
}

func (c coinType) NewValue(_ context.Context, _ *Builder) Value {
func (c coinType) NewValue(context.Context, *Builder) Value {
return &coinValue{}
}

Expand Down
Loading