#ifndef SAHKDTREE__HXX_
#define SAHKDTREE__HXX_

#include <algorithm>
#include <cfloat>
#include <fstream>
#include <list>
#include <vector>
#include "./primitive/Box.hxx"

/**
 * Surface-area-heuristic (SAH) guided subdivision of a primitive set, used by
 * Intersect() to avoid testing a ray against every primitive.
 *
 * Construction partitions primitives object-wise: each primitive is stored in
 * exactly one child, and each child's box is grown to the full bounds of the
 * primitives it received, so sibling boxes may overlap. Traversal culls and
 * terminates early using these fitted child boxes, not the split plane.
 *
 * Ownership: the node tree reachable from `root` is allocated with `new`;
 * deleting `root` releases all of it. NOTE(review): SAHKDTree itself has no
 * destructor and is copyable (copies share `root`), so the tree is not freed
 * automatically.
 */
class SAHKDTree {

	/// Candidate split: an axis-aligned plane plus the child that primitives
	/// lying exactly in the plane are assigned to.
	struct split_plane {
		int dimension;   // split axis, 0..2
		bool side;       // primitives in the plane go: false = left, true = right
		float position;  // plane coordinate along `dimension`

		split_plane() : dimension(0), side(false), position(0.0f) {
		}

		split_plane(int new_dim, float new_pos)
			: dimension(new_dim), side(false), position(new_pos) {
		}
	};

	/// Common base of inner nodes and leaves.
	struct KDNode {
		Box bbox;  // box covering every primitive stored below this node
		virtual void traverse(Ray &ray)=0;

		virtual ~KDNode() {}
	};

	/// Inner node with two children; it owns and deletes both.
	struct KDInnerNode : public KDNode {
		KDNode *left_child, *right_child;

		KDInnerNode(Box& bounds) : left_child(NULL), right_child(NULL) {
			bbox=bounds;
		}

		virtual ~KDInnerNode() {
			delete left_child;
			delete right_child;
		}

		/// Visits the child whose box the ray enters first, then the other one
		/// only if its box is hit and no hit closer than its entry was found.
		/// Like the rest of this file, treats an entry distance of FLT_MAX
		/// from Box::Intersect() as a miss.
		virtual void traverse(Ray &ray) {
			float leftMinInt = left_child->bbox.Intersect(ray).first;
			float rightMinInt = right_child->bbox.Intersect(ray).first;

			KDNode *near_child = left_child;
			KDNode *far_child = right_child;
			float near_t = leftMinInt;
			float far_t = rightMinInt;
			if(!(leftMinInt < rightMinInt)) {
				// Ties visit the right child first.
				near_child = right_child;
				far_child = left_child;
				near_t = rightMinInt;
				far_t = leftMinInt;
			}

			if(!(near_t < FLT_MAX)) {
				return;  // neither child box is hit
			}
			near_child->traverse(ray);

			// The far child's primitives lie inside its box, so none of them can
			// be hit before far_t; skip it if missed or already beaten.
			if(!(far_t < FLT_MAX) || ray.t < far_t) {
				return;
			}
			far_child->traverse(ray);
		}
	};

	/// Leaf: tests the ray against every stored primitive.
	struct KDLeaf : public KDNode {
		std::vector<Primitive*> prim;

		KDLeaf(Box &bounds, std::vector<Primitive*> &primitives) : prim(primitives) {
			bbox=bounds;
		}

		virtual ~KDLeaf() {}

		virtual void traverse(Ray &ray) {
			for(size_t i=0; i<prim.size(); i++)
				prim[i]->Intersect(ray);
		}
	};

	/// One event of the SAH plane sweep along axis `dim`: where a primitive's
	/// bounds start or end, or where a primitive that is flat along that axis lies.
	struct SplitEvent {
		enum { END_EVENT = 0, PLANAR_EVENT = 1, START_EVENT = 2 };  // values of `type`

		Primitive* prim;
		float pos;
		int type;
		int dim;

		SplitEvent() : prim(NULL), pos(0.0f), type(END_EVENT), dim(0) {
		}

		SplitEvent(Primitive* new_prim, float new_pos, int new_type, int new_dim)
			: prim(new_prim), pos(new_pos), type(new_type), dim(new_dim) {
		}

		/// Orders by position, then axis, then END < PLANAR < START.
		bool operator<(const SplitEvent& event) const {
			if(pos != event.pos) return pos < event.pos;
			if(dim != event.dim) return dim < event.dim;
			return type < event.type;
		}
	};

	/// Child assignment shared by Stop_building() and buildKDTree(): a primitive
	/// spanning [lo, hi] along p.dimension goes left if it starts before the
	/// plane, right if it extends past it, and to p.side if it lies in the plane.
	static bool goes_left(float lo, float hi, const split_plane& p) {
		if(lo < p.position) {
			return true;
		}
		if(hi > p.position) {
			return false;
		}
		return !p.side;
	}

	public:
		KDNode *root;          // tree built by the constructor
		int maxdepth;          // nodes with depth > maxdepth become leaves
		unsigned int min_tri;  // nodes with <= min_tri primitives become leaves
		float kt, ki;          // SAH traversal cost and per-primitive intersection cost

		/// Decides whether the node (prim, bbox) at `depth` becomes a leaf
		/// instead of being split at `p`.
		///
		/// Returns true when a depth/size limit is reached, or when the SAH cost
		/// of splitting at `p` is not at most the leaf cost ki * prim.size()
		/// (a non-finite cost also stops). Primitives are counted with the same
		/// rule buildKDTree() uses to distribute them. `p` is left unchanged.
		bool Stop_building(std::vector<Primitive*> &prim, Box &bbox, int depth, split_plane& p) {
			if(depth>maxdepth || prim.size()<=min_tri) {
				return true;
			}

			int nl = 0;
			int nr = 0;
			for(size_t i=0; i<prim.size(); i++) {
				Box bounds = prim[i]->CalcBounds();
				if(goes_left(bounds.min[p.dimension], bounds.max[p.dimension], p)) {
					nl++;
				} else {
					nr++;
				}
			}

			// SAH() writes its preferred planar side into the plane it is given;
			// evaluate a copy so the side chosen by FindPlane() is kept.
			split_plane probe = p;
			float split_cost = SAH(bbox, probe, nl, 0, nr);
			return !(split_cost <= ki*prim.size());
		}

		/// Recursively builds the subtree for `prim` inside `bbox`.
		///
		/// Each primitive goes to one child (see goes_left()), and each child's
		/// box is extended by the full bounds of every primitive it receives,
		/// including primitives lying in the split plane. Returns a leaf when
		/// Stop_building() says so, or when the split would hand every primitive
		/// to one child without shrinking its box (it would just repeat below).
		KDNode* buildKDTree(std::vector<Primitive*>& prim, Box& bbox, int depth) {
			split_plane split=FindPlane(prim, bbox);

			if(Stop_building(prim, bbox, depth, split)) {
				return new KDLeaf(bbox, prim);
			}

			Box lchildBox, rchildBox;
			std::vector<Primitive*> lprims, rprims;
			int dim=split.dimension;

			for(size_t i=0; i<prim.size(); i++) {
				Box bounds = prim[i]->CalcBounds();
				if(goes_left(bounds.min[dim], bounds.max[dim], split)) {
					lprims.push_back(prim[i]);
					lchildBox.Extend(bounds);
				} else {
					rprims.push_back(prim[i]);
					rchildBox.Extend(bounds);
				}
			}

			// All primitives on one side only helps if that side's box is smaller
			// (empty space cut off); otherwise the same split would recur.
			if(lprims.empty() || rprims.empty()) {
				Box& filledBox = lprims.empty() ? rchildBox : lchildBox;
				if(!(calc_area(filledBox) < calc_area(bbox))) {
					return new KDLeaf(bbox, prim);
				}
			}

			KDInnerNode* node=new KDInnerNode(bbox);
			node->left_child=buildKDTree(lprims, lchildBox, depth+1);
			node->right_child=buildKDTree(rprims, rchildBox, depth+1);
			return node;
		}

		/// Sweeps candidate planes at primitive bound positions along each axis
		/// and returns the one with the lowest SAH() cost, with `side` set to
		/// where primitives lying in the plane should go.
		///
		/// All events at one position are handled together: NR drops by the
		/// ending and planar primitives before the plane is evaluated, and NL
		/// grows by the starting and planar ones afterwards. Only planes strictly
		/// inside `bbox` are evaluated; a plane on the boundary gives one child
		/// zero area and separates nothing. Without any candidate, a plane through
		/// the middle of `bbox` on axis 0 is returned.
		/// NOTE(review): these counts put primitives straddling the plane on both
		/// sides, whereas buildKDTree() stores each primitive in one child only.
		split_plane FindPlane(std::vector<Primitive*>& prim, Box bbox) {
			std::vector<SplitEvent> events[3];
			for(size_t t=0; t<prim.size(); t++) {
				Box bounds = prim[t]->CalcBounds();
				for(int k=0; k<3; k++) {
					if(bounds.min[k]==bounds.max[k]) {
						events[k].push_back(SplitEvent(prim[t], bounds.min[k], SplitEvent::PLANAR_EVENT, k));
					} else {
						events[k].push_back(SplitEvent(prim[t], bounds.min[k], SplitEvent::START_EVENT, k));
						events[k].push_back(SplitEvent(prim[t], bounds.max[k], SplitEvent::END_EVENT, k));
					}
				}
			}

			split_plane best(0, 0.5f*(bbox.min[0]+bbox.max[0]));
			float best_cost=FLT_MAX;

			for(int d=0; d<3; d++) {
				std::vector<SplitEvent>& ev = events[d];
				std::sort(ev.begin(), ev.end());

				int nl=0;
				int nr=(int)prim.size();
				size_t i=0;
				while(i<ev.size()) {
					const float pos=ev[i].pos;
					int ending=0, planar=0, starting=0;
					// do/while always consumes at least one event, even for NaN.
					do {
						switch(ev[i].type) {
							case SplitEvent::END_EVENT:    ending++;   break;
							case SplitEvent::PLANAR_EVENT: planar++;   break;
							default:                       starting++; break;
						}
						i++;
					} while(i<ev.size() && ev[i].pos==pos);

					int np=planar;
					nr-=planar+ending;

					if(pos>bbox.min[d] && pos<bbox.max[d]) {
						split_plane candidate(d, pos);
						float candidate_cost=SAH(bbox, candidate, nl, np, nr);
						if(candidate_cost<best_cost) {
							best_cost=candidate_cost;
							best=candidate;
						}
					}

					nl+=starting+planar;
				}
			}

			return best;
		}

		/// SAH cost of a split whose children have area ratios pl/pr and
		/// nl/nr primitives; splits leaving a child empty are discounted by 0.8,
		/// favouring them.
		float cost(float pl, float pr, int nl, int nr) {
			float split_cost=kt+ki*(pl*nl+pr*nr);
			if(nl==0 || nr==0) {
				return 0.8*split_cost;
			}
			return split_cost;
		}

		/// Cost of splitting `bbox` at `p` with nl primitives counted left, np in
		/// the plane and nr counted right. Tries the planar ones on each side,
		/// stores the cheaper choice in `p.side` (false = left, true = right;
		/// ties go right) and returns that cost.
		float SAH(Box& bbox, split_plane& p, int nl, int np, int nr) {
			int split_dim=p.dimension;

			Box LeftBox, RightBox;
			LeftBox.min=bbox.min;
			LeftBox.max=bbox.max;
			LeftBox.max[split_dim]=p.position;
			RightBox.min=bbox.min;
			RightBox.max=bbox.max;
			RightBox.min[split_dim]=p.position;

			float total_area=calc_area(bbox);
			float pl=calc_area(LeftBox)/total_area;
			float pr=calc_area(RightBox)/total_area;

			float left_cost=cost(pl,pr,nl+np,nr);
			float right_cost=cost(pl,pr,nl,np+nr);

			p.side = !(left_cost<right_cost);
			return p.side ? right_cost : left_cost;
		}

		/// Surface area of `bbox`: 2 * (xy + yz + zx) of its extents.
		float calc_area(Box &bbox) {
			float length_x=bbox.max.x()-bbox.min.x();
			float length_y=bbox.max.y()-bbox.min.y();
			float length_z=bbox.max.z()-bbox.min.z();
			return (2*(length_x*length_y+length_y*length_z+length_z*length_x));
		}

		/// Builds the tree over `prim` inside `start_box` using maxdepth 15,
		/// min_tri 3, kt 20 and ki 15.
		SAHKDTree(Box& start_box, std::vector<Primitive*>& prim)
			: root(NULL), maxdepth(15), min_tri(3), kt(20), ki(15) {
			root=buildKDTree(prim, start_box, 0);
		}

		/// Traverses the tree, letting each reached primitive intersect `ray`,
		/// and reports whether `ray.hit` is non-NULL afterwards.
		bool Intersect(Ray &ray) {
			if(root != NULL) {
				root->traverse(ray);
			}
			return ray.hit != NULL;
		}
};

#endif