#ifndef SMOOTH_TRIANGLE_HXX
#define SMOOTH_TRIANGLE_HXX

#include "Triangle.hxx"

class SmoothTriangle : public Triangle
{

private:

    Vec3f na, nb, nc;

public:

    //! Constructor
    SmoothTriangle(Vec3f a, Vec3f b, Vec3f c)
        : Triangle(a,b,c)
    {};

    //! Constructor
    SmoothTriangle(Vec3f a, Vec3f b, Vec3f c, Vec3f na, Vec3f nb, Vec3f nc)
        : Triangle(a,b,c),na(na), nb(nb), nc(nc)
    {};

    //! Destructor
    virtual ~SmoothTriangle()
    {
    }

    //! Set vertex normals
    void SetNormals(const Vec3f& _na, const Vec3f& _nb, const Vec3f& _nc)
    {
        na = _na;
        nb = _nb;
        nc = _nc;

        Normalize(na);
        Normalize(nb);
        Normalize(nc);
    }

    //! Get normal, but this time it is a smooth normal
    Vec3f GetNormal(Ray &ray)
    {
        // assume u/v coordinates in ray correspond to beta(u) and gamma
        // (v) barycentric coordinates of hit point on triangle (have to be
        // stored like this in the intersection code !)

        Vec3f normal = ray.u * nb + ray.v * nc + (1 - ray.u - ray.v) * na;
        Normalize(normal);

        return normal;
    };

    /* Transformations */

    // Rotation around major axis
    virtual void Rotation(int axis, float theta)
    {
        theta = theta * M_PI / 180.0f;

	float tmp = 0.0f;
      
	// rotate
	if( axis == 0 )
	{
	    // around x axis
	    tmp = a.y() * cosf(theta) - a.z() * sinf(theta);
	    a.z() = a.y() * sinf(theta) + a.z() * cosf(theta);
	    a.y() = tmp;
	    
	    tmp = b.y() * cosf(theta) - b.z() * sinf(theta);
	    b.z() = b.y() * sinf(theta) + b.z() * cosf(theta);
	    b.y() = tmp;
	    
	    tmp = c.y() * cosf(theta) - c.z() * sinf(theta);
	    c.z() = c.y() * sinf(theta) + c.z() * cosf(theta);
	    c.y() = tmp;
	  
	    tmp = na.y() * cosf(theta) - na.z() * sinf(theta);
	    na.z() = na.y() * sinf(theta) + na.z() * cosf(theta);
	    na.y() = tmp;
	    
	    tmp = nb.y() * cosf(theta) - nb.z() * sinf(theta);
	    nb.z() = nb.y() * sinf(theta) + nb.z() * cosf(theta);
	    nb.y() = tmp;
	    
	    tmp = nc.y() * cosf(theta) - nc.z() * sinf(theta);
	    nc.z() = nc.y() * sinf(theta) + nc.z() * cosf(theta);
	    nc.y() = tmp;
	}

	else if( axis ==  1)
	{
	    // around y axis
	    tmp = a.x() * cosf(theta) + a.z() * sinf(theta);
	    a.z() = a.z() * cosf(theta) - a.x() * sinf(theta);
	    a.x() = tmp;
	    
	    tmp = b.x() * cosf(theta) + b.z() * sinf(theta);
	    b.z() = b.z() * cosf(theta) - b.x() * sinf(theta);
	    b.x() = tmp;
	    
	    tmp = c.x() * cosf(theta) + c.z() * sinf(theta);
	    c.z() = c.z() * cosf(theta) - c.x() * sinf(theta);
	    c.x() = tmp;
	    
	    tmp = na.x() * cosf(theta) + na.z() * sinf(theta);
	    na.z() = na.z() * cosf(theta) - na.x() * sinf(theta);
	    na.x() = tmp;

	    tmp = nb.x() * cosf(theta) + nb.z() * sinf(theta);
	    nb.z() = nb.z() * cosf(theta) - nb.x() * sinf(theta);
	    nb.x() = tmp;
	    
	    tmp = nc.x() * cosf(theta) + nc.z() * sinf(theta);
	    nc.z() = nc.z() * cosf(theta) - nc.x() * sinf(theta);
	}
	
	else
	{
	    // around z axis
	    tmp = a.x() * cosf(theta) - a.y() * sinf(theta);
	    a.y() = a.x() * sinf(theta) + a.y() * cosf(theta);
	    a.x() = tmp;
	    
	    tmp = b.x() * cosf(theta) - b.y() * sinf(theta);
	    b.y() = b.x() * sinf(theta) + b.y() * cosf(theta);
	    b.x() = tmp;
	    
	    tmp = c.x() * cosf(theta) - c.y() * sinf(theta);
	    c.y() = c.x() * sinf(theta) + c.y() * cosf(theta);
	    c.x() = tmp;
	    
	    tmp = na.x() * cosf(theta) - na.y() * sinf(theta);
	    na.y() = na.x() * sinf(theta) + na.y() * cosf(theta);
	    na.x() = tmp;

	    tmp = nb.x() * cosf(theta) - nb.y() * sinf(theta);
	    nb.y() = nb.x() * sinf(theta) + nb.y() * cosf(theta);
	    nb.x() = tmp;
	    
	    tmp = nc.x() * cosf(theta) - nc.y() * sinf(theta);
	    nc.y() = nc.x() * sinf(theta) + nc.y() * cosf(theta);
	    nc.x() = tmp;
	}
    };
};

#endif