#!/usr/bin/env python

# Copyright (c) 2012 Aalto University and RWTH Aachen University.
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation
# files (the "Software"), to deal in the Software without
# restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following
# conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.

"""HIP name look-up daemon for HIPL hosts file and DNS servers."""

# Usage: Basic usage without any command line options.
#        See getopt() for the options.
#
# Working test cases with hipdnsproxy
# - Interoperates with libc and dnsmasq
# - Resolvconf(on/off) + dnsmasq (on/off)
#    - initial look up (check HIP and non-hip look up)
#      - check that ctrl+c restores /etc/resolv.conf
#    - change access network (check HIP and non-hip look up)
#      - check that ctrl+c restores /etc/resolv.conf
# - Watch out for cached entries! Restart dnsmasq and hipdnsproxy after
#   each test.
# - Test name resolution with following methods:
#   - Non-HIP records
#   - Hostname to HIT resolution
#     - HITs and LSIs from /etc/hip/hosts
#     - On-the-fly generated LSI; HIT either from DNS or from the hosts files
#     - HI records from DNS
#   - PTR records: maps HITs to hostnames from /etc/hip/hosts
#
# Actions to resolv.conf files and dnsproxy hooking:
# - Dnsmasq=on, resolvconf=on: only hooks dnsmasq
# - Dnsmasq=off, resolvconf=on: rewrites /etc/resolvconf/run/resolv.conf
# - Dnsmasq=on, resolvconf=off: hooks dnsmasq and rewrites /etc/resolv.conf
# - Dnsmasq=off, resolvconf=off: rewrites /etc/resolv.conf
#
# TBD:
# - rewrite the code to more object oriented
# - the use of alternative (multiple) dns servers
# - implement TTLs for cache
#   - applicable to HITs, LSIs and IP addresses
#   - host files: forever (purged when the file is changed)
#   - dns records: follow DNS TTL
# - bind to ::1, not 127.0.0.1 (setsockopt blah blah)
# - remove hardcoded addresses from ifconfig commands
# - compatibility with "unbound"

import copy
import errno
import fileinput
import logging
import logging.handlers
import os
import re
import select
import signal
import socket
import subprocess
import sys
import time

#local imports

# prepending (instead of appending) to make sure hosts.py does not
# collide with the system default
import hosts
import util
from DNS import Serialize, DeSerialize


# System hosts file; consulted after the HIP hosts file(s) when it exists.
DEFAULT_HOSTS = '/etc/hosts'
# Extracts an LSI (a dotted quad whose first octet is 1) from hipconf output.
# NOTE(review): the pattern is unanchored and the octets are unbounded, so
# e.g. "11.2.3.4" also yields "1.2.3.4".
LSI_RE = re.compile(r'(?P<lsi>1\.\d+\.\d+\.\d+)')


def usage(unused_utyp, *msg):
    """Print usage instructions (and an optional error) to stderr and exit.

    Args:
      unused_utyp: ignored; kept so existing callers keep working.
      *msg: optional error details; each item is printed with repr().

    Raises:
      SystemExit: always, with exit status 1.
    """
    sys.stderr.write('Usage: %s\n' % os.path.split(sys.argv[0])[1])
    if msg:
        # Format the items explicitly: applying '%r' directly to the msg
        # tuple unpacks it as format arguments and raises TypeError as soon
        # as more than one message is given.
        sys.stderr.write('Error: %s\n' % ', '.join(repr(m) for m in msg))
    sys.exit(1)


# Per-run identifier "<unix time>-<pid>". It names the resolv.conf backup and
# temporary files and is stored in the pidfile. The PID changes after fork(),
# so forkme() regenerates this value in the child.
MYID = '%d-%d' % (time.time(), os.getpid())


def add_hit_ip_map(hit, addr):
    """Register an IP address for a HIT with hipd.

    Runs "hipconf daemon add map HIT ADDR" and discards its output.

    Args:
      hit: the HIT (string).
      addr: an IPv4 or IPv6 address (string) to associate with hit.

    Raises:
      subprocess.CalledProcessError: if hipconf exits with non-zero status.
    """
    logging.info('Associating HIT %s with IP %s', hit, addr)
    # Close the /dev/null handle deterministically instead of relying on GC.
    with open(os.devnull, 'w') as devnull:
        subprocess.check_call(['hipconf', 'daemon', 'add', 'map', hit, addr],
                              stdout=devnull,
                              stderr=subprocess.STDOUT)


def hit_to_lsi(hit):
    """Ask hipd for the LSI bound to a HIT.

    Scans the output of "hipconf daemon hit-to-lsi HIT" line by line and
    returns the first LSI_RE match, or None when no line contains one.
    """
    command = ['hipconf', 'daemon', 'hit-to-lsi', hit]
    pipe = subprocess.Popen(command,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT).stdout
    for text in pipe:
        found = LSI_RE.search(text)
        if found is not None:
            return found.group('lsi')
    return None


def lsi_to_hit(lsi):
    """Ask hipd for the HIT bound to an LSI.

    Scans the output of "hipconf daemon lsi-to-hit LSI" line by line and
    returns the first hosts.HIT_RE match, or None when no line contains one.
    """
    command = ['hipconf', 'daemon', 'lsi-to-hit', lsi]
    pipe = subprocess.Popen(command,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT).stdout
    for text in pipe:
        found = hosts.HIT_RE.search(text)
        if found is not None:
            return found.group('hit')
    return None


def is_reverse_hit_query(name):
    """Tell whether name is a hit-to-ip reverse-lookup name for a HIT.

    The name must be exactly 86 characters long (32 single-nibble labels
    plus "hit-to-ip.infrahip.net"). It must end with the reversed nibbles
    2.0.0.1.0.0.1, i.e. the 2001:10::/28 HIT prefix.

    >>> is_reverse_hit_query('::1')
    False
    >>> is_reverse_hit_query('8.e.b.8.b.3.c.9.1.a.0.c.e.e.2.c.c.e.d.0.9.c.'
    ...                      '9.a.e.1.0.0.1.0.0.2.hit-to-ip.infrahip.net')
    True
    """
    suffix = '.1.0.0.1.0.0.2.hit-to-ip.infrahip.net'
    return len(name) == 86 and name.endswith(suffix)


class ResolvConf:
    """Point the system resolver at hipdnsproxy and restore it afterwards.

    Depending on what is installed, this either rewrites resolv.conf
    directly or hooks dnsmasq through its defaults file (see
    set_dnsmasq_hook). It also watches a resolv.conf file for external
    changes.
    """
    # Matches a "nameserver <addr>" line; the address is group 1.
    re_nameserver = re.compile(r'nameserver\s+(\S+)$')

    def __init__(self, dnsp, filetowatch=None):
        """Detect the distribution and resolver setup.

        Args:
          dnsp: the DNSProxy instance; only its overwrite_resolv_conf
            attribute is read here.
          filetowatch: resolv.conf file whose mtime is monitored. When None,
            guess_resolvconf() chooses one.
        """
        self.dnsmasq_initd_script = '/etc/init.d/dnsmasq'
        if os.path.exists('/etc/redhat-release'):
            self.distro = 'redhat'
            # rh_inject is inserted into the dnsmasq init script just before
            # the line starting with rh_before, so that the script sources
            # the sysconfig file.
            self.rh_before = '# See how we were called.'
            self.rh_inject = '. /etc/sysconfig/dnsmasq # Added by hipdnsproxy'
        elif os.path.exists('/etc/debian_version'):
            self.distro = 'debian'
        else:
            self.distro = 'unknown'

        if self.distro == 'redhat':
            self.dnsmasq_defaults = '/etc/sysconfig/dnsmasq'
            if not os.path.exists(self.dnsmasq_defaults):
                open(self.dnsmasq_defaults, 'w').close()
        else:
            self.dnsmasq_defaults = '/etc/default/dnsmasq'

        self.dnsmasq_defaults_backup = (self.dnsmasq_defaults +
                                        '.backup.hipdnsproxy')

        if (os.path.isdir('/etc/resolvconf/.') and
                os.path.exists('/sbin/resolvconf') and
                os.path.exists('/etc/resolvconf/run/resolv.conf')):
            self.use_resolvconf = True
        else:
            self.use_resolvconf = False
        self.use_dnsmasq_hook = False
        self.resolvconfd = None
        self.dnsmasq_hook = None
        self.alt_port = None

        self.dnsmasq_resolv = '/var/run/dnsmasq/resolv.conf'
        self.resolvconf_run = '/etc/resolvconf/run/resolv.conf'
        if self.use_resolvconf:
            self.resolvconf_towrite = '/etc/resolvconf/run/resolv.conf'
        else:
            self.resolvconf_towrite = '/etc/resolv.conf'

        self.dnsmasq_restart = [self.dnsmasq_initd_script, 'restart']
        # Honour an explicitly supplied file; only guess when none is given.
        # (An explicit argument used to be dropped, crashing on the next line.)
        if filetowatch is None:
            filetowatch = self.guess_resolvconf()
        self.filetowatch = filetowatch
        self.resolvconf_orig = self.filetowatch
        self.old_rc_mtime = os.stat(self.filetowatch).st_mtime
        self.resolvconf_bkname = '%s-%s' % (self.resolvconf_towrite, MYID)
        self.overwrite_resolv_conf = dnsp.overwrite_resolv_conf

    @staticmethod
    def _run_quietly(cmd):
        """Run cmd with output discarded; raise CalledProcessError on failure."""
        with open(os.devnull, 'w') as devnull:
            subprocess.check_call(cmd, stdout=devnull,
                                  stderr=subprocess.STDOUT)

    def _manages_resolv_conf(self):
        """Return True when this instance rewrites resolv.conf itself."""
        return (not (self.use_dnsmasq_hook and self.use_resolvconf) and
                self.overwrite_resolv_conf)

    def guess_resolvconf(self):
        """Guess the location of the correct resolv.conf file."""
        if self.use_dnsmasq_hook and self.use_resolvconf:
            return self.dnsmasq_resolv
        elif self.use_resolvconf:
            return self.resolvconf_run
        else:
            return '/etc/resolv.conf'

    def reread_old_rc(self):
        """Re-read filetowatch and return its first nameserver.

        Returns:
          {'nameserver': addr} for the first nameserver line, or {} if there
          is none. The dict is also stored in self.resolvconfd.
        """
        self.resolvconfd = {}
        with open(self.filetowatch) as ifile:
            for line in ifile:
                match = self.re_nameserver.match(line.strip())
                if match:
                    self.resolvconfd['nameserver'] = match.group(1)
                    break
        return self.resolvconfd

    def set_dnsmasq_hook(self, dnsp):
        """Build the dnsmasq options line that forwards queries to the proxy.

        Args:
          dnsp: the DNSProxy instance; bind_ip and bind_alt_port are read.
        """
        self.alt_port = dnsp.bind_alt_port
        self.use_dnsmasq_hook = True
        logging.info('Dnsmasq-resolvconf installation detected')
        if self.distro == 'redhat':
            self.dnsmasq_hook = ('OPTIONS+="--no-hosts --no-resolv '
                                 '--cache-size=0 --server=%s#%s"\n' % (
                                     dnsp.bind_ip, self.alt_port))
        else:
            self.dnsmasq_hook = ('DNSMASQ_OPTS="--no-hosts --no-resolv '
                                 '--cache-size=0 --server=%s#%s"\n' % (
                                     dnsp.bind_ip, self.alt_port))
        return

    def old_has_changed(self):
        """Return True (and re-read it) if filetowatch's mtime changed."""
        old_rc_mtime = os.stat(self.filetowatch).st_mtime
        if old_rc_mtime != self.old_rc_mtime:
            self.reread_old_rc()
            self.old_rc_mtime = old_rc_mtime
            return True
        return False

    def save_resolvconf_dnsmasq(self):
        """Inject the dnsmasq hook and restart dnsmasq if necessary.

        Also hard-links the current resolv.conf to resolvconf_bkname when this
        instance manages resolv.conf directly.
        """
        if self.use_dnsmasq_hook:
            if os.path.exists(self.dnsmasq_defaults):
                with open(self.dnsmasq_defaults, 'r') as ifile:
                    line = ifile.readline()
                hook_prefix = self.dnsmasq_hook[
                    :self.dnsmasq_hook.find('server=')]
                # A leftover hook from an earlier run must not end up in the
                # backup (it would be "restored" later), so empty the file.
                if (line.find('server=127') != -1 and
                        line[:line.find('server=')] == hook_prefix):
                    logging.info('Dnsmasq configuration file seems to be '
                                 'written by dnsproxy. Zeroing.')
                    open(self.dnsmasq_defaults, 'w').close()
                os.rename(self.dnsmasq_defaults,
                          self.dnsmasq_defaults_backup)
            with open(self.dnsmasq_defaults, 'w') as dmd:
                dmd.write(self.dnsmasq_hook)
            if self.distro == 'redhat':
                # With inplace=1, fileinput redirects stdout into the file.
                for line in fileinput.input(self.dnsmasq_initd_script,
                                            inplace=1):
                    if self.rh_inject in line:
                        # Drop an injection left behind by an earlier run so
                        # the line is never duplicated.
                        continue
                    if line.startswith(self.rh_before):
                        sys.stdout.write(self.rh_inject + '\n')
                    sys.stdout.write(line)
            self._run_quietly(self.dnsmasq_restart)
            logging.info('Hooked with dnsmasq')
            # Restarting of dnsproxy changes also resolv conf. Reset timer
            # to make sure that we don't load dnsproxy's IP address. Otherwise
            # all DNS requests are blocked when running dnsproxy in
            # combination with dnsmasq.
            self.old_rc_mtime = os.stat(self.filetowatch).st_mtime
        if self._manages_resolv_conf():
            os.link(self.resolvconf_towrite, self.resolvconf_bkname)
        return

    def restore_resolvconf_dnsmasq(self):
        """Restore old dnsmasq config and restart dnsmasq if necessary.

        Also moves resolvconf_bkname back over resolv.conf when this instance
        manages resolv.conf directly.
        """
        if self.use_dnsmasq_hook:
            logging.info('Removing dnsmasq hooks')
            if os.path.exists(self.dnsmasq_defaults_backup):
                os.rename(self.dnsmasq_defaults_backup,
                          self.dnsmasq_defaults)
            if self.distro == 'redhat':
                # With inplace=1, fileinput redirects stdout into the file.
                for line in fileinput.input(self.dnsmasq_initd_script,
                                            inplace=1):
                    if self.rh_inject not in line:
                        sys.stdout.write(line)
            self._run_quietly(self.dnsmasq_restart)
        if self._manages_resolv_conf():
            os.rename(self.resolvconf_bkname, self.resolvconf_towrite)
            logging.info('resolv.conf restored')
        return

    def write(self, params):
        """Atomically write resolv.conf from params.

        Args:
          params: maps an option name to a string or an iterable of strings;
            every value becomes one "<name> <value>" line.
        """
        tmp = '%s.tmp-%s' % (self.resolvconf_towrite, MYID)
        with open(tmp, 'w') as ofile:
            ofile.write('# This is written by hipdnsproxy\n')
            for key, value in params.items():
                if isinstance(value, str):
                    value = (value,)
                for val in value:
                    ofile.write('%-10s %s\n' % (key, val))
        os.rename(tmp, self.resolvconf_towrite)
        # Record our own write so old_has_changed() does not report it.
        self.old_rc_mtime = os.stat(self.filetowatch).st_mtime

    def overwrite_resolvconf(self):
        """Rewrite resolv.conf with its own contents (via temp file + rename).

        TODO(ptman): would just changing the mtime suffice?
        """
        tmp = '%s.tmp-%s' % (self.resolvconf_towrite, MYID)
        with open(self.resolvconf_towrite, 'r') as ifile:
            with open(tmp, 'w') as ofile:
                while True:
                    buf = ifile.read(16384)
                    if not buf:
                        break
                    ofile.write(buf)
        os.rename(tmp, self.resolvconf_towrite)
        logging.info('Rewrote resolv.conf')

    def start(self):
        """Perform startup routines."""
        self.save_resolvconf_dnsmasq()
        if self._manages_resolv_conf():
            self.overwrite_resolvconf()

    def restart(self):
        """Perform restart routines."""
        if self._manages_resolv_conf():
            self.overwrite_resolvconf()
        self.old_rc_mtime = os.stat(self.filetowatch).st_mtime

    def stop(self):
        """Perform shutdown routines."""
        self.restore_resolvconf_dnsmasq()
        subprocess.check_call(['ifconfig', 'lo:53', 'down'])
        # Sometimes hipconf processes get stuck, particularly when
        # hipd is busy or unresponsive. This is a workaround.
        with open(os.devnull, 'w') as devnull:
            subprocess.call(['killall', '--quiet', 'hipconf'],
                            stderr=devnull)


class DNSProxy:
    """HIP DNS proxy main class."""
    re_nameserver = re.compile(r'nameserver\s+(\S+)$')

    def __init__(self, bind_ip=None, bind_port=None, disable_lsi=False,
                 dns_timeout=2.0, fork=False, hiphosts=None, hostsnames=None,
                 overwrite_resolv_conf=True, pidfile=None, prefix=None,
                 resolv_conf=None, server_ip=None, server_port=None):
        self.bind_ip = bind_ip
        self.bind_port = bind_port
        self.disable_lsi = disable_lsi
        self.dns_timeout = dns_timeout
        self.fork = fork
        self.hiphosts = hiphosts
        self.hostsnames = hostsnames
        self.overwrite_resolv_conf = overwrite_resolv_conf
        self.pidfile = pidfile
        self.prefix = prefix
        self.resolv_conf = resolv_conf
        self.server_ip = server_ip
        self.server_port = server_port

        if self.hostsnames is None:
            self.hostsnames = []

        self.bind_alt_port = None
        self.use_alt_port = False
        self.app_timeout = 1
        self.hosts_ttl = 122
        self.sent_queue = []
        self.rc1 = None
        self.resolvconfd = None
        self.hosts = None
        # Keyed by ('server_ip',server_port,query_id) tuple
        self.sent_queue_d = {}
        # required for ifconfig and hipconf in Fedora
        # (rpm and "make install" targets)
        os.environ['PATH'] += ':/sbin:/usr/sbin:/usr/local/sbin'

    def add_query(self, server_ip, server_port, query_id, query):
        """Add a pending DNS query"""
        key = (server_ip, server_port, query_id)
        value = (key, time.time(), query)
        self.sent_queue.append(value)
        self.sent_queue_d[key] = value

    def find_query(self, server_ip, server_port, query_id):
        """Find a pending DNS query"""
        key = (server_ip, server_port, query_id)
        query = self.sent_queue_d.get(key)
        if query:
            idx = self.sent_queue.index(query)
            self.sent_queue.pop(idx)
            del self.sent_queue_d[key]
            return query[2]
        return None

    def clean_queries(self):
        """Clean old unanswered queries"""
        texp = time.time() - 30
        while self.sent_queue:
            if self.sent_queue[0][1] < texp:
                # TODO(ptman): test that key is used properly in del
                key = self.sent_queue[0][0]
                self.sent_queue.pop(0)
                del self.sent_queue_d[key]
            else:
                break
        return

    def read_resolv_conf(self, cfile=None):
        """Read resolv.conf."""
        if cfile is None:
            cfile = self.resolv_conf

        options = {}

        ifile = open(cfile)
        for line in ifile.xreadlines():
            line = line.strip()
            if 'nameserver' not in options:
                match = self.re_nameserver.match(line)
                if match:
                    options['nameserver'] = match.group(1)

        if self.server_ip is None:
            nameserver = options.get('nameserver', None)
            self.server_ip = nameserver

        self.resolvconfd = options
        return options

    def parameter_defaults(self):
        """Missing default parameters."""
        env = os.environ
        if self.server_ip is None:
            self.server_ip = env.get('SERVER', None)
        if self.server_port is None:
            server_port = env.get('SERVERPORT', None)
            if server_port is not None:
                self.server_port = int(server_port)
        if self.server_port is None:
            self.server_port = 53
        if self.bind_ip is None:
            self.bind_ip = env.get('IP', None)
        if self.bind_ip is None:
            self.bind_ip = '127.0.0.53'
        if self.bind_port is None:
            bind_port = env.get('PORT', None)
            if bind_port is not None:
                self.bind_port = int(bind_port)
        if self.bind_port is None:
            self.bind_port = 53
        if self.bind_alt_port is None:
            self.bind_alt_port = 60600

    def hosts_recheck(self):
        """Recheck all hosts files."""
        for hostsdb in self.hosts:
            hostsdb.recheck()
        return

    def getaddr(self, ahn):
        """Get a hostname matching address."""
        for hostsdb in self.hosts:
            result = hostsdb.getaddr(ahn)
            if result:
                return result

    def getaaaa(self, ahn):
        """Get an AAAA record from the hosts files."""
        for hostsdb in self.hosts:
            result = hostsdb.getaaaa(ahn)
            if result:
                return result

    def getaaaa_hit(self, ahn):
        """Get and HIT record from the hosts files."""
        for hostsdb in self.hosts:
            result = hostsdb.getaaaa_hit(ahn)
            if result:
                return result

    def cache_name(self, name, addr, ttl):
        """Cache the name-address mapping with ttl in all hosts files."""
        for hostsdb in self.hosts:
            hostsdb.cache_name(name, addr, ttl)

    def geta(self, ahn):
        """Get an A record from the hosts files."""
        for hostsdb in self.hosts:
            result = hostsdb.geta(ahn)
            if result:
                return result

    def forkme(self):
        """Fork into the background.

        Returns:
          False in the parent. True in the child, which first regenerates
          MYID, adds a syslog handler to the root logger, and points
          stdin/stdout/stderr at os.devnull.
        """
        global MYID
        if os.fork():
            return False

        # We are the child: MYID embeds the PID, which has just changed.
        MYID = '%d-%d' % (time.time(), os.getpid())
        loghandler = logging.handlers.SysLogHandler(
            address='/dev/log',
            facility=logging.handlers.SysLogHandler.LOG_DAEMON)
        loghandler.setFormatter(logging.Formatter(
            'hipdnsproxy[%(process)s] %(levelname)-8s %(message)s'))
        logging.getLogger().addHandler(loghandler)

        # Use raw descriptors (file() is Python-2-only) and close the
        # temporaries explicitly. Never close one that is itself a standard
        # descriptor, which happens when a standard fd was already closed.
        std_fds = (sys.stdin.fileno(), sys.stdout.fileno(),
                   sys.stderr.fileno())
        null_in = os.open(os.devnull, os.O_RDONLY)
        null_out = os.open(os.devnull, os.O_RDWR | os.O_APPEND)
        os.dup2(null_in, std_fds[0])
        os.dup2(null_out, std_fds[1])
        os.dup2(null_out, std_fds[2])
        for fd in (null_in, null_out):
            if fd not in std_fds:
                os.close(fd)
        return True

    def killold(self):
        """Kill process with PID from pidfile."""
        try:
            ifile = open(self.pidfile, 'r')
        except IOError, ioe:
            if ioe[0] == errno.ENOENT:
                return
            else:
                logging.error('Error opening pid file: %s', ioe)
                sys.exit(1)
        try:
            os.kill(int(ifile.readline().rstrip()), signal.SIGTERM)
        except OSError, ose:
            if ose[0] == errno.ESRCH:
                ifile.close()
                return
            else:
                logging.error('Error terminating old process: %s', ose)
                sys.exit(1)
        time.sleep(3)
        ifile.close()

    def recovery(self):
        """Recover from being harshly killed."""
        try:
            ifile = open(self.pidfile, 'r')
        except IOError, ioe:
            if ioe[0] == errno.ENOENT:
                return
            else:
                logging.error('Error opening pid file: %s', ioe)
                sys.exit(1)
        ifile.readline()
        bk_path = '%s-%s' % (self.rc1.resolvconf_towrite,
                             ifile.readline().rstrip())
        if os.path.exists(bk_path):
            logging.info('resolv.conf backup found. Restoring.')
            tmp = self.rc1.resolvconf_bkname
            self.rc1.resolvconf_bkname = bk_path
            self.rc1.restore_resolvconf_dnsmasq()
            self.rc1.resolvconf_bkname = tmp
        ifile.close()

    def savepid(self):
        """Write the current PID and MYID, one per line, to self.pidfile.

        Exits with status 1 if the pidfile cannot be opened for writing.
        """
        try:
            ofile = open(self.pidfile, 'w')
        except IOError as ioe:
            logging.error('Error opening pid file for writing: %s', ioe)
            sys.exit(1)
        with ofile:
            ofile.write('%d\n' % (os.getpid(),))
            ofile.write('%s\n' % MYID)

    def write_local_hits_to_hosts(self):
        """Add local HITs to the HIP hosts file.

        Otherwise certain services (sendmail, cups, httpd) time out when they
        are started and they query the local HITs from the DNS.

        Addresses are taken from "ifconfig dummy0" output: the text between
        "2001:1" and "/28" on a line. Ones for which getaddr() already knows
        a name are skipped. The rest are appended to self.hiphosts as
        "localhit1", "localhit2", ...

        FIXME: should we really write the local hits to a file rather than just
        adding them to the cache?
        """
        proc = subprocess.Popen(['ifconfig', 'dummy0'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                universal_newlines=True)
        # communicate() drains the output and reaps the child process.
        output = proc.communicate()[0]

        localhit = []
        for line in output.splitlines():
            start = line.find('2001:1')
            end = line.find('/28')
            # The prefix length must follow the address on the same line.
            if start != -1 and end > start:
                hit = line[start:end]
                if not self.getaddr(hit):
                    localhit.append(hit)

        with open(self.hiphosts, 'a') as ofile:
            for idx, hit in enumerate(localhit, 1):
                ofile.write('%s\tlocalhit%s\n' % (hit, idx))

    def hip_cache_lookup(self, packet):
        """Make a cache lookup."""
        result = None
        qname = packet['questions'][0][0]
        qtype = packet['questions'][0][1]

        if self.prefix and qname.startswith(self.prefix):
            qname = qname[len(self.prefix):]

        # convert 1.2....1.0.0.1.0.0.2.ip6.arpa to a HIT and
        # map host name to address from cache
        if qtype == 12:
            lr_ptr = None
            addr_str = hosts.ptr_to_addr(qname)
            if (not self.disable_lsi and addr_str is not None and
                hosts.valid_lsi(addr_str)):
                addr_str = lsi_to_hit(addr_str)
            lr_ptr = self.getaddr(addr_str)
            lr_aaaa_hit = None
        else:
            lr_a = self.geta(qname)
            lr_aaaa = self.getaaaa(qname)
            lr_aaaa_hit = self.getaaaa_hit(qname)

        if (lr_aaaa_hit is not None and
            (not self.prefix or
             packet['questions'][0][0].startswith(self.prefix))):
            if lr_a is not None:
                add_hit_ip_map(lr_aaaa_hit[0], lr_a[0])
            if lr_aaaa is not None:
                add_hit_ip_map(lr_aaaa_hit[0], lr_aaaa[0])
            if qtype == 28:               # 28: AAAA
                result = lr_aaaa_hit
            elif qtype == 1 and not self.disable_lsi:  # 1: A
                lsi = hit_to_lsi(lr_aaaa_hit[0])
                if lsi is not None:
                    result = (lsi, lr_aaaa_hit[1])
        elif self.prefix and packet['questions'][0][0].startswith(self.prefix):
            result = None
        elif qtype == 28:
            result = lr_aaaa
        elif qtype == 1:
            result = lr_a
        elif qtype == 12 and lr_ptr is not None:  # 12: PTR
            result = (lr_ptr, self.hosts_ttl)

        if result is not None:
            packet['answers'].append([packet['questions'][0][0], qtype, 1,
                                     result[1], result[0]])
            packet['ancount'] = len(packet['answers'])
            packet['qr'] = 1
            return True

        return False

    def hip_lookup(self, packet):
        """Make a lookup."""
        qname = packet['questions'][0][0]
        qtype = packet['questions'][0][1]

        dns_hit_found = False
        for answer in packet['answers']:
            if answer[1] == 55:
                dns_hit_found = True
                break

        lsi = None
        hit_found = dns_hit_found is not None
        if hit_found:
            hit_ans = []
            lsi_ans = []

            for answer in packet['answers']:
                if answer[1] != 55:
                    continue

                hit = socket.inet_ntop(socket.AF_INET6, answer[7])
                hit_ans.append([qname, 28, 1, answer[3], hit])

                if qtype == 1 and not self.disable_lsi:
                    lsi = hit_to_lsi(hit)
                    if lsi is not None:
                        lsi_ans.append([qname, 1, 1, self.hosts_ttl, lsi])

                self.cache_name(qname, hit, answer[3])

        if qtype == 28 and hit_found:
            packet['answers'] = hit_ans
        elif lsi is not None:
            packet['answers'] = lsi_ans
        else:
            packet['answers'] = []

        packet['ancount'] = len(packet['answers'])

    def mainloop(self, unused_args):
        """HIP DNS proxy main loop."""
        connected = False

        logging.info('Dns proxy for HIP started')

        self.parameter_defaults()

        # Default virtual interface and address for dnsproxy to
        # avoid problems with other dns forwarders (e.g. dnsmasq)
        os.system("ifconfig lo:53 %s" % (self.bind_ip,))

        servsock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            servsock.bind((self.bind_ip, self.bind_port))
        except socket.error:
            logging.info('Port %d occupied, falling back to port %d',
                         self.bind_port, self.bind_alt_port)
            servsock.bind((self.bind_ip, self.bind_alt_port))
            self.use_alt_port = True

        servsock.settimeout(self.app_timeout)

        rc1 = self.rc1
        if self.use_alt_port and os.path.exists(rc1.dnsmasq_defaults):
            rc1.set_dnsmasq_hook(self)

        if rc1.use_dnsmasq_hook and rc1.use_resolvconf:
            conf_file = rc1.guess_resolvconf()
        else:
            conf_file = None

        if conf_file is not None:
            logging.info('Using conf file %s', conf_file)

        self.read_resolv_conf(conf_file)
        if self.server_ip is not None:
            logging.info('DNS server is %s', self.server_ip)

        self.hosts = []
        if self.hostsnames:
            for hostsname in self.hostsnames:
                self.hosts.append(hosts.Hosts(hostsname))
        else:
            if os.path.exists(self.hiphosts):
                self.hosts.append(hosts.Hosts(self.hiphosts))

        if os.path.exists(DEFAULT_HOSTS):
            self.hosts.append(hosts.Hosts(DEFAULT_HOSTS))

        self.write_local_hits_to_hosts()

        util.init_wantdown()
        util.init_wantdown_int()        # Keyboard interrupts

        rc1.start()

        if (not (rc1.use_dnsmasq_hook and rc1.use_resolvconf) and
            self.overwrite_resolv_conf):
            rc1.write({'nameserver': self.bind_ip})

        if self.server_ip is not None:
            if self.server_ip.find(':') == -1:
                server_family = socket.AF_INET
            else:
                server_family = socket.AF_INET6
            clisock = socket.socket(server_family, socket.SOCK_DGRAM)
            clisock.settimeout(self.dns_timeout)
            try:
                clisock.connect((self.server_ip, self.server_port))
                connected = True
            except socket.error:
                connected = False

        query_id = 1

        while not util.wantdown():
            try:
                self.hosts_recheck()
                if rc1.old_has_changed():
                    connected = False
                    self.server_ip = rc1.resolvconfd.get('nameserver')
                    if self.server_ip is not None:
                        if self.server_ip.find(':') == -1:
                            server_family = socket.AF_INET
                        else:
                            server_family = socket.AF_INET6
                        clisock = socket.socket(server_family,
                                                socket.SOCK_DGRAM)
                        clisock.settimeout(self.dns_timeout)
                        try:
                            clisock.connect((self.server_ip, self.server_port))
                            connected = True
                            logging.info('DNS server is %s', self.server_ip)
                        except socket.error:
                            connected = False

                    rc1.restart()
                    if (not (rc1.use_dnsmasq_hook and rc1.use_resolvconf) and
                        self.overwrite_resolv_conf):
                        rc1.write({'nameserver': self.bind_ip})

                if connected:
                    rlist, _, _ = select.select([servsock, clisock], [], [],
                                                5.0)
                else:
                    rlist, _, _ = select.select([servsock], [], [], 5.0)
                self.clean_queries()
                if servsock in rlist:          # Incoming DNS request
                    inbuf, from_a = servsock.recvfrom(2048)

                    packet = DeSerialize(inbuf).get_dict()
                    qtype = packet['questions'][0][1]

                    sent_answer = False

                    if qtype in (1, 28, 12):
                        if self.hip_cache_lookup(packet):
                            try:
                                outbuf = Serialize(packet).get_packet()
                                servsock.sendto(outbuf, from_a)
                                sent_answer = True
                            except socket.error:
                                logging.exception('Exception:')
                    elif (self.prefix and
                          packet['questions'][0][0].startswith(self.prefix)):
                        # Query with HIP prefix for unsupported RR type.
                        # Send empty response.
                        packet['qr'] = 1
                        try:
                            outbuf = Serialize(packet).get_packet()
                            servsock.sendto(outbuf, from_a)
                            sent_answer = True
                        except socket.error:
                            logging.exception('Exception:')

                    if connected and not sent_answer:
                        logging.info('Query type %d for %s from %s',
                            qtype, packet['questions'][0][0],
                            (self.server_ip, self.server_port))

                        query = (packet, from_a[0], from_a[1], qtype)
                        # FIXME: Should randomize for security
                        query_id = (query_id % 65535) + 1
                        pckt = copy.copy(packet)
                        pckt['id'] = query_id
                        if ((qtype == 28 or
                             (qtype == 1 and not self.disable_lsi)) and not
                            is_reverse_hit_query(packet['questions'][0][0])):

                            if not self.prefix:
                                pckt['questions'][0][1] = 55
                            if (self.prefix and
                                pckt['questions'][0][0].startswith(
                                    self.prefix)):
                                pckt['questions'][0][0] = pckt[
                                    'questions'][0][0][len(self.prefix):]
                                pckt['questions'][0][1] = 55

                        if qtype == 12 and not self.disable_lsi:
                            qname = packet['questions'][0][0]
                            addr_str = hosts.ptr_to_addr(qname)
                            if (addr_str is not None and
                                hosts.valid_lsi(addr_str)):
                                query = (packet, from_a[0], from_a[1], qname)
                                hit_str = lsi_to_hit(addr_str)
                                if hit_str is not None:
                                    pckt['questions'][0][0] = \
                                            hosts.addr_to_ptr(hit_str)

                        outbuf = Serialize(pckt).get_packet()
                        clisock.sendto(outbuf, (self.server_ip,
                                                self.server_port))

                        self.add_query(self.server_ip, self.server_port,
                                       query_id, query)

                if connected and clisock in rlist:   # Incoming DNS reply
                    inbuf, from_a = clisock.recvfrom(2048)
                    logging.info('Packet from DNS server %d bytes from %s',
                        len(inbuf), from_a)
                    packet = DeSerialize(inbuf).get_dict()

                    # Find original query
                    query_id_o = packet['id']
                    query_o = self.find_query(from_a[0], from_a[1], query_id_o)
                    if query_o:
                        qname = packet['questions'][0][0]
                        qtype = packet['questions'][0][1]
                        send_reply = True
                        query_again = False
                        hit_found = False
                        packet_o = query_o[0]
                        # Replace with the original query id
                        packet['id'] = packet_o['id']

                        if qtype == 55 and query_o[3] in (1, 28):
                            # Restore qtype
                            packet['questions'][0][1] = query_o[3]
                            self.hip_lookup(packet)
                            if packet['ancount'] > 0:
                                hit_found = True
                            if (not self.prefix or
                                (hit_found and not (self.getaaaa(qname) or
                                                    self.geta(qname)))):
                                query_again = True
                                send_reply = False
                            elif self.prefix:
                                hit_found = True
                                packet['questions'][0][0] = (
                                    self.prefix + packet['questions'][0][0])
                                for answer in packet['answers']:
                                    answer[0] = self.prefix + answer[0]

                        elif qtype in (1, 28):
                            hit = self.getaaaa_hit(qname)
                            ip6 = self.getaaaa(qname)
                            ip4 = self.geta(qname)
                            for answer in packet['answers']:
                                if answer[1] in (1, 28):
                                    self.cache_name(qname, answer[4],
                                                    answer[3])
                            if hit is not None:
                                for answer in packet['answers']:
                                    if (answer[1] == 1 or
                                        (answer[1] == 28 and not
                                         hosts.valid_hit(answer[4]))):
                                        add_hit_ip_map(hit[0], answer[4])
                                # Reply with HIT/LSI once it's been mapped to
                                # an IP
                                if ip6 is None and ip4 is None:
                                    if (packet_o['ancount'] == 0 and
                                        not self.prefix):
                                        # No LSI available. Return IPv4
                                        tmp = packet['answers']
                                        packet = packet_o
                                        packet['answers'] = tmp
                                        packet['ancount'] = len(
                                            packet['answers'])
                                    else:
                                        packet = packet_o
                                        if self.prefix:
                                            packet['questions'][0][0] = \
                                                    (self.prefix +
                                                     packet['questions'][0][0])
                                            for answer in packet['answers']:
                                                answer[0] = (self.prefix +
                                                             answer[0])
                                else:
                                    send_reply = False
                            elif query_o[3] == 0:
                                # Prefix is in use
                                # IP was queried for cache only
                                send_reply = False

                        elif qtype == 12 and isinstance(query_o[3], str):
                            packet['questions'][0][0] = query_o[3]
                            for answer in packet['answers']:
                                answer[0] = query_o[3]

                        if query_again:
                            if hit_found:
                                qtypes = [28, 1]
                                pckt = copy.deepcopy(packet)
                            else:
                                qtypes = [query_o[3]]
                                pckt = copy.copy(packet)
                            pckt['qr'] = 0
                            pckt['answers'] = []
                            pckt['ancount'] = 0
                            pckt['nslist'] = []
                            pckt['nscount'] = 0
                            pckt['additional'] = []
                            pckt['arcount'] = 0
                            for qtype in qtypes:
                                if self.prefix:
                                    query = (packet, query_o[1], query_o[2], 0)
                                else:
                                    query = (packet, query_o[1], query_o[2],
                                             qtype)
                                query_id = (query_id % 65535) + 1
                                pckt['id'] = query_id
                                pckt['questions'][0][1] = qtype
                                outbuf = Serialize(pckt).get_packet()
                                clisock.sendto(outbuf, (self.server_ip,
                                                        self.server_port))
                                self.add_query(self.server_ip,
                                               self.server_port,
                                               query_id, query)
                            packet['questions'][0][1] = query_o[3]

                        if send_reply:
                            outbuf = Serialize(packet).get_packet()
                            servsock.sendto(outbuf, (query_o[1], query_o[2]))
            except (select.error, OSError), exc:
                if exc[0] == errno.EINTR:
                    pass
                else:
                    logging.exception('Exception:')

        logging.info('Wants down')
        rc1.stop()

if __name__ == '__main__':
    import doctest
    doctest.testmod(raise_on_error=True)
