Skip to content

Commit

Permalink
Don't return duplicate ports from SRV query
Browse files Browse the repository at this point in the history
Fixes #1656

Signed-off-by: Tom Pantelis <[email protected]>
  • Loading branch information
tpantelis committed Dec 4, 2024
1 parent c14f162 commit f61f699
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
47 changes: 44 additions & 3 deletions coredns/plugin/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ var (
}

port2 = mcsv1a1.ServicePort{
Name: "udp",
Protocol: v1.ProtocolUDP,
Port: 42,
}

port3 = mcsv1a1.ServicePort{
Name: "tcp",
Protocol: v1.ProtocolTCP,
Port: 42,
}

port4 = mcsv1a1.ServicePort{
Name: "dns",
Protocol: v1.ProtocolUDP,
Port: 53,
Expand Down Expand Up @@ -886,14 +898,18 @@ func testSRVMultiplePorts() {

t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, mcsv1a1.ClusterSetIP))

t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1, port2},
t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1, port2, port3},
newEndpoint(endpointIP, "", true)))

t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID2,
[]mcsv1a1.ServicePort{port1, port2, port3, port4},
newEndpoint(serviceIP2, "", true)))

rec = dnstest.NewRecorder(&test.ResponseWriter{})
})

Context("a DNS query of type SRV", func() {
Specify("without a port name should return all the ports", func() {
Specify("without a port name should return all the unique ports", func() {
qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1)

t.executeTestCase(rec, test.Case{
Expand Down Expand Up @@ -929,9 +945,21 @@ func testSRVMultiplePorts() {
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, port2.Port, service1, namespace1)),
},
})

qname = fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", port3.Name, port3.Protocol, service1, namespace1)

t.executeTestCase(rec, test.Case{
Qname: qname,
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname,
port3.Port, service1, namespace1)),
},
})
})

Specify("with a DNS cluster name requested should return all the ports from the cluster", func() {
Specify("with a DNS cluster name requested should return all the unique ports from the cluster", func() {
qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1)

t.executeTestCase(rec, test.Case{
Expand All @@ -943,6 +971,19 @@ func testSRVMultiplePorts() {
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)),
},
})

qname = fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID2, service1, namespace1)

t.executeTestCase(rec, test.Case{
Qname: qname,
Qtype: dns.TypeSRV,
Rcode: dns.RcodeSuccess,
Answer: []dns.RR{
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)),
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port4.Port, qname)),
test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)),
},
})
})

Specify("with a port name requested with underscore prefix should return the port", func() {
Expand Down
10 changes: 10 additions & 0 deletions coredns/plugin/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"github.com/submariner-io/lighthouse/coredns/resolver"
"k8s.io/utils/set"
"sigs.k8s.io/mcs-api/pkg/apis/v1alpha1"
)

Expand Down Expand Up @@ -83,14 +84,23 @@ func (lh *Lighthouse) createSRVRecords(dnsrecords []resolver.DNSRecord, state *r
target = dnsRecord.HostName + "." + target
}

portsSeen := set.New[int32]()

for _, port := range reqPorts {
if portsSeen.Has(port.Port) {
continue
}

portsSeen.Insert(port.Port)

record := &dns.SRV{
Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass(), Ttl: lh.TTL},
Priority: 0,
Weight: 50,
Port: uint16(port.Port), //nolint:gosec // Need to ignore integer conversion error
Target: target,
}

records = append(records, record)
}
}
Expand Down

0 comments on commit f61f699

Please sign in to comment.