#ifndef LENSCAMERA_HXX
#define LENSCAMERA_HXX

#include <cassert>

#include "Camera.hxx"

/* This class simulates a camera with a lens and the resulting depth of field. 
** It is based on the chapter 'Camera Models' 
** in the book "Principles of Digital Image Synthesis".
*/
class LensCamera : public Camera
{

private:

    // input values
    Vec3f pos, dir, up;                   // lens centre, viewing direction (normalized in the constructor), up hint
    float focus, lensradius, sharp_dist;  // focal length, lens radius, distance of the plane that is in perfect focus

    // preprocessed values
    Vec3f xAxis, yAxis, zAxis, proj_center;  // camera frame (zAxis == dir) and centre of the projection plane
    float proj_dist;                         // distance between lens and projection plane (thin-lens equation)
    float aspect;                            // resX / resY

    // Generate a random point on the lens disc around 'pos', spanned by xAxis/yAxis.
    //
    // inside == true  -> radius is drawn from the inner ring [0, lensradius/2)
    // inside == false -> radius is drawn from the outer ring [lensradius/2, lensradius)
    // (ranges assume frand() returns a value in [0,1) - TODO confirm against its definition)
    //
    // The radius is drawn linearly, not area-weighted, so samples are denser near
    // the lens centre. NOTE(review): kept as is; confirm this is intended.
    Vec3f getLensPoint(bool inside)
    {
        // ring selector: 0 for the inner circle, 1 for the outer circle
        float n = inside ? 0.0f : 1.0f;

        // Keep this order: the first frand() call feeds the radius, the second the angle.
        float r = (frand() + n) * (lensradius/2.0f);
        float phi = frand() * 2.0f * static_cast<float>(M_PI);

        Vec3f point = pos;
        point += r * cosf(phi) * xAxis + r * sinf(phi) * yAxis;

        return point;
    }


public:

    // Build a thin-lens camera.
    //
    //   pos     centre of the lens
    //   _dir    viewing direction (normalized here)
    //   up      up hint; must not be parallel to _dir
    //   resX/Y  image resolution, forwarded to Camera; both must be > 0
    //   f       focal length of the lens; must be non-zero
    //   radius  radius of the lens
    //   sharp   distance from the lens to the plane that is rendered sharp;
    //           must differ from f
    //
    // The preconditions are checked with assert() only, i.e. they vanish when
    // NDEBUG is defined. Violating them yields inf/NaN values in the camera
    // frame and therefore in every generated ray.
    LensCamera(Vec3f pos,Vec3f _dir,Vec3f up,int resX, int resY,
               float f, float radius, float sharp)
      : Camera(resX,resY),pos(pos),up(up),focus(f),lensradius(radius),sharp_dist(sharp)
    {
        // aspect and the pixel normalisation in InitRay divide by the resolution
        assert(resX > 0 && resY > 0);
        // InitRay divides by (focus * 1000)
        assert(focus != 0.0f);
        // thin-lens equation 1/f = 1/d_object + 1/d_image has no finite image
        // distance when the object distance equals the focal length
        assert(sharp_dist != focus);

        // viewing direction
        dir = _dir;
        Normalize(dir);

        // setup coordinate system
        zAxis = dir;
        xAxis = Cross(dir,up);
        yAxis = Cross(xAxis,zAxis);

        // a zero cross product means 'up' is parallel to the viewing direction
        // (or one of them is the zero vector), so no camera frame can be built
        assert(Dot(xAxis, xAxis) > 0.0f);

        Normalize(xAxis);
        Normalize(yAxis);
        Normalize(zAxis);

        aspect = static_cast<float>(resX) / static_cast<float>(resY);

        // compute distance between lens and projection plane
        proj_dist = (sharp_dist * focus) / (sharp_dist - focus);

        // compute center of projection plane (it lies behind the lens)
        proj_center = pos - proj_dist * zAxis;
    }

    virtual ~LensCamera()
    {}

    // Set up the primary ray for pixel position (x, y).
    //
    // The pixel is mapped to a point on the projection plane behind the lens.
    // The line from that point through the lens centre is intersected with
    // the focal plane, giving Q. The ray then starts at a random lens point
    // and aims at Q, so geometry on the focal plane stays sharp.
    //
    // ray.inside is read (to choose the lens ring) before anything is written,
    // so the caller must have set it. On return org, dir, t, hit, u and v are
    // initialized. Always returns true.
    virtual bool InitRay(float x, float y, Ray &ray)
    {
        // scale applied to the normalized pixel offsets;
        // NOTE(review): the factor 1000 looks like a unit conversion - confirm
        float film_scale = focus * 1000.0f;

        // normalized pixel offsets in [-0.5, 0.5]; mirrored because the image
        // on the projection plane is upside down
        float sx = (static_cast<float>(resX) - (x+0.5f)) / static_cast<float>(resX) - 0.5f;
        float sy = (static_cast<float>(resY) - (y+0.5f)) / static_cast<float>(resY) - 0.5f;

        // compute point on image plane
        Vec3f image_point = proj_center;
        image_point = image_point + (sx * xAxis) / film_scale;
        image_point = image_point + (sy * yAxis / aspect) / film_scale;

        // define focal plane in distance sharp_dist
        Vec3f normal = zAxis;
        Vec3f plane_point = pos + sharp_dist * zAxis;

        // define direction from image_point through the lens centre
        Vec3f direction = pos - image_point;
        Normalize(direction);

        // compute intersection with focal plane
        Vec3f Q = image_point + ( Dot(plane_point - image_point, normal) / Dot(direction, normal) ) * direction;

        // initialize ray
        ray.org = getLensPoint(ray.inside);
        ray.dir = Q - ray.org;
        Normalize(ray.dir);
        ray.t = Infinity;
        ray.hit = NULL;
        ray.u = ray.v = 0.0f;

        return true;
    }
};
#endif