#ifndef BVHSTRUCTURE_HXX_
#define BVHSTRUCTURE_HXX_

#include <cfloat>
#include <vector>

#include "./primitive/Box.hxx"

class BVHStructure {
	struct Node 	{
		Box bbox;
		virtual void traverse(Ray &ray) = 0;

		virtual ~Node() {};
	};

	struct InnerNode : public Node	{
		Node *leftChild, *rightChild;

		InnerNode(const Box& aBBox) {
			bbox = aBBox;
		}

		virtual void traverse(Ray &ray)	{
			float leftMinInt = leftChild->bbox.Intersect(ray).first;
			float rightMinInt = rightChild->bbox.Intersect(ray).first;

			if(leftMinInt < rightMinInt) {
				leftChild->traverse(ray);

				if(ray.t < rightMinInt) {
					return;
				}

				rightChild->traverse(ray);
			} else {
				if(rightMinInt < FLT_MAX) {
					rightChild->traverse(ray);

					if(ray.t < leftMinInt) {
						return;
					}

					leftChild->traverse(ray);
				}
			}
		}

		virtual ~InnerNode() {
			if (leftChild) delete leftChild;
			if (rightChild) delete rightChild;
		};
	};

	struct LeafNode : public Node {
		std::vector<Primitive *> primitive;

		LeafNode(const Box& aBBox, std::vector<Primitive *> &prim) {
			bbox = aBBox;
			primitive = prim;
		}

		virtual void traverse(Ray &ray) {
			for(int i=0; i<(int)primitive.size(); i++) {
				primitive[i]->Intersect(ray);
			}
		}

		virtual ~LeafNode() {};
	};

	public:
		int maxDepth, minTri;

		Node *BuildTree(Box &bounds, std::vector<Primitive *> prim, int depth = 0) {
			if(depth > maxDepth || (int)prim.size() <= minTri) {
				return new LeafNode(bounds, prim);
			}

			InnerNode *node = new InnerNode(bounds);

			Vec3f diam = bounds.max - bounds.min;
			int dim = diam.MaxDim();

			std::vector<Vec3f> centers;
			std::vector<Box> primBox;
			float sumx, sumy, sumz;
			sumx = sumy = sumz = 0.0;
			int psize = (int)prim.size();
			for(int i = 0; i < psize; i++) {
				primBox.push_back(prim[i]->CalcBounds());
				centers.push_back((primBox[i].min + primBox[i].max) * 0.5);
				sumx+= centers[i].x();
				sumy+= centers[i].y();
				sumz+= centers[i].z();                    
			}

			Vec3f center = Vec3f(sumx/(float)psize, sumy/(float)psize, sumz/(float)psize);

			Box lBounds, rBounds;
			std::vector<Primitive *> lPrim, rPrim;

			for(int i = 0; i < (int)prim.size(); i++) {
				if(centers[i][dim] <= center[dim]) {
					lPrim.push_back(prim[i]);
					lBounds.Extend(primBox[i]);
				} else {
					rPrim.push_back(prim[i]);
					rBounds.Extend(primBox[i]);
				}
			}

			node->leftChild = BuildTree(lBounds,lPrim,depth+1);
			node->rightChild = BuildTree(rBounds,rPrim,depth+1);

			return node;
		}

		Node *root;

		BVHStructure(Box topBox, std::vector<Primitive *> primitives) {
			maxDepth = 30;
			minTri = 3;
			root = NULL;
			root = BuildTree(topBox, primitives, 0);
		}

		~BVHStructure() {
			if(root) {
				delete root;
			}
		}

		bool Intersect(Ray &ray) {
			root->traverse(ray);

			return ray.hit != NULL;
		}
};

#endif /*BVHSTRUCTURE_HXX_*/