// dbscan / Exp3.java
/*
 * Name - Anisaftab Saiyed
 * ID- 300259073
 * CSI PA2
 */

import java.util.*;
import java.io.*;

/*
 * Exp3 is the DBScan class, but it uses NearestNeighborsKD so that the
 * neighbour searches during clustering are performed with a kd-tree.
 */

class Exp3{

    /**
     * eps is the epsilon distance used when searching for the points surrounding a point.
     * minPts is the minimum number of neighbours (as returned by the neighbour search)
     * a point needs in order to start or extend a cluster.
     */
    private double eps, minPts;
    /**
     * dBpoints holds all the points read from the csv file. Cluster labels are written
     * directly into these points: 0 = not visited yet, -1 = noise, n > 0 = cluster n.
     */
    private List<Point3D> dBpoints ;
    /**
     * clusters is the total number of clusters found so far; it is also the label
     * given to the most recently created cluster.
     */
    private int clusters = 0;

    /**
     * Constructs a scanner that will search the given points for clusters.
     * @param givenStack the list of all points from the csv file; it is stored by reference
     *                   (not copied) and its points' labels are updated by {@link #findClusters()}.
     */    
    Exp3( List<Point3D> givenStack ){
        this.dBpoints = givenStack;
    }

    /**
     * Setter for eps.
     * @param eps  epsilon distance from the user
     * @return the epsilon distance that was set
     */
    public double setEps(double eps){
        return this.eps = eps;
    }

    /**
     * Setter for the minPts value.
     * @param minPts  minimum number of neighbours for a group to be counted as a cluster
     * @return the minPts value that was set
     */
    public double setMinPts( double minPts){
        return this.minPts = minPts;
    }

    /**
     * Runs the DBScan algorithm over dBpoints using a kd-tree neighbour search.
     * Every unvisited point whose eps-neighbourhood has at least minPts points starts a new
     * cluster, which is then grown from its neighbours. Points with too few neighbours are
     * labelled -1 (noise). A noise point can later be relabelled as a border point of a cluster.
     */
    public void findClusters(){
        // Only labels change during the scan, never coordinates, so one search structure
        // can serve every query instead of being rebuilt for each point.
        // NOTE(review): assumes NearestNeighborsKD only reads the points' coordinates — confirm.
        NearestNeighborsKD neighbourSearch = new NearestNeighborsKD(dBpoints);

        for (Point3D currPoint : dBpoints){
            // Anything other than 0 has already been classified (cluster member or noise).
            if (currPoint.getLabel() != 0){
                continue;
            }

            List<Point3D> nearbyPoints = neighbourSearch.getNeighbors(currPoint, eps);
            if (nearbyPoints.size() < this.minPts){
                currPoint.setLabel(-1);
                continue;
            }

            clusters++;
            currPoint.setLabel(clusters);
            expandCluster(neighbourSearch, nearbyPoints, clusters);
        }
    }

    /**
     * Grows the cluster with the given label, starting from the neighbours of its core point.
     * @param neighbourSearch the neighbour-search structure built over dBpoints
     * @param seeds           the eps-neighbourhood of the cluster's first core point
     * @param label           the label of the cluster being grown
     */
    private void expandCluster(NearestNeighborsKD neighbourSearch, List<Point3D> seeds, int label){
        // push/pop on an ArrayDeque gives the same LIFO order as the former java.util.Stack.
        Deque<Point3D> pending = new ArrayDeque<Point3D>();
        for (Point3D seed : seeds){
            pending.push(seed);
        }

        while (!pending.isEmpty()){
            Point3D p = pending.pop();
            int pLabel = p.getLabel();

            if (pLabel == -1){
                // Former noise point reachable from this cluster: it becomes a border point,
                // but its own neighbourhood is not expanded.
                p.setLabel(label);
                continue;
            }
            if (pLabel != 0){
                continue;   // already belongs to a cluster
            }

            p.setLabel(label);
            List<Point3D> otherNearbyPoints = neighbourSearch.getNeighbors(p, eps);
            // Only core points propagate the cluster further.
            if (otherNearbyPoints.size() >= minPts){
                for (Point3D point : otherNearbyPoints){
                    pending.push(point);
                }
            }
        }
    }

    /**
     * Getter for the number of clusters found inside dBpoints.
     * @return total number of clusters
     */
    public int getNumberOfClusters(){
        return this.clusters;
    }

    /**
     * Getter for the list containing all the points from the csv file.
     * @return the internal (mutable) list of all points, with their current labels
     */
    public List<Point3D> getPoints(){
        return this.dBpoints;
    }

    /**
     * Reads the given csv file and returns its data points. The first line is treated as a
     * header and skipped. Each following non-blank line must start with the x,y,z values.
     * Errors are reported on stderr, and the points read before the error are returned.
     * @param filename path of the csv file to read
     * @return a list containing the points read from the file (possibly empty)
     */
    public static List<Point3D> read(String filename ) {
        List<Point3D> points = new ArrayList<Point3D>();
        // try-with-resources closes the Scanner (and the underlying reader) even on errors.
        try (Scanner sc = new Scanner(new BufferedReader(new FileReader(filename)))) {
            if (sc.hasNextLine()){
                sc.nextLine();   // skip the header line
            }
            while (sc.hasNextLine()){
                String line = sc.nextLine().trim();
                if (line.isEmpty()){
                    continue;    // tolerate blank lines such as a trailing empty line
                }
                String[] fields = line.split(",");
                points.add(new Point3D(Double.parseDouble(fields[0]),
                                       Double.parseDouble(fields[1]),
                                       Double.parseDouble(fields[2])));
            }
        }
        catch (FileNotFoundException f){
            f.printStackTrace();
        }
        catch (Exception e){
            e.printStackTrace();
        }
        return points;
    }

    /**
     * Stores all the points with their cluster label and rgb values in a csv file.
     * @param filename  base name of the output file; ".csv" is appended to it
     */
    public void save(String filename){
        File csv = new File(filename + ".csv");
        // try-with-resources guarantees the writer is closed even if writing fails.
        try (PrintWriter writer = new PrintWriter(csv)) {
            writer.println( "x,y,z,C,R,G,B");
            for (Point3D p : dBpoints){
                int label = p.getLabel();
                writer.println(p.getX() + "," +  p.getY() + "," + p.getZ() + "," + label + "," + rgb(label) );
            }
            // PrintWriter swallows IOExceptions; surface them instead of failing silently.
            if (writer.checkError()){
                System.err.println("Error while writing " + csv.getPath());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Computes an rgb colour for a cluster label (scheme taken from P1_sol).
     * Labels &lt;= 0 (unvisited or noise) are drawn black.
     * @param label integer cluster label
     * @return String containing r,g,b values separated by commas
     */
    private String rgb (int label){
        // Checking label <= 0 (not just == 0) keeps noise black when no cluster exists.
        // Otherwise clusters == 0 would make val infinite and the components NaN.
        if (label <= 0){return "0.0,0.0,0.0";}

        double val = (0.5*label)/clusters+0.5;
        int id=(label%7)+1;
        /*  Bit shift shenanigans to make 7 colors
        *   colors are then scaled up with val
        *   to make them unique for each cluster
        */
        double r = (id>>2)*val;
        double g = ((id%4)>>1)*val;
        double b = (id%2)*val;
        String s = r+","+g+','+b;
        return s;
    }

    /**
     * Prints the number of points carrying each label from 0 to the number of clusters.
     * Noise points (label -1) are not listed.
     */
    public void numberOfPointsInCluster(){
        // One pass over the points; counts[i] is the number of points labelled i.
        int[] counts = new int[clusters + 1];
        for (Point3D p : dBpoints){
            int label = p.getLabel();
            if (label >= 0 && label <= clusters){
                counts[label]++;
            }
        }
        for (int i = 0; i <= clusters; i++){
            System.out.println("Label " + i + ":" + counts[i]);
        }
    }

    /**
     * Returns the path without its final extension. Dots in directory names and a
     * leading dot in the file name are left untouched.
     * @param path a file path such as "data/cloud.csv"
     * @return the path without the extension, e.g. "data/cloud"
     */
    private static String stripExtension(String path){
        int dot = path.lastIndexOf('.');
        int sep = Math.max(path.lastIndexOf('/'), path.lastIndexOf(File.separatorChar));
        return (dot > sep + 1) ? path.substring(0, dot) : path;
    }

    /** 
     * Runs the kd-tree DBScan on a csv file and reports timing and cluster sizes.
     * @param args Pls enter arguments in the following order : args[0]= csv filename, 
     * args[1]= epsilon distance value to look for nearby points,
     * args[2]= minimum points to form a cluster
     */
    public static void main( String[] args ) {
        if (args.length < 3){
            System.err.println("Usage: java Exp3 <csv filename> <eps> <minPts>");
            return;
        }

        //start the timer
        long startTime = System.currentTimeMillis();
        List<Point3D> points = Exp3.read(args[0]);
        Exp3 myScan = new Exp3(points);

        myScan.setEps(Double.parseDouble(args[1]));
        myScan.setMinPts(Double.parseDouble(args[2]));

        myScan.findClusters();

        myScan.save(stripExtension(args[0]) + "_kd_"+ "_clusters_" + args[1] + "_" + args[2] );
        long endTime = System.currentTimeMillis();
        //timer ends
        long duration = (endTime - startTime);
        System.out.println( "KD DBScan takes " + duration + "ms");
        System.out.println("Total number of Clusters: " + myScan.getNumberOfClusters());
        myScan.numberOfPointsInCluster();
    }
}