package ssod_laba5;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class Iris {

	/** Number of labelled samples read from {@code IrisDB.txt}. */
	private static final int SAMPLE_COUNT = 75;

	/** Number of numeric input features per sample. */
	private static final int FEATURE_COUNT = 4;

	/** Samples per class: {@link #readDB} labels the file's rows in consecutive blocks of this size. */
	private static final int SAMPLES_PER_CLASS = 25;

	/** Number of output classes; labels are 1..CLASS_COUNT. */
	private static final int CLASS_COUNT = 3;

	/** Upper (exclusive) bounds of the model output for classes 1, 2 and 3 respectively. */
	private static final double CLASS1_UPPER = 1.2;
	private static final double CLASS2_UPPER = 2.2;
	private static final double CLASS3_UPPER = 4;

	/** Number of sample points written per membership-function plot file. */
	private static final int PLOT_POINTS = 100;

	/**
	 * Rule base. Columns 0-3 select the membership function (see {@link #getMu}) applied to the
	 * corresponding feature. Column 4 is the class label the rule was written for; the model in
	 * {@link #predict} uses {@link #RULE_COEFFICIENTS} as rule outputs, so it is kept for reference.
	 */
	private static final int[][] RULES = {
		{2, 3, 1, 1, 1},
		{2, 2, 1, 1, 1},
		{3, 2, 2, 2, 2},
		{3, 3, 2, 2, 2},
		{3, 2, 3, 3, 3},
		{3, 3, 3, 3, 3}
	};

	/**
	 * Per-rule linear output coefficients: rule r outputs
	 * {@code x0*C[r][0] + x1*C[r][1] + x2*C[r][2] + x3*C[r][3]}, and {@link #predict} averages these
	 * outputs weighted by the rules' firing strengths.
	 */
	private static final double[][] RULE_COEFFICIENTS = {
		{0.6, 0.6, 0.4, 0.1},
		{0.5, 0.4, 0.7, 0.2},
		{0.2, 0.3, 0.7, 0.8},
		{0.2, 0.6, 0.4, 0.9},
		{0.7, 0.4, 0.4, 0.5},
		{0.9, 0.1, 0.9, 0.5}
	};

	/**
	 * Reads {@value #SAMPLE_COUNT} samples of {@value #FEATURE_COUNT} numbers each from
	 * {@code IrisDB.txt} and labels them by position: the first block of
	 * {@value #SAMPLES_PER_CLASS} rows gets class 1, the next class 2, the rest class 3.
	 *
	 * @param inputData destination with at least SAMPLE_COUNT rows of FEATURE_COUNT + 1 columns;
	 *                  column FEATURE_COUNT receives the class label
	 * @throws FileNotFoundException if {@code IrisDB.txt} cannot be opened
	 */
	public static void readDB(double[][] inputData) throws FileNotFoundException {
		// try-with-resources: the original never closed this Scanner (file handle leak).
		try (Scanner scanFile = new Scanner(new File("IrisDB.txt"))) {
			for (int i = 0; i < SAMPLE_COUNT; i++) {
				for (int j = 0; j < FEATURE_COUNT; j++) {
					inputData[i][j] = scanFile.nextDouble();
				}
				inputData[i][FEATURE_COUNT] = i / SAMPLES_PER_CLASS + 1;
			}
		}
	}

	/**
	 * Gaussian-shaped membership value of {@code x} for a curve centred at {@code ave}.
	 * NOTE(review): the exponent is -((x-ave)/sigma)^2, without the 1/2 factor of the normal pdf,
	 * so this is not a normalized density. It is kept as-is because the model is tuned to it.
	 */
	static double gaussValue(double x, double ave, double sigma) {
		return 1.0 / (Math.pow(2 * Math.PI, 0.5) * sigma) * Math.exp(-Math.pow((x - ave) / sigma, 2.0));
	}

	/**
	 * Membership of {@code value} in fuzzy term {@code typeOfRule}: 1, 2 and 3 are Gaussians
	 * centred at 0.0, 0.5 and 1.0 (sigma 0.4); any other type returns 1, i.e. no constraint.
	 * The centres span 0..1, which suggests features are expected to be normalized to that
	 * range — TODO confirm against IrisDB.txt.
	 */
	static double getMu(int typeOfRule, double value) {
		switch (typeOfRule) {
			case 1: return gaussValue(value, 0.0, 0.4);
			case 2: return gaussValue(value, 0.5, 0.4);
			case 3: return gaussValue(value, 1, 0.4);
			default: return 1;
		}
	}

	/**
	 * Maps a model output to a class label: below 1.2 is class 1, below 2.2 is class 2,
	 * below 4 is class 3. Anything else, including NaN, yields -1.
	 */
	static int returnType(double y) {
		if (y < CLASS1_UPPER) {
			return 1;
		}
		if (y < CLASS2_UPPER) {
			return 2;
		}
		if (y < CLASS3_UPPER) {
			return 3;
		}
		return -1;
	}

	/**
	 * Computes the model output for one sample: the average of the per-rule linear outputs,
	 * weighted by each rule's firing strength (product of its four feature memberships).
	 *
	 * @param x feature vector; only the first FEATURE_COUNT entries are read, so a labelled
	 *          row of {@link #readDB} can be passed directly
	 * @return the weighted output, or NaN if every rule's firing strength is zero
	 */
	static double predict(double[] x) {
		double numerator = 0;
		double denominator = 0;
		for (int r = 0; r < RULES.length; r++) {
			double firing = 1.0;
			double ruleOutput = 0;
			for (int j = 0; j < FEATURE_COUNT; j++) {
				firing *= getMu(RULES[r][j], x[j]);
				ruleOutput += x[j] * RULE_COEFFICIENTS[r][j];
			}
			numerator += firing * ruleOutput;
			denominator += firing;
		}
		return numerator / denominator;
	}

	/**
	 * Converts a model output into pseudo-probabilities for classes 1..3, proportional to the
	 * inverse distance from {@code y} to each label and normalized to sum to 1.
	 * If {@code y} equals a label exactly, that class gets 1 and the others 0. Without this
	 * check, 1/0 is Infinity and the normalization would produce NaN.
	 */
	static double[] classProbabilities(double y) {
		double[] weights = new double[CLASS_COUNT];
		double sum = 0;
		for (int c = 0; c < CLASS_COUNT; c++) {
			double distance = Math.abs(y - (c + 1));
			if (distance == 0) {
				double[] exact = new double[CLASS_COUNT];
				exact[c] = 1.0;
				return exact;
			}
			weights[c] = 1.0 / distance;
			sum += weights[c];
		}
		for (int c = 0; c < CLASS_COUNT; c++) {
			weights[c] /= sum;
		}
		return weights;
	}

	/**
	 * Writes each line followed by '\n' to {@code fileName}, closing the file even on failure.
	 * Best-effort: an I/O error is reported on stderr and the caller continues.
	 */
	private static void writeLines(String fileName, List<String> lines) {
		try (BufferedWriter writer = new BufferedWriter(new FileWriter(fileName))) {
			for (String line : lines) {
				writer.write(line);
				writer.write('\n');
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Writes plot data for the three membership functions over x in [0, 1) with step 0.01.
	 * {@code gaus1.txt} holds the x values; {@code gaus2..4.txt} hold the memberships of
	 * terms 1, 2 and 3.
	 */
	static void saveGrafiki() {
		List<String> xs = new ArrayList<>();
		List<String> low = new ArrayList<>();
		List<String> middle = new ArrayList<>();
		List<String> high = new ArrayList<>();
		for (int k = 0; k < PLOT_POINTS; k++) {
			// Derive x from an integer index instead of accumulating 0.01, which drifts.
			double x = k / (double) PLOT_POINTS;
			xs.add(String.format("%.20f", x));
			low.add(String.format("%.20f", getMu(1, x)));
			middle.add(String.format("%.20f", getMu(2, x)));
			high.add(String.format("%.20f", getMu(3, x)));
		}
		writeLines("gaus1.txt", xs);
		writeLines("gaus2.txt", low);
		writeLines("gaus3.txt", middle);
		writeLines("gaus4.txt", high);
	}

	/**
	 * Writes per-sample results: raw model output to {@code data1.txt}, predicted class to
	 * {@code data2.txt}, and the success/failure marks to {@code data3.txt}.
	 *
	 * @param data          rows of {output, predicted class, true class}
	 * @param successOrNot  one mark per row
	 */
	public static void saveData(ArrayList<Double[]> data, ArrayList<String> successOrNot) {
		List<String> outputs = new ArrayList<>();
		List<String> predicted = new ArrayList<>();
		for (Double[] item : data) {
			outputs.add(String.format("%.20f", item[0]));
			predicted.add(String.format("%.20f", item[1]));
		}
		writeLines("data1.txt", outputs);
		writeLines("data2.txt", predicted);
		writeLines("data3.txt", successOrNot);
	}

	/**
	 * Writes, per sample, the pseudo-probability of classes 1, 2 and 3 (see
	 * {@link #classProbabilities}) as percentages to {@code prob1.txt}..{@code prob3.txt}.
	 *
	 * @param totalArr  rows whose element 0 is the model output
	 * @param inputData unused; kept for signature compatibility
	 */
	public static void saveProbability(ArrayList<Double[]> totalArr, double[][] inputData) {
		List<String> prob1 = new ArrayList<>();
		List<String> prob2 = new ArrayList<>();
		List<String> prob3 = new ArrayList<>();
		for (Double[] item : totalArr) {
			double[] p = classProbabilities(item[0]);
			prob1.add(String.format("%.2f%%", p[0] * 100));
			prob2.add(String.format("%.2f%%", p[1] * 100));
			prob3.add(String.format("%.2f%%", p[2] * 100));
		}
		writeLines("prob1.txt", prob1);
		writeLines("prob2.txt", prob2);
		writeLines("prob3.txt", prob3);
	}

	/** Reads the next x.length numbers into {@code x}; returns false if input ends or is non-numeric first. */
	private static boolean readFeatures(Scanner scanner, double[] x) {
		for (int i = 0; i < x.length; i++) {
			if (!scanner.hasNextDouble()) {
				return false;
			}
			x[i] = scanner.nextDouble();
		}
		return true;
	}

	/** Human-readable name for a class label from {@link #returnType}; "" for unknown labels. */
	private static String className(int type) {
		switch (type) {
			case 1: return "Setosa(1)";
			case 2: return "Versicolor(2)";
			case 3: return "Virginica(3)";
			default: return "";
		}
	}

	/**
	 * Writes plot files, then evaluates the model on IrisDB.txt, reports accuracy and writes
	 * result files. Finally it classifies feature vectors typed on stdin until input ends.
	 */
	public static void main(String[] args) throws FileNotFoundException {
		saveGrafiki();

		double[][] inputData = new double[SAMPLE_COUNT][FEATURE_COUNT + 1];
		readDB(inputData);

		ArrayList<Double[]> totalArr = new ArrayList<>();
		for (double[] sample : inputData) {
			double y1 = predict(sample);
			totalArr.add(new Double[] {y1, (double) returnType(y1), sample[FEATURE_COUNT]});
		}

		saveProbability(totalArr, inputData);

		ArrayList<String> successOrNot = new ArrayList<>();
		int right = 0;
		for (Double[] item : totalArr) {
			System.out.print(item[0] + "\t");
			System.out.print(item[1] + "\t");
			if (Math.abs(item[2] - item[1]) < 0.1) {
				System.out.println("Success");
				successOrNot.add("Success");
				right++;
			} else {
				System.out.println("Not Success");
				successOrNot.add("Not success");
			}
		}
		System.out.println("Total percentage = " + (double) right / totalArr.size());

		saveData(totalArr, successOrNot);

		// Interactive "black box": uses the same model as the evaluation above. System.in is not
		// owned here, so the Scanner is deliberately left open.
		Scanner scanner = new Scanner(System.in);
		double[] x = new double[FEATURE_COUNT];
		while (readFeatures(scanner, x)) {
			System.out.println("Answer is " + className(returnType(predict(x))));
		}
	}
}
