#ifndef TRANSPARENTPHONGSHADER_HXX_
#define TRANSPARENTPHONGSHADER_HXX_

#include "Shader.hxx"

// Area-light configuration switches. Neither macro is referenced inside this
// header; NOTE(review): presumably consumed by other files that include it —
// confirm where they are used before changing them.
#define AREALIGHT 
#define NUM_AREA_SAMPLES 1

/**
 * Phong shader with optional mirror reflection and refraction.
 *
 * The shaded colour is built from
 *   - an ambient term ka * color,
 *   - a diffuse + specular (Phong) term per light, averaged over the rays
 *     each light reports via GetNumberOfRays(),
 *   - a recursively traced mirror reflection, added with weight rl,
 *   - a recursively traced refraction, blended in with weight tr,
 * and is finally multiplied by every attached texture.
 */
class TransparentPhongShader : public Shader {
public:
	Vec3f color;  ///< base surface colour (ambient and diffuse)
	float ka;     ///< ambient coefficient
	float kd;     ///< diffuse coefficient
	float ks;     ///< specular coefficient
	float ke;     ///< specular (Phong) exponent
	float tr;     ///< transparency: blend weight of the refracted colour (0 = opaque)
	float rr;     ///< refraction ratio, used as eta in the vector form of Snell's law
	float rl;     ///< reflectivity: weight of the mirror-reflected colour (0 = none)

	/**
	 * @param scene     scene queried for lights and used for secondary rays
	 * @param color     base surface colour
	 * @param ka        ambient coefficient
	 * @param kd        diffuse coefficient
	 * @param ks        specular coefficient
	 * @param ke        specular exponent
	 * @param trans     transparency weight (stored in tr)
	 * @param refr      refraction ratio (stored in rr)
	 * @param refl      reflection weight (stored in rl)
	 */
	TransparentPhongShader(Scene *scene, Vec3f color, float ka, float kd, float ks, float ke,
			float trans = 0.0, float refr = 0.0, float refl = 0.0)
	: Shader(scene), color(color), ka(ka), kd(kd), ks(ks), ke(ke), tr(trans), rr(refr), rl(refl) {}

	/**
	 * Computes the colour seen along @p ray at its closest hit.
	 *
	 * Expects ray.hit to be set and ray.t to hold the hit distance.
	 * The incoming ray is NOT modified (earlier versions normalised ray.dir
	 * in place, which desynchronised it from ray.t).
	 *
	 * @return shaded colour; may trace secondary rays via scene->RayTrace.
	 */
	virtual Vec3f Shade(Ray &ray) {
		// Hit point uses the caller's direction as-is so it agrees with ray.t.
		Vec3f hitPoint = ray.org + ray.t * ray.dir;

		// Unit-length working copies; reflection/refraction formulas and the
		// Lambert/Phong cosines below all assume unit vectors.
		Vec3f dir = ray.dir;
		Normalize(dir);

		Vec3f normal = ray.hit->GetNormal(ray);
		Normalize(normal);
		// Shade the side of the surface that faces the incoming ray.
		if (Dot(normal, dir) > 0) {
			normal = -normal;
		}

		// Mirror direction: used for the highlight, for reflection rays and as
		// the fallback direction on total internal reflection.
		Vec3f reflect = dir - 2 * Dot(normal, dir) * normal;

		Vec3f result = ka * color;

		// NOTE(review): without shadow casting the direct terms are scaled by a
		// fixed 0.5. The original code halved kd/ks again on every lit sample,
		// compounding across samples and lights; the intent of this factor is
		// otherwise undocumented — confirm.
		const float lightScale = scene->castShadows ? 1.0f : 0.5f;
		const float diffuse = kd * lightScale;
		const float specular = ks * lightScale;

		Ray shadow;
		shadow.org = hitPoint;

		for (unsigned int l = 0; l < scene->mLights.size(); l++) {
			const unsigned int numRays = scene->mLights[l]->GetNumberOfRays();
			Vec3f intensity = Vec3f(1);
			// Sum of this light's samples only, so averaging below does not
			// touch the ambient term or other lights' contributions.
			Vec3f lightSum = Vec3f(0.0f);

			for (unsigned int s = 0; s < numRays; s++) {
				// Illuminate() presumably sets shadow.dir and intensity for this
				// sample; a false return contributes nothing.
				if (!scene->mLights[l]->Illuminate(shadow, intensity)) {
					continue;
				}

				float cosLightNormal = Dot(shadow.dir, normal);
				if (cosLightNormal > 0) {
					Vec3f diffuseColor = diffuse * color;
					lightSum += Product(diffuseColor * cosLightNormal, intensity);
				}

				float cosLightReflect = Dot(shadow.dir, reflect);
				if (cosLightReflect > 0) {
					lightSum += Product(specular * Vec3f(1, 1, 1) * powf(cosLightReflect, ke), intensity);
				}
			}

			// Guard against lights that report zero rays (division by zero).
			if (numRays > 0) {
				result += lightSum / static_cast<float>(numRays);
			}
		}

		// Mirror reflection.
		// NOTE(review): secondary rays start exactly at the hit point; this relies
		// on the intersection code rejecting t ~ 0 self-hits — confirm.
		if (rl > 0) {
			Ray reflec;
			reflec.org = hitPoint;
			reflec.dir = reflect;
			reflec.t = Infinity;
			reflec.hit = NULL;

			result += rl * scene->RayTrace(reflec);
		}

		// Refraction (vector form of Snell's law with eta = rr).
		// NOTE(review): rr is used unchanged whether the ray enters or leaves the
		// object; physically the ratio should be inverted on exit — confirm.
		if (tr > 0.0f) {
			// normal faces against dir, so this cosine is non-negative.
			float cosIncident = -Dot(dir, normal);
			float sin2Transmitted = rr * rr * (1.0f - cosIncident * cosIncident);

			Ray refrac;
			refrac.org = hitPoint;
			if (sin2Transmitted <= 1.0f) {
				float cosTransmitted = sqrt(1.0f - sin2Transmitted);
				refrac.dir = rr * dir + (rr * cosIncident - cosTransmitted) * normal;
			} else {
				// Total internal reflection: no transmitted ray exists (the square
				// root above would be NaN), so follow the mirror direction instead.
				refrac.dir = reflect;
			}
			refrac.t = Infinity;
			refrac.hit = NULL;

			Vec3f refracted = scene->RayTrace(refrac);
			result = (1.0f - tr) * result + tr * refracted;
		}

		// Modulate by every attached texture, sampled at the hit's UV coordinates.
		for (unsigned int i = 0; i < getTextureCount(); i++) {
			const Vec4f& tex = getTexture(i)->GetTexel(ray.hit->GetUV(ray));
			result = result * tex.xyz();
		}

		return result;
	}
};


#endif /* TRANSPARENTPHONGSHADER_HXX_ */