#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <string.h>
#include "ip-mtbdd.h"
#include "bitstream.h"

/* Counter of MTBDD blocks allocated by this module.  Nothing in this file
   decrements it when blocks are freed, so it is a running total. */
int mtbdd_malloced_blocks=0;
/* The empty diagram: a NULL subtree means "no value defined here". */
MTBDD* const mtbdd_empty=NULL;

/* A value is valid if it is the inner-node marker or fits in one byte;
   leaf values are stored as single bytes (putc/getc) in the .leaf file.
   The parameter is parenthesized so any expression can be passed safely. */
#define assert_value(value) assert((value)==MTBDD_NODE_VALUE ||\
				   ((value)>=0 && (value)<256))

/* Combine two subtrees into one inner node.  st0 is the branch followed
   when the next key bit is 0, st1 when it is 1; ownership of both passes
   to the result.  If both are leaves holding the same value the node would
   carry no information, so st1 is freed and st0 is returned in its place,
   keeping the diagram reduced. */
static MTBDD *create_node(MTBDD *st0, MTBDD *st1)
{
  MTBDD *inner;
  int mergeable = st0 != NULL && st1 != NULL
                  && st0->value != MTBDD_NODE_VALUE
                  && st0->value == st1->value;

  if (mergeable)
  {
    free(st1);
    return st0;
  }

  inner = malloc(sizeof(MTBDD));
  mtbdd_malloced_blocks++;
  inner->node.subtrees[0] = st0;
  inner->node.subtrees[1] = st1;
  inner->value = MTBDD_NODE_VALUE;
  return inner;
}

/* Fill every empty (NULL) position in bdd with a leaf carrying value.
   Leaves that already exist are kept unchanged.  Inner nodes on the way
   are freed and rebuilt through create_node, so sibling leaves that end up
   equal are merged.  Consumes bdd and returns the new root. */
static MTBDD *replace_undefined(MTBDD *bdd, int value)
{
  MTBDD *lo, *hi;

  assert(value != MTBDD_NODE_VALUE);
  assert_value(value);

  if (bdd == NULL)
  {
    /* Leaves are only ever accessed through ->value in this file, so only
       that field is allocated.  NOTE(review): accessing a block smaller than
       sizeof(MTBDD) through an MTBDD* is outside strict ISO C. */
    MTBDD *fresh = malloc(sizeof(fresh->value));
    mtbdd_malloced_blocks++;
    fresh->value = value;
    return fresh;
  }

  if (bdd->value != MTBDD_NODE_VALUE)
    return bdd;   /* already defined: keep the existing leaf */

  lo = bdd->node.subtrees[0];
  hi = bdd->node.subtrees[1];
  free(bdd);
  return create_node(replace_undefined(lo, value),
                     replace_undefined(hi, value));
}

/* Map the prefix made of the first netmask bits of key (most significant
   bit first) to value.  Consumes bdd and returns the new root.  The old
   root pointer must not be used afterwards, because blocks along the
   insertion path are freed and rebuilt. */
MTBDD* mtbdd_add(MTBDD *bdd, IPaddress key, int netmask, int value)
{
  int bit;
  IPaddress rest;
  int depth_left;

  assert(netmask >= 0);
  assert(value != MTBDD_NODE_VALUE);

  if (netmask == 0)
  {
    /* Prefix exhausted.  A leaf is overwritten outright.  An empty slot or
       an inner node only has its NULL positions filled, so entries already
       defined beneath it are kept.
       NOTE(review): create_node merges equal sibling leaves, so a leaf found
       here may stand for several longer prefixes; overwriting it then
       differs from the inner-node case.  Confirm this is intended. */
    if (bdd != NULL && bdd->value != MTBDD_NODE_VALUE)
    {
      bdd->value = value;
      return bdd;
    }
    return replace_undefined(bdd, value);
  }

  bit = (key >> 31) & 1;
  rest = key << 1;
  depth_left = netmask - 1;

  if (bdd == NULL)
  {
    /* Nothing here yet: build the remaining path, then hang it on the side
       chosen by the current bit.  The other side stays empty, so
       create_node never merges here. */
    MTBDD *path = mtbdd_add(NULL, rest, depth_left, value);
    return bit ? create_node(NULL, path) : create_node(path, NULL);
  }

  if (bdd->value == MTBDD_NODE_VALUE)
  {
    MTBDD *zero = bdd->node.subtrees[0];
    MTBDD *one = bdd->node.subtrees[1];

    free(bdd);
    if (bit)
      return create_node(zero, mtbdd_add(one, rest, depth_left, value));
    return create_node(mtbdd_add(zero, rest, depth_left, value), one);
  }

  {
    /* A leaf covers this whole range.  Replace it with the new, more
       specific path, then let the old value fill every position that the
       new prefix leaves undefined. */
    int covering = bdd->value;

    free(bdd);
    return replace_undefined(mtbdd_add(NULL, key, netmask, value),
                             covering);
  }
}

/* Look up the first netmask bits of key, descending one level per bit
   (most significant first).  A leaf reached before the bits run out
   answers for the whole range below it.
   Returns:
     - the leaf's value when a leaf is reached;
     - MTBDD_UNDEFINED_VALUE when an empty (NULL) position is reached;
     - MTBDD_NOT_ENOUGH_BITS when the netmask is exhausted while still at an
       inner node (the answer depends on further bits). */
int mtbdd_find(MTBDD *bdd, IPaddress key, int netmask)
{
  assert(netmask >= 0);

  for (; netmask > 0 && bdd != NULL && bdd->value == MTBDD_NODE_VALUE;
       netmask--, key <<= 1)
    bdd = bdd->node.subtrees[(key >> 31) & 1];

  if (bdd == NULL)
    return MTBDD_UNDEFINED_VALUE;
  return (bdd->value == MTBDD_NODE_VALUE) ? MTBDD_NOT_ENOUGH_BITS
                                          : bdd->value;
}

/* Print each leaf of bdd to fp as "<path> -> <value>", one per line, where
   <path> is the string of '0'/'1' branches taken from the root.
   buffer is scratch space of size bytes shared by the recursion, and i is
   the current depth (0 at the top-level call).  size must exceed the depth
   of the deepest leaf so that the terminating NUL fits.  Empty subtrees
   print nothing. */
void mtbdd_print(FILE *fp, char *buffer, int size, int i, MTBDD *bdd)
{
  int branch;

  assert(i < size);

  if (bdd == NULL)
    return;

  if (bdd->value != MTBDD_NODE_VALUE)
  {
    buffer[i] = 0;   /* terminate the path string at this depth */
    fprintf(fp, "%s -> %d\n", buffer, bdd->value);
    return;
  }

  for (branch = 0; branch < 2; branch++)
  {
    buffer[i] = (char)('0' + branch);
    mtbdd_print(fp, buffer, size, i + 1, bdd->node.subtrees[branch]);
  }
}

/* Count the blocks (inner nodes and leaves) reachable from bdd.
   An empty (NULL) diagram has size 0.  The 1-branch is followed
   iteratively; only the 0-branch recurses. */
int mtbdd_size(MTBDD *bdd)
{
  int count = 0;

  while (bdd != NULL)
  {
    count++;
    if (bdd->value != MTBDD_NODE_VALUE)
      break;   /* a leaf has no children */
    count += mtbdd_size(bdd->node.subtrees[0]);
    bdd = bdd->node.subtrees[1];
  }
  return count;
}

/* Free bdd and every block below it.  Children are released before their
   parent.  Passing NULL is a no-op. */
void mtbdd_delete(MTBDD *bdd)
{
  if (bdd != NULL)
  {
    int is_inner = (bdd->value == MTBDD_NODE_VALUE);
    int side;

    for (side = 0; is_inner && side < 2; side++)
      mtbdd_delete(bdd->node.subtrees[side]);
    free(bdd);
  }
}

/* Serialize bdd in pre-order.  Each position writes a two-bit tag to the
   skeleton stream:
     0,0  empty (NULL) subtree
     0,1  leaf; its value is written as one byte to the leaf file
     1,1  inner node; the 0-branch and then the 1-branch follow */
static void mtbdd_save_intern(BITSTREAM *skel, FILE *leaf, const MTBDD *bdd)
{
  int present = (bdd != NULL);
  int inner = present && bdd->value == MTBDD_NODE_VALUE;

  bitstream_wbit(skel, inner);
  bitstream_wbit(skel, present);

  if (inner)
  {
    mtbdd_save_intern(skel, leaf, bdd->node.subtrees[0]);
    mtbdd_save_intern(skel, leaf, bdd->node.subtrees[1]);
  }
  else if (present)
  {
    putc(bdd->value, leaf);
  }
}

/* Write bdd to two files: "<filename>.skel" holds the tree shape as a bit
   stream, and "<filename>.leaf" holds one byte per leaf value.
   Returns:
     0  on success;
     1  if the skeleton file could not be opened, or the file-name buffer
        could not be allocated;
     2  if the leaf file could not be opened;
     1  if writing or closing either file reported an error (a non-zero
        bitstream_wclose result is treated as an error, as before). */
int mtbdd_save(const char *filename, const MTBDD *bdd)
{
  BITSTREAM *skeleton_bsp;
  FILE *leaf_fp;
  int leaf_error;

  {
    size_t filename_length = strlen(filename);
    /* Room for the name, a 5-character suffix (".skel" and ".leaf" have the
       same length) and the terminating NUL: sizeof on the literal is 6. */
    char *filename_buf = malloc(filename_length + sizeof ".skel");
    FILE *skel_fp;

    if (filename_buf == NULL)
      return 1;
    memcpy(filename_buf, filename, filename_length);
    strcpy(filename_buf + filename_length, ".skel");
    skel_fp = fopen(filename_buf, "wb");
    /* Only hand a real stream to bitstream_wopen; how it treats NULL is not
       visible from here. */
    skeleton_bsp = (skel_fp != NULL) ? bitstream_wopen(skel_fp) : NULL;
    strcpy(filename_buf + filename_length, ".leaf");
    leaf_fp = fopen(filename_buf, "wb");
    free(filename_buf);
  }

  if (skeleton_bsp == NULL)
  {
    if (leaf_fp != NULL) fclose(leaf_fp);
    return 1;
  }
  if (leaf_fp == NULL)
  {
    bitstream_wclose(skeleton_bsp);
    return 2;
  }

  mtbdd_save_intern(skeleton_bsp, leaf_fp, bdd);

  /* The per-byte putc() results are ignored during the dump, so check the
     stream's error indicator as well as fclose()'s own result. */
  leaf_error = ferror(leaf_fp) != 0;
  if (fclose(leaf_fp) != 0)
    leaf_error = 1;
  return (bitstream_wclose(skeleton_bsp) || leaf_error);
}

/* Rebuild a diagram from the pre-order encoding produced by
   mtbdd_save_intern.  Each position starts with a two-bit tag:
     1,1  inner node; the 0-branch and then the 1-branch follow
     0,1  leaf; its value is the next byte of the leaf file
     0,0  empty (NULL) subtree
   Inner nodes go through create_node, so equal sibling leaves are merged. */
static MTBDD *mtbdd_read_intern(BITSTREAM *skel, FILE *leaf)
{
  if (bitstream_rbit(skel))
  {
    MTBDD *st0, *st1;
    /* The second tag bit must be consumed unconditionally.  Reading it
       inside assert() would skip the read under NDEBUG and desynchronize
       the skeleton stream. */
    int second_bit = bitstream_rbit(skel);

    assert(second_bit);
    (void)second_bit;   /* unused when NDEBUG is defined */
    st0 = mtbdd_read_intern(skel, leaf);
    st1 = mtbdd_read_intern(skel, leaf);
    return create_node(st0, st1);
  }
  if (bitstream_rbit(skel))
  {
    int value = getc(leaf);
    MTBDD *bdd;

    assert(value != EOF);
    /* Leaves only use the value field, matching replace_undefined's
       allocation; count the block like every other allocation site does. */
    bdd = malloc(sizeof(bdd->value));
    mtbdd_malloced_blocks++;
    bdd->value = value;
    return bdd;
  }
  return NULL;
}

/* Load a diagram written by mtbdd_save from "<filename>.skel" and
   "<filename>.leaf".  Returns the rebuilt root, or NULL if either file
   cannot be opened or the file-name buffer cannot be allocated.
   NOTE(review): NULL is also the valid empty diagram, so callers cannot
   tell a failed open from an empty saved tree. */
MTBDD *mtbdd_read(const char *filename)
{
  BITSTREAM *skeleton_bsp;
  FILE *leaf_fp;
  MTBDD *result;

  {
    size_t filename_length = strlen(filename);
    /* Room for the name, a 5-character suffix (".skel" and ".leaf" have the
       same length) and the terminating NUL: sizeof on the literal is 6. */
    char *filename_buf = malloc(filename_length + sizeof ".skel");
    FILE *skel_fp;

    if (filename_buf == NULL)
      return NULL;
    memcpy(filename_buf, filename, filename_length);
    strcpy(filename_buf + filename_length, ".skel");
    skel_fp = fopen(filename_buf, "rb");
    /* Only hand a real stream to bitstream_ropen; how it treats NULL is not
       visible from here. */
    skeleton_bsp = (skel_fp != NULL) ? bitstream_ropen(skel_fp) : NULL;
    strcpy(filename_buf + filename_length, ".leaf");
    leaf_fp = fopen(filename_buf, "rb");
    free(filename_buf);
  }

  if (skeleton_bsp == NULL)
  {
    if (leaf_fp != NULL) fclose(leaf_fp);
    return NULL;
  }
  if (leaf_fp == NULL)
  {
    bitstream_rclose(skeleton_bsp);
    return NULL;
  }

  result = mtbdd_read_intern(skeleton_bsp, leaf_fp);

  fclose(leaf_fp);
  bitstream_rclose(skeleton_bsp);
  return result;
}


