diff --git a/coredns/resolver/controller.go b/coredns/resolver/controller.go index 35e8f7dae..c30dae6f2 100644 --- a/coredns/resolver/controller.go +++ b/coredns/resolver/controller.go @@ -120,7 +120,7 @@ func (c *controller) getAllEndpointSlices(forEPS *discovery.EndpointSlice) []*di var epSlices []*discovery.EndpointSlice for i := range list { eps := list[i].(*discovery.EndpointSlice) - if !isLegacyEndpointSlice(eps) { + if !isOnBroker(eps) && !isLegacyEndpointSlice(eps) { epSlices = append(epSlices, eps) } } @@ -157,8 +157,11 @@ func (c *controller) onServiceImportDelete(obj runtime.Object, _ int) bool { } func (c *controller) ignoreEndpointSlice(eps *discovery.EndpointSlice) bool { - isOnBroker := eps.Namespace != eps.Labels[constants.LabelSourceNamespace] - return isOnBroker || (isLegacyEndpointSlice(eps) && len(c.getAllEndpointSlices(eps)) > 0) + return isOnBroker(eps) || (isLegacyEndpointSlice(eps) && len(c.getAllEndpointSlices(eps)) > 0) +} + +func isOnBroker(eps *discovery.EndpointSlice) bool { + return eps.Namespace != eps.Labels[constants.LabelSourceNamespace] } func isLegacyEndpointSlice(eps *discovery.EndpointSlice) bool { diff --git a/coredns/resolver/controller_test.go b/coredns/resolver/controller_test.go index 2399ec69a..2514ede90 100644 --- a/coredns/resolver/controller_test.go +++ b/coredns/resolver/controller_test.go @@ -134,6 +134,10 @@ var _ = Describe("Controller", func() { ) epsName2 = eps.Name t.createEndpointSlice(eps) + + epsOnBroker := eps.DeepCopy() + epsOnBroker.Namespace = test.RemoteNamespace + t.createEndpointSlice(epsOnBroker) }) Specify("GetDNSRecords should return their DNS record", func() { diff --git a/coredns/resolver/resolver_suite_test.go b/coredns/resolver/resolver_suite_test.go index 249ba5ee6..5cbebfa3a 100644 --- a/coredns/resolver/resolver_suite_test.go +++ b/coredns/resolver/resolver_suite_test.go @@ -19,6 +19,7 @@ limitations under the License. package resolver_test import ( + "context" "flag" "fmt" "reflect" @@ -175,12 +176,13 @@ func (t *testDriver) awaitDNSRecordsFound(ns, name, cluster, hostname string, ex var records []resolver.DNSRecord var found, isHeadless bool - err := wait.PollImmediate(50*time.Millisecond, 5*time.Second, func() (bool, error) { - records, isHeadless, found = t.resolver.GetDNSRecords(ns, name, cluster, hostname) - sortRecords(records) + err := wait.PollUntilContextTimeout(context.Background(), 50*time.Millisecond, 5*time.Second, true, + func(_ context.Context) (bool, error) { + records, isHeadless, found = t.resolver.GetDNSRecords(ns, name, cluster, hostname) + sortRecords(records) - return found && isHeadless == expIsHeadless && reflect.DeepEqual(records, expRecords), nil - }) + return found && isHeadless == expIsHeadless && reflect.DeepEqual(records, expRecords), nil + }) if err == nil { return }