From f9a0619703e32a937c1c8302367089186ae276ee Mon Sep 17 00:00:00 2001 From: vyzo Date: Wed, 7 Apr 2021 00:10:37 +0300 Subject: [PATCH 1/7] refactor Resolver to support custom per-TLD resolvers --- mock.go | 31 +++++++++++++ resolve.go | 114 +++++++++++++++++++++--------------------------- resolve_test.go | 4 +- util.go | 57 ++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 66 deletions(-) create mode 100644 mock.go create mode 100644 util.go diff --git a/mock.go b/mock.go new file mode 100644 index 0000000..3a054f8 --- /dev/null +++ b/mock.go @@ -0,0 +1,31 @@ +package madns + +import ( + "context" + "net" +) + +type MockResolver struct { + IP map[string][]net.IPAddr + TXT map[string][]string +} + +var _ BasicResolver = (*MockResolver)(nil) + +func (r *MockResolver) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) { + results, ok := r.IP[name] + if ok { + return results, nil + } else { + return []net.IPAddr{}, nil + } +} + +func (r *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + results, ok := r.TXT[name] + if ok { + return results, nil + } else { + return []string{}, nil + } +} diff --git a/resolve.go b/resolve.go index fd8d5c2..4a3e452 100644 --- a/resolve.go +++ b/resolve.go @@ -9,59 +9,71 @@ import ( ) var ResolvableProtocols = []ma.Protocol{DnsaddrProtocol, Dns4Protocol, Dns6Protocol, DnsProtocol} -var DefaultResolver = &Resolver{Backend: net.DefaultResolver} +var DefaultResolver = &Resolver{def: net.DefaultResolver} const dnsaddrTXTPrefix = "dnsaddr=" -type Backend interface { +// BasicResolver is a low level interface for DNS resolution +type BasicResolver interface { LookupIPAddr(context.Context, string) ([]net.IPAddr, error) LookupTXT(context.Context, string) ([]string, error) } +// Resolver is an object capable of resolving dns multiaddrs by using one or more BasicResolvers; +// it supports custom per TLD resolvers. +// It also implements the BasicResolver interface so that it can act as a custom per TLD resolver. type Resolver struct { - Backend Backend + def BasicResolver + tld map[string]BasicResolver } -var _ Backend = (*MockBackend)(nil) +var _ BasicResolver = (*Resolver)(nil) -type MockBackend struct { - IP map[string][]net.IPAddr - TXT map[string][]string +// NewResolver creates a new Resolver instance with the specified options +func NewResolver(opts ...Option) (*Resolver, error) { + r := &Resolver{def: net.DefaultResolver} + for _, opt := range opts { + err := opt(r) + if err != nil { + return nil, err + } + } + + return r, nil } -func (r *MockBackend) LookupIPAddr(ctx context.Context, name string) ([]net.IPAddr, error) { - results, ok := r.IP[name] - if ok { - return results, nil - } else { - return []net.IPAddr{}, nil +type Option func(*Resolver) error + +// WithDefaultResolver is an option that specifies the default basic resolver, +// which resolves any TLD that doesn't have a custom resolver. +// Defaults to net.DefaultResolver +func WithDefaultResolver(def BasicResolver) Option { + return func(r *Resolver) error { + r.def = def + return nil } } -func (r *MockBackend) LookupTXT(ctx context.Context, name string) ([]string, error) { - results, ok := r.TXT[name] - if ok { - return results, nil - } else { - return []string{}, nil +// WithTLDResolver specifies a custom resolver for a TLD. +func WithTLDResolver(tld string, rslv BasicResolver) Option { + return func(r *Resolver) error { + r.tld[tld] = rslv + return nil } } -func Matches(maddr ma.Multiaddr) (matches bool) { - ma.ForEach(maddr, func(c ma.Component) bool { - switch c.Protocol().Code { - case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: - matches = true - } - return !matches - }) - return matches -} +func (r *Resolver) getResolver(domain string) BasicResolver { + parts := strings.Split(domain, ".") + tld := parts[len(parts)-1] -func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - return DefaultResolver.Resolve(ctx, maddr) + rslv, ok := r.tld[tld] + if !ok { + rslv = r.def + } + return rslv } +// Resolve resolves a DNS multiaddr. func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { var results []ma.Multiaddr for i := 0; maddr != nil; i++ { @@ -99,6 +111,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia proto := resolve.Protocol() value := resolve.Value() + rslv := r.getResolver(value) // resolve the dns component var resolved []ma.Multiaddr @@ -114,7 +127,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia // differentiating between IPv6 and IPv4. A v4-in-v6 // AAAA record will _look_ like an A record to us and // there's nothing we can do about that. - records, err := r.Backend.LookupIPAddr(ctx, value) + records, err := rslv.LookupIPAddr(ctx, value) if err != nil { return nil, err } @@ -155,7 +168,7 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia // matching the result of step 2. // First, lookup the TXT record - records, err := r.Backend.LookupTXT(ctx, "_dnsaddr."+value) + records, err := rslv.LookupTXT(ctx, "_dnsaddr."+value) if err != nil { return nil, err } @@ -235,37 +248,10 @@ func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multia return results, nil } -// counts the number of components in the multiaddr -func addrLen(maddr ma.Multiaddr) int { - length := 0 - ma.ForEach(maddr, func(_ ma.Component) bool { - length++ - return true - }) - return length -} - -// trims `offset` components from the beginning of the multiaddr. -func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { - _, after := ma.SplitFunc(maddr, func(c ma.Component) bool { - if offset == 0 { - return true - } - offset-- - return false - }) - return after +func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { + return r.getResolver(domain).LookupIPAddr(ctx, domain) } -// takes the cross product of two sets of multiaddrs -// -// assumes `a` is non-empty. -func cross(a, b []ma.Multiaddr) []ma.Multiaddr { - res := make([]ma.Multiaddr, 0, len(a)*len(b)) - for _, x := range a { - for _, y := range b { - res = append(res, x.Encapsulate(y)) - } - } - return res +func (r *Resolver) LookupTXT(ctx context.Context, txt string) ([]string, error) { + return r.getResolver(txt).LookupTXT(ctx, txt) } diff --git a/resolve_test.go b/resolve_test.go index 1334611..ee446bf 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -29,7 +29,7 @@ var txtd = "dnsaddr=" + txtmd.String() var txte = "dnsaddr=" + txtme.String() func makeResolver() *Resolver { - mock := &MockBackend{ + mock := &MockResolver{ IP: map[string][]net.IPAddr{ "example.com": []net.IPAddr{ip4a, ip4b, ip6a, ip6b}, }, @@ -38,7 +38,7 @@ func makeResolver() *Resolver { "_dnsaddr.matching.com": []string{txtc, txtd, txte, "not a dnsaddr", "dnsaddr=/foobar"}, }, } - resolver := &Resolver{Backend: mock} + resolver := &Resolver{def: mock} return resolver } diff --git a/util.go b/util.go new file mode 100644 index 0000000..2953ddd --- /dev/null +++ b/util.go @@ -0,0 +1,57 @@ +package madns + +import ( + "context" + + ma "github.com/multiformats/go-multiaddr" +) + +func Matches(maddr ma.Multiaddr) (matches bool) { + ma.ForEach(maddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case DnsProtocol.Code, Dns4Protocol.Code, Dns6Protocol.Code, DnsaddrProtocol.Code: + matches = true + } + return !matches + }) + return matches +} + +func Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { + return DefaultResolver.Resolve(ctx, maddr) +} + +// counts the number of components in the multiaddr +func addrLen(maddr ma.Multiaddr) int { + length := 0 + ma.ForEach(maddr, func(_ ma.Component) bool { + length++ + return true + }) + return length +} + +// trims `offset` components from the beginning of the multiaddr. +func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { + _, after := ma.SplitFunc(maddr, func(c ma.Component) bool { + if offset == 0 { + return true + } + offset-- + return false + }) + return after +} + +// takes the cross product of two sets of multiaddrs +// +// assumes `a` is non-empty. +func cross(a, b []ma.Multiaddr) []ma.Multiaddr { + res := make([]ma.Multiaddr, 0, len(a)*len(b)) + for _, x := range a { + for _, y := range b { + res = append(res, x.Encapsulate(y)) + } + } + return res +} From 807a4262d367e71b660cd7e659eadfe1ec147e64 Mon Sep 17 00:00:00 2001 From: vyzo Date: Wed, 7 Apr 2021 19:53:12 +0300 Subject: [PATCH 2/7] more nuanced handling of custom resolvers --- resolve.go | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/resolve.go b/resolve.go index 4a3e452..eebd31f 100644 --- a/resolve.go +++ b/resolve.go @@ -20,11 +20,12 @@ type BasicResolver interface { } // Resolver is an object capable of resolving dns multiaddrs by using one or more BasicResolvers; -// it supports custom per TLD resolvers. -// It also implements the BasicResolver interface so that it can act as a custom per TLD resolver. +// it supports custom per domain/TLD resolvers. +// It also implements the BasicResolver interface so that it can act as a custom per domain/TLD +// resolver. type Resolver struct { - def BasicResolver - tld map[string]BasicResolver + def BasicResolver + custom map[string]BasicResolver } var _ BasicResolver = (*Resolver)(nil) @@ -54,23 +55,29 @@ func WithDefaultResolver(def BasicResolver) Option { } } -// WithTLDResolver specifies a custom resolver for a TLD. -func WithTLDResolver(tld string, rslv BasicResolver) Option { +// WithTLDResolver specifies a custom resolver for a domain/TLD. +func WithDomainResolver(domain string, rslv BasicResolver) Option { return func(r *Resolver) error { - r.tld[tld] = rslv + r.custom[domain] = rslv return nil } } func (r *Resolver) getResolver(domain string) BasicResolver { - parts := strings.Split(domain, ".") - tld := parts[len(parts)-1] + rslv, ok := r.custom[domain] + if ok { + return rslv + } - rslv, ok := r.tld[tld] - if !ok { - rslv = r.def + for i := strings.Index(domain, "."); i != -1; i = strings.Index(domain, ",") { + domain = domain[i+1:] + rslv, ok = r.custom[domain] + if ok { + return rslv + } } - return rslv + + return r.def } // Resolve resolves a DNS multiaddr. From 95b70c86e010f26f05ed43eb2928fef9f8722c19 Mon Sep 17 00:00:00 2001 From: vyzo Date: Wed, 7 Apr 2021 20:13:21 +0300 Subject: [PATCH 3/7] fix custom resolver map initialization --- resolve.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/resolve.go b/resolve.go index eebd31f..5cbfc68 100644 --- a/resolve.go +++ b/resolve.go @@ -58,6 +58,9 @@ func WithDefaultResolver(def BasicResolver) Option { // WithTLDResolver specifies a custom resolver for a domain/TLD. func WithDomainResolver(domain string, rslv BasicResolver) Option { return func(r *Resolver) error { + if r.custom == nil { + r.custom = make(map[string]BasicResolver) + } r.custom[domain] = rslv return nil } From cf42e5628e2fc6dd969609d59defcad348f5fa15 Mon Sep 17 00:00:00 2001 From: vyzo Date: Wed, 7 Apr 2021 20:13:42 +0300 Subject: [PATCH 4/7] add custom resolver test --- resolve_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/resolve_test.go b/resolve_test.go index ee446bf..92d6f5a 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -1,6 +1,7 @@ package madns import ( + "bytes" "context" "net" "testing" @@ -234,3 +235,87 @@ func TestBadDomain(t *testing.T) { t.Error("expected malformed address to fail to parse") } } + +func TestCustomResolver(t *testing.T) { + ip1 := net.IPAddr{IP: net.ParseIP("1.2.3.4")} + ip2 := net.IPAddr{IP: net.ParseIP("2.3.4.5")} + ip3 := net.IPAddr{IP: net.ParseIP("3.4.5.6")} + ip4 := net.IPAddr{IP: net.ParseIP("4.5.6.8")} + ip5 := net.IPAddr{IP: net.ParseIP("5.6.8.9")} + def := &MockResolver{ + IP: map[string][]net.IPAddr{ + "example.com": []net.IPAddr{ip1}, + }, + } + custom1 := &MockResolver{ + IP: map[string][]net.IPAddr{ + "custom.test": []net.IPAddr{ip2}, + "another.custom.test": []net.IPAddr{ip3}, + }, + } + custom2 := &MockResolver{ + IP: map[string][]net.IPAddr{ + "more.custom.test": []net.IPAddr{ip4}, + "some.more.custom.test": []net.IPAddr{ip5}, + }, + } + + rslv, err := NewResolver( + WithDefaultResolver(def), + WithDomainResolver("custom.test", custom1), + WithDomainResolver("more.custom.test", custom2), + ) + if err != nil { + t.Fatal(err) + } + + sameIP := func(ip1, ip2 net.IPAddr) bool { + return bytes.Equal(ip1.IP, ip2.IP) + } + + ctx := context.Background() + res, err := rslv.LookupIPAddr(ctx, "example.com") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip1) { + t.Fatal("expected result to be ip1") + } + + res, err = rslv.LookupIPAddr(ctx, "custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip2) { + t.Fatal("expected result to be ip2") + } + + res, err = rslv.LookupIPAddr(ctx, "another.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip3) { + t.Fatal("expected result to be ip3") + } + + res, err = rslv.LookupIPAddr(ctx, "more.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip4) { + t.Fatal("expected result to be ip4") + } + + res, err = rslv.LookupIPAddr(ctx, "some.more.custom.test") + if err != nil { + t.Fatal(err) + } + + if len(res) != 1 || !sameIP(res[0], ip5) { + t.Fatal("expected result to be ip5") + } +} From 1e9c3ecd7d8775b9971935b6cc3f3106f67fa302 Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 8 Apr 2021 17:53:10 +0300 Subject: [PATCH 5/7] fix bug; it's a dot damn it! --- resolve.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resolve.go b/resolve.go index 5cbfc68..612c818 100644 --- a/resolve.go +++ b/resolve.go @@ -72,7 +72,7 @@ func (r *Resolver) getResolver(domain string) BasicResolver { return rslv } - for i := strings.Index(domain, "."); i != -1; i = strings.Index(domain, ",") { + for i := strings.Index(domain, "."); i != -1; i = strings.Index(domain, ".") { domain = domain[i+1:] rslv, ok = r.custom[domain] if ok { From d965d38967a6a5b0576519c7e37ddd98559fa23f Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 8 Apr 2021 18:43:20 +0300 Subject: [PATCH 6/7] improve test with failsafe to ensure specific resolver supersedes generic one --- resolve_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/resolve_test.go b/resolve_test.go index 92d6f5a..6cb230a 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -242,6 +242,7 @@ func TestCustomResolver(t *testing.T) { ip3 := net.IPAddr{IP: net.ParseIP("3.4.5.6")} ip4 := net.IPAddr{IP: net.ParseIP("4.5.6.8")} ip5 := net.IPAddr{IP: net.ParseIP("5.6.8.9")} + ip6 := net.IPAddr{IP: net.ParseIP("6.8.9.10")} def := &MockResolver{ IP: map[string][]net.IPAddr{ "example.com": []net.IPAddr{ip1}, @@ -251,6 +252,7 @@ func TestCustomResolver(t *testing.T) { IP: map[string][]net.IPAddr{ "custom.test": []net.IPAddr{ip2}, "another.custom.test": []net.IPAddr{ip3}, + "more.custom.test": []net.IPAddr{ip6}, }, } custom2 := &MockResolver{ From 45cdfcfc1fd6617653765a102486dfff550f85b7 Mon Sep 17 00:00:00 2001 From: vyzo Date: Thu, 8 Apr 2021 18:43:50 +0300 Subject: [PATCH 7/7] add comments clarifying the left-to-right selection rule in WithDomainResolver --- resolve.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/resolve.go b/resolve.go index 612c818..705fed7 100644 --- a/resolve.go +++ b/resolve.go @@ -55,7 +55,9 @@ func WithDefaultResolver(def BasicResolver) Option { } } -// WithTLDResolver specifies a custom resolver for a domain/TLD. +// WithDomainResolver specifies a custom resolver for a domain/TLD. +// Custom resolver selection matches domains left to right, with more specific resolvers +// superseding generic ones. func WithDomainResolver(domain string, rslv BasicResolver) Option { return func(r *Resolver) error { if r.custom == nil { @@ -67,6 +69,9 @@ func WithDomainResolver(domain string, rslv BasicResolver) Option { } func (r *Resolver) getResolver(domain string) BasicResolver { + // we match left-to-right, with more specific resolvers superseding generic ones. + // So for a domain a.b.c, we will try a.b,c, b.c, c, and fallback to the default if + // there is no match rslv, ok := r.custom[domain] if ok { return rslv