#ifndef MARBLESHADER_HXX
#define MARBLESHADER_HXX

#include "Shader.hxx"
#include "Noise.hxx"

#include <vector>
 
/* MarbleShader: procedural shader producing a marble-like staining.
 *
 * A turbulence value is accumulated from four octaves of PerlinNoise3D_new,
 * scaled, added to a linear ramp along x/y/z and passed through sinf() to
 * form vein bands. The magnitude of the band value selects and linearly
 * blends two neighbouring entries of the color palette. That color is then
 * lit with a Phong model (ambient + diffuse + white specular highlight),
 * averaged over the samples of every scene light, with optional shadow rays.
 */
class MarbleShader : public Shader
{

private:

    std::vector<Vec3f> colors;   // marble color palette, blended in order along |band value|

    // multipliers for pattern adjustment
    float value_m;   // turbulence multiplier
    float x_m;       // x multiplier of the linear ramp
    float y_m;       // y multiplier of the linear ramp
    float z_m;       // z multiplier of the linear ramp

    // Phong shader parameters
    float ka;    // ambient coefficient
    float kd;    // diffuse reflection coefficient
    float ks;    // specular reflection coefficient
    float ke;    // shininess exponent

    /* Returns the unlit marble color at `point`.
     *
     * With an empty palette the result is black; with one color it is that
     * color; otherwise |band value| in [0,1] is mapped onto the palette and
     * the two surrounding entries are blended linearly.
     */
    Vec3f MarbleColor(Vec3f point)
    {
        float x = point.x();
        float y = point.y();
        float z = point.z();

        // turbulence: |noise| summed over frequencies 8, 4, 2, 1, each weighted by 1/frequency
        float value = 0.0f;
        for (float scale = 8.0f; scale >= 1.0f; scale /= 2.0f)
            value += fabs(PerlinNoise3D_new(x * scale, y * scale, z * scale)) / scale;

        // perturb a linear ramp with the scaled turbulence and fold it through sin() into bands
        value = value * value_m + (x * x_m + y * y_m + z * z_m);
        value = sinf(value);

        const std::size_t count = colors.size();
        if (count == 0)
            return Vec3f(0.0);     // no palette to sample from
        if (count == 1)
            return colors[0];      // nothing to blend with

        // position of |value| along the palette, in [0, count-1]
        float pos = fabs(value) * static_cast<float>(count - 1);
        std::size_t index = static_cast<std::size_t>(floor(pos));
        if (index > count - 2)
            index = count - 2;     // |value| == 1 would otherwise need colors[count]

        // blend weight within the segment [index, index+1], in [0,1]
        float t = pos - static_cast<float>(index);

        return colors[index] * (1.0f - t) + colors[index + 1] * t;
    }

public:

    /* scene:   scene whose lights and geometry are used for shading.
     * colors:  marble palette (copied).
     * value_m: scale applied to the turbulence before the ramp is added.
     * x_m, y_m, z_m: weights of the hit-point coordinates in the linear ramp.
     * ka, kd, ks, ke: Phong ambient, diffuse, specular coefficients and shininess exponent.
     */
    MarbleShader(Scene *scene, const std::vector<Vec3f> &colors,
                 float value_m = 20.0f, float x_m = 5.0f, float y_m = 10.0f, float z_m = 1.0f,
                 float ka = 0.0f, float kd = 1.0f, float ks = 0.0f, float ke = 0.0f)
      : Shader(scene), colors(colors), value_m(value_m), x_m(x_m), y_m(y_m), z_m(z_m),
        ka(ka), kd(kd), ks(ks), ke(ke)
    {}

    /* Shades the hit described by `ray` (ray.t is the hit distance, ray.hit the object).
     * Returns ambient plus, for every light, the average over its samples of the
     * diffuse and specular contributions that are not blocked by a shadow caster.
     */
    Vec3f Shade(Ray &ray)
    {
        // hit point: where the marble pattern is evaluated and shadow rays start
        Vec3f point = ray.org + ray.t * ray.dir;
        Vec3f color = MarbleColor(point);

        // shading normal, turned towards the incoming ray
        Vec3f normal = ray.hit->GetNormal(ray);
        if (Dot(normal, ray.dir) > 0)
            normal = -normal;

        // mirror reflection of the viewing direction about the normal
        Vec3f reflect = ray.dir - 2 * Dot(normal, ray.dir) * normal;

        // ambient term under white ambient light
        Vec3f ambientIntensity(1, 1, 1);
        Vec3f ambientColor = ka * color;
        Vec3f result = Product(ambientColor, ambientIntensity);

        // material terms do not depend on the light sample, so compute them once
        Vec3f diffuseColor = kd * color;
        Vec3f specularColor = ks * Vec3f(1, 1, 1); // white highlight

        // shadow ray starting at the hit point; Illuminate fills in the rest per sample.
        // NOTE(review): the origin is not offset along the normal; unless
        // Scene::Intersect ignores hits at very small t, the surface may shadow itself — confirm.
        Ray shadow;
        shadow.org = point;

        for (unsigned int l = 0; l < scene->mLights.size(); l++)
        {
            unsigned int numRays = scene->mLights[l]->GetNumberOfRays();
            if (numRays == 0)
                continue; // no samples: nothing to add, and avoids dividing by zero below

            Vec3f lightIntensity;
            Vec3f result_local = Vec3f(0.0);

            // accumulate every sample of this light (several for area lights)
            for (unsigned int s = 0; s < numRays; s++)
            {
                if (!scene->mLights[l]->Illuminate(shadow, lightIntensity, s))
                    continue;

                // distance to the light sample as left in shadow.t by Illuminate
                float distance = shadow.t;

                // only light reaching the front side contributes
                float cosLightNormal = Dot(shadow.dir, normal);
                if (cosLightNormal > 0)
                {
                    // in shadow if a shadow-casting surface lies closer than the light;
                    // tested only here so back-facing samples skip the intersection cost
                    if (scene->castShadows && scene->Intersect(shadow) &&
                        shadow.hit->castShadows() && shadow.t < distance)
                        continue;

                    // diffuse term
                    result_local = result_local + Product(diffuseColor * cosLightNormal,
                                                          lightIntensity);

                    // specular term, only when the light lies within the reflection lobe
                    float cosLightReflect = Dot(shadow.dir, reflect);
                    if (cosLightReflect > 0)
                        result_local = result_local + Product(specularColor * powf(cosLightReflect, ke),
                                                              lightIntensity);
                }
            }

            // average over the light's samples
            result += result_local / float(numRays);
        }

        return result;
    }
};

#endif