package descriptors;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Polygon;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import javax.imageio.ImageIO;

import preprocessing.Lanczos;
import preprocessing.Util;

/**
 * Histogram of Oriented Gradients (HOG) descriptor.
 *
 * <p>{@link #HOG} resizes the input image, converts it to grey levels (optionally with local
 * contrast normalization), computes per-pixel gradients and accumulates the gradient magnitudes
 * into orientation histograms, one histogram per cell of a regular grid. Orientations are
 * unsigned, i.e. folded into [0, PI].
 *
 * <p>The class has no fields; all state lives in method arguments and locals.
 */
public class hog {

	/**
	 * Computes the HOG descriptor of an image.
	 *
	 * @param original               input image
	 * @param rwt                    radius of the binomial window (size {@code 2*rwt+1}) used by
	 *                               the local contrast normalization
	 * @param number_of_cells_x      number of cells along x (must be positive)
	 * @param number_of_cells_y      number of cells along y (must be positive)
	 * @param bins_per_cell          number of orientation bins per cell (must be positive)
	 * @param new_height             target height of the resized image; {@code <= 0} keeps the
	 *                               original size
	 * @param normalize_image        if true, apply local contrast normalization before
	 *                               differentiating
	 * @param option                 cell weighting: 0 - hard assignment of each pixel to one cell,
	 *                               1 - step function, 2 - Bernstein polynomials, 3 - edge/core
	 * @param normalize_histogram_L1 if true each cell histogram is scaled to unit L1 norm,
	 *                               otherwise to unit L2 norm
	 * @return the concatenated cell histograms, indexed
	 *         {@code (cy*number_of_cells_x + cx)*bins_per_cell + bin}
	 * @throws IllegalArgumentException if a cell count or the bin count is not positive
	 */
	public double[] HOG (
			BufferedImage original,
			int rwt,
			int number_of_cells_x,
			int number_of_cells_y,
			int bins_per_cell,
			int new_height,
			boolean normalize_image,
			int option, /*0 - Hard assignment, 1 - Step, 2 - Bernstein, 3 - EdgeCore*/
			boolean normalize_histogram_L1) {

		if ((number_of_cells_x < 1) || (number_of_cells_y < 1) || (bins_per_cell < 1)) {
			throw new IllegalArgumentException(String.format(
					"number_of_cells_x=%d, number_of_cells_y=%d and bins_per_cell=%d must all be positive",
					number_of_cells_x, number_of_cells_y, bins_per_cell));
		}

		boolean debug = false;

		/*Resize to the requested height, or scale by 1.0 when no height is requested*/
		BufferedImage resize = (new_height > 0)
				? Util.imageResize (original, new_height)
				: Util.imageScale (original, 1.0);

		int width  = resize.getWidth();
		int height = resize.getHeight();

		/*The height is only constrained when a resize was actually requested*/
		assert (new_height <= 0) || (height == new_height);

		int n = width * height;

		/*Grey image (possibly locally normalized) that gets differentiated*/
		double[] grel = relative_grey_image (resize, rwt, normalize_image);

		int number_of_cells = number_of_cells_x * number_of_cells_y;

		/*Per-pixel gradient data, kept as whole images for the debug output*/
		double[] dx = new double[n];
		double[] dy = new double[n];
		double[] dnorm = new double[n];
		double[] dtheta = new double[n];

		/*Cell membership weight of every pixel; only needed for the debug images*/
		double[][] cell_weight_image = debug ? new double[number_of_cells][n] : null;

		/*Cell histograms (Java arrays start zeroed)*/
		double[] cells_histogram = new double[number_of_cells * bins_per_cell];

		/*Nominal deviation of pixel noise. NOTE(review): this presumably assumes grey levels in
		 *[0,1], which holds for the normalized image; the raw grey image is in [0,255], where this
		 *threshold is negligible - confirm the intended scale.*/
		double eps = 0.02;

		/*Reused for every pixel; gradient_simple overwrites both entries*/
		double[] grad = new double[2];

		for (int x = 1; x < (width - 1); x++) {

			for (int y = 1; y < (height - 1); y++) {

				int position = y * width + x;

				/*gradient_sobel is a drop-in alternative with the same conventions*/
				gradient_simple (grel, width, height, x, y, grad);
				dx[position] = grad[0];
				dy[position] = grad[1];

				/*Gradient norm reduced by the noise level, zero when too small*/
				double d2 = grad[0]*grad[0] + grad[1]*grad[1];
				dnorm[position] = (d2 <= eps*eps) ? 0.0 : Math.sqrt(d2 - eps*eps);

				/*Unsigned orientation: atan2 gives [-PI, PI], negative angles are folded by +PI*/
				double theta = Math.atan2(grad[1], grad[0]);
				if (theta < 0) {
					theta += Math.PI;
				}
				dtheta[position] = theta;

				/*Nearest bin center k*PI/bins_per_cell; the +1 keeps the argument of floor positive*/
				int bin = (int) Math.floor(bins_per_cell*(theta/Math.PI+1)+0.5) % bins_per_cell;
				assert (bin >= 0) && (bin < bins_per_cell);

				if (option > 0) {
					/*Soft assignment: every cell gets a separable weight wtx*wty*/
					for (int cx = 0; cx < number_of_cells_x; cx++) {

						double wtx = cell_weight (option, number_of_cells_x, cx, x, width);

						for (int cy = 0; cy < number_of_cells_y; cy++) {

							double wty = cell_weight (option, number_of_cells_y, cy, y, height);

							int c_pos = cy * number_of_cells_x + cx;

							cells_histogram[c_pos * bins_per_cell + bin] += dnorm[position]*wtx*wty;

							if (cell_weight_image != null) {
								cell_weight_image[c_pos][position] = wtx*wty;
							}
						}
					}
				}
				else {
					/*Hard assignment: the (width-2) x (height-2) interior pixels are split into
					 *equal cells, consistent with the interior coordinate used by cell_weight*/
					int cx = (number_of_cells_x * (x-1))/(width-2);
					assert (cx >= 0) && (cx < number_of_cells_x);
					int cy = (number_of_cells_y * (y-1))/(height-2);
					assert (cy >= 0) && (cy < number_of_cells_y);
					int c_pos = cy * number_of_cells_x + cx;
					cells_histogram[c_pos * bins_per_cell + bin] += dnorm[position];

					if (cell_weight_image != null) {
						cell_weight_image[c_pos][position] = 1.0;
					}
				}
			}
		}

		normalize_cell_histograms (cells_histogram, number_of_cells, bins_per_cell, normalize_histogram_L1);

		if (debug) {
			for (int cy = 0; cy < number_of_cells_y; cy++) {
				for (int cx = 0; cx < number_of_cells_x; cx++) {
					int c_pos = cy * number_of_cells_x + cx;
					for (int bin = 0; bin < bins_per_cell; bin++) {
						System.err.printf("cell : %d,%d bin : %d, sum : %f\n", cx, cy, bin, cells_histogram [c_pos * bins_per_cell + bin]);
					}
				}
			}

			double dnorm_max = 0.5;

			Util.writeDoubletoPGM (grel, width, height, 0.0, 1.0, "norm");
			Util.writeDoubletoPGM (dx, width, height, -1.0, 1.0, "dx");
			Util.writeDoubletoPGM (dy, width, height, -1.0, 1.0, "dy");
			Util.writeDoubletoPGM (dnorm, width, height, 0.0, dnorm_max, "dnorm");

			Util.writePolartoPNG (dnorm, dtheta, width, height, dnorm_max, 0.0, Math.PI, "dpolar");

			Util.writeDoubletoPGM (dtheta, width, height, 0, Math.PI, "dtheta");

			for (int cp = 0; cp < number_of_cells; cp++) {
				Util.writeDoubletoPGM (cell_weight_image[cp], width, height, 0.0, 1.0, "cwt_"+ String.format("%02d", cp));
			}

			try {
				ImageIO.write(original, "png", new File("original.png"));
				ImageIO.write(resize, "png", new File("resized.png"));
			}
			catch (IOException e1) {
				System.err.println("error: fail to save the images");
			}

			for (int cy = 0; cy < number_of_cells_y; cy++) {
				for (int cx = 0; cx < number_of_cells_x; cx++) {
					int c_pos = cy * number_of_cells_x + cx;
					double[] histogram = Arrays.copyOfRange(
							cells_histogram, c_pos * bins_per_cell, (c_pos + 1) * bins_per_cell);
					print_histogram (resize,
							         histogram,
							         number_of_cells_x,
							         number_of_cells_y,
							         bins_per_cell, "hog_directions_" + cy + "_" + cx + ".png");
				}
			}
		}

		return cells_histogram;
	}

	/*Returns the grey-level image of {image} as a row-major array. When {normalize} is true the
	 *grey levels are locally contrast-normalized by normalize_grey_image with a binomial window of
	 *radius {rwt}; otherwise the raw grey levels (range 0..255) are returned.*/
	private double[] relative_grey_image (BufferedImage image, int rwt, boolean normalize) {
		int width = image.getWidth();
		int height = image.getHeight();
		double[] grey = new double[width * height];
		get_grey_image (image, grey);
		if (!normalize) {
			return grey;
		}
		double[] weight = new double[2*rwt+1];
		compute_weights (weight, rwt);
		double[] grel = new double[width * height];
		normalize_grey_image (grey, width, height, weight, rwt, grel);
		return grel;
	}

	/*Scales, in place, each of the {ncells} consecutive histograms of {bins} entries in {hist} to
	 *unit L1 norm (if {L1}) or unit L2 norm. A 1e-100 bias avoids dividing by zero for empty
	 *cells, which therefore stay all zero.*/
	private void normalize_cell_histograms (double[] hist, int ncells, int bins, boolean L1) {
		for (int c = 0; c < ncells; c++) {
			int start = c * bins;
			double sum = 0.0;
			for (int b = 0; b < bins; b++) {
				double v = hist[start + b];
				sum += (L1 ? v : v * v);
			}
			double cell_norm = (L1 ? sum + 1.0e-100 : Math.sqrt(sum + 1.0e-100));
			for (int b = 0; b < bins; b++) {
				hist[start + b] /= cell_norm;
			}
		}
	}

	/*Sobel gradient of {image} (row-major, {width} x {height}) at the interior pixel (x,y).
	 *Stores d/dx in grad[0] and d/dy (y = row index) in grad[1], with the same sign and scale
	 *conventions as gradient_simple.*/
	public void gradient_sobel (double[] image, int width, int height, int x, int y, double[] grad) {

		assert (x >= 1) && (x < width-1);
		assert (y >= 1) && (y < height-1);

		int position = y * width + x;

		/*v<dx><dy>: m = -1, o = 0, p = +1 pixel offsets*/
		double vmo = image[position - 1];
		double vpo = image[position + 1];
		double vom = image[position - width];
		double vop = image[position + width];

		double vmm = image[position - 1 - width];
		double vmp = image[position - 1 + width];
		double vpm = image[position + 1 - width];
		double vpp = image[position + 1 + width];

		/*Right column minus left column*/
		grad[0] = (vpm + 2*vpo + vpp - vmm - 2*vmo - vmp)/8.0;

		/*Row y+1 minus row y-1, matching gradient_simple*/
		grad[1] = (vmp + 2*vop + vpp - vmm - 2*vom - vpm)/8.0;
	}

	/*Central-difference gradient of {image} (row-major, {width} x {height}) at the interior pixel
	 *(x,y): grad[0] = (I(x+1,y) - I(x-1,y))/2, grad[1] = (I(x,y+1) - I(x,y-1))/2.*/
	public void gradient_simple (double[] image, int width, int height, int x, int y, double[] grad) {

		assert (x >= 1) && (x < width-1);
		assert (y >= 1) && (y < height-1);

		int position = y * width + x;

		grad[0] = (image[position + 1] - image[position - 1])/2;

		grad[1] = (image[position + width] - image[position - width])/2;
	}

	/*Computes the cell weight factor for one axis {z} (x or y). Given: number of cells
	 *{ncz}, cell index {cz}, pixel index {z} and number of pixels {npz}, all along
	 *that axis. The option selects the weight type (1 - step, 2 - Bernstein, 3 - edge/core);
	 *any other option fails an assertion and yields 0.0.*/
	public double cell_weight (int option, int ncz, int cz, int z, int npz) {

		/*Interior pixels 1..npz-2 are mapped to their centers inside (0,1)*/
		double t = (z - 0.5)/(npz - 2);

		switch (option) {
			case 1:
				return StepFunc (ncz, cz, t);
			case 2:
				return Bernstein (ncz-1, cz, t);
			case 3:
				return EdgeCore (ncz, cz, t);
			default:
				assert false : "invalid cell weight option " + option;
				return 0.0;
		}
	}

	/**
	 * Debug helper: draws one cell histogram as a polar plot and saves it as a 600x600 PNG named
	 * {@code name}; also saves a copy of {@code image} with the cell grid drawn in white to
	 * {@code division.png}. Each bin is drawn as two opposite wedges (angle and angle+PI) whose
	 * length grows with the bin value; even bins are black, odd bins light grey. Write failures
	 * are reported on stderr and otherwise ignored.
	 */
	public void print_histogram (
			BufferedImage image,
			double[] histogram,
			int number_of_cells_x,
			int number_of_cells_y,
			int bins_per_cell,
			String name ) {

		int width = 600;
		int height = 600;

		int magnify = (int) (Math.sqrt(bins_per_cell) * 150);

		int radius = 90;

		int c_x = width/2;
		int c_y = height/2;

		BufferedImage temp = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);

		Color dark = Color.BLACK;
		Color weak = new Color(173, 173, 173);

		Graphics2D canvas = temp.createGraphics();
		try {
			/*White background*/
			canvas.setColor(Color.WHITE);
			canvas.fillRect(0, 0, width, height);

			for (int bin = 0; bin < bins_per_cell; bin++) {

				/*Wedge of +-0.4 bin widths around the bin center bin*PI/bins_per_cell*/
				double theta_a = (bin+0.4)*Math.PI/bins_per_cell;
				double theta_b = (bin-0.4)*Math.PI/bins_per_cell;

				/*Alternate shades so adjacent bins can be told apart*/
				canvas.setColor(((bin % 2) == 0) ? dark : weak);

				/*Orientations are unsigned: draw the bin in both opposite directions*/
				for (int k = 0; k < 2; k++) {
					Polygon polygon = draw_Polygon (
							c_x,
							c_y,
							radius,
							theta_a+k*Math.PI,
							theta_b+k*Math.PI,
							histogram[bin],
							magnify );
					canvas.fillPolygon(polygon);
				}
			}
		}
		finally {
			canvas.dispose();
		}

		/*Copy of the image with the cell grid*/
		BufferedImage img_temp = Util.imageScale (image, 1.0);
		Graphics2D g_image = img_temp.createGraphics();
		try {
			g_image.setStroke(new BasicStroke(1.0f));
			g_image.setColor(Color.white);
			int img_w = img_temp.getWidth();
			int img_h = img_temp.getHeight();
			int div_x = img_w/number_of_cells_x;
			int div_y = img_h/number_of_cells_y;
			for (int cx = 0; cx < number_of_cells_x; cx++) {
				g_image.drawLine(div_x*(cx+1), 0, div_x*(cx+1), img_h);
			}
			for (int cy = 0; cy < number_of_cells_y; cy++) {
				g_image.drawLine(0, div_y*(cy+1), img_w, div_y*(cy+1));
			}
		}
		finally {
			g_image.dispose();
		}

		try {
			ImageIO.write(temp, "png", new File(name));
			ImageIO.write(img_temp, "png", new File("division.png"));
		}
		catch (IOException e) {
			System.err.printf("cannot write image: %s\n", e.getMessage());
		}
	}

	/*Draws the bin index {bin} at distance {radius*shift} from (c_x,c_y).
	 *Not called anywhere in this class; kept for debugging the histogram plot.*/
	private void draw_String (
			 int c_x,
			 int c_y,
			 int radius,
			 double shift,
			 int bin,
			 int nbins,
			 Graphics2D g) {
		draw_label (c_x, c_y, radius, shift, bin, nbins, String.format("%d", bin), g);
	}

	/*Draws {value} (2 decimals) at distance {radius*shift} from (c_x,c_y).
	 *Not called anywhere in this class; kept for debugging the histogram plot.*/
	private void draw_String (
			 int c_x,
			 int c_y,
			 int radius,
			 double shift,
			 int bin,
			 double value,
			 int nbins,
			 Graphics2D g) {
		draw_label (c_x, c_y, radius, shift, bin, nbins, String.format("%.2f", value), g);
	}

	/*Draws {text} at distance {radius*shift} from (c_x,c_y) along the angle
	 *(bin+0.5)*PI/nbins - 1. NOTE(review): this angle does not match the bin centers
	 *bin*PI/nbins used by print_histogram; confirm before re-enabling the labels.*/
	private void draw_label (
			 int c_x,
			 int c_y,
			 int radius,
			 double shift,
			 int bin,
			 int nbins,
			 String text,
			 Graphics2D g) {
		double theta = ((bin + 0.5) * Math.PI)/nbins - 1;
		double r = radius*shift;
		g.drawString(text, (int)(c_x + r * Math.cos(theta)), (int)(c_y + r * Math.sin(theta)));
	}

	/*Returns the quadrilateral between the angles {theta_a} and {theta_b}, from distance {radius}
	 *to distance {radius + value*magnify} around (c_x,c_y).*/
	private Polygon draw_Polygon (
			 int c_x,
			 int c_y,
			 int radius,
			 double theta_a,
			 double theta_b,
			 double value,
			 int magnify) {

		double outer = radius + value*magnify;

		double x1 = c_x + radius * Math.cos(theta_a);
		double y1 = c_y + radius * Math.sin(theta_a);
		double x2 = c_x + outer * Math.cos(theta_a);
		double y2 = c_y + outer * Math.sin(theta_a);
		double x3 = c_x + radius * Math.cos(theta_b);
		double y3 = c_y + radius * Math.sin(theta_b);
		double x4 = c_x + outer * Math.cos(theta_b);
		double y4 = c_y + outer * Math.sin(theta_b);

		Polygon polygon = new Polygon();
		polygon.addPoint((int)x1, (int)y1);
		polygon.addPoint((int)x2, (int)y2);
		polygon.addPoint((int)x4, (int)y4);
		polygon.addPoint((int)x3, (int)y3);

		return polygon;
	}

	/*Draws the outline of a circle of radius {radius} centered at (x,y). Currently unused.*/
	private void drawCircle (int x, int y, int radius, Graphics2D g){
		g.drawOval(x - radius, y - radius, radius*2, radius*2);
	}

	/*Divides the interval [0,1] into {n} equal parts and returns 1.0 if
	 *{z} is in part number {k} (0..n-1), 0.0 otherwise. */
	public double StepFunc (int n, int k, double z) {
		assert (k >= 0) && (k < n);
		double zn = z*n;
		return ((k <= zn) && (zn < k+1)) ? 1.0 : 0.0;
	}

	/*Computes the Bernstein polynomial C(n,k) z^k (1-z)^(n-k) of degree {n} and index {k}
	 *for the argument {z}.*/
	public double Bernstein (int n, int k, double z) {
		assert (k >= 0) && (k <= n);
		/*Builds C(n,k) z^k incrementally*/
		double res = 1.0;
		for (int i = 0; i < k; i++) {
			res = (res * (n - i))/(i+1)*z;
		}
		return res*Math.pow(1-z,n-k);
	}

	/*An edge-core weight function. If {n == 1} returns 1. If {n == 2}, with
	 *v = (4z(1-z))^2 (0 at the edges of [0,1], 1 at the center), returns 1-v for {k == 0}
	 *(edge weight) and v for {k == 1} (core weight).*/
	public double EdgeCore (int n, int k, double z) {
		assert (n == 1) || (n == 2);
		assert (k >= 0) && (k < n);
		if (n == 1) {
			return 1.0;
		}
		double v = 4 * z * (1 - z);
		v = v*v;
		return (k == 0) ? 1 - v : v;
	}

	/*AUXILIARY FUNCTIONS*/

	/*Fills {grey} (row-major, at least width*height entries) with the luma
	 *0.299 R + 0.587 G + 0.114 B of {image}, in the range 0..255.*/
	public void get_grey_image (BufferedImage image, double[] grey) {

		int w = image.getWidth();
		int h = image.getHeight();
		for (int y = 0; y < h; y++) {
			for (int x = 0; x < w; x++) {
				int pixel = image.getRGB(x, y);
				double R = (pixel >> 16) & 255;
				double G = (pixel >> 8) & 255;
				double B = (pixel & 255);
				grey[y * w + x] = 0.299*R + 0.587*G + 0.114*B;
			}
		}
	}

	/*Fills {weight} (2*rwt+1 entries) with the binomial weights C(2rwt,i)/2^(2rwt),
	 *which sum to 1, by repeatedly averaging adjacent entries.*/
	public void compute_weights (double[] weight, int rwt) {
		int nwt = 2*rwt+1;
		weight[0] = 1;
		for (int i = 1; i < nwt; i++) {
			weight[i] = 0.0;
			for (int j = i; j >= 1; j--) {
				weight[j] = (weight[j] + weight[j-1])/2;
			}
			weight[0] /= 2;
		}
	}

	/*Local contrast normalization: each pixel of {grey} ({w} x {h}) becomes
	 *(grey - AVG)/(3*DEV) + 0.5, clipped to [0,1], where AVG and DEV are the local weighted
	 *mean and deviation over the (2rwt+1)^2 window. Result is written to {grel}.*/
	public void normalize_grey_image (double[] grey, int w, int h, double[] weight, int rwt, double[] grel) {
		for (int y = 0; y < h; y++) {
			for (int x = 0; x < w; x++) {
				int position = y * w + x;
				double AVG = get_grey_avg (grey, w, h, x, y, weight, rwt);
				double DEV = get_grey_dev (grey, w, h, x, y, weight, rwt, AVG);
				double v = (grey[position] - AVG)/(3*DEV) + 0.5;
				if (v < 0) { v = 0.0; }
				else if (v > 1) { v = 1.0; }
				grel[position] = v;
			}
		}
	}

	/*Weighted mean of {grey} over the (2rwt+1)^2 window around (x,y), clipped to the image,
	 *with separable weights weight[rwt+dx]*weight[rwt+dy].*/
	public double get_grey_avg (double[] grey, int w, int h, int x, int y, double[] weight, int rwt) {
		double sum_vwt = 0.0, sum_wt = 0.0;
		for (int dy = -rwt; dy <= rwt; dy++) {
			for (int dx = -rwt; dx <= rwt; dx++) {
				int x1 = x + dx;
				int y1 = y + dy;
				if ((x1 >= 0) && (x1 < w) && (y1 >= 0) && (y1 < h)) {
					double wt = weight[rwt+dx]*weight[rwt+dy];
					sum_vwt += grey[y1 * w + x1] * wt;
					sum_wt += wt;
				}
			}
		}
		return sum_vwt/sum_wt;
	}

	/*Weighted standard deviation of {grey} around {AVG} over the same window as get_grey_avg,
	 *with an assumed noise deviation of 0.01 added in quadrature (keeps the result positive).*/
	public double get_grey_dev (double[] grey, int w, int h, int x, int y, double[] weight, int rwt, double AVG) {
		double sum_v2wt = 0.0, sum_wt = 0.0;
		for (int dy = -rwt; dy <= rwt; dy++) {
			for (int dx = -rwt; dx <= rwt; dx++) {
				int x1 = x + dx;
				int y1 = y + dy;
				if ((x1 >= 0) && (x1 < w) && (y1 >= 0) && (y1 < h)) {
					double v = grey[y1 * w + x1] - AVG;
					double wt = weight[rwt+dx]*weight[rwt+dy];
					sum_v2wt += v * v * wt;
					sum_wt += wt;
				}
			}
		}
		double noise = 0.01; /*Assumed standard deviation of noise*/
		return Math.sqrt(sum_v2wt/sum_wt + noise*noise);
	}

	/*Experimental helper, not called by HOG. For every column x, computes for each row y the
	 *weighted variance of {grel} over the horizontal window [x-rwt, x+rwt] (clipped to the
	 *image), reduces that row profile with get_interval, and writes 125 into {image} at the two
	 *interval ends of column x. {weight} must have 2*rwt+1 entries. Interval ends outside the
	 *image, or undefined (NaN, e.g. for a column with zero variance), are skipped.*/
	public void distribution (int w, int h, double[] weight, int rwt, double[] grel, double[] image) {

		double[] s = new double[h];

		for (int x = 0; x < w; x++) {

			for (int y = 0; y < h; y++) {

				int row = y * w;

				double sum_avg = 0.0, sum_w = 0.0;
				for (int dx = -rwt; dx <= rwt; dx++) {
					int x1 = x + dx;
					if ((x1 >= 0) && (x1 < w)) {
						sum_avg += weight[rwt+dx]*grel[row + x1];
						sum_w += weight[rwt+dx];
					}
				}
				assert sum_w != 0.0;

				double AVG = sum_avg/sum_w;

				double sum_var = 0.0;
				for (int dx = -rwt; dx <= rwt; dx++) {
					int x1 = x + dx;
					if ((x1 >= 0) && (x1 < w)) {
						double d = grel[row + x1] - AVG;
						sum_var += weight[rwt+dx]*(d*d);
					}
				}
				s[y] = sum_var/sum_w;
			}

			double[] interval = get_interval (h, s);
			mark_row (image, w, h, x, interval[0]);
			mark_row (image, w, h, x, interval[1]);
		}
	}

	/*Writes 125 into {image} ({w} x {h}) at column {x} and row (int){y}, if that row exists.*/
	private void mark_row (double[] image, int w, int h, int x, double y) {
		if (Double.isNaN(y)) {
			return;
		}
		int row = (int) y;
		if ((row >= 0) && (row < h)) {
			image[row * w + x] = 125;
		}
	}

	/*Treats {s[0..h-1]} as weights over the row indices and returns
	 *{mean - 2*dev, mean + 2*dev}, where mean and dev are the weighted mean and standard
	 *deviation of the row index. Both entries are NaN when all weights are zero.*/
	public double[] get_interval (int h, double[] s)
	{
		double sum_s = 0.0;
		double sum_sy = 0.0;

		for (int y = 0; y < h; y++) {
			sum_sy += s[y] * y;
			sum_s += s[y];
		}

		double y_mean = sum_sy/sum_s;

		double sum_sdy2 = 0.0;
		for (int y = 0; y < h; y++) {
			double d = y - y_mean;
			sum_sdy2 += s[y] * (d*d);
		}

		double y_dev = Math.sqrt(sum_sdy2/sum_s);

		return new double[] { y_mean - 2*y_dev, y_mean + 2*y_dev };
	}

}