summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/miekg/dns/sanitize.go
blob: c415bdd6c36a6b3123e965eace54d9ecf4ee3fbc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package dns

// Dedup removes identical RRs from rrs. It preserves the original ordering.
// The lowest TTL of any duplicates is used in the remaining one. Dedup modifies
// rrs.
// m is used to store the RRs temporarily. If it is nil a new map will be
// allocated. A non-nil m must be empty on entry; Dedup always leaves it empty
// on return, so the same map can be reused across calls.
func Dedup(rrs []RR, m map[string]RR) []RR {
	if m == nil {
		m = make(map[string]RR, len(rrs))
	}
	// Save the keys, so we don't have to call normalizedString twice.
	keys := make([]string, 0, len(rrs))

	for _, r := range rrs {
		key := normalizedString(r)
		keys = append(keys, key)
		if kept, ok := m[key]; ok {
			// Shortest TTL wins; the first occurrence is the one that survives.
			if kh, rh := kept.Header(), r.Header(); kh.Ttl > rh.Ttl {
				kh.Ttl = rh.Ttl
			}
			continue
		}
		m[key] = r
	}

	// If the length of the result map equals the amount of RRs we got,
	// they were all different and the original rrset can be returned as is.
	if len(m) == len(rrs) {
		// Drain the map so a caller-supplied map is reusable, matching the
		// dedup path below, which empties it as a side effect.
		for k := range m {
			delete(m, k)
		}
		return rrs
	}

	j := 0
	for i, r := range rrs {
		// The first RR with a given key is the one stored in the map: keep it
		// and delete the key so that later duplicates are skipped.
		if _, ok := m[keys[i]]; ok {
			delete(m, keys[i])
			rrs[j] = r
			j++
		}
		// Every surviving RR has been placed; nothing after this is kept.
		if len(m) == 0 {
			break
		}
	}

	return rrs[:j]
}

// normalizedString returns a normalized string from r. The TTL
// is removed and the domain name is lowercased. We go from this:
// DomainName<TAB>TTL<TAB>CLASS<TAB>TYPE<TAB>RDATA to:
// lowercasename<TAB>CLASS<TAB>TYPE...
// Only the text before the second unescaped tab is lowercased, and escaped
// characters are left as is. If r.String() has fewer than two unescaped tabs,
// there is no TTL field to cut, and the (partially lowercased) string is
// returned without removing anything.
func normalizedString(r RR) string {
	// A string Go DNS makes has: domainname<TAB>TTL<TAB>...
	b := []byte(r.String())

	// Find the first two unescaped tabs; the TTL lives between them.
	// -1 means "not found", so a tab at index 0 (empty owner name) still counts.
	esc := false
	ttlStart, ttlEnd := -1, -1
	for i := 0; i < len(b) && ttlEnd < 0; i++ {
		switch {
		case b[i] == '\\':
			// A second backslash escapes the first, so toggle.
			esc = !esc
		case b[i] == '\t' && !esc:
			if ttlStart < 0 {
				ttlStart = i
			} else {
				ttlEnd = i
			}
		case b[i] >= 'A' && b[i] <= 'Z' && !esc:
			b[i] += 'a' - 'A'
		default:
			// Any other byte ends an escape sequence's first character.
			esc = false
		}
	}

	if ttlEnd < 0 {
		// No complete TTL field found; cutting would corrupt or overrun b.
		return string(b)
	}
	// Remove the TTL together with its leading tab (in place, no extra buffer).
	return string(append(b[:ttlStart], b[ttlEnd:]...))
}