// comp-vision-group-cw — src/main/java/uk/ac/soton/ecs/zmk1g19/Run1.java
package uk.ac.soton.ecs.zmk1g19;

import org.openimaj.data.dataset.VFSGroupDataset;
import org.openimaj.data.dataset.VFSListDataset;
import org.openimaj.experiment.dataset.split.GroupedRandomSplitter;
import org.openimaj.experiment.evaluation.classification.ClassificationEvaluator;
import org.openimaj.experiment.evaluation.classification.ClassificationResult;
import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMAnalyser;
import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMResult;
import org.openimaj.feature.DoubleFV;
import org.openimaj.feature.DoubleFVComparison;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.image.FImage;
import org.openimaj.ml.annotation.basic.KNNAnnotator;
import org.openimaj.util.parallel.Parallel;
import uk.ac.soton.ecs.zmk1g19.featureExtractors.TinyImageExtractor;

import java.util.Map;

/**
 * Run 1: K-nearest-neighbour classifier over "tiny image" features.
 *
 * @author Sofia Kisiala    (zmk1g19)
 * @author Harry Nelson     (hjn2g19)
 * @author Max Burgess      (mwmb1g19)
 * @author Anan Venkatesh   (av1u19)
 * @author Fergus Adams     (fhwa1g19)
 */
public class Run1 {

    /**
     * Main method used to run our first classifier: searches for the best k,
     * trains a KNN annotator on the full training set, and writes the
     * predictions for the test set to the "run1" output file.
     *
     * @param training Image set used to train classifier
     * @param testing  Image set used to test classifier
     */
    public static void run(VFSGroupDataset<FImage> training, VFSListDataset<FImage> testing) {
        //create custom feature extractor
        FeatureExtractor<DoubleFV, FImage> fe = new TinyImageExtractor();
        //search k = 1..15, averaging 5 evaluation runs per candidate
        int k = findK(training, fe, 15, 5);
        //create and train our knn annotator using euclidean metric and optimal k
        KNNAnnotator<FImage, String, DoubleFV> knn = new KNNAnnotator<>(fe, DoubleFVComparison.EUCLIDEAN, k);
        knn.train(training);
        //build our file output
        App.writeAnnotatorOutput(knn, "run1", testing);
    }

    /**
     * Testing utility function for getting the optimal value of k for the KNNAnnotator.
     * Each candidate k in 1..range is evaluated {@code runs} times on a random
     * 90/10 per-class train/test split, and the k with the highest mean accuracy wins.
     *
     * @param training Training image set
     * @param fe       Feature extractor
     * @param range    Largest k to test (candidates are 1..range inclusive)
     * @param runs     Number of runs to test each value on
     * @return Optimal value of k (always at least 1)
     */
    static int findK(VFSGroupDataset<FImage> training, FeatureExtractor<DoubleFV, FImage> fe, int range, int runs) {
        final Object lock = new Object();
        //split data: 90 training and 10 testing instances per class, none for validation
        GroupedRandomSplitter<String, FImage> split = new GroupedRandomSplitter<>(training, 90, 0, 10);
        double bestAcc = 0;
        //start from 1 so we never return k = 0 (invalid for KNN) even if all accuracies are 0
        int bestK = 1;
        //test all k between 1 and range
        for (int k = 1; k <= range; k++) {
            final double[] total = {0d};
            //effectively-final copy of k for capture by the lambda below
            final int candidateK = k;
            Parallel.forIndex(0, runs, 1, n -> { //run the evaluation runs in parallel
                KNNAnnotator<FImage, String, DoubleFV> knn = new KNNAnnotator<>(fe, DoubleFVComparison.EUCLIDEAN, candidateK);
                knn.train(split.getTrainingDataset());
                ClassificationEvaluator<CMResult<String>, String, FImage> eval = new ClassificationEvaluator<>(knn, split.getTestDataset(),
                        new CMAnalyser<FImage, String>(CMAnalyser.Strategy.SINGLE)); //create evaluator for our testing
                Map<FImage, ClassificationResult<String>> guesses = eval.evaluate();
                CMResult<String> result = eval.analyse(guesses);
                //read the accuracy straight off the confusion matrix instead of
                //parsing a fixed substring of the summary report, which silently
                //breaks if the report format or the value's width ever changes
                double accuracy = result.getMatrix().getAccuracy();
                synchronized (lock) { //guard the shared accumulator across worker threads
                    total[0] += accuracy;
                }
            });
            double avg = total[0] / runs;
            if (avg > bestAcc) { //record new average and k if record is beaten
                bestAcc = avg;
                bestK = k;
            }
        }
        System.out.println("Best accuracy was " + bestAcc + " when k = " + bestK + ".");
        return bestK;
    }
}