#!/usr/bin/perl -w
#
# Given a spam.log and nonspam.log from a "mass-check --bayes" run,
# work out optimum spam/ham cutoffs to maximise efficiency.
# 
# usage: bayes-threshold spam.log nonspam.log

use strict;
use warnings;

# Input logs: spam log first, ham log second.  If no ham log is named,
# "good.log" is used when it exists, otherwise "nonspam.log".
my $spam = $ARGV[0] || "spam.log";
my $nonspam = $ARGV[1] || (-f "good.log" ? "good.log" : "nonspam.log");

# Number of equal-width buckets the score range is divided into.
my $nbuckets = 50;

# Range of bayes scores that is bucketed.
my $range_lo = 0.00;
my $range_hi = 1.00;

# shamelessly nicked from spambayes' testing infrastructure; a system to
# compute the "cost" of a pair of thresholds and classifier.  I'm supporting
# it here so our stats are (at least a little) comparable against theirs.
#
# Note that they use a different way to run 10PCV tests; they train with 1
# bucket and test against 9, whereas the lit suggests doing the opposite,
# which is what we do; so the stats may still be unportable due to us getting
# better training and less testing.  TODO

my $best_cutoff_fp_weight = 10.0;
my $best_cutoff_fn_weight = 1.0;
my $best_cutoff_unsure_weight = 0.2;

my $step = ($range_hi - $range_lo) / $nbuckets;

# Buckets are addressed by integer index 0 .. $nbuckets.  Bucket $n covers
# scores in [$range_lo + $n*$step, $range_lo + ($n+1)*$step); the extra
# bucket $nbuckets holds scores equal to $range_hi.  Bounds are computed by
# multiplication rather than by repeatedly adding $step, so rounding drift
# cannot decide whether the top bucket exists or which bucket a score hits.
my @buckets = map { $range_lo + $_ * $step } 0 .. $nbuckets;
my @bux_sp = (0) x @buckets;    # spam messages per bucket
my @bux_ns = (0) x @buckets;    # ham messages per bucket

# Tally the bayes score of every matching log line.  The spam/ham flag is
# carried alongside the filename so that the two logs are told apart even
# if the same path is given twice.
foreach my $log ([$spam, 1], [$nonspam, 0]) {
  my ($file, $isspam) = @$log;

  open (my $in, '<', $file) or die "Could not open file '$file': $!";

  while (my $line = <$in>) {
    $line =~ /^(\.|Y)\s.+bayes=([^\s,]+)/ or next;
    my $score = $2+0;

    # a score outside the bucketed range has no bucket; report and skip it
    # instead of counting it under an undefined key
    if ($score < $range_lo || $score > $range_hi) {
      warn "$file:$.: bayes score $score is outside "
		."[$range_lo, $range_hi], ignored\n";
      next;
    }

    # the small epsilon keeps a score lying exactly on a bucket boundary
    # from being rounded down into the bucket below
    my $idx = int (($score - $range_lo) / $step + 1e-9);
    $idx = $nbuckets if ($idx > $nbuckets);

    if ($isspam) {
      $bux_sp[$idx]++;
    } else {
      $bux_ns[$idx]++;
    }
  }

  close ($in);
}

# Totals, per-bucket maxima, and running totals.  $below_xx[$k] is the
# number of messages in buckets 0 .. $k-1, i.e. scoring below $buckets[$k].
# ($max_sp and $max_ns are kept from the original; nothing visible in this
# file reads them.)
my $max_sp = 0;
my $max_ns = 0;
my $tot_sp = 0;
my $tot_ns = 0;
my @below_sp = (0);
my @below_ns = (0);
foreach my $idx (0 .. $nbuckets) {
  $tot_sp += $bux_sp[$idx];
  $max_sp = $bux_sp[$idx] if ($bux_sp[$idx] > $max_sp);
  $tot_ns += $bux_ns[$idx];
  $max_ns = $bux_ns[$idx] if ($bux_ns[$idx] > $max_ns);

  push (@below_sp, $tot_sp);
  push (@below_ns, $tot_ns);
}

# Evaluate every (hamcutoff, spamcutoff) pair on the bucket grid.  Both
# cutoffs range over the lower bounds of buckets 0 .. $nbuckets-1.
#   fn:     spam scoring below hamcutoff
#   unsure: messages scoring in [hamcutoff, spamcutoff)
#   fp:     ham scoring at or above spamcutoff, including ham that scores
#           exactly $range_hi
# Pairs with spamcutoff below hamcutoff are still evaluated, as before;
# they have no unsure zone, and messages between the two cutoffs count as
# both fn and fp.
my %results = ();
foreach my $ham_idx (0 .. $nbuckets - 1) {
  my $hamcutoff = $buckets[$ham_idx];

  foreach my $spam_idx (0 .. $nbuckets - 1) {
    my $spamcutoff = $buckets[$spam_idx];

    my $fn = $below_sp[$ham_idx];
    my $fp = $tot_ns - $below_ns[$spam_idx];

    my $unsure_sp = 0;
    my $unsure_ns = 0;
    if ($spam_idx > $ham_idx) {
      $unsure_sp = $below_sp[$spam_idx] - $below_sp[$ham_idx];
      $unsure_ns = $below_ns[$spam_idx] - $below_ns[$ham_idx];
    }

    my $cost = ($fp * $best_cutoff_fp_weight)
		+ ($fn * $best_cutoff_fn_weight)
		+ (($unsure_ns+$unsure_sp) * $best_cutoff_unsure_weight);

    $results{"$hamcutoff $spamcutoff"} = {
      'hamcutoff' => $hamcutoff,
      'spamcutoff'=> $spamcutoff,
      'cost' => $cost,
      'unsure_ns' => $unsure_ns,
      'unsure_sp' => $unsure_sp,
      'fp' => $fp,
      'fn' => $fn
    };
  }
}

# Percentage of $part in $whole; 0 when $whole is 0, so that a log with no
# usable lines does not abort the report with a division by zero.
sub percent_of {
  my ($part, $whole) = @_;
  return $whole ? ($part * 100) / $whole : 0;
}

# Without a single scored message there is nothing to rank, and every
# figure below would be meaningless.
if ($tot_sp + $tot_ns == 0) {
  die "No bayes scores found in '$spam' or '$nonspam'\n";
}

# Print the $count cheapest cutoff pairs, cheapest first.  Ties in cost are
# broken by hamcutoff, then spamcutoff, so the output does not depend on
# hash iteration order.
my $count = 10;
my @ranked = sort {
      $a->{cost} <=> $b->{cost}
	|| $a->{hamcutoff} <=> $b->{hamcutoff}
	|| $a->{spamcutoff} <=> $b->{spamcutoff}
    } values %results;

foreach my $r (@ranked) {
  last if ($count-- <= 0);

  printf "Threshold optimization for hamcutoff=%3.2f, spamcutoff=%3.2f: cost=\$%5.2f\n",
		$r->{hamcutoff}, $r->{spamcutoff}, $r->{cost};
  printf "Total ham:spam:   %d:%d\n", $tot_ns, $tot_sp;

  printf "FP: %5d %5.3f%%    ", $r->{fp}, percent_of ($r->{fp}, $tot_ns);
  printf "FN: %5d %5.3f%%\n", $r->{fn}, percent_of ($r->{fn}, $tot_sp);

  my $unsure = $r->{unsure_ns} + $r->{unsure_sp};
  printf "Unsure: %5d %5.3f%%     ", $unsure,
				percent_of ($unsure, $tot_sp+$tot_ns);
  printf "(ham: %5d %5.3f%%    ", $r->{unsure_ns},
				percent_of ($r->{unsure_ns}, $tot_ns);
  printf "spam: %5d %5.3f%%)\n", $r->{unsure_sp},
				percent_of ($r->{unsure_sp}, $tot_sp);

  # for TCR calc, treat "unsures" as ham: unsure spam is added to the
  # false negatives, unsure ham counts as correctly classified
  # TODO: unsure_sp should probably be treated as spam, assuming
  # it'll fall in the 5.0-6.0 score range
  my $fn = $r->{unsure_sp} + $r->{fn};
  my $fp = $r->{fp};
  my @confusion = ($tot_sp - $fn, $fn, $fp, $tot_ns - $fp);
  printf "TCRs:              l=1 %5.3f    l=5 %5.3f    l=9 %5.3f\n",
    tcr (@confusion, 1),
    tcr (@confusion, 5),
    tcr (@confusion, 9);

  print "\n";
}


# Compute the TCR figure for one set of classification counts.
#
# Arguments, in order:
#   $nspamspam    spam classified as spam
#   $nspamlegit   spam classified as legitimate (false negatives)
#   $nlegitspam   legitimate mail classified as spam (false positives)
#   $nlegitlegit  legitimate mail classified as legitimate
#   $lambda       weight of one false positive relative to one false negative
#
# Returns $werr_base / $werr, where
#   $werr      = (lambda * FP + FN) / (lambda * nlegit + nspam)
#   $werr_base = nspam / (lambda * nlegit + nspam)
# A $werr of zero (no errors at all) is replaced by 0.000001 so the ratio
# stays finite.  Returns 0 when the shared denominator is zero, i.e. when
# there are no messages to score; this matches the value already returned
# whenever there is no spam.
sub tcr {
  my ($nspamspam, $nspamlegit, $nlegitspam, $nlegitlegit, $lambda) = @_;

  my $nlegit = $nlegitspam+$nlegitlegit;
  my $nspam = $nspamspam+$nspamlegit;

  # both error rates share this denominator; guard it once
  my $denom = $lambda * $nlegit + $nspam;
  return 0 unless ($denom);

  my $werr = ($lambda * $nlegitspam + $nspamlegit) / $denom;
  my $werr_base = $nspam / $denom;

  $werr ||= 0.000001;     # avoid / by 0

#my $sr = ($nspamspam / $nspam) * 100.0;
#my $sp = ($nspamspam / ($nspamspam + $nlegitspam)) * 100.0;

  return $werr_base / $werr;
}

