#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <fcntl.h>
#include <libgen.h>
#include <stdio.h>
#include <assert.h>
#include <math.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/mman.h>

#include "rabinpoly.h"

/* Length of the file message digest (MD) in bytes.  Longer MDs give
   better results but increase processing time, with diminishing
   returns.  MD_BITS must be a multiple of NUM_HASHES_PER_CHAR (that is,
   MD_LENGTH a multiple of NUM_HASHES_PER_CHAR / 8), and MD_LENGTH should
   be at least 24 for good results.
*/
#define MD_LENGTH 32
#define MD_BITS (MD_LENGTH * 8)

/* Number of checksum bits each window contributes to the digest.  Must
   be a power of two.  Since the Rabin hash only has 63 usable bits, the
   number of hashes is limited to 32.  Lower powers of two could be used
   to speed up processing of very large files.  */
#define NUM_HASHES_PER_CHAR 32


/* For the final counting, do not count each bit individually, but
   group them.  Must be a power of two, at most NUM_HASHES_PER_CHAR.
   However, larger sizes result in higher cache usage.  Use 8 bits
   per group for efficient processing of large files on fast machines
   with decent caches, or 4 bits for faster processing of small files
   and for machines with small caches.  */
#define GROUP_BITS 4
#define GROUP_COUNTERS (1 << GROUP_BITS)


/* RABIN_WINDOW_SIZE (from rabinpoly.h) is the size of the fingerprint
   window used by the Rabin algorithm.  It is not a modifiable parameter.

   The first RABIN_WINDOW_SIZE - 1 bytes are skipped, in order to ensure
   fingerprints are good hashes.  This does somewhat reduce the
   influence of the first few bytes in the file (they're part of
   fewer windows, like the last few bytes), but that actually isn't
   so bad, as files often start with fixed content that may bias
   comparisons.
*/

/* MIN_FILE_SIZE is the absolute minimum file size that can be
   processed.  As indicated above, the first and last
   RABIN_WINDOW_SIZE - 1 bytes contribute less.
   In order to get an average of at least 12 samples per bit in the
   final message digest (3 * MD_LENGTH windows * NUM_HASHES_PER_CHAR
   bits / MD_BITS), require at least 3 * MD_LENGTH complete windows in
   the file.  */
#define MIN_FILE_SIZE (3 * MD_LENGTH + 2 * (RABIN_WINDOW_SIZE - 1))

/* Limit the matching algorithm to files of less than 256 MB, so we can
   use 32-bit integers everywhere without fear of overflow.  For larger
   files we should add logic to mmap the file piece by piece and
   accumulate the frequency counts.  */
#define MAX_FILE_SIZE (256 * 1024 * 1024 - 1)

/* Size of the cache used to eliminate duplicate substrings.
   Make it small enough to comfortably fit in the L1 cache.  */
#define DUP_CACHE_SIZE 256

/* Enforce at compile time the constraints documented above, so that an
   invalid tuning change fails loudly instead of silently indexing the
   counter tables out of bounds.  */
#if NUM_HASHES_PER_CHAR < 1 || NUM_HASHES_PER_CHAR > 32 \
    || (NUM_HASHES_PER_CHAR & (NUM_HASHES_PER_CHAR - 1)) != 0
# error "NUM_HASHES_PER_CHAR must be a power of two, at most 32"
#endif
#if MD_BITS % NUM_HASHES_PER_CHAR != 0
# error "MD_LENGTH must be a multiple of NUM_HASHES_PER_CHAR / 8"
#endif
#if GROUP_BITS < 1 || GROUP_BITS > NUM_HASHES_PER_CHAR \
    || (GROUP_BITS & (GROUP_BITS - 1)) != 0
# error "GROUP_BITS must be a power of two, at most NUM_HASHES_PER_CHAR"
#endif

/* Note: both macros may evaluate their arguments twice; do not pass
   expressions with side effects.  */
#define MIN(x,y) ((y)<(x) ? (y) : (x))
#define MAX(x,y) ((y)>(x) ? (y) : (x))

/* Per-file record filled in by process_file.  */
typedef struct fileinfo
{ char		*name;		/* path as given on the command line */
  size_t	length;		/* file size in bytes */
  u_char	md[MD_LENGTH];	/* message digest computed by process_data */
  int		match;		/* not used in this file */
} File;

/* Option flags; each counts how often the option was given.  */
int flag_verbose = 0;		/* -v */
int flag_debug = 0;		/* -d */
int flag_warning = 0;		/* -w */
char *flag_relative = 0;	/* argument of -r once validated, else 0 */

/* Program name for messages.  main copies at most 8 characters of the
   name over the leading spaces; for longer names the trailing "..."
   remains and marks the truncation.  */
char cmd[12] = "        ...";
char md_strbuf[MD_LENGTH * 2 + 1];	/* result buffer returned by md_to_str */
u_char relative_md [MD_LENGTH];		/* digest decoded from the -r argument */

File *file;		/* one entry per command-line file, allocated in main */
int    file_count;	/* number of entries of file[] filled so far */
size_t file_bytes;	/* total size of all processed files */

FILE *msgout;		/* set to stdout in main */

char hex[17] = "0123456789abcdef";
/* Used by the redundancy statistics in process_data.  */
double pi = 3.14159265358979323846;

/* Signed per-bit counters accumulated by process_data.  freq_to_md sets
   a digest bit for every positive counter and then clears the table.  */
int freq[MD_BITS];
/* Windows skipped because their checksum matched the duplicate cache.  */
u_int64_t freq_dups = 0;

/* Print the command-line synopsis and option summary to stderr, then
   terminate the program with exit status 1.  Never returns.  */
void usage (void)
{ fprintf (stderr, "usage: %s [-dhvw] [-r fingerprint] file ...\n", cmd);
  fprintf (stderr, " -d\tdebug output, repeat for more verbosity\n");
  fprintf (stderr, " -h\tshow this usage information\n");
  fprintf (stderr, " -r\tshow distance relative to fingerprint "
                   "(%d hex digits)\n", MD_LENGTH * 2);
  fprintf (stderr, " -v\tverbose output, repeat for even more verbosity\n");
  fprintf (stderr, " -w\tenable warnings for suspect statistics\n");
  exit (1);
}

/* Hamming distance between two digests: the number of bit positions,
   over all MD_LENGTH bytes, in which L and R differ.  */
int dist (u_char *l, u_char *r)
{ int total = 0;
  int i;

  for (i = 0; i < MD_LENGTH; i++)
  { unsigned diff = l[i] ^ r[i];

    /* Clear the lowest set bit until none remain; one step per bit.  */
    while (diff)
    { diff &= diff - 1;
      total++;
    }
  }

  return total;
}

/* Render MD as 2 * MD_LENGTH lowercase hex digits, most significant
   nibble of each byte first.  The result lives in the global md_strbuf,
   so it is overwritten by the next call.  */
char *md_to_str(u_char *md)
{ char *out = md_strbuf;
  int i;

  for (i = 0; i < MD_LENGTH; i++)
  { *out++ = hex[md[i] >> 4];
    *out++ = hex[md[i] & 0xF];
  }
  *out = 0;

  return md_strbuf;
}

/* Parse a digest written as exactly 2 * MD_LENGTH hex digits (either
   case) into MD.  Returns MD on success.  Returns 0 if either pointer
   is null, if a non-hex character appears among the first
   2 * MD_LENGTH characters, or if the string continues past them.
   On failure MD may be partially overwritten.  */
u_char *str_to_md(char *str, u_char *md)
{ int pos;

  if (!md || !str) return 0;

  bzero (md, MD_LENGTH);

  for (pos = 0; pos < MD_LENGTH * 2; pos++)
  { char c = str[pos];
    int nibble;

    if (c >= '0' && c <= '9')
      nibble = c - '0';
    else
    { c |= 32;			/* fold 'A'-'F' onto 'a'-'f' */
      if (c < 'a' || c > 'f') break;
      nibble = c - 'a' + 10;
    }
    md[pos / 2] = (md[pos / 2] << 4) + nibble;
  }

  if (pos != MD_LENGTH * 2) return 0;	/* stopped at a non-hex character */
  if (str[pos] != 0) return 0;		/* trailing characters */
  return md;
}
    
/* Convert the global frequency table into digest MD: bit k of byte b
   (counting from the most significant bit) is set when freq[8*b + k] is
   positive.  With -d the table is dumped to stdout first.  Afterwards
   freq and freq_dups are reset for the next file.  */
void freq_to_md(u_char *md)
{ int byte, bit, i;

  for (byte = 0; byte < MD_LENGTH; byte++)
  { const int *f = freq + 8 * byte;
    u_char bits = 0;

    for (bit = 0; bit < 8; bit++)
      if (f[bit] > 0) bits |= (u_char) (0x80 >> bit);
    md[byte] = bits;
  }

  if (flag_debug)
  { for (i = 0; i < MD_BITS; i++)
    { if (i % 8 == 0) printf ("\n%3u: ", i);
      printf ("%7i ", freq[i]);
    }
    printf ("\n");
  }

  bzero (freq, sizeof(freq));
  freq_dups = 0;
}

void process_data (char *name, u_char *data, unsigned len, u_char *md)
{ size_t j = 0;
  u_int32_t ofs;
  u_int32_t dup_cache[DUP_CACHE_SIZE];
  u_int32_t count [MD_BITS * (GROUP_COUNTERS/GROUP_BITS)];
  bzero (dup_cache, DUP_CACHE_SIZE * sizeof (u_int32_t));
  bzero (count, (MD_BITS * (GROUP_COUNTERS/GROUP_BITS) * sizeof (u_int32_t)));

  /* Ignore incomplete substrings */
  while (j < len && j < RABIN_WINDOW_SIZE) rabin_slide8 (data[j++]);

  while (j < len)
  { u_int64_t hash;
    u_int32_t ofs, sum;
    u_char idx;
    int k;

    hash = rabin_slide8 (data[j++]);

    /* In order to update a much larger frequency table
       with only 32 bits of checksum, randomly select a
       part of the table to update. The selection should
       only depend on the content of the represented data,
       and be independent of the bits used for the update.
       
       Instead of updating 32 individual counters, process
       the checksum in MD_BITS / GROUP_BITS groups of 
       GROUP_BITS bits, and count the frequency of each bit pattern.
    */

    idx = (hash >> 32);
    sum = (u_int32_t) hash;
    ofs = idx % (MD_BITS / NUM_HASHES_PER_CHAR) * NUM_HASHES_PER_CHAR;
    idx %= DUP_CACHE_SIZE;
    if (dup_cache[idx] == sum)
    { freq_dups++; 
    }
    else
    { dup_cache[idx] = sum; 
      for (k = 0; k < NUM_HASHES_PER_CHAR / GROUP_BITS; k++)
      { count[ofs * GROUP_COUNTERS / GROUP_BITS + (sum % GROUP_COUNTERS)]++;
        ofs += GROUP_BITS;
        sum >>= GROUP_BITS;
  } } }

  /* Distribute the occurrences of each bit group over the frequency table. */
  for (ofs = 0; ofs < MD_BITS; ofs += GROUP_BITS)
  { int j;
    for (j = 0; j < GROUP_COUNTERS; j++)
    { int k;
      for (k = 0; k < GROUP_BITS; k++)
      { freq[ofs + k] += ((1<<k) & j) 
          ? count[ofs * GROUP_COUNTERS / GROUP_BITS + j]
          : -count[ofs * GROUP_COUNTERS / GROUP_BITS + j];
  } } }
      
  { int j;
    int num = MD_BITS;
    int stat_warn = 0;
    double sum = 0.0;
    double sumsqr = 0.0;
    double average, variance, stddev, bits, exp_average, max_average;

    assert (num >= 2);

    sum = 0;

    for (j = 0; j < num; j++)
    { double f = abs ((double) freq[j]);
      sum += f;
      sumsqr += f*f;
    }

    variance = (sumsqr - (sum * sum / num)) / (num - 1);
    average = sum / num;
    stddev = sqrt (variance);
    bits = (NUM_HASHES_PER_CHAR * (file[file_count].length - freq_dups)) 
             / (8 * MD_LENGTH);
    /* Random files, or short files with few repetitions should have
       average very close to the expected average. Large deviations
       show there is too much redundancy, or there is another problem
       with the statistical fundamentals of the algorithm. */
    exp_average = sqrt (2 * bits / pi);
    max_average = 2.0 * pow (2 * bits / pi, 0.6);

    stat_warn = flag_warning
      && (average < exp_average * 0.5 || average > max_average);
    if (stat_warn)
    { fprintf (stdout, "%s: warning: "
               "too much redundancy, fingerprint may not be accurate\n",
               file[file_count].name);
      
    }

    if (flag_verbose > 1 || (flag_verbose && stat_warn))
    { printf 
        ("%i frequencies, average %5.1f, std dev %5.1f, %2.1f %% duplicates, "
         "\"%s\"\n",
         num, average, stddev,
         100.0 * freq_dups / (double) file[file_count].length,
         file[file_count].name);
      printf
        ("%1.0f expected bits per frequency, "
         "expected average %1.1f, max average %1.1f\n",
         bits, exp_average, max_average);
  } }

  if (md)
  { rabin_reset();
    freq_to_md (md);
    if (flag_relative)
    { int d = dist (md, relative_md);
      double sim = 1.0 - MIN (1.0, (double) (d) / (MD_LENGTH * 4 - 1));
      fprintf (stdout, "%s %llu %u %s %u %3.1f\n", 
               md_to_str (md), (long long unsigned) 0, len, name, 
               d, 100.0 * sim);
    }
    else
    {
      fprintf (stdout, "%s %llu %u %s\n", 
               md_to_str (md), (long long unsigned) 0, len, name);
} } }

/* Fingerprint the file NAME and record it in file[file_count].

   Files whose size is outside [MIN_FILE_SIZE, MAX_FILE_SIZE] are skipped;
   with -v a note is printed.  For accepted files the contents are
   mmap'ed and passed to process_data, and file_count and file_bytes are
   updated.  Failure to open, stat or map the file is fatal: the error is
   reported with perror() and the program exits with status 2.  */
void process_file (char *name)
{ int fd;
  struct stat fs;
  File *fi = file + file_count;

  fd = open (name, O_RDONLY, 0);
  if (fd < 0)
  { perror (name);
    exit (2);
  }

  if (fstat (fd, &fs))
  { perror (name);
    exit (2);
  }

  if (fs.st_size >= MIN_FILE_SIZE
      && fs.st_size <= MAX_FILE_SIZE)
  { u_char *data;

    fi->length = fs.st_size;
    fi->name = name;

    data = mmap (0, fs.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    if (data == MAP_FAILED)
    { perror (name);
      exit (2);
    }

    /* The size fits in unsigned: it was just checked against
       MAX_FILE_SIZE.  */
    process_data (name, data, (unsigned) fs.st_size, fi->md);
    munmap (data, fs.st_size);
    file_bytes += fs.st_size;
    file_count++;
  }
  else if (flag_verbose)
  { /* off_t has no dedicated printf conversion, so widen explicitly
       to match %llu.  */
    fprintf (stdout, "skipping %s (size %llu)\n", name,
             (unsigned long long) fs.st_size);
  }

  close (fd);
}

/* Parse options (-d, -h, -r fingerprint, -v, -w), then fingerprint every
   file named on the command line.  Returns 0 on success, 1 for usage
   errors, and 2 if the file table cannot be allocated.  process_file
   exits with 2 on I/O errors.  */
int main (int argc, char *argv[])
{ int ch, j;

  /* Keep at most 8 characters of the program name for messages.  cmd is
     initialized to "        ...", so a name longer than 8 characters
     shows as its first 8 characters followed by "...".  Shorter names
     (including exactly 8) are copied with their terminator.  */
  { const char *base = basename (argv[0]);
    size_t n = strlen (base);

    if (n > 8) memcpy (cmd, base, 8);
    else memcpy (cmd, base, n + 1);
  }
  msgout = stdout;

  while ((ch = getopt(argc, argv, "dhr:vw")) != -1)
  { switch (ch)
    { case 'd': flag_debug++;
                break;
      case 'r': if (!optarg)
                { fprintf (stderr, "%s: missing argument for -r\n", cmd);
                  return 1;
                }
                if (str_to_md (optarg, relative_md)) flag_relative = optarg;
                else
                { fprintf (stderr, "%s: not a valid fingerprint\n", optarg);
                  return 1;
                }
                break;
      case 'v': flag_verbose++;
                break;
      case 'w': flag_warning++;
                break;
      default : usage();		/* does not return */
                return (ch != 'h');
    }
  }

  argc -= optind;
  argv += optind;

  if (argc == 0) usage();

  rabin_reset ();
  if (flag_verbose && flag_relative)
  { fprintf (stdout, "distances are relative to %s\n", flag_relative);
  }

  file = calloc (argc, sizeof *file);
  if (!file)
  { perror (cmd);
    return 2;
  }

  for (j = 0; j < argc; j++) process_file (argv[j]);

  if (flag_verbose)
  { /* file_bytes is a size_t; widen explicitly to match %llu.  */
    fprintf (stdout, "%llu bytes in %i files\n",
             (unsigned long long) file_bytes, file_count);
  }

  return 0;
}

