How to Write Your Own DNS Proxy in Go
For a long time, we wanted to solve the problem of advertising. The easiest way to do this on all devices at once was to run our own DNS server that blocks requests for the IP addresses of advertising domains.
We started with dnsmasq, but we also wanted to load blocklists from the Internet and collect some usage statistics, so we decided to write our own DNS server.
Of course, it is not written entirely from scratch: all the DNS handling comes from this library, and we implemented the rest of the code ourselves.
Configuration
The program starts, of course, by loading the configuration file. We immediately thought about reloading the config automatically whenever it changes, to avoid restarting the server. For this, the fsnotify package came in handy.
    type Config struct {
        Nameservers    []string      `yaml:"nameservers"`
        Blocklist      []string      `yaml:"blocklist"`
        BlockAddress4  string        `yaml:"blockAddress4"`
        BlockAddress6  string        `yaml:"blockAddress6"`
        ConfigUpdate   bool          `yaml:"configUpdate"`
        UpdateInterval time.Duration `yaml:"updateInterval"`
    }
Structure of the config
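A loaded configuration is only useful if its values make sense, so a reload can reject a bad file before it replaces the working one. The following is a minimal sketch of such a check, not the article's loadConfig: the Settings type and the validate method are hypothetical, and the rules (host:port nameservers, one IPv4 and one IPv6 block address) are assumptions about how the fields are used.

```go
package main

import (
	"errors"
	"fmt"
	"net"
)

// Settings mirrors a few fields of the article's Config; the validation
// rules below are assumptions, not the original loadConfig logic.
type Settings struct {
	Nameservers   []string
	BlockAddress4 string
	BlockAddress6 string
}

// validate reports the first problem found in s, or nil if it looks usable.
func (s Settings) validate() error {
	if len(s.Nameservers) == 0 {
		return errors.New("no nameservers configured")
	}
	for _, ns := range s.Nameservers {
		// Upstream addresses are assumed to be in host:port form.
		if _, _, err := net.SplitHostPort(ns); err != nil {
			return fmt.Errorf("nameserver %q: %v", ns, err)
		}
	}
	if ip := net.ParseIP(s.BlockAddress4); ip == nil || ip.To4() == nil {
		return fmt.Errorf("blockAddress4 %q is not an IPv4 address", s.BlockAddress4)
	}
	if ip := net.ParseIP(s.BlockAddress6); ip == nil || ip.To4() != nil {
		return fmt.Errorf("blockAddress6 %q is not an IPv6 address", s.BlockAddress6)
	}
	return nil
}

func main() {
	good := Settings{
		Nameservers:   []string{"8.8.8.8:53"},
		BlockAddress4: "0.0.0.0",
		BlockAddress6: "::",
	}
	bad := good
	bad.Nameservers = []string{"8.8.8.8"} // the port is missing

	fmt.Println(good.validate())
	fmt.Println(bad.validate())
}
```

Running a check like this before assigning the new config means a typo in the file leaves the previous configuration in place.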
The most interesting part here is tracking updates to the configuration file. With the library this is quite simple: we create a Watcher, attach the file to it, and listen for events on a channel. True Go!
    func configWatcher() {
        watcher, err := fsnotify.NewWatcher()
        if err != nil {
            log.Fatal(err)
        }
        defer watcher.Close()

        err = watcher.Add(*configFile)
        if err != nil {
            log.Fatal(err)
        }

        for {
            select {
            case event := <-watcher.Events:
                // React only to writes to the watched file.
                if event.Op&fsnotify.Write == fsnotify.Write {
                    log.Println("Config file updated, reload config")
                    c, err := loadConfig()
                    if err != nil {
                        // Keep the previous config if the new one is invalid.
                        log.Println("Bad config: ", err)
                    } else {
                        log.Println("Config successfully updated")
                        config = c
                        // Stop watching if the new config disables updates.
                        if !c.ConfigUpdate {
                            return
                        }
                    }
                }
            case err := <-watcher.Errors:
                log.Println("error:", err)
            }
        }
    }
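The watcher replaces the global config from its own goroutine while request handlers may be reading it. One way to make that swap safe is to publish the configuration through an atomic value. This is a sketch under assumptions, not what the article's code does: Settings, storeSettings and loadSettings are hypothetical names.

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// Settings is a hypothetical, trimmed-down stand-in for the article's Config.
type Settings struct {
	Nameservers []string
}

// current holds a *Settings. Readers and the reloader never block each other.
var current atomic.Value

// storeSettings publishes a freshly loaded configuration in one atomic step.
func storeSettings(s *Settings) { current.Store(s) }

// loadSettings returns the most recently published configuration,
// or nil if nothing has been stored yet.
func loadSettings() *Settings {
	s, _ := current.Load().(*Settings)
	return s
}

func main() {
	storeSettings(&Settings{Nameservers: []string{"8.8.8.8:53"}})
	fmt.Println(loadSettings().Nameservers)

	// A reload builds a complete new value and swaps it in.
	storeSettings(&Settings{Nameservers: []string{"8.8.4.4:53"}})
	fmt.Println(loadSettings().Nameservers)
}
```

Each handler calls loadSettings once per request and works with that snapshot, so a reload in the middle of a request cannot produce a half-updated view.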
BlackList
Since the goal is to block unwanted sites, they need to be stored somewhere. Under a small load, a simple hash table of empty structs is enough, with the blocked domain as the key. Note that every domain must end with a trailing dot, which is why Add appends one when it is missing.
Since the list is never read and written at the same time, we can do without mutexes.
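    type BlackList struct {
        data map[string]struct{}
    }

    // Add stores the domain, appending the trailing dot if it is missing.
    // It returns false for an empty entry.
    func (b *BlackList) Add(server string) bool {
        server = strings.Trim(server, " ")
        if len(server) == 0 {
            return false
        }
        if !strings.HasSuffix(server, ".") {
            server += "."
        }
        b.data[server] = struct{}{}
        return true
    }

    // Contains reports whether the dot-terminated domain is blocked.
    func (b *BlackList) Contains(server string) bool {
        _, ok := b.data[server]
        return ok
    }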
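Lists downloaded from the Internet have to be turned into such keys first. The sketch below assumes one entry per line, either a bare domain or a hosts-style pair such as "0.0.0.0 ads.example.com". The helper parseBlocklist is hypothetical, and the lower-casing is an addition of this sketch, not something the BlackList type above does.

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

// parseBlocklist is a hypothetical helper: it reads a list with one entry per
// line, either a bare domain or a hosts-style "0.0.0.0 domain" pair, and
// returns a set keyed by lower-case, dot-terminated names.
func parseBlocklist(text string) map[string]struct{} {
	set := make(map[string]struct{})
	sc := bufio.NewScanner(strings.NewReader(text))
	for sc.Scan() {
		line := sc.Text()
		// Cut off comments.
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue // blank or comment-only line
		}
		// Simplification: the domain is taken from the last field on the line.
		name := strings.ToLower(fields[len(fields)-1])
		if !strings.HasSuffix(name, ".") {
			name += "."
		}
		set[name] = struct{}{}
	}
	return set
}

func main() {
	list := "# ads\n0.0.0.0 Ads.Example.com\ntracker.example.net\n\n"
	set := parseBlocklist(list)
	_, blocked := set["ads.example.com."]
	fmt.Println(len(set), blocked)
}
```

Building a complete new set and then replacing the old one in a single assignment also fits the no-mutex approach, as long as nothing reads the set while it is being filled.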
Caching
Initially, we thought we could do without it; after all, our devices do not generate a significant number of requests. But one fine evening someone somehow found our server and started flooding it with the same query at roughly 100 requests per second. That is not much, but the requests are proxied to real name servers (Google's, in our case), and it would be very unpleasant to get blocked.
The main difficulty with caching is that there are many different kinds of requests and their answers must be stored separately, so we ended up with a two-level hash table: the first level is keyed by request type, the second by domain name.
    type Cache interface {
        Get(reqType uint16, domain string) dns.RR
        Set(reqType uint16, domain string, ip dns.RR)
    }

    // CacheItem is a cached record together with its expiry time.
    type CacheItem struct {
        Ip  dns.RR
        Die time.Time
    }

    type MemoryCache struct {
        cache  map[uint16]map[string]*CacheItem
        locker sync.RWMutex
    }

    // Get returns the cached record, or nil if it is missing or expired.
    func (c *MemoryCache) Get(reqType uint16, domain string) dns.RR {
        c.locker.RLock()
        defer c.locker.RUnlock()

        if m, ok := c.cache[reqType]; ok {
            if ip, ok := m[domain]; ok {
                if ip.Die.After(time.Now()) {
                    return ip.Ip
                }
            }
        }
        return nil
    }

    // Set stores the record until the TTL from its header runs out.
    func (c *MemoryCache) Set(reqType uint16, domain string, ip dns.RR) {
        c.locker.Lock()
        defer c.locker.Unlock()

        m, ok := c.cache[reqType]
        if !ok {
            m = make(map[string]*CacheItem)
            c.cache[reqType] = m
        }
        m[domain] = &CacheItem{
            Ip:  ip,
            Die: time.Now().Add(time.Duration(ip.Header().Ttl) * time.Second),
        }
    }
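To see the expiry and the per-type separation in action without a DNS library, here is a small self-contained sketch of the same idea. It is a simplified model: the answer is a plain string instead of dns.RR, and ttlCache, fakeClock and the injectable clock are hypothetical names introduced only for the demonstration.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// entry pairs a cached answer with the moment it stops being valid.
type entry struct {
	value   string // stand-in for dns.RR
	expires time.Time
}

// ttlCache is a hypothetical two-level cache: request type, then domain name.
type ttlCache struct {
	mu    sync.RWMutex
	now   func() time.Time // injectable clock, so expiry can be demonstrated
	items map[uint16]map[string]entry
}

func newTTLCache(now func() time.Time) *ttlCache {
	return &ttlCache{now: now, items: make(map[uint16]map[string]entry)}
}

// Set stores value under (qtype, name) for ttl seconds.
func (c *ttlCache) Set(qtype uint16, name, value string, ttl uint32) {
	c.mu.Lock()
	defer c.mu.Unlock()
	m, ok := c.items[qtype]
	if !ok {
		m = make(map[string]entry)
		c.items[qtype] = m
	}
	m[name] = entry{value: value, expires: c.now().Add(time.Duration(ttl) * time.Second)}
}

// Get returns the cached value and true while the entry is still fresh.
func (c *ttlCache) Get(qtype uint16, name string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	e, ok := c.items[qtype][name] // reading a missing inner map is safe
	if !ok || !e.expires.After(c.now()) {
		return "", false
	}
	return e.value, true
}

// fakeClock is a manually advanced clock for the demonstration below.
type fakeClock struct{ t time.Time }

func (f *fakeClock) Now() time.Time      { return f.t }
func (f *fakeClock) Advance(seconds int) { f.t = f.t.Add(time.Duration(seconds) * time.Second) }

func main() {
	clk := &fakeClock{t: time.Unix(0, 0)}
	c := newTTLCache(clk.Now)

	c.Set(1, "example.com.", "192.0.2.1", 60) // type 1 is an A record
	_, hit := c.Get(1, "example.com.")
	_, other := c.Get(28, "example.com.") // type 28 (AAAA) is stored separately
	clk.Advance(61)
	_, afterExpiry := c.Get(1, "example.com.")

	fmt.Println(hit, other, afterExpiry)
}
```

Like the MemoryCache above, this sketch never deletes expired entries; they are simply ignored on lookup and overwritten by the next Set.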
Handler
The main part of the program is the handler for incoming requests, so we saved it for dessert. The basic logic is as follows: we receive a request, check whether the domain is in the blacklist, check whether the answer is in the cache, and otherwise proxy the request to the real servers.
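The order of these checks can be sketched with plain strings in place of DNS messages. Everything here is a simplified model with hypothetical names (resolver, resolve, upstream). In particular, answering blocked names with a fixed block address is an assumption based on the BlockAddress4 field in the config, since the handler itself is not shown.

```go
package main

import (
	"errors"
	"fmt"
)

// resolver is a hypothetical, simplified model of the handler's decision
// order. Answers are plain strings here instead of DNS records.
type resolver struct {
	blocked      map[string]struct{}
	cache        map[string]string
	blockAddress string
	upstream     func(name string) (string, error)
}

// resolve checks the blacklist first, then the cache, and only then asks
// the upstream servers, caching whatever they return.
func (r *resolver) resolve(name string) (string, error) {
	if _, ok := r.blocked[name]; ok {
		return r.blockAddress, nil
	}
	if answer, ok := r.cache[name]; ok {
		return answer, nil
	}
	answer, err := r.upstream(name)
	if err != nil {
		return "", err
	}
	r.cache[name] = answer
	return answer, nil
}

func main() {
	calls := 0
	r := &resolver{
		blocked:      map[string]struct{}{"ads.example.com.": {}},
		cache:        map[string]string{},
		blockAddress: "0.0.0.0",
		upstream: func(name string) (string, error) {
			calls++
			if name == "example.org." {
				return "192.0.2.10", nil
			}
			return "", errors.New("no such name")
		},
	}

	fmt.Println(r.resolve("ads.example.com.")) // blocked, upstream untouched
	fmt.Println(r.resolve("example.org."))     // goes upstream
	fmt.Println(r.resolve("example.org."))     // served from the cache
	fmt.Println("upstream calls:", calls)
}
```

Checking the blacklist before the cache means that a domain added to the list is blocked immediately, even if an earlier answer for it is still cached.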
The most interesting part is the Lookup function. It queries the name servers in order, starting a request to the next server only if no answer has arrived within the ticker interval, and returns the first successful answer from any of them.
    func Lookup(req *dns.Msg) (*dns.Msg, error) {
        c := &dns.Client{
            Net:          "tcp",
            ReadTimeout:  time.Second * 5,
            WriteTimeout: time.Second * 5,
        }

        qName := req.Question[0].Name

        res := make(chan *dns.Msg, 1)
        var wg sync.WaitGroup
        L := func(nameserver string) {
            defer wg.Done()
            r, _, err := c.Exchange(req, nameserver)
            totalRequestsToGoogle.Inc()
            if err != nil {
                log.Printf("%s socket error on %s", qName, nameserver)
                log.Printf("error:%s", err.Error())
                return
            }
            // Ignore server failures and let another nameserver answer.
            if r != nil && r.Rcode != dns.RcodeSuccess {
                if r.Rcode == dns.RcodeServerFailure {
                    return
                }
            }
            // Non-blocking send: only the first answer is kept.
            select {
            case res <- r:
            default:
            }
        }

        ticker := time.NewTicker(5 * time.Second)
        defer ticker.Stop()

        // Start a lookup on each nameserver top-down, one per ticker interval
        for _, nameserver := range config.Nameservers {
            wg.Add(1)
            go L(nameserver)
            // but exit early, if we have an answer
            select {
            case r := <-res:
                return r, nil
            case <-ticker.C:
                continue
            }
        }

        // wait for all the nameservers to finish
        wg.Wait()
        select {
        case r := <-res:
            return r, nil
        default:
            return nil, errors.New("can't resolve ip for " + qName)
        }
    }
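The staggered start used by Lookup can be tried out without a network. The sketch below is a simplified variant, not the article's code: firstAnswer, result and errUnreachable are hypothetical names, the servers are simulated by a callback, and unlike Lookup it moves on to the next server as soon as the current one fails instead of waiting for the timer.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// errUnreachable is a hypothetical error used by the simulated servers.
var errUnreachable = errors.New("server unreachable")

type result struct {
	answer string
	err    error
}

// firstAnswer starts query on each server in order, waits up to
// staggerMillis for a result before starting the next one, and returns
// the first successful answer.
func firstAnswer(servers []string, staggerMillis int, query func(server string) (string, error)) (string, error) {
	// Buffered for every server, so late goroutines never block on send.
	results := make(chan result, len(servers))
	pending := 0

	for _, s := range servers {
		s := s
		pending++
		go func() {
			a, err := query(s)
			results <- result{a, err}
		}()

		timer := time.NewTimer(time.Duration(staggerMillis) * time.Millisecond)
		select {
		case r := <-results:
			timer.Stop()
			pending--
			if r.err == nil {
				return r.answer, nil
			}
		case <-timer.C:
			// No result yet: start the next server as well.
		}
	}

	// All servers have been started; wait for the outstanding ones.
	for ; pending > 0; pending-- {
		if r := <-results; r.err == nil {
			return r.answer, nil
		}
	}
	return "", errors.New("no server answered")
}

func main() {
	answer, err := firstAnswer([]string{"down", "slow", "fast"}, 20, func(server string) (string, error) {
		switch server {
		case "down":
			return "", errUnreachable
		case "slow":
			time.Sleep(200 * time.Millisecond)
		}
		return "answer from " + server, nil
	})
	fmt.Println(answer, err)
}
```

The design trades a little latency for politeness: a healthy first server handles everything alone, and the others are contacted only when it is slow or failing.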
Metrics
For metrics, we use the Prometheus client library. It is very simple to use: declare a counter, register it, and call its Inc() method in the right place. The main thing is not to forget to run a web server with the Prometheus handler so that Prometheus can scrape the metrics.
We think main needs no introduction or description. The code in this article is shown in shortened form.
The complete code can be found in the repository here. The repository also contains a Dockerfile and a sample CI configuration for GitLab.
    var (
        totalRequestsTcp = prometheus.NewCounter(prometheus.CounterOpts(prometheus.Opts{
            Namespace: "dns",
            Subsystem: "requests",
            Name:      "total",
            Help:      "total requests",
            ConstLabels: map[string]string{
                "type": "tcp",
            },
        }))
    )

    func runPrometheus() {
        prometheus.MustRegister(totalRequestsTcp)

        http.Handle("/metrics", promhttp.Handler())
        log.Fatal(http.ListenAndServe(":9970", nil))
    }
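To make it clearer what a scrape of /metrics returns for such a counter, here is a hand-rolled stand-in that uses only the standard library. It is an illustration, not a replacement for the Prometheus client: the counter type is hypothetical, and the real handler additionally emits HELP and TYPE comment lines and many runtime metrics. The metric name follows from joining namespace, subsystem and name with underscores.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
)

// counter is a hypothetical stand-in for a Prometheus counter with
// constant labels already baked into its name.
type counter struct {
	name string
	n    uint64
}

// Inc adds one to the counter and is safe for concurrent use.
func (c *counter) Inc() { atomic.AddUint64(&c.n, 1) }

// Text renders the sample line in the Prometheus text exposition format.
func (c *counter) Text() string {
	return fmt.Sprintf("%s %d\n", c.name, atomic.LoadUint64(&c.n))
}

// ServeHTTP lets the counter act as a tiny /metrics endpoint.
func (c *counter) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	fmt.Fprint(w, c.Text())
}

func main() {
	tcp := &counter{name: `dns_requests_total{type="tcp"}`}
	tcp.Inc()
	tcp.Inc()

	// Simulate one scrape without opening a network port.
	rec := httptest.NewRecorder()
	tcp.ServeHTTP(rec, httptest.NewRequest("GET", "/metrics", nil))
	fmt.Print(rec.Body.String())
}
```

Counters only ever go up; rates such as requests per second are computed on the Prometheus side from successive scrapes.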