#ifndef DIAMONDSHADER_HXX
#define DIAMONDSHADER_HXX

#include "Shader.hxx"

/* Shading for isotropic transparent crystals, e.g. diamonds.
** Every hit spawns a mirror-reflection ray and, unless total internal
** reflection occurs, a refraction ray whose direction follows Snell's law.
** The two traced colors are blended with the Fresnel reflectance for
** unpolarized light.
*/
class DiamondShader : public Shader
{

private:

    // refraction index of the crystal (diamond is roughly 2.42);
    // the surrounding medium is treated as air with index 1.0
    float refIndex;

public:

    /* scene - scene used to trace secondary rays (passed on to Shader)
    ** refId - refraction index of the crystal material
    */
    DiamondShader(Scene *scene, float refId)
      : Shader(scene), refIndex(refId)
    {}

    /* Shade a ray that hit the crystal surface.
    ** Traces a reflection ray and, if there is no total internal reflection,
    ** a refraction ray; returns R * reflectColor + (1 - R) * refractColor,
    ** where R is the Fresnel reflectance. On total internal reflection the
    ** reflected color is returned alone.
    ** NOTE(review): the Snell/Fresnel math assumes ray.dir and the shading
    ** normal are unit length — confirm with Ray / GetNormal.
    */
    virtual Vec3f Shade(Ray &ray)
    {
        // Snell ratio n_incident / n_transmitted; default case is a ray
        // entering the crystal from air (index 1.0)
        float n = 1.0f / refIndex;

        // get shading normal
        Vec3f normal = ray.hit->GetNormal(ray);

        // make the normal face the incoming ray; if it had to be flipped,
        // the ray is leaving the crystal, so the ratio is inverted
        if (Dot(normal, ray.dir) > 0)
        {
            normal = -normal;
            n = refIndex / 1.0f;
        }

        // hit point shared by both secondary rays
        // NOTE(review): no epsilon offset is applied; self-intersection
        // avoidance presumably happens in Scene::RayTrace — confirm.
        Vec3f hitPoint = ray.org + ray.t * ray.dir;

        // mirror reflection direction
        Vec3f reflect = ray.dir - 2.0f * Dot(normal, ray.dir) * normal;
        Normalize(reflect);

        Ray reflection;
        reflection.org = hitPoint;
        reflection.dir = reflect;
        reflection.t = Infinity;
        reflection.hit = NULL;
        reflection.u = reflection.v = 0.0;

        Vec3f reflectColor = scene->RayTrace(reflection);

        // cosine of the incidence angle (>= 0 because the normal faces the ray)
        float c = Dot(normal, -ray.dir);
        // squared cosine of the transmission angle (Snell's law)
        float term = 1.0f - n * n * (1.0f - c * c);

        // total internal reflection: all energy is reflected
        if (term < 0)
            return reflectColor;

        // refraction direction (Snell's law, vector form)
        Vec3f refract = (n * c - sqrtf(term)) * normal + n * ray.dir;
        Normalize(refract);

        Ray refraction;
        refraction.org = hitPoint;
        refraction.dir = refract;
        refraction.t = Infinity;
        refraction.hit = NULL;
        refraction.u = refraction.v = 0.0;

        Vec3f refractColor = scene->RayTrace(refraction);

        // Fresnel reflectance for unpolarized light (Cook-Torrance form):
        //   R = 1/2 * (g-c)^2/(g+c)^2 * (1 + (c(g+c)-1)^2 / (c(g-c)+1)^2)
        // with eta = n_t / n_i and g = sqrt(eta^2 + c^2 - 1).
        // term >= 0 implies eta^2 + c^2 - 1 >= 0; the clamp only absorbs
        // floating-point rounding before the sqrt.
        float eta = 1.0f / n;
        float gSquared = eta * eta + c * c - 1.0f;
        if (gSquared < 0.0f)
            gSquared = 0.0f;
        float g = sqrtf(gSquared);

        float gMinusC = g - c;
        float gPlusC = g + c;
        float numer = c * gPlusC - 1.0f;
        float denom = c * gMinusC + 1.0f;

        float R = 0.5f * (gMinusC * gMinusC) / (gPlusC * gPlusC)
                * (1.0f + (numer * numer) / (denom * denom));
        float T = 1.0f - R;

        return R * reflectColor + T * refractColor;
    }
};

#endif