package main;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import javax.imageio.ImageIO;
import preprocessing.Parser;
import preprocessing.Util;
import descriptors.Fourier;
import descriptors.HoG;
import descriptors.Polar;
import descriptors.Z;
import descriptors.Z2;
import fr.lip6.classifier.Classifier;
import fr.lip6.classifier.DoublePegasosSVM;
import fr.lip6.classifier.SMOSVM;
import fr.lip6.kernel.SimpleCacheKernel;
import fr.lip6.kernel.typed.DoubleGaussChi2;
import fr.lip6.kernel.typed.DoubleGaussL2;
import fr.lip6.type.TrainingSample;

public class Compute_Descriptors {

	/**
	 * Command-line entry point. Depending on {@code args[3]} it either trains an
	 * SVM on the descriptors extracted from the positive/negative inputs and saves
	 * it, or reloads a previously saved SVM and scores the inputs with it.
	 *
	 * All eleven positions must be present (positions 7 and 8 are only read while
	 * training):
	 * <pre>
	 *  args[0]  positives input file
	 *  args[1]  negatives input file
	 *  args[2]  descriptor type handed to Get_Descriptors
	 *           (-fourier, -polar, -zernike, -hog or -all)
	 *  args[3]  -training or -testing
	 *  args[4]  file the serialized classifier is written to / read from
	 *  args[5]  -true to standardize descriptors with mean and deviation
	 *  args[6]  -linear or -chi2 (training only)
	 *  args[7]  training: C for -linear, gamma for -chi2
	 *  args[8]  training: iterations for -linear, C for -chi2
	 *  args[9]  output prefix; "normalization.txt" and the score files are appended to it
	 *  args[10] prefix of the kernel cache file (-chi2 only)
	 * </pre>
	 *
	 * @param args the positional arguments described above
	 * @throws IllegalArgumentException if fewer than eleven arguments are given or
	 *         the classification type is unknown while training
	 * @throws IllegalStateException if no descriptor could be extracted for
	 *         training, or the saved model cannot be loaded for testing
	 */
	public static void main(String[] args) {

		if (args.length < 11) {
			throw new IllegalArgumentException(
					"Expected 11 arguments: <positives> <negatives> <descriptor_type> <-training|-testing> "
					+ "<svm_object_name> <normalization> <classification> <param1> <param2> <path> <kernel_path>, got "
					+ args.length);
		}

		String positives = args[0];
		String negatives = args[1];
		String descriptor_type = args[2];
		String option = args[3];
		String svm_object_name = args[4];
		String normalization = args[5];
		String classification = args[6];
		String path = args[9];
		String kernel_path = args[10];

		boolean normalize = "-true".equals(normalization);

		Parser parser = new Parser();

		LineNumberReader positives_file = parser.Open_File (positives);
		LineNumberReader negatives_file = parser.Open_File (negatives);

		ArrayList<TrainingSample<double[]>> descriptors_list = new ArrayList<TrainingSample<double[]>>();
		ArrayList<Integer> class_list = new ArrayList<Integer>();

		/* The finally block makes sure both inputs are closed on every exit path */
		try {
			if ("-training".equals(option)) {

				/* Fail before the expensive descriptor extraction, not after it */
				boolean linear = "-linear".equals(classification);
				if (!linear && !"-chi2".equals(classification)) {
					throw new IllegalArgumentException(
							"Unknown classification '" + classification + "' (expected -linear or -chi2)");
				}

				ArrayList<double[]> list = new ArrayList<double[]>();

				int pos_size = Get_Descriptors (+1, positives_file, list, class_list, descriptor_type);
				int neg_size = Get_Descriptors (-1, negatives_file, list, class_list, descriptor_type);

				if (list.isEmpty()) {
					throw new IllegalStateException("No usable descriptor was extracted; nothing to train on");
				}

				double[] mean = Get_Mean (list);
				double[] deviation = Get_Deviation (list, mean);

				Building (descriptors_list, class_list, mean, deviation, list, normalize);

				Classifier cls;
				double C;

				if (linear) {
					C = Double.parseDouble(args[7]);
					int iterations = Integer.parseInt(args[8]);

					System.out.printf("C : %f, iterations : %d\n", C, iterations);

					cls = Get_Linear_Classifier (descriptors_list, C, iterations);
				}
				else {
					/* TODO: double-check this */
					double gamma = Double.parseDouble(args[7]);

					/* The estimate is only reported; the kernel uses the gamma given on the command line */
					System.err.printf("Computing best gamma\n");
					double gamma2 = Gamma_Estimation (list);
					System.err.printf("Computing gamma end\n");

					C = Double.parseDouble(args[8]);

					System.out.printf("C : %f, gamma: %f, computed_gamma = %f\n", C, gamma, gamma2);

					cls = Get_Chi2_Classifier (descriptors_list, gamma, C, kernel_path);
				}

				System.err.printf("Training List Size : %d\n", descriptors_list.size());

				/* Training over the list */
				cls.train(descriptors_list);

				Statistics (cls, descriptors_list, class_list, path, C, true, pos_size, neg_size);

				/* Recording trained object */
				try {
					/* Writing SVM trained object; the raw stream is closed even if serialization fails */
					FileOutputStream raw = new FileOutputStream(svm_object_name);
					try {
						ObjectOutputStream object = new ObjectOutputStream(raw);
						object.writeObject(cls);
						object.close();
					}
					finally {
						raw.close();
					}

					/* Recording Mean and Deviation */
					PrintStream norm = parser.Open_Print_Stream (path + "normalization.txt");
					norm.printf("%d\n", mean.length);
					parser.Write_Print_Stream (norm, mean);
					parser.Write_Print_Stream (norm, deviation);
					parser.Close_Print_Stream (norm, path + "normalization.txt");
				} catch (IOException e) {
					System.err.println("Failed to record svm trainning object");
					e.printStackTrace();
				}
			}
			else if ("-testing".equals(option)) {

				/* Reading trained object */
				Classifier cls = null;
				double[] mean = null;
				double[] deviation = null;

				try {
					/* NOTE: native Java deserialization; only load classifier files from a trusted source */
					FileInputStream raw = new FileInputStream(svm_object_name);
					try {
						ObjectInputStream object = new ObjectInputStream(raw);
						cls = (Classifier) object.readObject();
					}
					finally {
						raw.close();
					}

					LineNumberReader file = new LineNumberReader(new FileReader(path + "normalization.txt"));
					try {
						int size = Integer.parseInt(parser.getLineParameter (file, "\n", 0));
						mean = parser.Get_Data (size, file);
						deviation = parser.Get_Data (size, file);
					}
					finally {
						file.close();
					}
				} catch (IOException e1) {
					System.err.println("Failed to open svm trainning object");
					e1.printStackTrace();
				}
				catch(ClassNotFoundException e2) {
					System.err.println("Failed to open svm trainning object");
					e2.printStackTrace();
				}

				/* Without the model there is nothing to evaluate; stop with a clear message */
				if (cls == null || mean == null || deviation == null) {
					throw new IllegalStateException(
							"Cannot evaluate: the classifier or its normalization data could not be loaded");
				}

				ArrayList<double[]> list = new ArrayList<double[]>();

				int pos_size = Get_Descriptors (+1, positives_file, list, class_list, descriptor_type);
				int neg_size = Get_Descriptors (-1, negatives_file, list, class_list, descriptor_type);

				System.err.println("Building vector");

				Building (descriptors_list, class_list, mean, deviation, list, normalize);

				System.err.println("End Building vector");

				/* Double.MAX_VALUE only serves as the tag in the score file names */
				Statistics (cls, descriptors_list, class_list, path, Double.MAX_VALUE, true, pos_size, neg_size);
			}
			else {
				System.err.println("Unknown option '" + option + "' (expected -training or -testing); nothing done");
			}
		}
		finally {
			parser.Close_File (positives_file, positives);
			parser.Close_File (negatives_file, negatives);
		}
	}
	
	/**
	 * Creates an untrained {@code DoublePegasosSVM} configured for the given
	 * training set. The list is only used for its size.
	 *
	 * @param descriptors_list training samples; its size scales the T setting
	 * @param C value passed to {@code setC}
	 * @param iterations multiplier used to derive T as
	 *        {@code iterations * sample count / K}
	 * @return the configured classifier; the caller still has to train it
	 */
	static Classifier Get_Linear_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double C, int iterations) {
		final int sample_count = descriptors_list.size();

		DoublePegasosSVM pegasos = new DoublePegasosSVM();
		pegasos.setVerbosityLevel(2);
		pegasos.setC(C);
		pegasos.setK(100);
		/* T is derived from the K value configured on the line above */
		pegasos.setT(iterations * sample_count / pegasos.getK());

		return pegasos;
	}
	
	/**
	 * Creates an untrained {@code SMOSVM} using a {@code DoubleGaussChi2} kernel
	 * wrapped in a {@code SimpleCacheKernel}.
	 *
	 * The cache is stored in the file {@code kernel_path + "kernel_" + gamma}.
	 * If that file exists it is deserialized; otherwise the cache is computed
	 * from {@code descriptors_list} and written to the file. A cache that fails
	 * to load is recomputed in memory, and a cache that fails to save is only
	 * reported, so the method always returns a usable classifier.
	 *
	 * NOTE(review): a loaded cache is not checked against
	 * {@code descriptors_list}; presumably the file must come from the same
	 * training set - confirm with the callers.
	 *
	 * @param descriptors_list training samples used to compute the kernel cache
	 * @param gamma kernel gamma, also part of the cache file name
	 * @param C value passed to {@code setC}
	 * @param kernel_path prefix of the cache file name
	 * @return the configured classifier; the caller still has to train it
	 */
	static Classifier Get_Chi2_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double gamma, double C, String kernel_path) {

		DoubleGaussChi2 k = new DoubleGaussChi2();

		k.setGamma(gamma);

		/*Loading kernel*/
		String kernel_name = kernel_path+"kernel_"+k.getGamma();
		File kfile = new File(kernel_name);
		System.err.println("Trying to open kernel : " + kernel_name);
		SimpleCacheKernel<double[]> sk = null;

		if(kfile.exists()) {
			try {
				System.out.println("Loading kernel file");
				/* NOTE: native Java deserialization; only use cache files from a trusted source */
				FileInputStream raw = new FileInputStream(kfile);
				try {
					ObjectInputStream kin = new ObjectInputStream(raw);
					@SuppressWarnings("unchecked")
					SimpleCacheKernel<double[]> cached = (SimpleCacheKernel<double[]>)kin.readObject();
					sk = cached;
				}
				finally {
					/* Closed on failure too, so a corrupt cache does not leak the handle */
					raw.close();
				}
				System.out.println("Loaded");
			}
			catch(Exception e) {
				System.out.println("Unable to read kernel");
				e.printStackTrace();
				sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			}
		}
		else {
			System.out.println("Computing kernel cache");
			sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			System.out.println("Done");
			try {
				System.out.println("Writting kernel file");
				FileOutputStream raw = new FileOutputStream(kfile);
				try {
					ObjectOutputStream kout = new ObjectOutputStream(raw);
					kout.writeObject(sk);
					kout.close();
				}
				finally {
					raw.close();
				}
				System.out.println("Done");
			}
			catch(Exception e) {
				System.err.println("Unable to save kernel");
				e.printStackTrace();
				/* Drop a partially written file so the next run does not try to load it */
				kfile.delete();
			}
		}
		/*Creating SVM class*/
		SMOSVM<double[]> cls = new SMOSVM<double[]>(sk);
		cls.setVerbosityLevel(2);
		cls.setC(C);
		/* NOTE(review): this replaces the kernel passed to the constructor; confirm the cache is still used */
		cls.setKernel(k);
		return cls;
	}
	
	/**
	 * Creates an untrained {@code SMOSVM} using a {@code DoubleGaussL2} kernel
	 * wrapped in a {@code SimpleCacheKernel}.
	 *
	 * The cache is stored in the file {@code kernel_path + "kernel_" + gamma}.
	 * If that file exists it is deserialized; otherwise the cache is computed
	 * from {@code descriptors_list} and written to the file. A cache that fails
	 * to load is recomputed in memory, and a cache that fails to save is only
	 * reported, so the method always returns a usable classifier.
	 *
	 * NOTE(review): a loaded cache is not checked against
	 * {@code descriptors_list}; presumably the file must come from the same
	 * training set - confirm with the callers.
	 *
	 * @param descriptors_list training samples used to compute the kernel cache
	 * @param gamma kernel gamma, also part of the cache file name
	 * @param C value passed to {@code setC}
	 * @param kernel_path prefix of the cache file name
	 * @return the configured classifier; the caller still has to train it
	 */
	static Classifier Get_GaussL2_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double gamma, double C, String kernel_path) {

		DoubleGaussL2 k = new DoubleGaussL2();

		k.setGamma(gamma);

		/*Loading kernel*/
		String kernel_name = kernel_path+"kernel_"+k.getGamma();
		File kfile = new File(kernel_name);
		System.err.println("Trying to open kernel : " + kernel_name);
		SimpleCacheKernel<double[]> sk = null;

		if(kfile.exists()) {
			try {
				System.out.println("Loading kernel file");
				/* NOTE: native Java deserialization; only use cache files from a trusted source */
				FileInputStream raw = new FileInputStream(kfile);
				try {
					ObjectInputStream kin = new ObjectInputStream(raw);
					@SuppressWarnings("unchecked")
					SimpleCacheKernel<double[]> cached = (SimpleCacheKernel<double[]>)kin.readObject();
					sk = cached;
				}
				finally {
					/* Closed on failure too, so a corrupt cache does not leak the handle */
					raw.close();
				}
				System.out.println("Loaded");
			}
			catch(Exception e) {
				System.out.println("Unable to read kernel");
				e.printStackTrace();
				sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			}
		}
		else {
			System.out.println("Computing kernel cache");
			sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			System.out.println("Done");
			try {
				System.out.println("Writting kernel file");
				FileOutputStream raw = new FileOutputStream(kfile);
				try {
					ObjectOutputStream kout = new ObjectOutputStream(raw);
					kout.writeObject(sk);
					kout.close();
				}
				finally {
					raw.close();
				}
				System.out.println("Done");
			}
			catch(Exception e) {
				System.err.println("Unable to save kernel");
				e.printStackTrace();
				/* Drop a partially written file so the next run does not try to load it */
				kfile.delete();
			}
		}
		/*Creating SVM class*/
		SMOSVM<double[]> cls = new SMOSVM<double[]>(sk);
		cls.setVerbosityLevel(2);
		cls.setC(C);
		/* NOTE(review): this replaces the kernel passed to the constructor; confirm the cache is still used */
		cls.setKernel(k);
		return cls;
	}
	
	
	/**
	 * Scores every sample with the classifier, optionally writes the scores to
	 * disk, and prints a confusion-matrix summary to standard output.
	 *
	 * A sample counts as an error when {@code score * label} is negative; a
	 * score of exactly zero is therefore counted as correct. Errors on samples
	 * of class {@code 1} are false negatives, errors on the other samples are
	 * false positives.
	 *
	 * When {@code write} is true, one line {@code "<index> <score>"} per sample
	 * goes to {@code path + "scores_pos_" + C + ".txt"} (label &gt;= 0) or
	 * {@code path + "scores_neg_" + C + ".txt"} (label &lt; 0).
	 *
	 * @param cls classifier used to score the samples
	 * @param descriptor_list samples to score
	 * @param class_list true class of each sample, index-aligned with {@code descriptor_list}
	 * @param path prefix of the score file names
	 * @param C value embedded in the score file names
	 * @param write whether the score files are written
	 * @param pos_size currently unused
	 * @param neg_size currently unused
	 */
	public static void Statistics (
			Classifier<double[]> cls,
			ArrayList<TrainingSample<double[]>> descriptor_list,
			ArrayList<Integer> class_list,
			String path,
			double C,
			boolean write,
			int pos_size,
			int neg_size) {

		String pos_name = path + "scores_pos_" + C + ".txt";
		String neg_name = path + "scores_neg_" + C + ".txt";

		PrintStream scores_pos = null;
		PrintStream scores_neg = null;

		Parser parser = new Parser();

		/*Evaluating the object training over the train list*/
		int errors = 0;

		int tp = 0, fp = 0, fn = 0, tn = 0;

		try {
			if (write) {
				System.err.printf("Gravando positive arquivo em %s\n", pos_name);
				scores_pos = parser.Open_Print_Stream (pos_name);
				System.err.printf("Gravando negative arquivo em %s\n", neg_name);
				scores_neg = parser.Open_Print_Stream (neg_name);
			}

			for(int i = 0 ; i < descriptor_list.size(); i++) {

				TrainingSample<double[]> e = descriptor_list.get(i);

				int l = e.label;

				double v = cls.valueOf(e.sample);

				if (write) {
					if (l < 0) {
						scores_neg.printf("%d %f\n", i, v);
					}
					else {
						scores_pos.printf("%d %f\n", i, v);
					}
				}

				boolean positive_sample = class_list.get(i).intValue() == 1;

				if ( (v * l) < 0 ) {
					errors++;
					if (positive_sample) {
						/* a positive sample scored on the negative side is a miss */
						fn++;
					}
					else {
						/* a negative sample scored on the positive side is a false alarm */
						fp++;
					}
				}
				else {
					if (positive_sample) {
						tp++;
					}
					else {
						tn++;
					}
				}
				/* progress marker every 100 samples */
				if(i%100 == 0) {
					System.err.print(".");
				}
			}
		}
		finally {
			/* Close whatever was opened, even if scoring failed half-way */
			if (scores_pos != null) {
				parser.Close_Print_Stream (scores_pos, pos_name);
			}
			if (scores_neg != null) {
				parser.Close_Print_Stream (scores_neg, neg_name);
			}
		}
		System.out.printf("tp : %d, fp : %d, fn : %d, tn : %d, errors : %d\n", tp, fp, fn, tn, errors);
	}
	
	/**
	 * Converts raw descriptor vectors into labelled training samples and appends
	 * them to {@code descriptor_list}.
	 *
	 * Dimensions whose deviation is exactly zero are dropped from every vector,
	 * whether or not normalization is requested. With normalization the kept
	 * values become {@code (value - mean) / deviation}; without it they are
	 * copied unchanged.
	 *
	 * @param descriptor_list output list receiving one sample per input vector
	 * @param class_list label of each vector, index-aligned with {@code list}
	 * @param mean per-dimension mean
	 * @param deviation per-dimension deviation; zero marks a dimension to drop
	 * @param list raw descriptor vectors
	 * @param normalization whether the kept values are standardized
	 */
	public static void Building (
			ArrayList<TrainingSample<double[]>> descriptor_list,
			ArrayList<Integer> class_list,
			double[] mean,
			double[] deviation,
			ArrayList<double[]> list,
			boolean normalization) {

		/* Number of dimensions that survive the zero-deviation filter */
		int kept = 0;
		for (double spread : deviation) {
			if (spread != 0) {
				kept++;
			}
		}

		System.err.printf("Vetor descriptors before : %d\n", mean.length);

		System.err.printf("Vetor descriptors after : %d\n", kept);

		int index = 0;
		for (double[] raw : list) {
			double[] packed = new double[kept];
			int next = 0;
			for (int dim = 0; dim < raw.length; dim++) {
				if (deviation[dim] == 0) {
					continue;
				}
				packed[next] = normalization
						? (raw[dim] - mean[dim])/deviation[dim]
						: raw[dim];
				next++;
			}
			descriptor_list.add(new TrainingSample<double[]>(packed, class_list.get(index)));
			index++;
		}
	}
	
	/**
	 * Estimates a kernel gamma as the inverse of the mean pairwise Chi2 distance
	 * between all vectors of the list (every ordered pair, including each vector
	 * with itself).
	 *
	 * The distance between two vectors is the sum over the components of
	 * {@code (a - b)^2 / (a + b)}, skipping components where {@code a + b} is
	 * zero. Distances are accumulated on the fly, so no n x n matrix is kept.
	 *
	 * Intermediate values are reported on standard error.
	 *
	 * @param list descriptor vectors; every vector must be at least as long as
	 *        each vector it is compared against as the left operand
	 * @return {@code 1 / mean distance}; {@code NaN} for an empty list and
	 *         {@code Infinity} when all distances are zero
	 */
	public static double Gamma_Estimation (
			ArrayList<double[]> list) {

		final int n = list.size();

		/* Running total of all pairwise distances, added in row-major order */
		double total = 0.0;

		for (int i = 0; i < n; i++) {
			double[] h_i = list.get(i);
			for (int j = 0; j < n; j++) {
				double[] h_j = list.get(j);
				double sum = 0.0;
				for (int k = 0; k < h_i.length; k++) {
					/*Chi2 distance*/
					double denominator = h_i[k] + h_j[k];
					if (denominator != 0) {
						double difference = h_i[k] - h_j[k];
						sum += (difference*difference)/denominator;
					}
				}
				total += sum;
			}
		}

		System.err.printf("Matrix mean: %f\n", total);

		/* The pair count is computed in double: n*n overflows int for large lists */
		double mean_distance = total/((double)n*n);

		System.err.printf("Mean distance: %f\n", mean_distance);

		double gamma = 1.0/mean_distance;

		System.err.printf("The best gamma is hope to be (near of): %f\n", gamma);

		return gamma;

	}
	
	/**
	 * Reads samples from the given reader, computes one descriptor per sample
	 * and appends the usable ones to {@code list} together with their class.
	 *
	 * For {@code -all} the reader is handed to {@code Get_All_Descriptors},
	 * which reads precomputed scores instead of images. For every other type,
	 * images are pulled with {@code Util.Get_Image} until it returns
	 * {@code null}. A descriptor is rejected as "bad" when its computation
	 * throws, when it is all zeros, or when {@code descriptor_type} is not one
	 * of the known types.
	 *
	 * @param sample_class label stored in {@code class_list} for every accepted sample
	 * @param positives_file reader over the input (used for positives and negatives alike)
	 * @param list output list receiving the accepted descriptors
	 * @param class_list output list receiving one label per accepted descriptor
	 * @param descriptor_type one of -fourier, -polar, -zernike, -hog or -all
	 * @return the number of descriptors appended to {@code list}
	 */
	public static int Get_Descriptors (
			int sample_class,
			LineNumberReader positives_file,
			ArrayList<double[]> list,
			ArrayList<Integer> class_list,
			String descriptor_type) {

		if ("-all".equals(descriptor_type)) {
			/* Report how many samples were really added instead of always 0 */
			int before = list.size();
			Get_All_Descriptors (sample_class, positives_file, list, class_list);
			return list.size() - before;
		}

		int size = 0;
		int bad = 0;

		while (true) {
			BufferedImage image = Util.Get_Image (positives_file);

			if (image == null) { break; }

			double[] descriptor = null;

			/* A failure on one image must not stop the whole run: log it and count the image as bad */
			try {
				if ("-fourier".equals(descriptor_type)) {
					descriptor = Get_Fourier_Descriptor (image);
				}
				else if ("-polar".equals(descriptor_type)) {
					descriptor = Get_Polar_Descriptor (image);
				}
				else if ("-zernike".equals(descriptor_type)) {
					descriptor = Get_Zernike_Descriptor (image);
				}
				else if ("-hog".equals(descriptor_type)) {
					descriptor = Get_HoG_Descriptor (image);
				}
			}
			catch(Exception e) {
				System.err.println("problem with image "+positives_file.getLineNumber());
				e.printStackTrace();
			}

			if (descriptor != null && notZero(descriptor)) {
				list.add(descriptor);
				class_list.add(sample_class);
				size++;
			}
			else {
				System.err.println("bad image : "+positives_file.getLineNumber());
				bad++;
			}
		}
		System.err.println("bad images : " + bad);
		return size;
	}
	
	/**
	 * Computes the per-dimension arithmetic mean of a list of vectors.
	 *
	 * The dimension is taken from the first vector; every other vector must not
	 * be longer than it.
	 *
	 * @param list vectors to average
	 * @return one mean per dimension, or an empty array when {@code list} is empty
	 */
	public static double[] Get_Mean (ArrayList<double[]> list) {

		/* No sample means no known dimension: report "no means" instead of failing on get(0) */
		if (list.isEmpty()) {
			return new double[0];
		}

		int size = list.get(0).length;

		/* Java zero-initializes the array, so it can be used as the accumulator directly */
		double[] mean = new double[size];

		for (double[] vector : list) {
			for (int j = 0; j < vector.length; j++) {
				mean[j] += vector[j];
			}
		}

		for (int i = 0; i < size; i++) {
			mean[i] /= list.size();
		}

		return mean;

	}
	
	/**
	 * Computes the per-dimension population standard deviation (divisor is the
	 * number of vectors, not that number minus one) of a list of vectors.
	 *
	 * The dimension is taken from the first vector; every other vector must not
	 * be longer than it, and {@code mean} must cover every dimension.
	 *
	 * @param list vectors to measure
	 * @param mean per-dimension mean of the same vectors
	 * @return one deviation per dimension, or an empty array when {@code list} is empty
	 */
	public static double[] Get_Deviation (ArrayList<double[]> list, double[] mean) {

		/* No sample means no known dimension: report "no deviations" instead of failing on get(0) */
		if (list.isEmpty()) {
			return new double[0];
		}

		int size = list.get(0).length;

		/* Java zero-initializes the array, so it can be used as the accumulator directly */
		double[] deviation = new double[size];

		for (double[] vector : list) {
			for (int j = 0; j < vector.length; j++) {
				double delta = vector[j] - mean[j];
				deviation[j] += delta*delta;
			}
		}

		for (int i = 0; i < size; i++) {
			deviation[i] = Math.sqrt(deviation[i]/(double)list.size());
		}

		return deviation;
	}
	
	
	/**
	 * Tells whether a vector has at least one component different from zero.
	 *
	 * @param d vector to inspect
	 * @return {@code true} as soon as a non-zero component is found,
	 *         {@code false} for an all-zero or empty vector
	 */
	private static boolean notZero(double[] d)
	{
		for (double component : d) {
			if (component != 0) {
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Builds the Fourier descriptor of an image.
	 *
	 * The image is converted to grey, handed to {@code Fourier} and discretized
	 * to 30 points. The two entries at indices {@code center} and
	 * {@code center + 1} (center being the value returned by
	 * {@code Compute_TF_Image}) are left out, giving {@code (30 - 2) * 2 = 56}
	 * values laid out as: X below center, Y below center, X above, Y above.
	 *
	 * @param image source image
	 * @return the 56-value descriptor
	 */
	public static double[] Get_Fourier_Descriptor (BufferedImage image)
	{
		BufferedImage grey = Util.colorToGreyImage(image);

		final int size = 30;

		Fourier transform = new Fourier();
		transform.fourier_fibd_contour (grey);
		transform.Discretize(size);

		final int center = transform.Compute_TF_Image ();
		/* First index after the two skipped entries */
		final int resume = center + 2;

		double[] descriptor = new double[(size - 2)*2];
		int out = 0;

		/* Entries below the center: all X values, then all Y values */
		for (int i = 0; i < center; i++) {
			descriptor[out++] = transform.getX(i);
		}
		for (int i = 0; i < center; i++) {
			descriptor[out++] = transform.getY(i);
		}

		/* Entries above the skipped pair: all X values, then all Y values */
		for (int i = resume; i < size; i++) {
			descriptor[out++] = transform.getX(i);
		}
		for (int i = resume; i < size; i++) {
			descriptor[out++] = transform.getY(i);
		}

		return descriptor;
	}
	
	/**
	 * Builds the polar descriptor of an image.
	 *
	 * The image is converted to grey, resampled by {@code Polar} with 45 angle
	 * steps and 15 radius steps, transformed by {@code Compute_TF_Image}, and
	 * the first {@code 45 * 15} values of the resulting image array are
	 * returned as doubles.
	 *
	 * @param image source image
	 * @return the {@code 45 * 15 = 675}-value descriptor
	 */
	public static double[] Get_Polar_Descriptor (BufferedImage image)
	{
		final int angle = 45;
		final int radius = 15;
		final int length = angle * radius;

		BufferedImage grey = Util.colorToGreyImage(image);

		Polar polar = new Polar();
		polar.computeGravityCenter (grey);

		BufferedImage polar_image = polar.Compute_Polar_Image (grey, angle, radius);
		BufferedImage transformed = polar.Compute_TF_Image (polar_image, angle, radius);

		int[] pixels = Util.getImageArray(transformed);

		/* Widen the integer values to double one by one */
		double[] descriptor = new double[length];
		for (int index = 0; index < length; index++) {
			descriptor[index] = (double) pixels[index];
		}
		return descriptor;
	}
	
	/**
	 * Builds the pseudo-Zernike descriptor of an image: the modules of the
	 * moments of orders 1 to 6, with repetitions {@code -p..p} for each order
	 * {@code p}.
	 *
	 * @param image source image
	 * @return the 48-value descriptor (the sum of {@code 2p + 1} for p = 1..6)
	 */
	public static double[] Get_Zernike_Descriptor (BufferedImage image)
	{
		BufferedImage grey = Util.colorToGreyImage(image);

		final int first_order = 1;
		final int last_order = 6;
		/* Column holding repetition q = 0; repetition q lives at column q + zero_column */
		final int zero_column = last_order;

		double[] descriptor = new double[48];

		Z[][] moments = new Z[last_order + 1][2 * last_order + 1];

		/* Supplies the rho maximum and the barycenter used by every moment below */
		Z reference = new Z();
		double rho_max = reference.Computing_Rho_Max (grey);

		int out = 0;
		for (int p = first_order; p <= last_order; p++) {

			/* Repetition 0 */
			Z central = new Z();
			central.pseudo_zernike_moment_0 (p, grey, rho_max, reference.barycenter_x, reference.barycenter_y);
			moments[p][zero_column] = central;

			/* Each Z2 computation yields the pair of moments for -q and +q */
			for (int q = 1; q <= p; q++) {
				Z2 pair = new Z2();
				pair.pseudo_zernike_moment (p, q, grey, rho_max, reference.barycenter_x, reference.barycenter_y);
				moments[p][zero_column - q] = pair.getNum1();
				moments[p][zero_column + q] = pair.getNum2();
			}

			/* Emit the modules for repetitions -p..p, in that order */
			for (int column = zero_column - p; column <= zero_column + p; column++) {
				descriptor[out] = moments[p][column].module();
				out++;
			}
		}
		return descriptor;
	}
		
	/**
	 * Builds the HoG descriptor of an image by delegating to {@code HoG.hog}
	 * with the fixed settings {@code (5, 1, 3, 16, 16, false)}.
	 *
	 * @param image source image
	 * @return the descriptor returned by {@code HoG.hog}
	 */
	public static double[] Get_HoG_Descriptor (BufferedImage image)
	{
		return new HoG().hog(image, 5, 1, 3, 16, 16, false);
	}
	
	/**
	 * Reads precomputed Fourier, polar and Zernike scores from a file and turns
	 * them into 3-value descriptors {@code {fourier, polar, zernike}}, one per
	 * sample.
	 *
	 * The file is read as three consecutive blocks, each being a count obtained
	 * through {@code parser.getLineParameter} followed by that many values
	 * obtained through {@code parser.Get_Data}. The reader is closed once the
	 * three blocks are read. Descriptors that are all zeros are skipped and
	 * reported as bad. An I/O failure is logged and leaves the lists with
	 * whatever was added so far (nothing, since reading happens first).
	 *
	 * @param sample_class label stored in {@code class_list} for every accepted sample
	 * @param file reader over the scores file; closed by this method
	 * @param list output list receiving the accepted descriptors
	 * @param class_list output list receiving one label per accepted descriptor
	 * @throws IllegalArgumentException if the three blocks do not announce the same size
	 */
	public static void Get_All_Descriptors (
			int sample_class,
			LineNumberReader file,
			ArrayList<double[]> list,
			ArrayList<Integer> class_list )
	{
		try {
			Parser parser = new Parser();

			int fourier_size = Integer.parseInt(parser.getLineParameter (file, "\n", 0));
			double[] fourier_scores = parser.Get_Data (fourier_size, file);

			int polar_size = Integer.parseInt(parser.getLineParameter (file, "\n", 0));
			double[] polar_scores = parser.Get_Data (polar_size, file);

			int zernike_size = Integer.parseInt(parser.getLineParameter (file, "\n", 0));
			double[] zernike_scores = parser.Get_Data (zernike_size, file);

			file.close();

			/* Checked explicitly: assert is disabled unless the JVM runs with -ea */
			if (fourier_size != polar_size || polar_size != zernike_size) {
				throw new IllegalArgumentException(
						"Score blocks differ in size: fourier=" + fourier_size
						+ ", polar=" + polar_size + ", zernike=" + zernike_size);
			}

			int size = fourier_size;

			int bad = 0;
			for (int i = 0; i < size; i++) {
				double[] descriptor = new double[] {
						fourier_scores[i],
						polar_scores[i],
						zernike_scores[i]
				};
				if (notZero(descriptor)) {
					list.add(descriptor);
					class_list.add(sample_class);
				}
				else {
					System.err.println("bad descriptor : " + i);
					bad++;
				}
			}
			System.err.println("bad images : "+bad);
		} catch (IOException e) {
			System.err.println("Failed to the scores file");
			e.printStackTrace();
		}
	}
	
	

}
