Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade HostProvider #6

Merged
merged 8 commits
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,11 @@
type HostProvider interface {
// Init is called first, with the servers specified in the connection string.
Init(servers []string) error
// Len returns the number of servers.
Len() int
// Next returns the next server to connect to. retryStart will be true if we've looped through
// all known servers without Connected() being called.
// Next returns the next server to connect to. retryStart should be true if this call to Next
// exhausted the list of known servers without Connected being called. If connecting to this final
// host fails, the connect loop will back off before invoking Next again for a fresh server.
Next() (server string, retryStart bool)
// Notify the HostProvider of a successful connection.
// Connected notifies the HostProvider of a successful connection.
Connected()
}

Expand All @@ -203,12 +202,12 @@
srvs := FormatServers(servers)

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(srvs)
shuffleSlice(srvs)

ec := make(chan Event, eventChanSize)
conn := &Conn{
dialer: net.DialTimeout,
hostProvider: &DNSHostProvider{},
hostProvider: new(StaticHostProvider),
conn: nil,
state: StateDisconnected,
eventChan: ec,
Expand Down Expand Up @@ -387,7 +386,7 @@
}
}

func (c *Conn) connect() error {
func (c *Conn) connect() (err error) {
var retryStart bool
for {
c.serverMu.Lock()
Expand All @@ -396,18 +395,6 @@

c.setState(StateConnecting)

if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
// pass
case <-c.shouldQuit:
c.setState(StateDisconnected)
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
}

zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
if err == nil {
c.conn = zkConn
Expand All @@ -419,6 +406,18 @@
}

c.logger.Printf("failed to connect to %s: %v", c.Server(), err)

if retryStart {
c.flushUnsentRequests(ErrNoServer)
select {
case <-time.After(time.Second):
// pass
case <-c.shouldQuit:
c.setState(StateDisconnected)
c.flushUnsentRequests(ErrClosing)
return ErrClosing
}
}
}
}

Expand Down Expand Up @@ -758,17 +757,17 @@

binary.BigEndian.PutUint32(buf[:4], uint32(n))

c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10))

Check failure on line 760 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)

Check failure on line 760 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)
_, err = c.conn.Write(buf[:n+4])
c.conn.SetWriteDeadline(time.Time{})

Check failure on line 762 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)

Check failure on line 762 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)
if err != nil {
return err
}

// Receive and decode a connect response.
c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10))

Check failure on line 768 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetReadDeadline` is not checked (errcheck)

Check failure on line 768 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetReadDeadline` is not checked (errcheck)
_, err = io.ReadFull(c.conn, buf[:4])
c.conn.SetReadDeadline(time.Time{})

Check failure on line 770 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetReadDeadline` is not checked (errcheck)

Check failure on line 770 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetReadDeadline` is not checked (errcheck)
if err != nil {
return err
}
Expand Down Expand Up @@ -833,7 +832,7 @@
c.requests[req.xid] = req
c.requestsLock.Unlock()

c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))

Check failure on line 835 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)

Check failure on line 835 in conn.go

View workflow job for this annotation

GitHub Actions / lint (1.21)

Error return value of `c.conn.SetWriteDeadline` is not checked (errcheck)
_, err = c.conn.Write(c.buf[:n+4])
c.conn.SetWriteDeadline(time.Time{})
if err != nil {
Expand Down
47 changes: 22 additions & 25 deletions dnshostprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"sync"
)

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
// DNSHostProvider is a simple implementation of a HostProvider. It resolves the hosts once during
// Init, and iterates through the resolved addresses for every call to Next. Note that if the
// addresses that back the ZK hosts change, those changes will not be reflected.
//
// Deprecated: Because this HostProvider does not attempt to re-read from DNS, it can lead to issues
// if the addresses of the hosts change. It is preserved for backwards compatibility.
type DNSHostProvider struct {
mu sync.Mutex // Protects everything, so we can add asynchronous updates later.
servers []string
Expand All @@ -30,7 +32,7 @@ func (hp *DNSHostProvider) Init(servers []string) error {
lookupHost = net.LookupHost
}

found := []string{}
var found []string
for _, server := range servers {
host, port, err := net.SplitHostPort(server)
if err != nil {
Expand All @@ -46,43 +48,38 @@ func (hp *DNSHostProvider) Init(servers []string) error {
}

if len(found) == 0 {
return fmt.Errorf("No hosts found for addresses %q", servers)
return fmt.Errorf("zk: no hosts found for addresses %q", servers)
}

// Randomize the order of the servers to avoid creating hotspots
stringShuffle(found)
shuffleSlice(found)

hp.servers = found
hp.curr = -1
hp.last = -1
hp.curr = 0
hp.last = len(hp.servers) - 1

return nil
}

// Len returns the number of servers available
func (hp *DNSHostProvider) Len() int {
hp.mu.Lock()
defer hp.mu.Unlock()
return len(hp.servers)
}

// Next returns the next server to connect to. retryStart will be true
// if we've looped through all known servers without Connected() being
// called.
// Next returns the next server to connect to. retryStart should be true if this call to Next
// exhausted the list of known servers without Connected being called. If connecting to this final
// host fails, the connect loop will back off before invoking Next again for a fresh server.
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
hp.mu.Lock()
defer hp.mu.Unlock()
hp.curr = (hp.curr + 1) % len(hp.servers)
retryStart = hp.curr == hp.last
if hp.last == -1 {
hp.last = 0
}
return hp.servers[hp.curr], retryStart
server = hp.servers[hp.curr]
hp.curr = (hp.curr + 1) % len(hp.servers)
return server, retryStart
}

// Connected notifies the HostProvider of a successful connection.
func (hp *DNSHostProvider) Connected() {
hp.mu.Lock()
defer hp.mu.Unlock()
hp.last = hp.curr
if hp.curr == 0 {
hp.last = len(hp.servers) - 1
} else {
hp.last = hp.curr - 1
}
}
125 changes: 71 additions & 54 deletions dnshostprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFac
}
}

func (lhpf *localHostPortsFacade) Len() int { return lhpf.inner.Len() }
func (lhpf *localHostPortsFacade) Connected() { lhpf.inner.Connected() }
func (lhpf *localHostPortsFacade) Init(servers []string) error { return lhpf.inner.Init(servers) }
func (lhpf *localHostPortsFacade) Next() (string, bool) {
Expand Down Expand Up @@ -165,60 +164,78 @@ func TestDNSHostProviderReconnect(t *testing.T) {
}
}

// TestDNSHostProviderRetryStart tests the `retryStart` functionality
// of DNSHostProvider.
// It's also probably the clearest visual explanation of exactly how
// it works.
func TestDNSHostProviderRetryStart(t *testing.T) {
// TestHostProvidersRetryStart tests the `retryStart` functionality of DNSHostProvider and
// StaticHostProvider.
// It's also probably the clearest visual explanation of exactly how it works.
func TestHostProvidersRetryStart(t *testing.T) {
t.Parallel()

hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
}}

if err := hp.Init([]string{"foo.example.com:12345"}); err != nil {
t.Fatal(err)
}

testdata := []struct {
retryStartWant bool
callConnected bool
}{
// Repeated failures.
{false, false},
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, true},

// One success offsets things.
{false, false},
{false, true},
{false, true},

// Repeated successes.
{false, true},
{false, true},
{false, true},
{false, true},
{false, true},

// And some more failures.
{false, false},
{false, false},
{true, false}, // Looped back to last known good server: all alternates failed.
{false, false},
}

for i, td := range testdata {
_, retryStartGot := hp.Next()
if retryStartGot != td.retryStartWant {
t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant)
}
if td.callConnected {
hp.Connected()
}
lookupHost := func(host string) ([]string, error) {
return []string{host}, nil
}

providers := []HostProvider{
&DNSHostProvider{
lookupHost: lookupHost,
},
&StaticHostProvider{
lookupHost: lookupHost,
},
}

for _, hp := range providers {
t.Run(fmt.Sprintf("%T", hp), func(t *testing.T) {
if err := hp.Init([]string{"foo.com:2121", "bar.com:2121", "baz.com:2121"}); err != nil {
t.Fatal(err)
}

testdata := []struct {
retryStartWant bool
callConnected bool
}{
// Repeated failures.
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, false},
{false, true},

// One success offsets things.
{false, false},
{false, true},
{false, true},

// Repeated successes.
{false, true},
{false, true},
{false, true},
{false, true},
{false, true},

// And some more failures.
{false, false},
{false, false},
{true, false}, // Looped back to last known good server: all alternates failed.
{false, false},
{false, false},
{true, false},
{false, false},
{false, false},
{true, false},
{false, false},
}

for i, td := range testdata {
_, retryStartGot := hp.Next()
if retryStartGot != td.retryStartWant {
t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant)
}
if td.callConnected {
hp.Connected()
}
}
})
}
}
Loading
Loading