summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/udp_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/miekg/dns/udp_linux.go')
-rw-r--r--vendor/github.com/miekg/dns/udp_linux.go115
1 files changed, 115 insertions, 0 deletions
diff --git a/vendor/github.com/miekg/dns/udp_linux.go b/vendor/github.com/miekg/dns/udp_linux.go
index 033df4239..13747ed34 100644
--- a/vendor/github.com/miekg/dns/udp_linux.go
+++ b/vendor/github.com/miekg/dns/udp_linux.go
@@ -13,8 +13,34 @@ package dns
import (
"net"
"syscall"
+ "unsafe"
+
+ "github.com/miekg/dns/internal/socket"
+)
+
+const (
+ sizeofInet6Pktinfo = 0x14
+ sizeofInetPktinfo = 0xc
+ protocolIP = 0
+ protocolIPv6 = 41
)
+type inetPktinfo struct {
+ Ifindex int32
+ Spec_dst [4]byte /* in_addr */
+ Addr [4]byte /* in_addr */
+}
+
+type inet6Pktinfo struct {
+ Addr [16]byte /* in6_addr */
+ Ifindex int32
+}
+
+type inetControlMessage struct {
+ Src net.IP // source address, specifying only
+ Dst net.IP // destination address, receiving only
+}
+
// setUDPSocketOptions sets the UDP socket options.
// This function is implemented on a per platform basis. See udp_*.go for more details
func setUDPSocketOptions(conn *net.UDPConn) error {
@@ -103,3 +129,92 @@ func getUDPSocketName(conn *net.UDPConn) (syscall.Sockaddr, error) {
defer file.Close()
return syscall.Getsockname(int(file.Fd()))
}
+
+// marshalInetPacketInfo marshals a ipv4 control message, returning
+// the byte slice for the next marshal, if any
+func marshalInetPacketInfo(b []byte, cm *inetControlMessage) []byte {
+ m := socket.ControlMessage(b)
+ m.MarshalHeader(protocolIP, syscall.IP_PKTINFO, sizeofInetPktinfo)
+ if cm != nil {
+ pi := (*inetPktinfo)(unsafe.Pointer(&m.Data(sizeofInetPktinfo)[0]))
+ if ip := cm.Src.To4(); ip != nil {
+ copy(pi.Spec_dst[:], ip)
+ }
+ }
+ return m.Next(sizeofInetPktinfo)
+}
+
+// marshalInet6PacketInfo marshals a ipv6 control message, returning
+// the byte slice for the next marshal, if any
+func marshalInet6PacketInfo(b []byte, cm *inetControlMessage) []byte {
+ m := socket.ControlMessage(b)
+ m.MarshalHeader(protocolIPv6, syscall.IPV6_PKTINFO, sizeofInet6Pktinfo)
+ if cm != nil {
+ pi := (*inet6Pktinfo)(unsafe.Pointer(&m.Data(sizeofInet6Pktinfo)[0]))
+ if ip := cm.Src.To16(); ip != nil && ip.To4() == nil {
+ copy(pi.Addr[:], ip)
+ }
+ }
+ return m.Next(sizeofInet6Pktinfo)
+}
+
+func parseInetPacketInfo(cm *inetControlMessage, b []byte) {
+ pi := (*inetPktinfo)(unsafe.Pointer(&b[0]))
+ if len(cm.Dst) < net.IPv4len {
+ cm.Dst = make(net.IP, net.IPv4len)
+ }
+ copy(cm.Dst, pi.Addr[:])
+}
+
+func parseInet6PacketInfo(cm *inetControlMessage, b []byte) {
+ pi := (*inet6Pktinfo)(unsafe.Pointer(&b[0]))
+ if len(cm.Dst) < net.IPv6len {
+ cm.Dst = make(net.IP, net.IPv6len)
+ }
+ copy(cm.Dst, pi.Addr[:])
+}
+
+// parseUDPSocketDst takes out-of-band data from ReadMsgUDP and parses it for
+// the Dst address
+func parseUDPSocketDst(oob []byte) (net.IP, error) {
+ cm := new(inetControlMessage)
+ ms, err := socket.ControlMessage(oob).Parse()
+ if err != nil {
+ return nil, err
+ }
+ for _, m := range ms {
+ lvl, typ, l, err := m.ParseHeader()
+ if err != nil {
+ return nil, err
+ }
+ if lvl == protocolIPv6 { // IPv6
+ if typ == syscall.IPV6_PKTINFO && l >= sizeofInet6Pktinfo {
+ parseInet6PacketInfo(cm, m.Data(l))
+ }
+ } else if lvl == protocolIP { // IPv4
+ if typ == syscall.IP_PKTINFO && l >= sizeofInetPktinfo {
+ parseInetPacketInfo(cm, m.Data(l))
+ }
+ }
+ }
+ return cm.Dst, nil
+}
+
+// marshalUDPSocketSrc takes the given src address and returns out-of-band data
+// to give to WriteMsgUDP
+func marshalUDPSocketSrc(src net.IP) []byte {
+ var oob []byte
+ // If the dst is definitely an ipv6, then use ipv6 control to respond
+ // otherwise use ipv4 because the ipv6 marshal ignores ipv4 messages.
+ // See marshalInet6PacketInfo
+ cm := new(inetControlMessage)
+ cm.Src = src
+ if src.To4() == nil {
+ oob = make([]byte, socket.ControlMessageSpace(sizeofInet6Pktinfo))
+ marshalInet6PacketInfo(oob, cm)
+ } else {
+ oob = make([]byte, socket.ControlMessageSpace(sizeofInetPktinfo))
+ marshalInetPacketInfo(oob, cm)
+ }
+ return oob
+}