// Source file for mesh search tree class



////////////////////////////////////////////////////////////////////////
// Include files
////////////////////////////////////////////////////////////////////////

#include "R3Shapes/R3Shapes.h"





////////////////////////////////////////////////////////////////////////
// Constant definitions
////////////////////////////////////////////////////////////////////////

// Capacity of a leaf: once a leaf holds this many face references, the next
// insertion splits it into two children
static const int max_faces_per_node = 32;

// Deepest level at which a full leaf may still be split; a full leaf at this
// depth drops (with a warning) any further face routed to it
static const int max_depth = 1024;





////////////////////////////////////////////////////////////////////////
// Face class definition
////////////////////////////////////////////////////////////////////////

class R3MeshSearchTreeFace {
public:
  // Wrap a mesh face.  The visit mark starts at 0, which never equals the
  // tree's query mark (that starts at 1 and is incremented before each search).
  R3MeshSearchTreeFace(R3MeshFace *face)
    : face(face),
      mark(0)
  {
  }

public:
  // Mesh face referenced by this container (the container does not delete it)
  R3MeshFace *face;

  // Last query mark that visited this face, so a face stored in several
  // leaves is tested at most once per query
  RNMark mark;
};



////////////////////////////////////////////////////////////////////////
// Node class definition
////////////////////////////////////////////////////////////////////////

// Node declaration

class R3MeshSearchTreeNode {
public:
  // Create an empty leaf whose parent is the given node (NULL for the root)
  R3MeshSearchTreeNode(R3MeshSearchTreeNode *parent = NULL)
    : parent(parent),
      split_coordinate(0),
      split_dimension(0),
      nfaces(0)
  {
    children[0] = children[1] = NULL;
  }

public:
  // Enclosing node (NULL for the root)
  R3MeshSearchTreeNode *parent;

  // Both NULL for a leaf; both set for an interior node
  // ([0] is the low side of the split plane, [1] the high side)
  R3MeshSearchTreeNode *children[2];

  // Split plane of an interior node: coordinate along split_dimension
  RNScalar split_coordinate;
  RNDimension split_dimension;

  // Face containers stored in a leaf (only the first nfaces entries are valid;
  // interior nodes keep nfaces == 0)
  R3MeshSearchTreeFace *faces[max_faces_per_node];
  int nfaces;
};




////////////////////////////////////////////////////////////////////////
// Constructor/destructor functions
////////////////////////////////////////////////////////////////////////

R3MeshSearchTree::
R3MeshSearchTree(R3Mesh *mesh)
  : mesh(mesh),
    nnodes(1),
    mark(1)
{
  // Begin with a single empty leaf as the root (already counted by nnodes(1))
  root = new R3MeshSearchTreeNode(NULL);
  assert(root);

  // Grow the tree by inserting every mesh face in index order
  int face_index = 0;
  while (face_index < mesh->NFaces()) {
    InsertFace(mesh->Face(face_index));
    face_index++;
  }
}



R3MeshSearchTree::
~R3MeshSearchTree(void)
{
  // Check root
  if (!root) return;

  // A face container is stored in every leaf whose box its face bbox overlaps
  // (InsertFace routes a face that straddles a split plane into both children),
  // so deleting node->faces[i] leaf by leaf would free shared containers more
  // than once.  The per-container mark is therefore used as a "seen" flag so
  // that each container is deleted exactly once.

  // Pass 1: gather every node and clear the mark of every face container
  RNArray<R3MeshSearchTreeNode *> nodes;
  RNArray<R3MeshSearchTreeNode *> stack;
  stack.InsertTail(root);
  while (!stack.IsEmpty()) {
    R3MeshSearchTreeNode *node = stack.Tail();
    stack.RemoveTail();
    nodes.InsertTail(node);
    if (node->children[0]) stack.InsertTail(node->children[0]);
    if (node->children[1]) stack.InsertTail(node->children[1]);
    for (int i = 0; i < node->nfaces; i++) node->faces[i]->mark = 0;
  }

  // Pass 2: collect each face container the first time it is seen, and delete the nodes
  RNArray<R3MeshSearchTreeFace *> face_containers;
  while (!nodes.IsEmpty()) {
    R3MeshSearchTreeNode *node = nodes.Tail();
    nodes.RemoveTail();
    for (int i = 0; i < node->nfaces; i++) {
      R3MeshSearchTreeFace *face_container = node->faces[i];
      if (face_container->mark != 0) continue;
      face_container->mark = 1;
      face_containers.InsertTail(face_container);
    }
    delete node;
  }

  // Pass 3: delete each unique face container (the mesh faces themselves are not owned)
  while (!face_containers.IsEmpty()) {
    R3MeshSearchTreeFace *face_container = face_containers.Tail();
    face_containers.RemoveTail();
    delete face_container;
  }

  // Do not leave a dangling root pointer behind
  root = NULL;
}



////////////////////////////////////////////////////////////////////////
// Property functions
////////////////////////////////////////////////////////////////////////

const R3Box& R3MeshSearchTree::
BBox(void) const
{
  // The root node's box is the mesh's bounding box, so simply delegate
  return this->mesh->BBox();
}



////////////////////////////////////////////////////////////////////////
// Insert functions
////////////////////////////////////////////////////////////////////////

void R3MeshSearchTree::
InsertFace(R3MeshFace *face)
{
  // Wrap the mesh face in a container carrying the per-query visit mark
  R3MeshSearchTreeFace *container = new R3MeshSearchTreeFace(face);
  assert(container);

  // Route it down from the root, whose box is the whole tree's box, at depth 0
  const R3Box& root_box = BBox();
  InsertFace(container, root, root_box, 0);
}



void R3MeshSearchTree::
InsertFace(R3MeshSearchTreeFace *face, R3MeshSearchTreeNode *node, const R3Box& node_box, int depth) 
{
  // Interior node: send the face to every child whose half of node_box its bbox
  // touches (a face straddling the split plane goes to both children)
  if (node->children[0] != NULL) {
    assert(node->children[1]);
    const RNDimension axis = node->split_dimension;
    const RNScalar split = node->split_coordinate;
    const R3Box& face_box = mesh->FaceBBox(face->face);
    if (face_box[RN_LO][axis] <= split) {
      R3Box low_box(node_box);
      low_box[RN_HI][axis] = split;
      InsertFace(face, node->children[0], low_box, depth + 1);
    }
    if (face_box[RN_HI][axis] >= split) {
      R3Box high_box(node_box);
      high_box[RN_LO][axis] = split;
      InsertFace(face, node->children[1], high_box, depth + 1);
    }
    return;
  }

  // Leaf with spare capacity: append the face
  if (node->nfaces < max_faces_per_node) {
    node->faces[node->nfaces] = face;
    node->nfaces++;
    return;
  }

  // Full leaf that may not be split any further: drop the face with a warning
  if (depth >= max_depth) {
    fprintf(stderr, "Warning: KD tree depth exceeded %d -- skipped face %d\n", max_depth, mesh->FaceID(face->face));
    return;
  }

  // Full leaf: split node_box at the middle of its longest axis
  node->children[0] = new R3MeshSearchTreeNode(node);
  node->children[1] = new R3MeshSearchTreeNode(node);
  node->split_dimension = node_box.LongestAxis();
  node->split_coordinate = node_box.AxisCenter(node->split_dimension);
  nnodes += 2;

  // Node is now interior, so these calls push the new face and then the old
  // ones down into the children
  InsertFace(face, node, node_box, depth);
  for (int k = 0; k < node->nfaces; k++) InsertFace(node->faces[k], node, node_box, depth);

  // Interior nodes hold no faces themselves
  node->nfaces = 0;
}



////////////////////////////////////////////////////////////////////////
// Search functions
////////////////////////////////////////////////////////////////////////

void R3MeshSearchTree::
FindClosest(const R3Point& query_position, const R3Vector& query_normal, R3MeshIntersection& closest, 
  RNScalar min_distance_squared, RNScalar& max_distance_squared, 
  int (*IsCompatible)(const R3Point&, const R3Vector&, R3MeshFace *, void *), void *compatible_data,
  R3MeshFace *face) const
{
  // Check distance to plane
  const R3Plane& plane = mesh->FacePlane(face);
  RNScalar plane_signed_distance = R3SignedDistance(plane, query_position);
  RNScalar plane_distance_squared = plane_signed_distance * plane_signed_distance;
  if (plane_distance_squared >= max_distance_squared) return;

  // Check distance to bounding box
  RNScalar bbox_distance_squared = DistanceSquared(query_position, mesh->FaceBBox(face), max_distance_squared);
  if (bbox_distance_squared >= max_distance_squared) return;

  // Check compatibility 
  if (IsCompatible) {
    if (!(*IsCompatible)(query_position, query_normal, face, compatible_data)) return;
  }

  // Get face vertices
  R3MeshVertex *v0 = mesh->VertexOnFace(face, 0);
  R3MeshVertex *v1 = mesh->VertexOnFace(face, 1);
  R3MeshVertex *v2 = mesh->VertexOnFace(face, 2);

  // Get vertex positions
  const R3Point& p0 = mesh->VertexPosition(v0);
  const R3Point& p1 = mesh->VertexPosition(v1);
  const R3Point& p2 = mesh->VertexPosition(v2);

  // Project query point onto face plane
  const R3Vector& face_normal = mesh->FaceNormal(face);
  R3Point plane_point = query_position - plane_signed_distance * face_normal;

  // Check sides of edges
  R3Vector e0 = p1 - p0;
  e0.Normalize();
  R3Vector n0 = mesh->FaceNormal(face) % e0;
  R3Plane s0(p0, n0);
  RNScalar b0 = R3SignedDistance(s0, plane_point);
  R3Vector e1 = p2 - p1;
  e1.Normalize();
  R3Vector n1 = mesh->FaceNormal(face) % e1;
  R3Plane s1(p1, n1);
  RNScalar b1 = R3SignedDistance(s1, plane_point);
  R3Vector e2 = p0 - p2;
  e2.Normalize();
  R3Vector n2 = mesh->FaceNormal(face) % e2;
  R3Plane s2(p2, n2);
  RNScalar b2 = R3SignedDistance(s2, plane_point);

  // Consider plane_point's position in relation to edges of the triangle
  if ((b0 >= 0) && (b1 >= 0) && (b2 >= 0)) {
    // Point is inside face
    if (plane_distance_squared >= min_distance_squared) {
      closest.type = R3_MESH_FACE_TYPE;
      closest.face = face;
      closest.point = plane_point;
      max_distance_squared = plane_distance_squared;
    }
  }
  else {
    // Point is outside face -- check each edge
    if (b0 < 0) {
      // Outside edge0
      R3Vector edge_vector = p1 - p0;
      RNScalar edge_length = edge_vector.Length();
      if (edge_length > 0) {
        edge_vector /= edge_length;
        R3Vector point_vector = plane_point - p0;
        RNScalar t = edge_vector.Dot(point_vector);
        if (t <= 0) {
          RNScalar distance_squared = DistanceSquared(query_position, p0);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v0;
            closest.point = p0;
            max_distance_squared = distance_squared;
          }
        }
        else if (t >= edge_length) {
          RNScalar distance_squared = DistanceSquared(query_position, p1);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v1;
            closest.point = p1;
            max_distance_squared = distance_squared;
          }
        }
        else {
          R3Point point = p0 + t * edge_vector;
          RNScalar distance_squared = DistanceSquared(query_position, point);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_EDGE_TYPE;
            closest.edge = mesh->EdgeOnFace(face, 0);
            closest.point = point;
            max_distance_squared = distance_squared;
          }
        }
      }
    }
    if (b1 < 0) {
      // Outside edge1
      R3Vector edge_vector = p2 - p1;
      RNScalar edge_length = edge_vector.Length();
      if (edge_length > 0) {
        edge_vector /= edge_length;
        R3Vector point_vector = plane_point - p1;
        RNScalar t = edge_vector.Dot(point_vector);
        if (t <= 0) {
          RNScalar distance_squared = DistanceSquared(query_position, p1);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v1;
            closest.point = p1;
            max_distance_squared = distance_squared;
          }
        }
        else if (t >= edge_length) {
          RNScalar distance_squared = DistanceSquared(query_position, p2);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v2;
            closest.point = p2;
            max_distance_squared = distance_squared;
          }
        }
        else {
          R3Point point = p1 + t * edge_vector;
          RNScalar distance_squared = DistanceSquared(query_position, point);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_EDGE_TYPE;
            closest.edge = mesh->EdgeOnFace(face, 1);
            closest.point = point;
            max_distance_squared = distance_squared;
          }
        }
      }
    }
    if (b2 < 0) {
      // Outside edge2
      R3Vector edge_vector = p0 - p2;
      RNScalar edge_length = edge_vector.Length();
      if (edge_length > 0) {
        edge_vector /= edge_length;
        R3Vector point_vector = plane_point - p2;
        RNScalar t = edge_vector.Dot(point_vector);
        if (t <= 0) {
          RNScalar distance_squared = DistanceSquared(query_position, p2);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v2;
            closest.point = p2;
            max_distance_squared = distance_squared;
          }
        }
        else if (t >= edge_length) {
          RNScalar distance_squared = DistanceSquared(query_position, p0);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_VERTEX_TYPE;
            closest.vertex = v0;
            closest.point = p0;
            max_distance_squared = distance_squared;
          }
        }
        else {
          R3Point point = p2 + t * edge_vector;
          RNScalar distance_squared = DistanceSquared(query_position, point);
          if ((distance_squared >= min_distance_squared) && (distance_squared < max_distance_squared)) {
            closest.type = R3_MESH_EDGE_TYPE;
            closest.edge = mesh->EdgeOnFace(face, 2);
            closest.point = point;
            max_distance_squared = distance_squared;
          }
        }
      }
    }
  }
}



void R3MeshSearchTree::
FindClosest(const R3Point& query_position, const R3Vector& query_normal, R3MeshIntersection& closest, 
  RNScalar min_distance_squared, RNScalar& max_distance_squared, 
  int (*IsCompatible)(const R3Point&, const R3Vector&, R3MeshFace *, void *), void *compatible_data,
  R3MeshSearchTreeNode *node, const R3Box& node_box) const
{
  // Compute distance (squared) from query point to node bbox
  RNScalar distance_squared = DistanceSquared(query_position, node_box, max_distance_squared);
  if (distance_squared >= max_distance_squared) return;

  // Check if node is interior
  if (node->children[0]) {
    assert(node->children[1]);

    // Compute distance from query point to split plane
    RNScalar side = query_position[node->split_dimension] - node->split_coordinate;

    // Search children nodes
    if (side <= 0) {
      // Search negative side first
      R3Box child_box(node_box);
      child_box[RN_HI][node->split_dimension] = node->split_coordinate;
      FindClosest(query_position, query_normal, closest, 
        min_distance_squared, max_distance_squared, IsCompatible, compatible_data,
        node->children[0], child_box);
      if (side*side < max_distance_squared) {
        R3Box child_box(node_box);
        child_box[RN_LO][node->split_dimension] = node->split_coordinate;
        FindClosest(query_position, query_normal, closest, 
          min_distance_squared, max_distance_squared, IsCompatible, compatible_data,
          node->children[1], child_box);
      }
    }
    else {
      // Search positive side first
      R3Box child_box(node_box);
      child_box[RN_LO][node->split_dimension] = node->split_coordinate;
      FindClosest(query_position, query_normal, closest, 
        min_distance_squared, max_distance_squared, IsCompatible, compatible_data,
        node->children[1], child_box);
      if (side*side < max_distance_squared) {
        R3Box child_box(node_box);
        child_box[RN_HI][node->split_dimension] = node->split_coordinate;
        FindClosest(query_position, query_normal, closest, 
          min_distance_squared, max_distance_squared, IsCompatible, compatible_data,
          node->children[0], child_box);
      }
    }
  }
  else {
    // Update based on distance to each face
    for (int i = 0; i < node->nfaces; i++) {
      // Get face container and check mark
      R3MeshSearchTreeFace *face_container = node->faces[i];
      if (face_container->mark == mark) continue;
      face_container->mark = mark;

      // Find closest point in mesh face
      FindClosest(query_position, query_normal, closest, 
        min_distance_squared, max_distance_squared, 
        IsCompatible, compatible_data, face_container->face);
    }
  }
}



void R3MeshSearchTree::
FindClosest(const R3Point& query_position, const R3Vector& query_normal, R3MeshIntersection& closest,
  RNScalar min_distance, RNScalar max_distance, 
  int (*IsCompatible)(const R3Point&, const R3Vector&, R3MeshFace *, void *), void *compatible_data)
{
  // Initialize result
  closest.type = R3_MESH_NULL_TYPE;
  closest.vertex = NULL;
  closest.edge = NULL;
  closest.face = NULL;
  closest.point = R3zero_point;
  closest.t = 0;

  // Check root
  if (!root) return;

  // Update mark (used to avoid checking same face twice)
  mark++;

  // Use squared distances for efficiency.  Distances are never negative, so a
  // negative min_distance imposes no lower bound.  It is clamped to zero before
  // squaring, because squaring -d would otherwise exclude every point closer than d.
  RNScalar min_distance_squared = (min_distance > 0) ? min_distance * min_distance : 0.0;
  RNScalar closest_distance_squared = max_distance * max_distance;

  // Search nodes recursively
  FindClosest(query_position, query_normal, closest, 
    min_distance_squared, closest_distance_squared, 
    IsCompatible, compatible_data, 
    root, BBox());

  // Update result.  t is the distance to the closest point.  If nothing was
  // found, type stays R3_MESH_NULL_TYPE and t is sqrt(max_distance^2).
  closest.t = sqrt(closest_distance_squared);
  if (closest.type == R3_MESH_VERTEX_TYPE) { 
    closest.edge = mesh->EdgeOnVertex(closest.vertex); 
    closest.face = mesh->FaceOnEdge(closest.edge);
  }
  else if (closest.type == R3_MESH_EDGE_TYPE) { 
    closest.vertex = NULL; 
    closest.face = mesh->FaceOnEdge(closest.edge);
  }
  else if (closest.type == R3_MESH_FACE_TYPE) { 
    closest.vertex = NULL; 
    closest.edge = NULL; 
  }
}



void R3MeshSearchTree::
FindClosest(const R3Point& query_position, R3MeshIntersection& closest,
  RNScalar min_distance, RNScalar max_distance,
  int (*IsCompatible)(const R3Point&, const R3Vector&, R3MeshFace *, void *), void *compatible_data)
{
  // Find closest point, ignoring normal
  FindClosest(query_position, R3zero_vector, closest, min_distance, max_distance, IsCompatible, compatible_data);
}



////////////////////////////////////////////////////////////////////////
// Visualization and debugging functions
////////////////////////////////////////////////////////////////////////

void R3MeshSearchTree::
Outline(R3MeshSearchTreeNode *node, const R3Box& node_box) const
{
  // Leaves draw their box; interior nodes only recurse into the two halves
  if (!node->children[0]) {
    node_box.Outline();
    return;
  }

  assert(node->children[1]);
  const RNDimension axis = node->split_dimension;
  const RNScalar split = node->split_coordinate;
  assert(split >= node_box[RN_LO][axis]);
  assert(split <= node_box[RN_HI][axis]);

  R3Box low_box(node_box);
  low_box[RN_HI][axis] = split;
  Outline(node->children[0], low_box);

  R3Box high_box(node_box);
  high_box[RN_LO][axis] = split;
  Outline(node->children[1], high_box);
}



void R3MeshSearchTree::
Outline(void) const
{
  // Draw leaf boxes recursively, starting from the root with the whole-tree box
  if (root) Outline(root, BBox());
}



int R3MeshSearchTree::
Print(R3MeshSearchTreeNode *node, int depth) const
{
  // An absent subtree contributes no nodes
  if (node == NULL) return 0;

  // Interior nodes print their children first (post-order), counting each subtree's nodes
  const bool interior = (node->children[0] && node->children[1]);
  int ndescendants0 = 0;
  int ndescendants1 = 0;
  if (interior) {
    ndescendants0 = Print(node->children[0], depth + 1);
    ndescendants1 = Print(node->children[1], depth + 1);
  }

  // Line prefix: the depth, then (depth + 1) two-space indents
  printf("%d", depth);
  for (int level = 0; level <= depth; level++) printf("  ");

  // Interior: subtree sizes and their ratio (balance); leaf: number of faces held
  if (interior) printf("I %d %d %g\n", ndescendants0, ndescendants1, (double) ndescendants0 / (double) ndescendants1);
  else printf("L %d\n", node->nfaces);

  // Number of nodes in the subtree rooted here
  return 1 + ndescendants0 + ndescendants1;
}



void R3MeshSearchTree::
Print(void) const
{
  // Dump the whole tree from depth 0; the returned node count is not needed here
  (void) Print(root, 0);
}


int R3MeshSearchTree::
NNodes(void) const
{
  // Maintained incrementally: 1 for the root plus 2 for every leaf split
  const int count = nnodes;
  return count;
}



////////////////////////////////////////////////////////////////////////
// Utility functions
////////////////////////////////////////////////////////////////////////

RNScalar R3MeshSearchTree::
DistanceSquared(const R3Point& query_position, const R3Point& point) const
{
  // Sum of squared per-coordinate differences (x, then y, then z)
  RNScalar sum = 0;
  for (int axis = 0; axis < 3; axis++) {
    RNScalar delta = query_position[axis] - point[axis];
    sum += delta * delta;
  }
  return sum;
}



RNScalar R3MeshSearchTree::
DistanceSquared(const R3Point& query_position, const R3Box& box, RNScalar max_distance_squared) const
{
  // Per-axis gap between the query point and the box (zero when the point lies
  // within the box's extent on that axis).  Each squared gap is a lower bound
  // on the full squared distance, so we return it as soon as it already reaches
  // max_distance_squared.  Callers in this file only compare the result against
  // that bound.
  RNScalar gap_x = 0.0;
  if (query_position.X() > box.XMax()) gap_x = query_position.X() - box.XMax();
  else if (query_position.X() < box.XMin()) gap_x = box.XMin() - query_position.X();
  const RNScalar gap_x_squared = gap_x * gap_x;
  if (gap_x_squared >= max_distance_squared) return gap_x_squared;

  RNScalar gap_y = 0.0;
  if (query_position.Y() > box.YMax()) gap_y = query_position.Y() - box.YMax();
  else if (query_position.Y() < box.YMin()) gap_y = box.YMin() - query_position.Y();
  const RNScalar gap_y_squared = gap_y * gap_y;
  if (gap_y_squared >= max_distance_squared) return gap_y_squared;

  RNScalar gap_z = 0.0;
  if (query_position.Z() > box.ZMax()) gap_z = query_position.Z() - box.ZMax();
  else if (query_position.Z() < box.ZMin()) gap_z = box.ZMin() - query_position.Z();
  const RNScalar gap_z_squared = gap_z * gap_z;
  if (gap_z_squared >= max_distance_squared) return gap_z_squared;

  // Squared Euclidean distance from the query point to the box.  Zero gaps
  // contribute nothing, so this also covers the single-axis cases.
  return gap_x_squared + gap_y_squared + gap_z_squared;
}






