diff --git a/conn.go b/conn.go index b3e52d6d..9d880e36 100644 --- a/conn.go +++ b/conn.go @@ -173,12 +173,11 @@ type Event struct { type HostProvider interface { // Init is called first, with the servers specified in the connection string. Init(servers []string) error - // Len returns the number of servers. - Len() int - // Next returns the next server to connect to. retryStart will be true if we've looped through - // all known servers without Connected() being called. + // Next returns the next server to connect to. retryStart should be true if this call to Next + // exhausted the list of known servers without Connected being called. If connecting to this final + // host fails, the connect loop will back off before invoking Next again for a fresh server. Next() (server string, retryStart bool) - // Notify the HostProvider of a successful connection. + // Connected notifies the HostProvider of a successful connection. Connected() } @@ -203,12 +202,12 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti srvs := FormatServers(servers) // Randomize the order of the servers to avoid creating hotspots - stringShuffle(srvs) + shuffleSlice(srvs) ec := make(chan Event, eventChanSize) conn := &Conn{ dialer: net.DialTimeout, - hostProvider: &DNSHostProvider{}, + hostProvider: new(StaticHostProvider), conn: nil, state: StateDisconnected, eventChan: ec, @@ -387,7 +386,7 @@ func (c *Conn) sendEvent(evt Event) { } } -func (c *Conn) connect() error { +func (c *Conn) connect() (err error) { var retryStart bool for { c.serverMu.Lock() @@ -396,18 +395,6 @@ func (c *Conn) connect() error { c.setState(StateConnecting) - if retryStart { - c.flushUnsentRequests(ErrNoServer) - select { - case <-time.After(time.Second): - // pass - case <-c.shouldQuit: - c.setState(StateDisconnected) - c.flushUnsentRequests(ErrClosing) - return ErrClosing - } - } - zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) if err == nil { c.conn = zkConn @@ -419,6 +406,18 @@ func (c *Conn) connect() error { } c.logger.Printf("failed to connect to %s: %v", c.Server(), err) + + if retryStart { + c.flushUnsentRequests(ErrNoServer) + select { + case <-time.After(time.Second): + // pass + case <-c.shouldQuit: + c.setState(StateDisconnected) + c.flushUnsentRequests(ErrClosing) + return ErrClosing + } + } } } diff --git a/dnshostprovider.go b/dnshostprovider.go index f4bba8d0..3dd74c87 100644 --- a/dnshostprovider.go +++ b/dnshostprovider.go @@ -6,10 +6,12 @@ import ( "sync" ) -// DNSHostProvider is the default HostProvider. It currently matches -// the Java StaticHostProvider, resolving hosts from DNS once during -// the call to Init. It could be easily extended to re-query DNS -// periodically or if there is trouble connecting. +// DNSHostProvider is a simple implementation of a HostProvider. It resolves the hosts once during +// Init, and iterates through the resolved addresses for every call to Next. Note that if the +// addresses that back the ZK hosts change, those changes will not be reflected. +// +// Deprecated: Because this HostProvider does not attempt to re-read from DNS, it can lead to issues +// if the addresses of the hosts change. It is preserved for backwards compatibility. type DNSHostProvider struct { mu sync.Mutex // Protects everything, so we can add asynchronous updates later. servers []string @@ -30,7 +32,7 @@ func (hp *DNSHostProvider) Init(servers []string) error { lookupHost = net.LookupHost } - found := []string{} + var found []string for _, server := range servers { host, port, err := net.SplitHostPort(server) if err != nil { @@ -46,43 +48,38 @@ func (hp *DNSHostProvider) Init(servers []string) error { } if len(found) == 0 { - return fmt.Errorf("No hosts found for addresses %q", servers) + return fmt.Errorf("zk: no hosts found for addresses %q", servers) } // Randomize the order of the servers to avoid creating hotspots - stringShuffle(found) + shuffleSlice(found) hp.servers = found - hp.curr = -1 - hp.last = -1 + hp.curr = 0 + hp.last = len(hp.servers) - 1 return nil } -// Len returns the number of servers available -func (hp *DNSHostProvider) Len() int { - hp.mu.Lock() - defer hp.mu.Unlock() - return len(hp.servers) -} - -// Next returns the next server to connect to. retryStart will be true -// if we've looped through all known servers without Connected() being -// called. +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server. func (hp *DNSHostProvider) Next() (server string, retryStart bool) { hp.mu.Lock() defer hp.mu.Unlock() - hp.curr = (hp.curr + 1) % len(hp.servers) retryStart = hp.curr == hp.last - if hp.last == -1 { - hp.last = 0 - } - return hp.servers[hp.curr], retryStart + server = hp.servers[hp.curr] + hp.curr = (hp.curr + 1) % len(hp.servers) + return server, retryStart } // Connected notifies the HostProvider of a successful connection. func (hp *DNSHostProvider) Connected() { hp.mu.Lock() defer hp.mu.Unlock() - hp.last = hp.curr + if hp.curr == 0 { + hp.last = len(hp.servers) - 1 + } else { + hp.last = hp.curr - 1 + } } diff --git a/dnshostprovider_test.go b/dnshostprovider_test.go index 48000a5f..00bdea80 100644 --- a/dnshostprovider_test.go +++ b/dnshostprovider_test.go @@ -68,7 +68,6 @@ func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFac } } -func (lhpf *localHostPortsFacade) Len() int { return lhpf.inner.Len() } func (lhpf *localHostPortsFacade) Connected() { lhpf.inner.Connected() } func (lhpf *localHostPortsFacade) Init(servers []string) error { return lhpf.inner.Init(servers) } func (lhpf *localHostPortsFacade) Next() (string, bool) { @@ -165,60 +164,78 @@ func TestDNSHostProviderReconnect(t *testing.T) { } } -// TestDNSHostProviderRetryStart tests the `retryStart` functionality -// of DNSHostProvider. -// It's also probably the clearest visual explanation of exactly how -// it works. -func TestDNSHostProviderRetryStart(t *testing.T) { +// TestHostProvidersRetryStart tests the `retryStart` functionality of DNSHostProvider and +// StaticHostProvider. +// It's also probably the clearest visual explanation of exactly how it works. +func TestHostProvidersRetryStart(t *testing.T) { t.Parallel() - hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) { - return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil - }} - - if err := hp.Init([]string{"foo.example.com:12345"}); err != nil { - t.Fatal(err) - } - - testdata := []struct { - retryStartWant bool - callConnected bool - }{ - // Repeated failures. - {false, false}, - {false, false}, - {false, false}, - {true, false}, - {false, false}, - {false, false}, - {true, true}, - - // One success offsets things. - {false, false}, - {false, true}, - {false, true}, - - // Repeated successes. - {false, true}, - {false, true}, - {false, true}, - {false, true}, - {false, true}, - - // And some more failures. - {false, false}, - {false, false}, - {true, false}, // Looped back to last known good server: all alternates failed. - {false, false}, - } - - for i, td := range testdata { - _, retryStartGot := hp.Next() - if retryStartGot != td.retryStartWant { - t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) - } - if td.callConnected { - hp.Connected() - } + lookupHost := func(host string) ([]string, error) { + return []string{host}, nil + } + + providers := []HostProvider{ + &DNSHostProvider{ + lookupHost: lookupHost, + }, + &StaticHostProvider{ + lookupHost: lookupHost, + }, + } + + for _, hp := range providers { + t.Run(fmt.Sprintf("%T", hp), func(t *testing.T) { + if err := hp.Init([]string{"foo.com:2121", "bar.com:2121", "baz.com:2121"}); err != nil { + t.Fatal(err) + } + + testdata := []struct { + retryStartWant bool + callConnected bool + }{ + // Repeated failures. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, true}, + + // One success offsets things. + {false, false}, + {false, true}, + {false, true}, + + // Repeated successes. + {false, true}, + {false, true}, + {false, true}, + {false, true}, + {false, true}, + + // And some more failures. + {false, false}, + {false, false}, + {true, false}, // Looped back to last known good server: all alternates failed. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, false}, + } + + for i, td := range testdata { + _, retryStartGot := hp.Next() + if retryStartGot != td.retryStartWant { + t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) + } + if td.callConnected { + hp.Connected() + } + } + }) } } diff --git a/staticdnshostprovider.go b/staticdnshostprovider.go new file mode 100644 index 00000000..cb298ce3 --- /dev/null +++ b/staticdnshostprovider.go @@ -0,0 +1,115 @@ +package zk + +import ( + "fmt" + "log/slog" + "math/rand" + "net" + "sync" +) + +type hostPort struct { + host, port string +} + +func (hp *hostPort) String() string { + return hp.host + ":" + hp.port +} + +// StaticHostProvider is the default HostProvider, and replaces the now deprecated DNSHostProvider. +// It will iterate through the ZK hosts on every call to Next, and return a random address selected +// from the resolved addresses of the ZK host (if the host is already an IP, it will return that +// directly). It is important to manually resolve and shuffle the addresses because the DNS record +// that backs a host may rarely (or never) change, so repeated calls to connect to this host may +// always connect to the same IP. This mode is the default mode, and matches the Java client's +// implementation. Note that if the host cannot be resolved, Next will return it directly, instead of +// an error. This will cause Dial to fail and the loop will move on to a new host. It is implemented +// as a pound-for-pound copy of the standard Java client's equivalent: +// https://github.com/linkedin/zookeeper/blob/629518b5ea2b26d88a9ec53d5a422afe9b12e452/zookeeper-server/src/main/java/org/apache/zookeeper/client/StaticHostProvider.java#L368 +type StaticHostProvider struct { + mu sync.Mutex // Protects everything, so we can add asynchronous updates later. + servers []hostPort + // nextServer is the index (in servers) of the next server that will be returned by Next. + nextServer int + // lastConnectedServer is the index (in servers) of the last server to which a successful connection + // was established. Used to track whether Next iterated through all available servers without + // successfully connecting. + lastConnectedServer int + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. +} + +func (shp *StaticHostProvider) Init(servers []string) error { + shp.mu.Lock() + defer shp.mu.Unlock() + + if shp.lookupHost == nil { + shp.lookupHost = net.LookupHost + } + + var found []hostPort + for _, server := range servers { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + // Perform the lookup to validate the initial set of hosts, but discard the results as the addresses + // will be resolved dynamically when Next is called. + _, err = shp.lookupHost(host) + if err != nil { + return err + } + + found = append(found, hostPort{host, port}) + } + + if len(found) == 0 { + return fmt.Errorf("zk: no hosts found for addresses %q", servers) + } + + // Randomize the order of the servers to avoid creating hotspots + shuffleSlice(found) + + shp.servers = found + shp.nextServer = 0 + shp.lastConnectedServer = len(shp.servers) - 1 + + return nil +} + +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server. +func (shp *StaticHostProvider) Next() (server string, retryStart bool) { + shp.mu.Lock() + defer shp.mu.Unlock() + retryStart = shp.nextServer == shp.lastConnectedServer + + next := shp.servers[shp.nextServer] + addrs, err := shp.lookupHost(next.host) + if len(addrs) == 0 { + if err == nil { + // If for whatever reason lookupHosts returned an empty list of addresses but a nil error, use a + // default error + err = fmt.Errorf("zk: no hosts resolved by lookup for %q", next.host) + } + slog.Warn("Could not resolve ZK host", "host", next.host, "err", err) + server = next.String() + } else { + server = addrs[rand.Intn(len(addrs))] + ":" + next.port + } + + shp.nextServer = (shp.nextServer + 1) % len(shp.servers) + + return server, retryStart +} + +// Connected notifies the HostProvider of a successful connection. +func (shp *StaticHostProvider) Connected() { + shp.mu.Lock() + defer shp.mu.Unlock() + if shp.nextServer == 0 { + shp.lastConnectedServer = len(shp.servers) - 1 + } else { + shp.lastConnectedServer = shp.nextServer - 1 + } +} diff --git a/staticdnshostprovider_test.go b/staticdnshostprovider_test.go new file mode 100644 index 00000000..7cd2ae86 --- /dev/null +++ b/staticdnshostprovider_test.go @@ -0,0 +1,71 @@ +package zk + +import "testing" + +// The test in TestHostProvidersRetryStart checks that the semantics of StaticHostProvider's +// implementation of Next are correct, this test only checks that the provider correctly interacts +// with the resolver. +func TestStaticHostProvider(t *testing.T) { + const fooPort, barPort = "2121", "6464" + const fooHost, barHost = "foo.com", "bar.com" + hostToPort := map[string]string{ + fooHost: fooPort, + barHost: barPort, + } + hostToAddrs := map[string][]string{ + fooHost: {"0.0.0.1", "0.0.0.2", "0.0.0.3"}, + barHost: {"0.0.0.4", "0.0.0.5", "0.0.0.6"}, + } + addrToHost := map[string]string{} + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + addrToHost[addr+":"+hostToPort[host]] = host + } + } + + hp := &StaticHostProvider{ + lookupHost: func(host string) ([]string, error) { + addrs, ok := hostToAddrs[host] + if !ok { + t.Fatalf("Unexpected argument to lookupHost %q", host) + } + return addrs, nil + }, + } + + err := hp.Init([]string{fooHost + ":" + fooPort, barHost + ":" + barPort}) + if err != nil { + t.Fatalf("Unexpected err from Init %v", err) + } + + addr1, retryStart := hp.Next() + if retryStart { + t.Fatalf("retryStart should be false") + } + addr2, retryStart := hp.Next() + if !retryStart { + t.Fatalf("retryStart should be true") + } + host1, host2 := addrToHost[addr1], addrToHost[addr2] + if host1 == host2 { + t.Fatalf("Next yielded addresses from same host (%q)", host1) + } + + // Final sanity check that it is shuffling the addresses + seenAddresses := map[string]map[string]bool{ + fooHost: {}, + barHost: {}, + } + for i := 0; i < 10_000; i++ { + addr, _ := hp.Next() + seenAddresses[addrToHost[addr]][addr] = true + } + + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + if !seenAddresses[host][addr+":"+hostToPort[host]] { + t.Fatalf("expected addr %q for host %q not seen (seen: %v)", addr, host, seenAddresses) + } + } + } +} diff --git a/tcp_server_test.go b/tcp_server_test.go index 09254948..72bbd09c 100644 --- a/tcp_server_test.go +++ b/tcp_server_test.go @@ -1,17 +1,13 @@ package zk import ( - "fmt" - "math/rand" "net" "testing" "time" ) func WithListenServer(t *testing.T, test func(server string)) { - startPort := int(rand.Int31n(6000) + 10000) - server := fmt.Sprintf("localhost:%d", startPort) - l, err := net.Listen("tcp", server) + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to start listen server: %v", err) } @@ -26,7 +22,7 @@ func WithListenServer(t *testing.T, test func(server string)) { handleRequest(conn) }() - test(server) + test(l.Addr().String()) } // Handles incoming requests. diff --git a/util.go b/util.go index 5a92b66b..9244a0bb 100644 --- a/util.go +++ b/util.go @@ -49,12 +49,11 @@ func FormatServers(servers []string) []string { return srvs } -// stringShuffle performs a Fisher-Yates shuffle on a slice of strings -func stringShuffle(s []string) { - for i := len(s) - 1; i > 0; i-- { - j := rand.Intn(i + 1) +// shuffleSlice invokes rand.Shuffle on the given slice. +func shuffleSlice[T any](s []T) { + rand.Shuffle(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] - } + }) } // validatePath will make sure a path is valid before sending the request