// dbscan / DBScan.java
// DBScan.java
// Raw
/*Name- Anisaftab Saiyed
 * ID- 300259073
 * CSI 2110 Programming assignment 1
 */
import java.util.*;
import java.io.*;



/**
 * DBSCAN density-based clustering over a collection of {@code Point3D}s.
 *
 * <p>Typical use: load points with {@link #read(String)}, construct a {@code DBScan}, set
 * {@code eps} and {@code minPts}, call {@link #findClusters()}, then {@link #save(String)}.
 *
 * <p>Label convention used by this class: {@code 0} = not yet visited, {@code -1} = noise,
 * {@code 1..getNumberOfClusters()} = cluster id. Labels are written onto the points in place.
 */
class DBScan {

    /** Radius of the neighbourhood searched around each point. */
    private double eps;

    /**
     * Minimum neighbourhood size (as returned by {@code NearestNeighbours.rangeQuery}) for a point
     * to start or extend a cluster.
     */
    private double minPts;

    /** All points read from the csv file; their labels are updated in place by findClusters(). */
    private final Stack<Point3D> dBpoints;

    /** Number of clusters found so far; also the label given to the most recent cluster. */
    private int clusters = 0;

    /**
     * Constructs a scanner that will search the given points for clusters.
     *
     * @param givenStack a Stack of all the points from the given csv file (must not be null)
     */
    public DBScan(Stack<Point3D> givenStack) {
        this.dBpoints = Objects.requireNonNull(givenStack, "givenStack");
    }

    /**
     * Sets the epsilon (neighbourhood radius).
     *
     * @param eps epsilon distance supplied by the user
     * @return the value that was set
     */
    public double setEps(double eps) {
        return this.eps = eps;
    }

    /**
     * Sets the minimum number of points a neighbourhood needs for a point to be a core point.
     *
     * @param minPts minimum neighbourhood size for a group to be counted as a cluster
     * @return the value that was set
     */
    public double setMinPts(double minPts) {
        return this.minPts = minPts;
    }

    /**
     * Runs the DBSCAN algorithm over every point in dBpoints.
     *
     * <p>Each unvisited point whose eps-neighbourhood holds fewer than minPts points is marked
     * noise (-1). Otherwise a new cluster is started. The cluster is then grown through the
     * neighbourhoods of every reachable core point. Noise points reached this way become border
     * points of the cluster but are not expanded further.
     */
    public void findClusters() {
        for (Point3D currPoint : dBpoints) {
            // Only unvisited points (label 0) can start a new cluster.
            if (currPoint.getLabel() != 0) {
                continue;
            }

            Stack<Point3D> nearbyPoints = neighboursOf(currPoint);
            if (nearbyPoints.size() < this.minPts) {
                currPoint.setLabel(-1); // noise for now; may later be claimed as a border point
                continue;
            }

            clusters++;
            currPoint.setLabel(clusters);

            // Points still to be examined for the current cluster.
            Stack<Point3D> seeds = nearbyPoints;
            while (!seeds.isEmpty()) {
                Point3D p = seeds.pop();
                if (p.getLabel() == -1) {
                    // Former noise reachable from a core point: border point, claimed but not expanded.
                    p.setLabel(clusters);
                }
                if (p.getLabel() != 0) {
                    continue;
                }
                p.setLabel(clusters);

                Stack<Point3D> otherNearbyPoints = neighboursOf(p);
                if (otherNearbyPoints.size() >= minPts) {
                    // p is a core point: its not-yet-clustered neighbours join the seed set.
                    // Points that already carry a cluster label would be popped and skipped anyway.
                    for (Point3D q : otherNearbyPoints) {
                        int label = q.getLabel();
                        if (label == 0 || label == -1) {
                            seeds.push(q);
                        }
                    }
                }
            }
        }
    }

    /**
     * Returns the points within eps of {@code p}, as reported by NearestNeighbours.rangeQuery.
     */
    private Stack<Point3D> neighboursOf(Point3D p) {
        // A fresh NearestNeighbours per query, exactly as before. NOTE(review): if rangeQuery's
        // result is a new collection on every call, this instance could be hoisted out of the loop.
        return new NearestNeighbours(dBpoints).rangeQuery(eps, p);
    }

    /**
     * Getter for the number of clusters found inside the dBpoints.
     *
     * @return total number of clusters
     */
    public int getNumberOfClusters() {
        return this.clusters;
    }

    /**
     * Getter for the stack containing all the points from the csv file.
     *
     * @return stack of all given points in a file (the internal instance, not a copy)
     */
    public Stack<Point3D> getPoints() {
        return this.dBpoints;
    }

    /**
     * Reads data points from a csv file and stores them in a stack.
     *
     * <p>The first line is treated as a header and skipped. Each following non-blank line must
     * start with three comma-separated numbers x,y,z. Any error is printed to stderr and the
     * points read so far are returned.
     *
     * @param filename file to read, relative to the current directory
     * @return a stack containing the points read from the file (possibly empty)
     */
    public static Stack<Point3D> read(String filename) {
        Stack<Point3D> myStack = new Stack<Point3D>();
        try (Scanner sc = new Scanner(new BufferedReader(new FileReader(filename)))) {
            // Skip the header row, if the file has any content at all.
            if (sc.hasNextLine()) {
                sc.nextLine();
            }
            while (sc.hasNextLine()) {
                String line = sc.nextLine().trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank or trailing empty lines
                }
                String[] fields = line.split(",");
                myStack.push(new Point3D(Double.parseDouble(fields[0]),
                        Double.parseDouble(fields[1]),
                        Double.parseDouble(fields[2])));
            }
        } catch (FileNotFoundException f) {
            f.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return myStack;
    }

    /**
     * Writes every point with its cluster label and an RGB shade to {@code filename + ".csv"}.
     *
     * <p>R, G and B are equal. They are {@code label / clusters} as a fraction in (0, 1] for
     * clustered points, and 0.0 for noise or unvisited points (or when no cluster was found).
     * I/O errors are printed to stderr.
     *
     * @param filename output file name without the ".csv" extension
     */
    public void save(String filename) {
        File csv = new File(filename + ".csv");
        int numClusters = this.getNumberOfClusters();
        try (PrintWriter writer = new PrintWriter(csv)) {
            writer.println("x,y,z,C,R,G,B");
            for (Point3D p : dBpoints) {
                int label = p.getLabel();
                // Floating-point division: integer division collapsed nearly every shade to 0,
                // and dividing by zero clusters threw ArithmeticException.
                double shade = (label > 0 && numClusters > 0) ? (double) label / numClusters : 0.0;
                writer.println(p.getX() + "," + p.getY() + "," + p.getZ() + "," + label
                        + "," + shade + "," + shade + "," + shade);
            }
            // PrintWriter never throws on write; it only records the failure.
            if (writer.checkError()) {
                System.err.println("Error while writing " + csv);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Prints the number of points carrying each label 0..getNumberOfClusters().
     */
    public void numberOfPointsInCluster() {
        // One pass over the points; each label gets its own (non-cumulative) count.
        int[] counts = new int[clusters + 1];
        for (Point3D p : dBpoints) {
            int label = p.getLabel();
            if (label >= 0 && label <= clusters) {
                counts[label]++;
            }
        }
        for (int i = 0; i <= clusters; i++) {
            System.out.println("Label " + i + ":" + counts[i]);
        }
    }

    /**
     * Removes the final extension from a path, leaving directories and dot-files intact.
     */
    private static String stripExtension(String path) {
        int dot = path.lastIndexOf('.');
        int sep = Math.max(path.lastIndexOf('/'), path.lastIndexOf('\\'));
        // Only strip a dot that is inside the file name and not its first character.
        return (dot > sep + 1) ? path.substring(0, dot) : path;
    }

    /**
     * Command-line entry point: {@code java DBScan <points.csv> <eps> <minPts>}.
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: java DBScan <points.csv> <eps> <minPts>");
            System.exit(1);
        }

        long startTime = System.currentTimeMillis();
        Stack<Point3D> myStack = DBScan.read(args[0]);
        DBScan myScan = new DBScan(myStack);

        myScan.setEps(Double.parseDouble(args[1]));
        myScan.setMinPts(Double.parseDouble(args[2]));

        myScan.findClusters();

        myScan.save(stripExtension(args[0]) + "_clusters_" + args[1] + "_" + args[2]);
        long duration = System.currentTimeMillis() - startTime;

        System.out.println("Linear DBScan takes " + duration + "ms");
        System.out.println("Total number of Clusters: " + myScan.getNumberOfClusters());
        myScan.numberOfPointsInCluster();
    }
}