Skip to content

Commit

Permalink
feat: add host flag
Browse files Browse the repository at this point in the history
  • Loading branch information
shifty11 committed Dec 21, 2023
1 parent 5a8b733 commit 0850d4a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 17 deletions.
1 change: 1 addition & 0 deletions tools/kystrap/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,5 @@ func init() {
// Cobra also supports local flags, which will only run
// when this action is called directly.
rootCmd.PersistentFlags().BoolP(yesFlag, "y", false, "Skip all prompts and use provided or default values")
rootCmd.SetErrPrefix(promptui.IconBad)
}
77 changes: 60 additions & 17 deletions tools/kystrap/cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
"regexp"
"strings"
"time"
)
Expand All @@ -22,6 +23,16 @@ type executionInfo struct {
method protoreflect.MethodDescriptor
position int
success bool
host string
}

func newExecutionInfo() executionInfo {
return executionInfo{
method: nil,
position: 0,
success: false,
host: "localhost:50051",
}
}

func findDescriptorMethod(name string) protoreflect.MethodDescriptor {
Expand All @@ -47,6 +58,28 @@ func setPosition(execution *executionInfo) {
}
}

var addressRegex = regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+:([0-9]+)+$`)

func promptAddress(execution *executionInfo) error {
validate := func(input string) error {
if !addressRegex.MatchString(input) {
return errors.New("invalid address... must be in the format host:port")
}
return nil
}
prompt := promptui.Prompt{
Label: "Enter the host and port of the runtime server",
Default: execution.host,
Validate: validate,
}
result, err := prompt.Run()
if err != nil {
return err
}
execution.host = result
return nil
}

func promptMethod(execution *executionInfo) error {
prompt := promptui.Select{
Label: "Which method do you want to test?",
Expand Down Expand Up @@ -132,7 +165,7 @@ func promptAction(isYesNo bool) (action, error) {
return action(result), nil
}

func dial() (*grpc.ClientConn, error) {
func dial(host string) (*grpc.ClientConn, error) {
dialTime := 10 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), dialTime)
defer cancel()
Expand All @@ -144,9 +177,7 @@ func dial() (*grpc.ClientConn, error) {

network := "tcp"

target := "localhost:50051"

cc, err := grpcurl.BlockingDial(ctx, network, target, creds, opts...)
cc, err := grpcurl.BlockingDial(ctx, network, host, creds, opts...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -190,8 +221,8 @@ func printError(cmd *cobra.Command, err error) {
cmd.Printf("%s %s\n", promptui.IconBad, err.Error())
}

func performMethodCall(cmd *cobra.Command, method protoreflect.MethodDescriptor, data string) (bool, error) {
cc, err := dial()
func performMethodCall(cmd *cobra.Command, address string, method protoreflect.MethodDescriptor, data string) (bool, error) {
cc, err := dial(address)
if err != nil {
return false, err
}
Expand All @@ -216,9 +247,8 @@ func performMethodCall(cmd *cobra.Command, method protoreflect.MethodDescriptor,

out := new(bytes.Buffer)
h := &grpcurl.DefaultEventHandler{
Out: out,
Formatter: formatter,
VerbosityLevel: 0,
Out: out,
Formatter: formatter,
}

err = grpcurl.InvokeRPC(ctx, types.Rdk.DescriptorSource, cc, string(method.FullName()), nil, h, rf.Next)
Expand Down Expand Up @@ -288,31 +318,43 @@ func runTestIntegration(
}
}

success, err := performMethodCall(cmd, execution.method, data)
success, err := performMethodCall(cmd, execution.host, execution.method, data)
execution.success = success
return err
}

func CmdTestIntegration() *cobra.Command {
const flagMethod = "method"
const flagData = "data"
const flagHost = "host"

cmd := &cobra.Command{
Use: "test",
Short: "Test a runtime integration",
RunE: func(cmd *cobra.Command, args []string) error {
execution := executionInfo{}
defaultMethod, _ := cmd.Flags().GetString(flagMethod)
if defaultMethod != "" {
execution.method = findDescriptorMethod(defaultMethod)
execution := newExecutionInfo()
method, _ := cmd.Flags().GetString(flagMethod)
data, _ := cmd.Flags().GetString(flagData)
host, _ := cmd.Flags().GetString(flagHost)

if method != "" {
execution.method = findDescriptorMethod(method)
if execution.method == nil {
return errors.New(fmt.Sprintf("invalid gRPC method %s", defaultMethod))
return errors.New(fmt.Sprintf("invalid gRPC method %s", method))
}
setPosition(&execution)
}

defaultData, _ := cmd.Flags().GetString(flagData)
err := runTestIntegration(cmd, &execution, defaultData)
if host != "" {
execution.host = host
} else if !skipPrompts(cmd) {
err := promptAddress(&execution)
if err != nil {
return err
}
}

err := runTestIntegration(cmd, &execution, data)
if err != nil {
return err
}
Expand Down Expand Up @@ -349,6 +391,7 @@ func CmdTestIntegration() *cobra.Command {
}
cmd.Flags().StringP(flagMethod, "m", "", "gRPC method that you want to test")
cmd.Flags().StringP(flagData, "d", "", "data that you want to send with the gRPC method call")
cmd.Flags().StringP(flagHost, "H", "", "host and port of the runtime server (ex: localhost:50051)")
return cmd
}

Expand Down

0 comments on commit 0850d4a

Please sign in to comment.