//  CS461 Distributed Computing and Networking
//  Princeton University
//  Tammo Spalink
// 
//  Revision 6
//

#include "interface.h"

#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <net/ethernet.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/udp.h>
#include <linux/icmp.h>
#include <linux/if.h>
#include <linux/ip_fw.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <signal.h>
#include <fcntl.h>
#include <string.h>
#include <sys/time.h>

//
// ------ For Debugging ------
//

// ASSERT(expression): if "expression" evaluates to false, print the
// failed expression together with the enclosing function, file and line
// to stderr and abort() the process.
//
// The body is wrapped in do { ... } while (0) so that "ASSERT(x);" is
// exactly one statement and stays correct as the unbraced body of an
// if/else.
#define ASSERT(expression) do {                                               \
    if (! (expression)) {                                                     \
        fprintf(stderr, "ASSERTION (%s) FAILED in %s (%s:%d)\n",              \
                         (#expression), __func__, __FILE__, __LINE__);        \
        abort();                                                              \
    }                                                                         \
} while (0)

// TRACE(fmt, args...): when CS461Debug is non-zero, print a printf-style
// message to stderr, prefixed with "[<calling function>] " and followed
// by a newline.  "fmt" must be a string literal because it is pasted
// together with the prefix and the trailing newline.
//
// The function name is passed as a "%s" argument: it is a predefined
// identifier, not a string literal, so it cannot be concatenated with
// adjacent literals.  ", ##__VA_ARGS__" (a GNU extension) drops the comma
// when no arguments follow "fmt".  The do { ... } while (0) wrapper makes
// "TRACE(...);" a single statement that is safe inside if/else.
#define TRACE(fmt, ...) do {                                                  \
    if (CS461Debug) {                                                         \
        fprintf(stderr, "[%s] " fmt "\n", __func__, ##__VA_ARGS__);           \
    }                                                                         \
} while (0)

static int 
catchError(int returnValue, 
           int badValue, 
           char *message)
    // Guard for system calls.  Hands "returnValue" straight back unless
    // it equals "badValue"; in that case "message" is reported through
    // perror() and the process exits with status 1.
{
    if (returnValue != badValue)
        return returnValue;

    perror(message);
    exit(1);
}

//
// ------ Datatypes and Useful Defines ------
//

// Clock interval constants.  NOTE(review): neither is referenced anywhere
// else in this file -- the tick period actually used is the clockPeriod
// argument of CS461_Initialize.
#define CLOCK_INTERVAL_USECS 0
#define CLOCK_INTERVAL_SECS 5

// Path of the device that CS461_Initialize opens read-only and from which
// CS461RecvPacket reads captured packets.
#define IP_NETLINK_DEVICE "/dev/netlink"

// EXPORT expands to nothing; it only labels the functions that are meant
// to be visible outside this file.  FOREVER turns "for (FOREVER)" into
// the endless loop "for (;;)".
#define EXPORT
#define FOREVER ;;

// Header found at the start of every buffer read from the firewall device
// (see CS461PacketBuf).
// NOTE(review): the layout presumably mirrors what the kernel's netlink
// firewall device emits -- confirm against the kernel headers.
typedef struct {
    __u32 length;                  // Total bytes: this header + the packet;
                                   // CS461RecvPacket requires read() to
                                   // return exactly this many bytes
    __u32 mark;                    // Not used by this library
    char interface[IFNAMSIZ];      // Not read by this file; presumably the
                                   // interface name -- TODO confirm
} CS461DeviceHeader;

// Receive buffer handed to read() on the firewall device: the device
// header, then the IP header, then the rest of the packet.  "body" is
// sized so that header + body together span 65535 bytes.  Only the part
// starting at "header" is passed on to the user's receive callback.
typedef struct {
    CS461DeviceHeader deviceHeader;
    struct iphdr header;
    __u8 body[65535 - sizeof(struct iphdr)];
} CS461PacketBuf;

//
// ------ Variables Static Global to Interface ------
//

// Non-zero enables the TRACE debug messages; set from the "debug"
// argument of CS461_Initialize.
static int CS461Debug;

// Process IDs of the two cooperating processes.  CS461Daemon is the
// process that called CS461_Initialize (it becomes the IO watcher daemon,
// i.e. the parent).  CS461Child holds the return value of fork(): the pid
// of the main (child) process when read in the daemon, 0 in the child.
static pid_t CS461Daemon;
static pid_t CS461Child;

// Period between clock ticks in milliseconds
static unsigned long long CS461ClockPeriod;

// Time at which the next clock tick should happen (in milliseconds, on
// the CS461GetTime() clock)
static unsigned long long CS461NextClockEvent;

// The callbacks to higher level code; a NULL callback is never invoked
static void (* CS461RecvCallback)(void *);
static void (* CS461ClockCallback)(void);

// FirewallFD is the descriptor of the opened IP_NETLINK_DEVICE that
// packets are read from; it stays -1 until CS461_Initialize has run, and
// the exported functions assert on that.  FirewallSocketFD is the raw
// socket used both to insert/delete the firewall rule and to send
// packets.
static int FirewallFD = -1;
static int FirewallSocketFD = -1;

// Main book-keeping datastructure for ipchains: the rule inserted by
// CS461_Initialize, kept so CS461TerminationHandler can delete that same
// rule at shutdown.
static struct ip_fw FirewallRule;

//
// ------ Functions Static Global to Interface ------
//

static unsigned long long
CS461GetTime(void)
    // Current time as reported by gettimeofday(), converted to
    // milliseconds.  The sub-millisecond part of the microsecond field
    // is truncated.
{
    struct timeval now;
    unsigned long long wholeSeconds;
    unsigned long long microseconds;

    gettimeofday(&now, NULL);
    wholeSeconds = (unsigned long long)(now.tv_sec);
    microseconds = (unsigned long long)(now.tv_usec);

    return (wholeSeconds * 1000) + (microseconds / 1000);
}

static void 
CS461TerminationHandler(
    int sig,
    siginfo_t *info, 
    void *unknown)
    // Signal handler (SA_SIGINFO style) that CS461WatcherDaemon installs
    // for SIGCHLD and SIGINT.  It reaps the child on SIGCHLD, removes
    // the firewall rule that CS461_Initialize inserted, closes both
    // firewall descriptors and terminates the process.
    //
    // Exit status: for SIGCHLD it is info->si_status (the child's exit
    // value or terminating signal).  si_status is only meaningful for
    // SIGCHLD, so for any other signal the shell convention 128 + sig is
    // used instead of reading it.
    //
    // "unknown" (the ucontext pointer) is not used.
{
    struct ip_fwchange deleteRecord;
    int exitStatus;

    (void)unknown;

    if (sig == SIGCHLD) {
        int status;

        catchError(waitpid(0, &status, 0), -1, "waitpid");
        if (! WIFEXITED(status)) {
            // WEXITSTATUS is only defined for a normal exit, so the raw
            // wait status is reported for a signal-terminated child.
            if (WIFSIGNALED(status))
                fprintf(stderr, "Fatal Program Error "
                        "[signal: %d, status: %d]\n",
                        WTERMSIG(status), status);
            else 
                fprintf(stderr, "Program Exited\n");
        }
        exitStatus = (info != NULL) ? info->si_status : 1;
    } else {
        fprintf(stderr, "Interrupted\n");
        exitStatus = 128 + sig;
    }

    TRACE("shutting down...");

    // Zero the record first so that the bytes strcpy() leaves untouched
    // (label tails, padding) are not stack garbage when handed to the
    // kernel.
    memset(&deleteRecord, 0, sizeof(deleteRecord));
    deleteRecord.fwc_rule.ipfw = FirewallRule;
    strcpy(deleteRecord.fwc_rule.label, "DENY");
    strcpy(deleteRecord.fwc_label, "input");

    catchError(setsockopt(FirewallSocketFD, IPPROTO_IP, IP_FW_DELETE, 
                          (char *)(&deleteRecord), sizeof(deleteRecord)),
               -1, "delete firewall rule");
    close(FirewallSocketFD);
    close(FirewallFD);

    TRACE("succeeded");

    exit(exitStatus);
}

static void
CS461RecvPacket(
    CS461PacketBuf *packet)
    // Wait for a packet to arrive.  The packet will be written to
    // "packet", which must be already allocated.
    //
    // A read() interrupted by a signal (EINTR) is retried; any other
    // read() failure is reported with perror() and ends the process.
    // The read must deliver exactly deviceHeader.length bytes, otherwise
    // the process exits.  When debugging is on, a summary of the IP
    // header is traced.
{
    for (FOREVER) {                // Retry until we have a whole packet
        int readLen;

        TRACE("reading...");
        packet->deviceHeader.length = 0;
        readLen = read(FirewallFD, packet, sizeof(CS461PacketBuf));
        if (readLen == -1) {
            if (errno == EINTR) continue;  // interrupted, try again
            perror("read of firewall");
            exit(1);
        }
        ASSERT((packet->deviceHeader.length != 0) && (readLen != 0));

        // readLen is known to be positive here, so the conversion to
        // unsigned for the comparison is value-preserving.
        if ((unsigned int)readLen != packet->deviceHeader.length) {
            fprintf(stderr, 
                    "CS461RecvPacket: only got %d of %u bytes of data\n",
                    readLen, (unsigned int)packet->deviceHeader.length);
            exit(1);
        }
        break;
    }

    if (CS461Debug) { // Print some packet info
        struct in_addr addr;
        TRACE("--- new packet:");
        TRACE("id:   %u", ntohs(packet->header.id));
        TRACE("prot: %u", (unsigned int) packet->header.protocol);
        addr.s_addr = packet->header.saddr;
        TRACE("src:  %s", inet_ntoa(addr));
        addr.s_addr = packet->header.daddr;
        TRACE("dest: %s", inet_ntoa(addr));
        // IP header described in /usr/include/linux/ip.h
        TRACE("len:  %u", ntohs(packet->header.tot_len));
        // The fragment offset is the low 13 bits of frag_off once it
        // is in host byte order; the upper 3 bits are flags.
        TRACE("frgo: %u",
              (unsigned int)(ntohs(packet->header.frag_off) & 0x1FFF));

        TRACE("--- succeeded");
    }
}

static void
CS461PerformRecvCallback(void)
    // Fetch one packet and hand it to the higher level.  Must only be
    // called when data is available for reading, since it reads via
    // CS461RecvPacket.  With a NULL receive callback nothing is read
    // and nothing is called.  The callback receives a pointer to the IP
    // header inside a buffer local to this function, so the memory is
    // only valid for the duration of the callback.
{
    CS461PacketBuf incoming;

    if (CS461RecvCallback != NULL) {
        CS461RecvPacket(&incoming);
        TRACE("have data, performing callback");
        CS461RecvCallback((void *)&(incoming.header));
    }
}

static void
CS461PerformClockCallback(void)
    // Invoke the higher level's clock callback.  Should be called
    // whenever the clock callback should be attempted; it is a no-op
    // when the callback is NULL.
{
    if (CS461ClockCallback != NULL) {
        TRACE("clock tick, performing callback");
        CS461ClockCallback();
    }
}

static void 
CS461EventHandler(
    int sig)
    // Signal handler for SIGUSR1 and SIGUSR2 in the main (child)
    // process.  The watcher daemon sends SIGUSR1 when data has arrived
    // and SIGUSR2 when a clock event happens.  SIGUSR1 runs the receive
    // callback; any other signal routed here runs the clock callback.
    // Afterwards the handler is re-established and the daemon is told
    // (via SIGUSR1) that it may continue.
{
    TRACE("trying callback");
    if (sig == SIGUSR1) CS461PerformRecvCallback();
    else                CS461PerformClockCallback();

    // Re-establish this handler BEFORE releasing the daemon.  On systems
    // where signal() resets the disposition on delivery, the daemon's
    // next signal could otherwise arrive while the default action
    // (terminate) is in effect.  The result of signal() is compared as a
    // pointer: casting it to an integer can truncate it.
    if (signal(sig, CS461EventHandler) == SIG_ERR) {
        perror("CS461EventHandler setup");
        exit(1);
    }

    // Tell the daemon that we are done and that it may continue
    // to monitor for new packet arrivals and clock events.
    TRACE("signaling watcher daemon to continue");
    catchError(kill(CS461Daemon, SIGUSR1), -1, "signal to daemon");
}

static void
CS461WatcherDaemonHandler(
    int sig)
    // SIGUSR1 handler of the watcher daemon; it does all of the
    // daemon's work.  Doing the wait inside the handler avoids a race
    // between sending the child a signal and receiving its response.
    // Each activation waits in select() until either the firewall
    // device becomes readable or the time to the next clock tick runs
    // out, then notifies the child: SIGUSR1 for a packet, SIGUSR2 for a
    // clock tick.  To activate the daemon again the child must send
    // SIGUSR1.
{
    fd_set readSet;
    long long int timediff;
    unsigned long long nextTick;
    struct timeval tv;
    int retval;

    TRACE("activation signal recieved");

    // Milliseconds until the next tick, clamped at zero when overdue
    timediff = CS461NextClockEvent - CS461GetTime();
    if (timediff < 0) nextTick = 0;
    else nextTick = timediff;
    tv.tv_sec = nextTick / 1000;
    tv.tv_usec = (nextTick - (tv.tv_sec * 1000)) * 1000;
    TRACE("clock timeout is %ld sec, %ld usec", 
          (long int)tv.tv_sec, (long int)tv.tv_usec);
        
    FD_ZERO(&readSet);
    FD_SET(FirewallFD, &readSet);

    // Wait for either a packet to arrive or a clock event
    TRACE("waiting for events...");
    retval = select(FirewallFD + 1, &readSet, NULL, NULL, &tv);
    catchError(retval, -1, "(select) wait for packet or clock event");
    ASSERT((retval == 1) || (retval == 0));

    // Re-install this handler BEFORE notifying the child, so that the
    // child's reply can never arrive while the disposition has been
    // reset to the default action.  The result of signal() is compared
    // as a pointer: casting it to an integer can truncate it.
    if (signal(SIGUSR1, CS461WatcherDaemonHandler) == SIG_ERR) {
        perror("CS461WatcherDaemonHandler setup");
        exit(1);
    }

    // Notify the child of whatever happened
    if (retval == 1) { // Send signal SIGUSR1 for packet events
        TRACE("packet event, signaling child");
        catchError(kill(CS461Child, SIGUSR1), -1, "signal to child");
    }
    if (retval == 0) { // Send signal SIGUSR2 for clock events
        TRACE("clock event, signaling child");
        CS461NextClockEvent = CS461GetTime() + CS461ClockPeriod;
        catchError(kill(CS461Child, SIGUSR2), -1, "signal to child");
    }
}

static void 
CS461WatcherDaemon(void)
    // Main routine of the watcher daemon; CS461_Initialize runs it in
    // the parent process after fork().  It schedules the first clock
    // tick, installs CS461TerminationHandler for SIGCHLD and SIGINT and
    // CS461WatcherDaemonHandler for SIGUSR1, sends itself SIGUSR1 to
    // start the first wait, and then idles forever.  All real work
    // (watching the device, timing, signaling the child) happens in the
    // SIGUSR1 handler, so the main process can run code that neither
    // knows nor cares about such events until they occur.  This
    // function never returns.
{
    static struct sigaction sa;

    // Initialize the clock
    CS461NextClockEvent = CS461GetTime() + CS461ClockPeriod;

    // Set up a handler to catch when we should shut down (when the
    // main process quits or dies unexpectedly).
    sigemptyset(&sa.sa_mask);
    sa.sa_sigaction = CS461TerminationHandler;
    sa.sa_flags = SA_NOCLDSTOP | SA_SIGINFO;
    catchError(sigaction(SIGCHLD, &sa, NULL), -1, "SIGCHLD setup");
    catchError(sigaction(SIGINT, &sa, NULL), -1, "SIGINT setup");

    // Set up a handler to receive signals from the child.  The result
    // of signal() is compared as a pointer: casting it to an integer
    // can truncate it.
    if (signal(SIGUSR1, CS461WatcherDaemonHandler) == SIG_ERR) {
        perror("CS461WatcherDaemonHandler setup");
        exit(1);
    }

    // Activate myself
    catchError(kill(getpid(), SIGUSR1), -1, "self activation signal");

    for (FOREVER) {
        int retval;

        // Just loop and ignore signals from the child (which are dealt
        // with by the handler).  This select() has nothing to watch and
        // no timeout, so it only returns when a signal interrupts it;
        // anything other than -1/EINTR is treated as fatal.
        retval = select(0, NULL, NULL, NULL, NULL);
        if ((retval != -1) || (errno != EINTR)) {
            catchError(-1, -1, "(select) waiting for reply from child");
        }
    }
}

//
// ------ Visible Exported Interface Functions ------
//

EXPORT void
CS461_Initialize(
    unsigned short protocol,       // e.g. IPPROTO_TCP
    void (* recvCallback)(void *), // called for incoming packets
    void (* clockCallback)(void),  // called at clockPeriod intervals
    unsigned long long clockPeriod,// time in milliseconds
    int debug)                     // Boolean to debug or not to debug
    // This will initialize the interface library.  It opens the netlink
    // device and a raw socket, inserts a firewall rule that captures
    // "protocol" packets destined to the cluster IP alias and copies
    // them to user space, and then forks.  The calling process becomes
    // the watcher daemon and never returns from this function; the
    // forked child returns and carries on as the main program.
    // Debugging messages for the library are turned on if "debug" is
    // TRUE.  Whenever data arrives, recvCallback is called with the new
    // packet; clockCallback is called every clockPeriod milliseconds.
    // If either callback is NULL, it is disabled.  The recvCallback is
    // passed a region of memory that begins with an IP header.  The
    // memory region may be reused after return from the callback, so
    // important data must be copied out.
    //
    // Any failing system call, including fork(), prints a message and
    // exits the process.  NOTE: a failure after the rule has been
    // inserted but before the daemon is running leaves the rule
    // installed.
{
    struct ip_fwnew insertRecord;

    CS461Debug = debug;
    CS461RecvCallback = recvCallback;
    CS461ClockCallback = clockCallback;
    CS461ClockPeriod = clockPeriod;
    CS461Daemon = getpid();
    
    FirewallFD = catchError(open(IP_NETLINK_DEVICE, O_RDONLY), -1,
                            "open of firewall");
    FirewallSocketFD = catchError(socket(AF_INET, SOCK_RAW, IPPROTO_RAW), -1,
                                  "open of firewall socket");

    FirewallRule.fw_src.s_addr = 0;         // No restriction on source
    FirewallRule.fw_dst.s_addr = CS461_GetIpAlias(); // Me
    FirewallRule.fw_smsk.s_addr = 0;        // No restriction on source
    FirewallRule.fw_dmsk.s_addr = inet_addr("128.112.5.255"); // Cluster mask
    FirewallRule.fw_mark = 0;               // An unused field
    FirewallRule.fw_proto = protocol;       // See only <protocol> packets
    FirewallRule.fw_flg = IP_FW_F_NETLINK;  // Copy packets to /dev/netlink
    FirewallRule.fw_invflg = IP_FW_INV_VIA; // XXX necessary?
    FirewallRule.fw_spts[0] = 0x0;          // port number min/max restrictions
    FirewallRule.fw_spts[1] = 0xFFFF;       // port number min/max restrictions
    FirewallRule.fw_dpts[0] = 0x0;          // port number min/max restrictions
    FirewallRule.fw_dpts[1] = 0xFFFF;       // port number min/max restrictions
    FirewallRule.fw_redirpt = 0;            // No redirection
    FirewallRule.fw_outputsize = 65535;     // No size restrictions
    strcpy(FirewallRule.fw_vianame, "tap0"); // XXX necessary?
    FirewallRule.fw_tosand = 0xFF;          // Don't change the TOS hdr field
    FirewallRule.fw_tosxor = 0x00;          // Don't change the TOS hdr field

    // Zero the record first so that the bytes strcpy() leaves untouched
    // (label tails, padding) are not stack garbage when handed to the
    // kernel.
    memset(&insertRecord, 0, sizeof(insertRecord));
    insertRecord.fwn_rulenum = 1;           // Make this the first rule
    insertRecord.fwn_rule.ipfw = FirewallRule;
    strcpy(insertRecord.fwn_rule.label, "DENY");
    strcpy(insertRecord.fwn_label, "input");

    catchError(setsockopt(FirewallSocketFD, IPPROTO_IP, IP_FW_INSERT, 
                          (char *)(&insertRecord), sizeof(insertRecord)),
               -1, "insert firewall rule");

    // Install the main program's event handlers BEFORE forking.  The
    // child inherits them, so the daemon's first SIGUSR1/SIGUSR2 can
    // never reach the child while the default action (terminate) is
    // still in effect.  The daemon replaces its own SIGUSR1 handler in
    // CS461WatcherDaemon.  The result of signal() is compared as a
    // pointer: casting it to an integer can truncate it.
    if (signal(SIGUSR1, CS461EventHandler) == SIG_ERR) {
        perror("SIGUSR1 CS461EventHandler setup");
        exit(1);
    }
    if (signal(SIGUSR2, CS461EventHandler) == SIG_ERR) {
        perror("SIGUSR2 CS461EventHandler setup");
        exit(1);
    }
    TRACE("set up event handler");

    // A failed fork() must not be mistaken for "I am the parent":
    // CS461Child would be -1 and the daemon would later kill(-1, ...),
    // signaling every process it is allowed to signal.
    CS461Child = catchError(fork(), -1, "fork");
    if (CS461Child != 0) {
        // If this is the parent, we will just watch the netlink
        // device for incoming data and then signal the child once
        // data arrives.
        CS461WatcherDaemon();
        exit(0);
    }

    // Only the child (the main program) gets here.
    TRACE("succeeded (firewall rule inserted)");
}

EXPORT CS461_InterruptState
CS461_InterruptToggle(CS461_InterruptState state)
    // Toggle the use of the recv and clock callbacks.  This is
    // important to synchronize the execution of the callbacks so that
    // they do not interfere with other tasks.  "state" is a combination
    // of CS461_INTCLOCK and CS461_INTRECV; a disabled event has its
    // signal (SIGUSR2 for the clock, SIGUSR1 for packets) set to
    // SIG_IGN, an enabled one is routed to CS461EventHandler.  If a
    // callback would have occurred one or more times during a toggle to
    // off period, it will be called only once immediately after a
    // toggle to on.  The previous toggle setting is returned.
    //
    // NOTE(review): the remembered state starts out zero-initialized,
    // while CS461_Initialize installs both handlers.  If zero means
    // CS461_INTOFF the first call sees "off" although both events are
    // live -- confirm against interface.h.
{
    static CS461_InterruptState CS461InterruptState;
    CS461_InterruptState oldState;

    ASSERT(FirewallFD != -1);               // Verify we are initialized

    TRACE("Clock: %s, Packets: %s",
          state & CS461_INTCLOCK ? "Y" : "N",
          state & CS461_INTRECV  ? "Y" : "N");

    // Capture the previous setting up front; every decision below
    // depends on it.
    oldState = CS461InterruptState;
    if (oldState == state) return oldState;

    // The result of signal() is compared as a pointer: casting it to an
    // integer can truncate it.
    if (state & CS461_INTCLOCK & ~oldState) {      // clock: off -> on
        if (signal(SIGUSR2, CS461EventHandler) == SIG_ERR) {
            perror("CS461EventHandler setup");
            exit(1);
        }
    }
    if (~state & CS461_INTCLOCK & oldState) {      // clock: on -> off
        if (signal(SIGUSR2, SIG_IGN) == SIG_ERR) {
            perror("CS461EventHandler setup");
            exit(1);
        }
    }
    if (state & CS461_INTRECV & ~oldState) {       // packets: off -> on
        if (signal(SIGUSR1, CS461EventHandler) == SIG_ERR) {
            perror("CS461EventHandler setup");
            exit(1);
        }
    }
    if (~state & CS461_INTRECV & oldState) {       // packets: on -> off
        if (signal(SIGUSR1, SIG_IGN) == SIG_ERR) {
            perror("CS461EventHandler setup");
            exit(1);
        }
    }

    CS461InterruptState = state;

    // Coming out of "everything off": wake the daemon so that pending
    // events get delivered.  This is done only after the handlers above
    // are in place, otherwise the daemon's signal could arrive while it
    // is still being ignored and the event would be lost.
    if (oldState == CS461_INTOFF) {
        TRACE("signaling watcher daemon to continue");
        catchError(kill(CS461Daemon, SIGUSR1), -1, "signal to daemon");
    }

    return oldState;
}

// Macro used by CS461_GetIpAlias.  MATCH(x, y): if the local variable
// "hostname" equals "micron<x>.CS.Princeton.EDU", trace the alias and
// RETURN inet_addr("128.112.5.<y>") from the enclosing function;
// otherwise fall through.  Both arguments are pasted into string
// literals.  Wrapped in do { ... } while (0) so that "MATCH(a, b);" is a
// single statement and stays correct inside if/else.
#define MATCH(x, y) do {                                                      \
    if (strcmp(hostname, "micron" #x ".CS.Princeton.EDU") == 0) {             \
        TRACE("IP alias is %s", "128.112.5." #y);                             \
        return inet_addr("128.112.5." #y);                                    \
    }                                                                         \
} while (0)

EXPORT unsigned int
CS461_GetIpAlias(void)
    // Returns the IP address at which data may be received, as produced
    // by inet_addr() (i.e. in network byte order).  The address is
    // looked up from this machine's hostname in a fixed table; the
    // comparison is exact and case-sensitive.  Returns 0 when the
    // hostname is not in the table.  Exits the process if gethostname()
    // fails.
{
    static const struct {
        const char *host;          // exact hostname to match
        const char *alias;         // dotted-quad alias for that host
    } aliasTable[] = {
        { "micron01.CS.Princeton.EDU", "128.112.5.54" },
        { "micron02.CS.Princeton.EDU", "128.112.5.55" },
        { "micron03.CS.Princeton.EDU", "128.112.5.56" },
        { "micron04.CS.Princeton.EDU", "128.112.5.57" },
        { "micron34.CS.Princeton.EDU", "128.112.5.58" },
        { "micron06.CS.Princeton.EDU", "128.112.5.59" },
        { "micron21.CS.Princeton.EDU", "128.112.5.60" },
        { "micron22.CS.Princeton.EDU", "128.112.5.61" },
        { "micron09.CS.Princeton.EDU", "128.112.5.62" },
        { "micron20.CS.Princeton.EDU", "128.112.5.63" },
        { "micron11.CS.Princeton.EDU", "128.112.5.64" },
        { "micron12.CS.Princeton.EDU", "128.112.5.65" },
        { "micron13.CS.Princeton.EDU", "128.112.5.66" },
        { "micron14.CS.Princeton.EDU", "128.112.5.67" },
        { "micron15.CS.Princeton.EDU", "128.112.5.68" },
        { "micron16.CS.Princeton.EDU", "128.112.5.69" },
        { "micron17.CS.Princeton.EDU", "128.112.5.70" },
    };
    char hostname[256];
    size_t i;

    catchError(gethostname(hostname, sizeof(hostname)), -1, "gethostname");
    // gethostname() need not NUL-terminate a name that had to be
    // truncated; make sure strcmp() below always sees a terminated
    // string.
    hostname[sizeof(hostname) - 1] = '\0';

    for (i = 0; i < sizeof(aliasTable) / sizeof(aliasTable[0]); i++) {
        if (strcmp(hostname, aliasTable[i].host) == 0) {
            TRACE("IP alias is %s", aliasTable[i].alias);
            return inet_addr(aliasTable[i].alias);
        }
    }

    return 0;
}

EXPORT void
CS461_Block(void)
    // Block until an event occurs.  Whatever interrupt state the caller
    // had selected, only clock interrupts are enabled while waiting and
    // the caller's state is put back before returning.  The wait is a
    // select() on the firewall descriptor with no timeout: if it
    // reports the descriptor readable, the receive callback is run from
    // here; any other return from select() (for instance being
    // interrupted by a signal) just ends the wait.
{
    CS461_InterruptState savedState;
    fd_set firewallReadable;
    int ready;

    ASSERT(FirewallFD != -1);               // Verify we are initialized

    // Packet interrupts stay off during the wait: a packet handled by
    // the signal handler just before select() is entered would
    // otherwise leave us blocked with nothing to read.  The same kind
    // of race can still swallow a clock interrupt that fires very
    // close to the call; the following tick will be seen.
    savedState = CS461_InterruptToggle(CS461_INTCLOCK);

    // Wait until we are interrupted
    TRACE("waiting...");
    FD_ZERO(&firewallReadable);
    FD_SET(FirewallFD, &firewallReadable);
    ready = select(FirewallFD + 1, &firewallReadable, NULL, NULL, NULL);
    if (ready == 1) {
        CS461PerformRecvCallback();
    }

    CS461_InterruptToggle(savedState);
}

EXPORT void
CS461_SendIov(
    unsigned long int nextHop,  // Where packet is actually sent to
    struct iovec *iov,          // Vector of data to send from
    size_t iovlen)              // Number of elements in the vector
    // Send the data specified by the iov through the raw firewall
    // socket, addressed to "nextHop" (an IPv4 address in the form
    // stored in in_addr.s_addr).  Returns once sendmsg() has completed;
    // a sendmsg() failure prints a message and exits the process.  The
    // data must start with an IP header.  IOVs are described in
    // bits/uio.h
{
    struct sockaddr_in destAddr;
    struct msghdr msg;          // defined in bits/socket.h

    ASSERT(FirewallFD != -1);               // Verify we are initialized

    TRACE("sending packet...");
    if (CS461Debug) { // Print the addresses
        struct in_addr addr;
        addr.s_addr = nextHop;
        TRACE("nhop: %s", inet_ntoa(addr));
    }

    // Zero both structures before filling them in: sin_port, sin_zero
    // and any padding would otherwise reach the kernel as stack garbage.
    memset(&destAddr, 0, sizeof(destAddr));
    destAddr.sin_family = AF_INET;
    destAddr.sin_addr.s_addr = nextHop;

    memset(&msg, 0, sizeof(msg));
    msg.msg_name = &destAddr;
    msg.msg_namelen = sizeof(destAddr);
    msg.msg_iov = iov;
    msg.msg_iovlen = iovlen;
    msg.msg_control = NULL;
    msg.msg_controllen = 0;
    msg.msg_flags = 0;

    catchError(sendmsg(FirewallSocketFD, &msg, 0), -1, "sendmsg");
    TRACE("succeeded");
}

