package main;

import hypothesis_validation.HoG_Score;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import preprocessing.Parser;
import preprocessing.Util;
import fr.lip6.classifier.Classifier;
import fr.lip6.classifier.DoublePegasosSVM;
import fr.lip6.classifier.SMOSVM;
import fr.lip6.kernel.SimpleCacheKernel;
import fr.lip6.kernel.typed.DoubleGaussChi2;
import fr.lip6.kernel.typed.DoubleGaussL2;
import fr.lip6.type.TrainingSample;

/**
 * Command-line driver that extracts descriptors (HoG descriptors computed from
 * images, or pre-computed "Dalal" descriptors read from text files), trains or
 * evaluates an SVM on them, and writes per-sample scores and summary statistics.
 */
public class Compute_Descriptor {

	/**
	 * Entry point.
	 *
	 * <p>All 11 arguments must be present whatever the mode:
	 * <ol start="0">
	 * <li>positives file: image list for HoG modes (read through {@code Util.Get_Image};
	 *     presumably one image per line — confirm), or a descriptor file for Dalal modes
	 *     (one "label v1 v2 ..." line per sample)</li>
	 * <li>negatives file (same format as the positives file)</li>
	 * <li>option: {@code -training}, {@code -training-dalal}, {@code -testing} or {@code -testing-dalal}</li>
	 * <li>serialized SVM object file (written when training, read when testing)</li>
	 * <li>classification: {@code -linear} or {@code -chi2} (training only)</li>
	 * <li>{@code -linear}: C; {@code -chi2}: a gamma that is only printed (the gamma
	 *     actually used is estimated by {@link #Gamma_Estimation})</li>
	 * <li>{@code -linear}: iteration count; {@code -chi2}: C</li>
	 * <li>output path prefix for {@code scores_information.txt} / {@code statistics.txt}</li>
	 * <li>kernel cache path prefix ({@code -chi2} only)</li>
	 * <li>descriptor arguments, passed to {@link #build_descriptor} (HoG modes only)</li>
	 * <li>decision threshold on the classifier score</li>
	 * </ol>
	 */
	public static void main(String[] args) {

		if (args.length < 11) {
			System.err.println("Usage: Compute_Descriptor <positives> <negatives> "
					+ "<-training|-training-dalal|-testing|-testing-dalal> <svm_object> "
					+ "<-linear|-chi2> <param1> <param2> <output_path> <kernel_path> "
					+ "<descriptor_arguments> <threshold>");
			return;
		}

		String positives = args[0];
		String negatives = args[1];
		String option = args[2];
		String svm_object_name = args[3];
		String classification = args[4];
		String path = args[7];
		String kernel_path = args[8];
		String descriptor_arguments = args[9];
		double threshold = Double.parseDouble(args[10]);

		boolean training = option.equals("-training") || option.equals("-training-dalal");
		boolean testing = option.equals("-testing") || option.equals("-testing-dalal");
		boolean dalal = option.equals("-training-dalal") || option.equals("-testing-dalal");

		if (!training && !testing) {
			System.err.println("Unknown option : " + option);
			return;
		}

		Parser parser = new Parser();
		LineNumberReader positives_file = parser.Open_File (positives);
		LineNumberReader negatives_file = parser.Open_File (negatives);

		try {
			/* In testing mode the classifier is loaded first, so a missing or corrupt
			 * object file is reported before the (slow) descriptor extraction. */
			Classifier cls = null;
			if (testing) {
				cls = Load_Classifier (svm_object_name);
				if (cls == null) {
					return;
				}
			}

			ArrayList<double[]> list = new ArrayList<double[]>();
			ArrayList<Integer> class_list = new ArrayList<Integer>();
			ArrayList<String> image_list_name = new ArrayList<String>();

			int pos_size, neg_size;
			if (dalal) {
				pos_size = Get_Dalal_Descriptor (+1, positives_file, list, class_list, image_list_name);
				neg_size = Get_Dalal_Descriptor (-1, negatives_file, list, class_list, image_list_name);
			}
			else {
				pos_size = Get_HoG_Descriptor (+1, positives_file, list, class_list, image_list_name, descriptor_arguments);
				neg_size = Get_HoG_Descriptor (-1, negatives_file, list, class_list, image_list_name, descriptor_arguments);
			}
			System.err.printf("pos_samples : %d, neg_samples : %d\n", pos_size, neg_size);

			System.err.println("Building vector");
			ArrayList<TrainingSample<double[]>> descriptors_list = new ArrayList<TrainingSample<double[]>>();
			Building (descriptors_list, class_list, list);
			System.err.println("End Building vector");

			if (testing) {
				Statistics (cls, descriptors_list, class_list, path, Double.MAX_VALUE, true, pos_size, neg_size, threshold, image_list_name);
				return;
			}

			double C;
			if (classification.equals("-linear")) {

				C = Double.parseDouble(args[5]);
				int iterations = Integer.parseInt(args[6]);
				System.out.printf("C : %f, iterations : %d\n", C, iterations);
				cls = Get_Linear_Classifier (descriptors_list, C, iterations);
			}
			else if (classification.equals("-chi2")) {

				/* args[5] is only reported; the gamma actually used is estimated from the data. */
				double gamma2 = Double.parseDouble(args[5]);

				System.err.printf("Computing best gamma\n");
				double gamma = Gamma_Estimation (list);
				System.err.printf("Computing gamma end\n");

				C = Double.parseDouble(args[6]);
				System.out.printf("C : %f, using gamma: %f, other_gamma = %f\n", C, gamma, gamma2);
				cls = Get_Chi2_Classifier (descriptors_list, gamma, C, kernel_path);
			}
			else {
				System.err.println("Unknown classification : " + classification);
				return;
			}

			System.err.printf("Training List Size : %d\n", descriptors_list.size());

			/*Training over the list*/
			cls.train(descriptors_list);

			Statistics (cls, descriptors_list, class_list, path, C, true, pos_size, neg_size, threshold, image_list_name);

			Save_Classifier (cls, svm_object_name);
		}
		finally {
			parser.Close_File (positives_file, positives);
			parser.Close_File (negatives_file, negatives);
		}
	}

	/**
	 * Deserializes a classifier from {@code svm_object_name}.
	 * NOTE: uses Java native deserialization — only load files this tool wrote itself.
	 *
	 * @return the classifier, or null when the file cannot be read (reported on stderr)
	 */
	private static Classifier Load_Classifier (String svm_object_name) {
		try {
			return (Classifier) Read_Object (new File(svm_object_name));
		}
		catch (IOException e) {
			System.err.println("Failed to open svm trainning object");
			e.printStackTrace();
		}
		catch (ClassNotFoundException e) {
			System.err.println("Failed to open svm trainning object");
			e.printStackTrace();
		}
		return null;
	}

	/** Serializes the trained classifier to {@code svm_object_name}; failures are reported, not thrown. */
	private static void Save_Classifier (Classifier cls, String svm_object_name) {
		try {
			Write_Object (new File(svm_object_name), cls);
		}
		catch (IOException e) {
			System.err.println("Failed to record svm trainning object");
			e.printStackTrace();
		}
	}

	/** Reads one serialized object from {@code file}, always closing the file. */
	private static Object Read_Object (File file) throws IOException, ClassNotFoundException {
		FileInputStream in = new FileInputStream(file);
		try {
			return new ObjectInputStream(in).readObject();
		}
		finally {
			in.close();
		}
	}

	/** Writes {@code value} as one serialized object to {@code file}, always closing the file. */
	private static void Write_Object (File file, Object value) throws IOException {
		FileOutputStream out = new FileOutputStream(file);
		try {
			ObjectOutputStream object = new ObjectOutputStream(out);
			object.writeObject(value);
			object.flush();
		}
		finally {
			out.close();
		}
	}

	/**
	 * Builds an (untrained) Pegasos linear SVM.
	 *
	 * @param descriptors_list training samples; only its size is used here
	 * @param C                regularization constant
	 * @param iterations       scales the step count: T = iterations * n / K, with K = 100
	 *                         (presumably T steps of K samples each, i.e. about
	 *                         {@code iterations} passes over the data — confirm with the library)
	 */
	static Classifier Get_Linear_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double C, int iterations) {
		DoublePegasosSVM cls = new DoublePegasosSVM();
		cls.setVerbosityLevel(2);
		cls.setC(C);
		cls.setK(100);
		/* long arithmetic: iterations * n can exceed Integer.MAX_VALUE on large sets */
		cls.setT((int) ((long) iterations * descriptors_list.size() / cls.getK()));
		return cls;
	}

	/**
	 * Builds an (untrained) SMO SVM with a Gaussian chi2 kernel. The kernel cache is
	 * loaded from {@code kernel_path + "kernel_" + gamma} if that file exists. Otherwise
	 * it is computed and saved there.
	 * NOTE(review): the cache file name depends only on gamma. A cache built from a
	 * different training set will be reused silently.
	 */
	static Classifier Get_Chi2_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double gamma, double C, String kernel_path) {

		DoubleGaussChi2 k = new DoubleGaussChi2();
		k.setGamma(gamma);

		String kernel_file_name = kernel_path + "kernel_" + k.getGamma();
		File kfile = new File(kernel_file_name);
		System.err.println("Trying to open kernel : " + kernel_file_name);

		boolean cache_exists = kfile.exists();
		SimpleCacheKernel<double[]> sk = cache_exists ? Load_Cached_Kernel (kfile) : null;
		if (sk == null) {
			System.out.println("Computing kernel cache");
			sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			System.out.println("Done");
			/* An existing but unreadable cache file is left untouched. */
			if (!cache_exists) {
				Save_Cached_Kernel (kfile, sk);
			}
		}

		/*Creating SVM class*/
		SMOSVM<double[]> cls = new SMOSVM<double[]>(sk);
		cls.setVerbosityLevel(2);
		cls.setC(C);
		/* NOTE(review): replaces the cached kernel passed to the constructor with the raw
		 * kernel. Confirm this is intended for training with this library version. */
		cls.setKernel(k);
		return cls;
	}

	/**
	 * Builds an (untrained) SMO SVM with a Gaussian L2 kernel. Caching works as in
	 * {@link #Get_Chi2_Classifier}, and so does the staleness caveat.
	 */
	static Classifier Get_GaussL2_Classifier (ArrayList<TrainingSample<double[]>> descriptors_list, double gamma, double C, String kernel_path) {

		DoubleGaussL2 k = new DoubleGaussL2();
		k.setGamma(gamma);

		String kernel_file_name = kernel_path + "kernel_" + k.getGamma();
		File kfile = new File(kernel_file_name);
		System.err.println("Trying to open kernel : " + kernel_file_name);

		boolean cache_exists = kfile.exists();
		SimpleCacheKernel<double[]> sk = cache_exists ? Load_Cached_Kernel (kfile) : null;
		if (sk == null) {
			System.out.println("Computing kernel cache");
			sk = new SimpleCacheKernel<double[]>(k, descriptors_list);
			System.out.println("Done");
			if (!cache_exists) {
				Save_Cached_Kernel (kfile, sk);
			}
		}

		/*Creating SVM class*/
		SMOSVM<double[]> cls = new SMOSVM<double[]>(sk);
		cls.setVerbosityLevel(2);
		cls.setC(C);
		cls.setKernel(k);
		return cls;
	}

	/**
	 * Deserializes a kernel cache file.
	 *
	 * @return the cached kernel, or null if it cannot be read (reported, not fatal)
	 */
	@SuppressWarnings("unchecked")
	private static SimpleCacheKernel<double[]> Load_Cached_Kernel (File kfile) {
		try {
			System.out.println("Loading kernel file");
			/* Unchecked: the file is trusted to hold a SimpleCacheKernel<double[]> written by Save_Cached_Kernel. */
			SimpleCacheKernel<double[]> sk = (SimpleCacheKernel<double[]>) Read_Object (kfile);
			System.out.println("Loaded");
			return sk;
		}
		catch (Exception e) {
			System.out.println("Unable to read kernel");
			e.printStackTrace();
			return null;
		}
	}

	/** Serializes a kernel cache to {@code kfile}; failures are reported, not thrown. */
	private static void Save_Cached_Kernel (File kfile, SimpleCacheKernel<double[]> sk) {
		try {
			System.out.println("Writting kernel file");
			Write_Object (kfile, sk);
			System.out.println("Done");
		}
		catch (IOException e) {
			System.err.println("Unable to save kernel");
			e.printStackTrace();
		}
	}

	/**
	 * Scores every sample with {@code cls} and computes confusion counts. A sample
	 * is predicted positive when its score is {@code >= threshold}.
	 *
	 * <p>When {@code write} is true, it writes {@code path + "scores_information.txt"}
	 * with one line per sample: score, score (again), label and image name. The
	 * summary line goes to {@code path + "statistics.txt"}. When {@code write} is
	 * false, the summary line goes to stdout.
	 *
	 * <p>Summary: precision, recall, F-measure (weighted harmonic mean with
	 * alpha = 0.5, i.e. F1), true-negative rate, then tp/fn/tn/fp. Ratios are NaN
	 * when their denominator is zero.
	 *
	 * @param C        unused; kept for caller compatibility
	 * @param pos_size unused; kept for caller compatibility
	 * @param neg_size unused; kept for caller compatibility
	 */
	public static void Statistics (
			Classifier<double[]> cls,
			ArrayList<TrainingSample<double[]>> descriptor_list,
			ArrayList<Integer> class_list,
			String path,
			double C,
			boolean write,
			int pos_size,
			int neg_size,
			double threshold,
			ArrayList<String> image_list_name) {

		PrintStream scores = null, statistics = null;

		Parser parser = new Parser();
		if (write) {
			scores = parser.Open_Print_Stream (path + "scores_information.txt");
			statistics = parser.Open_Print_Stream (path + "statistics.txt");
		}

		try {
			int tp = 0, fp = 0, fn = 0, tn = 0;

			for (int i = 0; i < descriptor_list.size(); i++) {

				double v = cls.valueOf(descriptor_list.get(i).sample);
				int label = class_list.get(i);

				/* A NaN score satisfies neither comparison and is not counted. */
				if (v >= threshold) {
					if (label == 1) {
						tp++;
					}
					else if (label == -1) {
						fp++;
					}
				}
				else if (v < threshold) {
					if (label == 1) {
						fn++;
					}
					else if (label == -1) {
						tn++;
					}
				}

				if (write) {
					scores.printf("%2.6f %2.6f %d %s\n", v, v, label, image_list_name.get(i));
				}
			}

			double alpha = 0.5;
			double precision = (double) tp / (double) (fp + tp);
			double recall = (double) tp / (double) (fn + tp);
			double error = 1.0 / (alpha / precision + (1.0 - alpha) / recall);
			double phi = (double) tn / (double) (fp + tn);

			PrintStream out = write ? statistics : System.out;
			out.printf("%f %f %f %f,  tp : %d, fn : %d, tn : %d, fp : %d\n", precision, recall, error, phi, tp, fn, tn, fp);
		}
		finally {
			if (write) {
				parser.Close_Print_Stream (scores, path + "scores_information.txt");
				parser.Close_Print_Stream (statistics, path + "statistics.txt");
			}
		}
	}

	/**
	 * Appends one {@code TrainingSample} per descriptor to {@code descriptor_list}.
	 * Each sample holds a copy of {@code list.get(i)} labelled {@code class_list.get(i)}.
	 * The copy keeps later changes to {@code list} from affecting the samples.
	 */
	public static void Building (
			ArrayList<TrainingSample<double[]>> descriptor_list,
			ArrayList<Integer> class_list,
			ArrayList<double[]> list) {

		for (int i = 0; i < list.size(); i++) {
			descriptor_list.add(new TrainingSample<double[]>(list.get(i).clone(), class_list.get(i)));
		}
	}

	/**
	 * Heuristic gamma for a Gaussian chi2 kernel: the inverse of the mean chi2
	 * distance over all ordered pairs of descriptors (including i == j).
	 *
	 * @throws IllegalArgumentException if {@code list} is empty
	 */
	public static double Gamma_Estimation (
			ArrayList<double[]> list) {

		int n = list.size();
		if (n == 0) {
			throw new IllegalArgumentException("Gamma_Estimation: empty descriptor list");
		}

		/* The chi2 distance is symmetric and zero on the diagonal, so the sum over all
		 * ordered pairs is twice the sum over i < j. No n x n matrix is needed. */
		double sum = 0.0;
		for (int i = 0; i < n; i++) {
			double[] h_i = list.get(i);
			for (int j = i + 1; j < n; j++) {
				sum += Chi2_Distance (h_i, list.get(j));
			}
		}
		sum *= 2.0;

		System.err.printf("Matrix mean: %f\n", sum);

		/* double product: n * n overflows int for n > 46340 */
		double mean_distance = sum / ((double) n * n);

		System.err.printf("Mean distance: %f\n", mean_distance);

		double gamma = 1.0 / mean_distance;

		System.err.printf("The best gamma is hope to be (near of): %f\n", gamma);

		return gamma;
	}

	/** Chi2 distance: sum over k of (a_k - b_k)^2 / (a_k + b_k), skipping bins where a_k + b_k == 0. */
	private static double Chi2_Distance (double[] h_i, double[] h_j) {
		double sum = 0.0;
		for (int k = 0; k < h_i.length; k++) {
			double s = h_i[k] + h_j[k];
			if (s != 0) {
				double d = h_i[k] - h_j[k];
				sum += (d * d) / s;
			}
		}
		return sum;
	}

	/**
	 * Reads images from {@code image_file} until {@code Util.Get_Image} returns null.
	 * For each image it computes a HoG descriptor. Each non-null, not-all-zero
	 * descriptor is appended to {@code list}, with {@code sample_class} appended to
	 * {@code class_list} and the image name to {@code image_list_name}. Failed or
	 * all-zero descriptors are counted as bad images and skipped.
	 *
	 * @return the number of descriptors appended
	 */
	public static int Get_HoG_Descriptor (
			int sample_class,
			LineNumberReader image_file,
			ArrayList<double[]> list,
			ArrayList<Integer> class_list,
			ArrayList<String> image_list_name,
			String descriptor_arguments) {

		int size = 0;
		int bad_images = 0;
		int descriptor_size = -1;

		while (true) {

			String[] image_name = new String[1];
			BufferedImage image = Util.Get_Image (image_file, image_name);
			if (image == null) {
				break;
			}

			double[] descriptor = null;
			try {
				descriptor = build_descriptor (image, descriptor_arguments);
				if (descriptor != null) {
					descriptor_size = descriptor.length;
				}
			}
			catch (Exception e) {
				/* Best effort: one failing image must not abort the whole list. */
				System.err.println("problem with image " + image_file.getLineNumber());
				e.printStackTrace();
			}

			if (descriptor != null && notZero(descriptor)) {
				list.add(descriptor);
				class_list.add(sample_class);
				image_list_name.add(image_name[0]);
				size++;
			}
			else {
				System.err.println("Bad image : " + image_file.getLineNumber());
				bad_images++;
			}
		}
		System.err.println("Bad images : " + bad_images);
		System.out.println("Descriptor Size : " + descriptor_size);
		return size;
	}

	/**
	 * Reads pre-computed descriptors, one per line, formatted as
	 * "label v1 v2 ..." with whitespace-separated fields. Blank lines are skipped.
	 * Each not-all-zero descriptor is appended with label {@code sample_class} and
	 * image name "unknown".
	 *
	 * @return the number of descriptors appended
	 * @throws NumberFormatException if a field is not a number
	 */
	public static int Get_Dalal_Descriptor (
			int sample_class,
			LineNumberReader descriptor_file,
			ArrayList<double[]> list,
			ArrayList<Integer> class_list,
			ArrayList<String> image_list_name) {

		Parser parser = new Parser();
		int size = 0;

		while (true) {

			String[] s = parser.getTokens (descriptor_file, "\n");
			if (s == null) {
				break;
			}

			String line = s[0].trim();
			if (line.length() == 0) {
				continue;
			}

			String[] fields = line.split("\\s+");

			int sample_class_dalal = Integer.parseInt(fields[0]);

			/* Checked only when assertions are enabled (-ea). Otherwise the file's label
			 * is ignored and sample_class is used. */
			assert (sample_class == sample_class_dalal);

			double[] descriptor = new double[fields.length - 1];
			for (int i = 1; i < fields.length; i++) {
				descriptor[i - 1] = Double.parseDouble(fields[i]);
			}

			if (notZero(descriptor)) {
				list.add(descriptor);
				class_list.add(sample_class);
				image_list_name.add("unknown");
				size++;
			}
		}
		return size;
	}

	/** Returns true if at least one component of {@code d} is not equal to 0 (NaN counts as non-zero). */
	private static boolean notZero (double[] d)
	{
		for (double value : d) {
			if (value != 0) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Computes a HoG descriptor for {@code image}.
	 *
	 * @param file_parameters passed to {@code Parser.Get_Array_List}; presumably a path to a
	 *                        parameter file, which is re-read on every call — confirm
	 * @return the descriptor produced by {@code HoG_Score.get_descriptor}
	 */
	public static double[] build_descriptor (BufferedImage image, String file_parameters)
	{
		Parser parser = new Parser();

		ArrayList<String> list = parser.Get_Array_List (file_parameters);

		HoG_Score object = new HoG_Score();

		boolean debug = false;

		double descriptor[] = object.get_descriptor (debug, image, list);

		return descriptor;
	}
}
