#ifndef CRYSTALSHADER_HXX
#define CRYSTALSHADER_HXX

#include "Shader.hxx"

/* This class calculates shading for a uniaxial crystal 
** based on the algorithm described in 
** "Realistic Rendering of Birefringency in Uniaxial Crystals"
*/ 
class CrystalShader : public Shader
{

private:

    // refraction index of ordinary ray
    float n_o; 
    
    // refraction index of extraordinary ray
    float n_e;

    // optical axis 
    Vec3f A;

public:
      
    CrystalShader(Scene *scene, float ref_o, float ref_e, Vec3f axis)
      : Shader(scene),n_o(ref_o),n_e(ref_e),A(axis)
    {
        Normalize(A);
    };
    
    virtual Vec3f Shade(Ray &ray)
    {
        // reflection and refraction rays
        Ray reflection;
	Ray refraction;
	Ray refrac_extra;

	// final color
	Vec3f result;

	// store if ray propagates inside crystal
	bool inside = false;

	// get shading normal
        Vec3f normal = ray.hit->GetNormal(ray);

	// test if we are inside crystal
	if (Dot(normal,ray.dir) > 0)
        {
	    // turn normal to front
	    normal = -normal;
	
	    inside = true;
	}

	// calculate reflection vector
	Vec3f reflect = ray.dir - 2*Dot(normal,ray.dir)*normal;
	Normalize(reflect);

	// initialize reflection ray
	reflection.org = ray.org + ray.t * ray.dir;
	reflection.dir = reflect;
	reflection.t = Infinity;
	reflection.hit = NULL;
	reflection.u = reflection.v = 0.0;
	
	// get color of reflection
	Vec3f reflectColor = scene->RayTrace(reflection);
	    
	// compute refraction part

	// ratio of refraction indices
	// ref. index of air ~1.0
	float n = n_o / 1.0f;
	
	// compute transmission part
	float c = Dot(normal, -ray.dir);
	float term = 1.0f - n * n * (1.0f - c * c);
	
	// testing for total internal reflection
	if(term < 0)
	{
	    result = reflectColor;
	}
	
	else
	{
	    // calculate ordinary refraction vector with Snell's law
	    Vec3f refract = (n * c - sqrtf(term)) * normal + n * ray.dir; 
	    Normalize(refract);
	    
	    // initialize ordinary refraction ray
	    refraction.org = ray.org + ray.t * ray.dir;
	    refraction.dir = refract;
	    refraction.t = Infinity;
	    refraction.hit = NULL;
	    refraction.u = refraction.v = 0.0;
	    
	    // get color of refraction
	    Vec3f refractColor = scene->RayTrace(refraction);
	    
	    // only one refraction ray
	    if(inside)
	    {
	        // compute reflection and refraction cofficients by solving Fresnel's equation
		n = 1.0f / n;
		float g = n * n + c * c -1;
		float R = 0.5f * powf(g-c, 2.0f) /  powf(g+c, 2.0f) * 
		  (1.0f + powf( c * (g+c) - 1.0f, 2.0f) / powf( c * (g-c) + 1.0f, 2.0f));
		float T = 1.0f - R;
		
		result = R * reflectColor + T * refractColor;
	    }

	    // ray hits crystal from outside => split refracted part
	    else
	    {
	        // define coordinate system whose z axis is surface normal
	        Vec3f zAxis = normal;
		Vec3f xAxis = ray.hit->getSurfaceVector();
		Vec3f yAxis = Cross(xAxis,zAxis);
		Normalize(yAxis);
		
		// compute direction cosines
		Vec3f A_cos = Vec3f(Dot(A,xAxis), Dot(A,yAxis), Dot(A,zAxis)); 
		Vec3f So_cos = Vec3f(Dot(refract,xAxis), Dot(refract,yAxis), Dot(refract,zAxis)); 		

		// define required terms
		float n_o_2 = n_o * n_o;
		float n_e_2 = n_e * n_e;

		float N = n_o_2 - n_e_2;
		float gamma = n_e_2 + N * (1.0f - A_cos.z() * A_cos.z()); 
		float delta = sqrt(gamma * (n_e_2 - n_o_2 * (1.0f - So_cos.z() * So_cos.z())) + n_o_2 * N * A_cos.x() * A_cos.x());
		float d_eg = sqrt(n_e_2 * (n_o_2 * gamma * gamma - N * powf(A_cos.z() * delta + n_o_2 * A_cos.x(), 2.0f)));

		// calculate direction cosines of extraordinary ray
		float Se_x = n_o_2 * So_cos.x() * (n_e_2 + N * A_cos.y() * A_cos.y()) -  n_o_2 * So_cos.y() * N * A_cos.x() * A_cos.y();
		Se_x = Se_x - A_cos.z() * N * delta * A_cos.x();

		float Se_y = n_o_2 * So_cos.y() * (n_e_2 + N * A_cos.x() * A_cos.x()) -  n_o_2 * So_cos.x() * N * A_cos.x() * A_cos.y();;
                Se_y = Se_y - A_cos.z() * N * delta * A_cos.y();

		float Se_z = gamma * delta;

		Vec3f Se_cos = 1.0f / d_eg * Vec3f(Se_x,Se_y,Se_z);

		Vec3f extra_dir = Se_cos.x() * xAxis + Se_cos.y() * yAxis + Se_cos.z() * zAxis;
		Normalize(extra_dir);

		// initialize extraordinary refraction ray
		refrac_extra.org = ray.org + ray.t * ray.dir;
		refrac_extra.dir = extra_dir;
		refrac_extra.t = Infinity;
		refrac_extra.hit = NULL;
		refrac_extra.u = refrac_extra.v = 0.0;

		// get color of refraction
		Vec3f refractColor_extra = scene->RayTrace(refrac_extra);

		// exact refraction index of extraordinary ray
		float cos_2 = Dot(extra_dir,A) * Dot(extra_dir,A);
		float n_p = n_o * n_e / sqrt(n_o_2 * (1.0f - cos_2) + n_e_2 * cos_2);

		// calculation of Fresnel amplitudes
		float cos_theta = Dot(normal,-ray.dir);
		float sin_theta = sin(acos(cos_theta));
		float tan_theta = tan(acos(cos_theta));

		float e_o = n_o * n_o;
		float e_e = n_p * n_p;

		float q_o = sqrt(e_o - sin_theta * sin_theta);
		float q_e = A.x() * A.z() * sin_theta * (e_o - e_e); 
		q_e += sqrt(e_o * (A.z() * A.z() * e_e * e_e 
				   + A.y() * A.y() * sin_theta * sin_theta * (e_e - e_o) 
				   + e_e * (e_o - sin_theta * sin_theta - A.z() * A.z() * e_o)));
		q_e /= A.z() * A.z() * e_e + e_o - A.z() * A.z() * e_o;

		float N_e = sqrt(powf(sin_theta * A.z() * q_e - A.x() * q_o * q_o, 2.0f) + A.y() * A.y() * e_o * e_o + 
				 powf(sin_theta * A.x() * q_e + A.z() * q_e * q_e - A.z() * e_o, 2.0f));
		float N_o = sqrt(1.0f / (A.y() * A.y() * e_o + powf(A.z() * sin_theta - A.x() * sqrt(- sin_theta * sin_theta + e_o), 2.0f)));

		float E_ox = A.y() * sqrt(- sin_theta * sin_theta + e_o) * N_o;
		float E_oy = (- A.z() * sin_theta + A.x() * sqrt(- sin_theta * sin_theta + e_o)) * N_o;
		float E_oz = A.y() * sin_theta * N_o;

		float E_ex = - (A.x() * q_o * q_o - A.z() * q_e * sin_theta) / N_e;
		float E_ey = A.y() * e_o / N_e;
		float E_ez = - (A.z() - e_o + q_e * q_e - A.x() * q_e * sin_theta) / N_e;

		float A = (q_o + cos_theta + sin_theta * tan_theta) * E_ox - sin_theta * E_oz; 
		float B = (q_e + cos_theta + sin_theta * tan_theta) * E_ex - sin_theta * E_ez;
		float C = (cos_theta - q_e) * A * E_ey - (cos_theta + q_o) * B * E_oy; 
 
		float t_so = -2.0f * cos_theta * B / C;
		float t_po = 2.0f * (cos_theta + q_e) * E_ey / C;
		float t_se = -2.0f * cos_theta * A / C;;
		float t_pe = -2.0f * (cos_theta + q_o) * E_oy / C;;

		//calculation of Fresnel coefficients
		float T_so = n_o * Dot(-normal,refract) / cos_theta * t_so * t_so;
		float T_po = n_o * Dot(-normal,refract) / cos_theta * t_po * t_po;
		float T_se = n_p * Dot(normal,extra_dir) / cos_theta * t_se * t_se;
		float T_pe = n_p * Dot(normal,extra_dir) / cos_theta * t_pe * t_pe;;

		if(T_po > 1.0f)
		    T_po = 0.5f;

		if(T_se > 1.0f)
		    T_se = 0.5f;
	
		float T_o = (T_so + T_po) / 2.0f;
		float T_e = (T_se + T_pe) / 2.0f; 

		float R = 1.0f - T_o - T_e;
	
		result = R * reflectColor + T_e * refractColor_extra + T_o * refractColor;
	    }
	}

	float a = 0.5f;

	return (a * result + (1.0f-a) * Vec3f(0.9f,1.0f,1.0f));
    };
};

#endif