#ifndef TRIANGLE_HXX
#define TRIANGLE_HXX

#include "Primitive.hxx"

#include "math.h"

class Triangle : public Primitive
{
private:

    Box CalcBounds()
    {
        Box bounds_;
        
        bounds_.Extend(a);
        bounds_.Extend(b);
        bounds_.Extend(c);

        return bounds_;
    }

protected:

    Vec3f a, b, c;

public:

    //! Constructor
    Triangle(Vec3f a, Vec3f b, Vec3f c)
        : Primitive(),a(a),b(b),c(c)
    {};

    //! Destructor
    virtual ~Triangle()
    {
    }
    
    //! Compute ray-triangle intersection
    bool Intersect(Ray &ray)
    {
        const Vec3f edge1 = b-a;
        const Vec3f edge2 = c-a;

        const Vec3f pvec = ray.dir^edge2;

        const float det = Dot(edge1, pvec);
        if (fabs(det) < Epsilon) return false;

        const float inv_det = 1.0f / det;

        const Vec3f tvec = ray.org-a;
        float lambda = Dot(tvec, pvec);
        lambda *= inv_det;

        if (lambda < 0.0f || lambda > 1.0f) return false;

        const Vec3f qvec = tvec^edge1;
        float mue = Dot(ray.dir, qvec);
        mue *= inv_det;

        if (mue < 0.0f || mue+lambda > 1.0f) return false;

        float f = Dot(edge2, qvec);
        f = f * inv_det - Epsilon;
        if (ray.t <= f || f <  Epsilon  ) return false;

        ray.u = lambda;
        ray.v = mue;
        ray.t = f;
        ray.hit = this;

        return true;
    };

    //! Get triangle's normal
    virtual Vec3f GetNormal(Ray &ray)
    {
        Vec3f e1 = b - a;
        Vec3f e2 = c - a;
        Vec3f normal = Cross(e1,e2);
        Normalize(normal);
        return normal;
    };
 
    virtual Vec3f getSurfaceVector()
    {
      Vec3f vec = b - a;
      Normalize(vec);
      return vec;
    }

    /* Transformations */

    virtual void Translation(Vec3f d)
    {
        a += d;
	b += d;
	c += d;
    };

    // Rotation around major axis
    virtual void Rotation(int axis, float theta)
    {
        theta = theta * M_PI / 180.0f;

	float tmp = 0.0f;

	// rotate 
	if( axis == 0 ) 
	{
	    // around x axis
	    tmp = a.y() * cosf(theta) - a.z() * sinf(theta);
	    a.z() = a.y() * sinf(theta) + a.z() * cosf(theta);
	    a.y() = tmp;	    

	    tmp = b.y() * cosf(theta) - b.z() * sinf(theta);
	    b.z() = b.y() * sinf(theta) + b.z() * cosf(theta);
	    b.y() = tmp;	    

	    tmp = c.y() * cosf(theta) - c.z() * sinf(theta);
	    c.z() = c.y() * sinf(theta) + c.z() * cosf(theta);
	    c.y() = tmp;
	}

	else if( axis ==  1) 
	{
	    // around y axis
	    tmp = a.x() * cosf(theta) + a.z() * sinf(theta);
	    a.z() = a.z() * cosf(theta) - a.x() * sinf(theta);
	    a.x() = tmp;	    

	    tmp = b.x() * cosf(theta) + b.z() * sinf(theta);
	    b.z() = b.z() * cosf(theta) - b.x() * sinf(theta);
	    b.x() = tmp;

	    tmp = c.x() * cosf(theta) + c.z() * sinf(theta);
	    c.z() = c.z() * cosf(theta) - c.x() * sinf(theta);
	    c.x() = tmp;
	}
	
	else
        {
	    // around z axis
	    tmp = a.x() * cosf(theta) - a.y() * sinf(theta);
	    a.y() = a.x() * sinf(theta) + a.y() * cosf(theta);
	    a.x() = tmp;	    

	    tmp = b.x() * cosf(theta) - b.y() * sinf(theta);
	    b.y() = b.x() * sinf(theta) + b.y() * cosf(theta);
	    b.x() = tmp;	    

	    tmp = c.x() * cosf(theta) - c.y() * sinf(theta);
	    c.y() = c.x() * sinf(theta) + c.y() * cosf(theta);
	    c.x() = tmp;
	}
    };

    // Rotation around arbitrary point
    virtual void Rotation(int axis, Vec3f point, float theta)
    {
      this->Translation(-point);
      this->Rotation(axis, theta);
      this->Translation(point);
    };

    // only uniform scaling possible
    virtual void Scaling(float s)
    {
        a *= s;
	b *= s;
	c *= s;
    };
};

#endif