#ifndef CSG_HXX_
#define CSG_HXX_

#include "../vector/Vec3f.hxx"
#include "Primitive.hxx"
#include "../ray/Ray.hxx"

#include <cstddef>
#include <list>

class CSGPrimitive : public Primitive {

public:
	enum CSGNODETYPE {
		NODE = 0,
		LEAF = 1
	};

	enum CSGOPTYPE {
		UNION = 0, 
		DIFFERENCE = 1,
		INTERSECTION = 2
	};

	enum CSGCLASSIFICATION {
		ENTER = 0,
		EXIT  = 1,
		MISS  = 2,
	}; 

	class Intersection {
	public:
		float             t;
		Vec3f            normal;
		Primitive         *hit;
		CSGCLASSIFICATION cl;

		Intersection(float t, Vec3f normal, Primitive *hit, CSGCLASSIFICATION cl) :
			t(t), normal(normal), hit(hit), cl(cl) {};
		
		~Intersection() {};
	};

	typedef std::list<Intersection *> CSGilist;

	class CSGNode {
	public:
		class CSGOp {
		public :
			virtual void operation(CSGilist &r, CSGilist &l1, CSGilist &l2) = 0;
			CSGOp() {};
			virtual ~CSGOp() {};
		};

		class CSGUnion : public CSGOp {
			virtual void operation(CSGilist &r, CSGilist &l1, CSGilist &l2) {
				int status = 0;
				CSGilist::iterator l1_iter = l1.begin();
				CSGilist::iterator l2_iter = l2.begin();
				CSGilist::iterator l1_end = l1.end();
				CSGilist::iterator l2_end = l2.end();
				Intersection *is1, *is2, *ist;

				for(;(l1_iter != l1_end) && (l2_iter != l2_end);) {
					is1 = *l1_iter;
					is2 = *l2_iter;

					if(is1->t > is2->t) {
						ist = is1;
						is1 = is2;
						is2 = ist;
						l2_iter++;
					} else {
						l1_iter++;
					}

					if(is1->cl == ENTER) {
						if(status == 0) {
							status++;
							
							r.push_back(is1);
						} else {
							delete is1;
						}
					} else { 
						if(is1->cl == EXIT) {
							if(is2->cl == ENTER) {
								status--;
								
								r.push_back(is1);
							} else {
								delete is1;
							}
						}
					}
				}

				if(l1_iter == l1_end) {
					l1_iter = l2_iter;
					l1_end = l2_end;
				}
				
				for(;l1_iter != l1_end;) {
					r.push_back(*l1_iter++);
				}
			};
			
			~CSGUnion() {};
		};

		class CSGIntersection : public CSGOp {
			virtual void operation(CSGilist &r, CSGilist &l1, CSGilist &l2) {
				int status = 0;
				CSGilist::iterator l1_iter = l1.begin();
				CSGilist::iterator l2_iter = l2.begin();
				CSGilist::iterator l1_end = l1.end();
				CSGilist::iterator l2_end = l2.end();
				Intersection *is1, *is2, *ist;

				for(;(l1_iter != l1_end) && (l2_iter != l2_end);) {
					is1 = *l1_iter;
					is2 = *l2_iter;

					if(is1->t > is2->t) {
						ist = is1;
						is1 = is2;
						is2 = ist;
						l2_iter++;
					} else {
						l1_iter++;
					}

					if(is1->cl == ENTER) {
						status++;
						
						if (status == 2) {
							r.push_back(is1);
						} else {
							delete is1;
						}
					} else { 
						if (is1->cl == EXIT) {
							status--;
							
							if(status == 1) {
								r.push_back(is1);
							} else {
								delete is1;
							}
						}
					}
				}
			};
			
			~CSGIntersection() {};
		};

		class CSGDifference : public CSGOp {
			virtual void operation(CSGilist &r, CSGilist &l1, CSGilist &l2) {
				CSGilist::iterator l1_iter = l1.begin();
				CSGilist::iterator l2_iter = l2.begin();
				CSGilist::iterator l1_end = l1.end();
				CSGilist::iterator l2_end = l2.end();
				Intersection *is1, *is2, *ist;

				for(;(l1_iter != l1_end) && (l2_iter != l2_end);) {
					is1 = *l1_iter;
					is2 = *l2_iter;

					if(is1->t > is2->t) {
						if(is2->cl == ENTER) {
							if(is1->cl == EXIT) {
								delete is2;
							} else { 
								r.push_back(is2);
							}
							
							is2->normal =  is2->normal;
						} else {
							if(is1->cl == EXIT) {
								delete is2;
							} else { 
								r.push_back(is2);
							}
							
							is2->normal =  is2->normal;
						}
						
						l2_iter++;
					} else {
						if(is1->cl == ENTER) {
							if (is2->cl == EXIT) {
								r.push_back(is1);
								is1->normal = - is1->normal;
							} else { 
								delete is1;
							}
						} else {
							if(is2->cl == EXIT) {
								r.push_back(is1);
								is1->normal = - is1->normal;
							} else { 
								delete is1;
							}
						}
						
						l1_iter++;
					}
				}

				for(;l2_iter != l2_end;) {
					r.push_back(*l2_iter++);
				}
			};
			
			~CSGDifference() {};
		};

		CSGNODETYPE  nodetype;
		CSGOPTYPE    optype;   
		CSGNode     *lson;
		CSGNode     *rson;
		Primitive   *primitive;
		CSGOp       *csgop;

		CSGNode(CSGOPTYPE optype, CSGNode *lson, CSGNode *rson) : lson(lson), rson(rson), optype(optype) {
			nodetype = NODE;
			if (optype == UNION) {
				csgop = new CSGUnion();
			} else { 
				if (optype == INTERSECTION) {
					csgop = new CSGIntersection();
				} else { 
					if (optype == DIFFERENCE) {
						csgop = new CSGDifference();
					}
				}
			}
		};

		CSGNode(Primitive *primitive) : primitive(primitive) {
			nodetype = LEAF;
		};

		void Intersect(Ray &ray, CSGilist &ilist) {
			bool ret;

			ilist.clear();
			
			if(nodetype == LEAF) {
				float t = 0;
				Vec3f org = ray.org;
				ray.t = 0;
				
				do{
					t += ray.t;
					ray.org = ray.org + t * ray.dir;
					ray.t = Infinity;
					ret = primitive->Intersect(ray);
					if(ret)  {
						Vec3f norm = primitive->GetNormal(ray);
						CSGCLASSIFICATION c = (Dot(ray.dir, norm) > 0.0f ? EXIT : ENTER);
						ilist.push_back(new Intersection(t + ray.t, norm, primitive, c));
					}
				} while (ret);
				
				ray.org = org;
			} else {
				if (nodetype == NODE) {
					CSGilist rlist;
					CSGilist llist;
					rson->Intersect(ray, rlist);
					lson->Intersect(ray, llist);
					csgop->operation(ilist, rlist, llist);
				}
			}
		};

		Box CalcBounds(void) {
			if (nodetype == LEAF) {
				return primitive->CalcPBounds();
			} else {
				if (nodetype == NODE) {
					Box lbounds = lson->CalcBounds();
					Box rbounds = rson->CalcBounds();
					lbounds.Extend(rbounds);
					return lbounds;
				}
			}

			return Box();
		};

	};

	CSGNode *csgroot;
	Intersection *intersection;

	CSGPrimitive(CSGNode *csgroot) 
	: Primitive(), csgroot(csgroot) {
	};

	virtual bool Intersect(Ray &ray) {
		Ray tray = Ray(ray);
		CSGilist list;

		csgroot->Intersect(tray, list);
		Intersection *is;
		
		if((list.size() != 0)) {
			for(unsigned int i = 0; i < list.size(); i++) {
				is = list.front();
				if (is->t > Epsilon)
					break;
				list.pop_front();
			}
			
			Vec3f ip = tray.org + tray.dir * is->t;

			float t = (is->t == 0 ? 0 : fabsf(is->t) / is->t) * Length(ip - ray.org);
			
			if((ray.t < t) || (t < Epsilon)) {
				return false;
			}
			
			intersection = is;
			ray.hit = this;
			ray.t = t;
			return true;
		}
		else {
			return false; 
		}
	};

	virtual Vec3f GetNormal(Ray &ray) {
		Vec3f norm = intersection->normal;
		Normalize(norm);
		return norm;
	};

	virtual Box CalcBounds() {
		return csgroot->CalcBounds();
	};

	virtual Box CalcPBounds() {
		return csgroot->CalcBounds();
	};
};

#endif /*CSG_HXX_*/