Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wait on daemon down #2279

Merged
merged 11 commits into from
Jul 17, 2024
2 changes: 1 addition & 1 deletion client/cmd/down.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var downCmd = &cobra.Command{
return err
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
defer cancel()

conn, err := DialClientGRPCServer(ctx, daemonAddr)
Expand Down
36 changes: 34 additions & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,23 @@ func (e *Engine) Stop() error {

e.close()
e.wgConnWorker.Wait()
log.Infof("stopped Netbird Engine")
return nil

maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)

for {
if !e.IsWGIfaceUp() {
log.Infof("stopped Netbird Engine")
return nil
}

select {
case <-timeout:
return fmt.Errorf("timeout when waiting for interface shutdown")
default:
time.Sleep(100 * time.Millisecond)
}
}
}

// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
Expand Down Expand Up @@ -1533,3 +1548,20 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}

// IsWGIfaceUp reports whether the engine's WireGuard interface can be found
// on the system by name and currently has the net.FlagUp flag set.
//
// It returns false when the engine or its wgInterface is nil, and also when
// the lookup by name fails (the lookup error is logged at debug level).
func (e *Engine) IsWGIfaceUp() bool {
	if e == nil {
		return false
	}
	if e.wgInterface == nil {
		return false
	}

	name := e.wgInterface.Name()
	link, lookupErr := net.InterfaceByName(name)
	if lookupErr != nil {
		// A failed lookup is reported as "not up" rather than as an error.
		log.Debugf("failed to get interface by name %s: %v", name, lookupErr)
		return false
	}

	return link.Flags&net.FlagUp != 0
}
16 changes: 14 additions & 2 deletions client/internal/networkmonitor/monitor_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package networkmonitor

import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
Expand All @@ -21,11 +22,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
return fmt.Errorf("failed to open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err)
}
}()

go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket")
}
}()

for {
select {
case <-ctx.Done():
Expand All @@ -34,7 +44,9 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
Expand Down
22 changes: 20 additions & 2 deletions client/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
}

// Down engine work in the daemon.
func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()

Expand All @@ -593,7 +593,25 @@ func (s *Server) Down(_ context.Context, _ *proto.DownRequest) (*proto.DownRespo
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)

return &proto.DownResponse{}, nil
maxWaitTime := 5 * time.Second
timeout := time.After(maxWaitTime)

engine := s.connectClient.Engine()

for {
if !engine.IsWGIfaceUp() {
return &proto.DownResponse{}, nil
}

select {
case <-ctx.Done():
return &proto.DownResponse{}, nil
case <-timeout:
return nil, fmt.Errorf("failed to shut down properly")
default:
time.Sleep(100 * time.Millisecond)
}
}
}

// Status returns the daemon status
Expand Down
64 changes: 35 additions & 29 deletions management/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
Expand All @@ -11,15 +10,11 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"

"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"

"github.com/cenkalti/backoff/v4"

"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
Expand Down Expand Up @@ -51,26 +46,21 @@ type GrpcClient struct {

// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
var conn *grpc.ClientConn

if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}

mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
mgmCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service %v", err)
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}

Expand Down Expand Up @@ -326,25 +316,41 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
if !c.ready() {
return nil, fmt.Errorf(errMsgNoMgmtConnection)
}

loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})

var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()

var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
log.Printf("Login error: %v", err)
return err
}

return nil
}

err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}

loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt registration message: %s", err)
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}

Expand Down
35 changes: 12 additions & 23 deletions signal/client/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"context"
"crypto/tls"
"fmt"
"io"
"sync"
Expand All @@ -14,9 +13,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -64,28 +60,21 @@ func (c *GrpcClient) Close() error {

// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn

transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())

if tlsEnabled {
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}

sigCtx, cancel := context.WithTimeout(ctx, client.ConnectTimeout)
defer cancel()
conn, err := grpc.DialContext(
sigCtx,
addr,
transportOption,
nbgrpc.WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}))

err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to connect to the signalling server %v", err)
log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err
}

Expand Down Expand Up @@ -408,7 +397,7 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,

if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
//todo send something??
// todo send something??
}
}
}
Expand Down
43 changes: 43 additions & 0 deletions util/grpc/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@ package grpc

import (
"context"
"crypto/tls"
"net"
"os/user"
"runtime"
"time"

"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"

nbnet "github.com/netbirdio/netbird/util/net"
)
Expand Down Expand Up @@ -35,3 +41,40 @@ func WithCustomDialer() grpc.DialOption {
return conn, nil
})
}

// Backoff returns the backoff policy used for the gRPC calls: an exponential
// backoff driven by the system clock with MaxElapsedTime set to 10 seconds,
// wrapped via backoff.WithContext with the supplied ctx.
func Backoff(ctx context.Context) backoff.BackOff {
	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.Clock = backoff.SystemClock
	expBackoff.MaxElapsedTime = 10 * time.Second

	return backoff.WithContext(expBackoff, ctx)
}

// CreateConnection dials addr and returns the resulting gRPC client
// connection.
//
// TLS transport credentials (with a zero-value tls.Config) are used when
// tlsEnabled is true; otherwise insecure credentials are used. The dial is
// performed with grpc.WithBlock under a 30-second timeout derived from
// context.Background(), goes through WithCustomDialer, and sets client
// keepalive parameters (Time 30s, Timeout 10s).
//
// On failure the error is logged and returned together with a nil connection.
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
	// Start from insecure credentials and upgrade to TLS only when asked to.
	creds := insecure.NewCredentials()
	if tlsEnabled {
		creds = credentials.NewTLS(&tls.Config{})
	}

	dialCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(creds),
		WithCustomDialer(),
		grpc.WithBlock(),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:    30 * time.Second,
			Timeout: 10 * time.Second,
		}),
	}

	conn, err := grpc.DialContext(dialCtx, addr, dialOpts...)
	if err != nil {
		log.Printf("DialContext error: %v", err)
		return nil, err
	}

	return conn, nil
}
Loading