package fr.lip6.kernel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import fr.lip6.type.TrainingSample;

/**
 * Simple multithreaded wrapper around a given {@link Kernel}. Single evaluations are delegated
 * unchanged to the wrapped kernel; multithreading is used only when computing the Gram matrix.<br />
 * The Gram matrix is split into non-overlapping rectangular tiles computed concurrently, one
 * worker thread per tile; the number of tiles is derived from the number of available processors.<br />
 * Because {@link #valueOf(Object, Object)} is invoked concurrently from several worker threads,
 * the wrapped kernel must tolerate concurrent evaluation.
 * @author dpicard
 *
 * @param <T> type of the samples handled by the wrapped kernel
 */
public class ThreadedKernel<T> extends Kernel<T> {

	private static final long serialVersionUID = -2193768216118832033L;

	/** Wrapped kernel to which every evaluation is delegated. */
	protected Kernel<T> k;

	/**
	 * Multithreads the given kernel.
	 * @param kernel kernel to which every evaluation is delegated
	 */
	public ThreadedKernel(Kernel<T> kernel)
	{
		this.k = kernel;
	}

	/** Delegates to the wrapped kernel. */
	@Override
	public double valueOf(T t1, T t2) {
		return k.valueOf(t1, t2);
	}

	/** Delegates to the wrapped kernel. */
	@Override
	public double valueOf(T t1) {
		return k.valueOf(t1);
	}

	/**
	 * Computes in parallel the Gram matrix
	 * {@code m[i][j] = valueOf(e.get(i).sample, e.get(j).sample)}.
	 * @param e training samples; only their {@code sample} field is used
	 * @return a new {@code e.size() x e.size()} matrix (a 0x0 matrix if {@code e} is empty)
	 * @throws ArithmeticException if a kernel evaluation returns NaN
	 * @throws IllegalStateException if the calling thread is interrupted while waiting for the
	 *         workers, or if a worker fails with a checked exception
	 */
	@Override
	public double[][] getKernelMatrix(List<TrainingSample<T>> e) {
		final int n = e.size();
		final double[][] matrix = new double[n][n];
		if (n == 0) {
			return matrix;
		}

		// Extract the samples once so workers do not repeatedly call get() on the caller's
		// list (which may have O(n) random access, e.g. a LinkedList).
		final List<T> samples = new ArrayList<T>(n);
		for (TrainingSample<T> ts : e) {
			samples.add(ts.sample);
		}

		// Heuristic: cut each dimension into about sqrt(processors + 1) parts, giving roughly
		// one tile (and one thread) per processor. Ceiling division keeps the tile size >= 1,
		// so the tiling loops below always advance, even when n is smaller than nbc.
		final int nbc = (int) Math.sqrt(Runtime.getRuntime().availableProcessors() + 1);
		final int tile = Math.max(1, (n + nbc - 1) / nbc);

		final List<Callable<Void>> tasks = new ArrayList<Callable<Void>>();
		for (int i = 0; i < n; i += tile) {
			for (int j = 0; j < n; j += tile) {
				tasks.add(new TileTask(matrix, samples,
						i, Math.min(i + tile, n), j, Math.min(j + tile, n)));
			}
		}

		final ExecutorService pool = Executors.newFixedThreadPool(tasks.size());
		try {
			// invokeAll blocks until every tile is done; Future.get then guarantees the
			// workers' writes to matrix are visible here and rethrows any worker failure.
			for (Future<Void> f : pool.invokeAll(tasks)) {
				f.get();
			}
		} catch (InterruptedException ie) {
			Thread.currentThread().interrupt();
			throw new IllegalStateException("Interrupted while computing the kernel matrix", ie);
		} catch (ExecutionException ee) {
			final Throwable cause = ee.getCause();
			if (cause instanceof RuntimeException) {
				throw (RuntimeException) cause;
			}
			if (cause instanceof Error) {
				throw (Error) cause;
			}
			throw new IllegalStateException("Kernel matrix computation failed", cause);
		} finally {
			pool.shutdownNow();
		}
		return matrix;
	}

	/**
	 * Fills the tile {@code [rowStart, rowEnd) x [colStart, colEnd)} of the Gram matrix.
	 * Tiles produced by {@link #getKernelMatrix(List)} never overlap, so concurrent workers
	 * write disjoint cells of the shared matrix.
	 */
	private class TileTask implements Callable<Void> {
		private final double[][] m;
		private final List<T> samples;
		private final int rowStart, rowEnd, colStart, colEnd;

		TileTask(double[][] m, List<T> samples, int rowStart, int rowEnd,
				int colStart, int colEnd) {
			this.m = m;
			this.samples = samples;
			this.rowStart = rowStart;
			this.rowEnd = rowEnd;
			this.colStart = colStart;
			this.colEnd = colEnd;
		}

		@Override
		public Void call() {
			for (int i = rowStart; i < rowEnd; i++) {
				final T t1 = samples.get(i);
				for (int j = colStart; j < colEnd; j++) {
					final T t2 = samples.get(j);
					// Evaluate once, through the outer instance so that subclasses
					// overriding valueOf are honoured.
					final double v = ThreadedKernel.this.valueOf(t1, t2);
					if (Double.isNaN(v)) {
						throw new ArithmeticException("NaN kernel value at (" + i + ", " + j
								+ "): t1=" + describe(t1) + ", t2=" + describe(t2));
					}
					m[i][j] = v;
				}
			}
			return null;
		}
	}

	/**
	 * Renders a sample for diagnostic messages; arrays (including primitive arrays) are printed
	 * element by element rather than by identity, and no cast to a specific type is needed.
	 */
	private static String describe(Object sample) {
		final String wrapped = Arrays.deepToString(new Object[] { sample });
		// strip the brackets contributed by the one-element wrapper array
		return wrapped.substring(1, wrapped.length() - 1);
	}
}
