# MVA-2021 / npm3d / hw1_basics / code / subsampling.py
# Point-cloud subsampling: naive stride decimation and voxel-grid decimation.
import numpy as np
from ply import write_ply, read_ply
from collections import defaultdict

import time

def cloud_decimation(points, colors, labels, factor):
    """Naively decimate a point cloud by keeping one row out of ``factor``.

    Rows 0, factor, 2*factor, ... are kept in all three arrays, so the
    returned arrays stay aligned with each other.

    Args:
        points: 2-D array, one row per point.
        colors: 2-D array, one row per point.
        labels: array indexed along its first axis, one entry per point.
        factor: stride between two kept rows.

    Returns:
        Tuple ``(decimated_points, decimated_colors, decimated_labels)``.
    """
    every_nth = slice(None, None, factor)
    kept = (points[every_nth, :], colors[every_nth, :], labels[every_nth])
    return kept

def cloud_decimation_grid(points, colors, labels, factor,
                          res_init=0.10, res_step=0.001, verbose=False):
    """Subsample a point cloud on a regular voxel grid.

    The voxel size starts at ``res_init`` and grows by ``res_step`` until the
    grid spanning the bounding box has at most ``max(N / factor, 1)`` cells.
    Each non-empty voxel is then replaced by the barycenter of its points,
    together with the color and label of the first point (in input order)
    that fell into it.

    Args:
        points: (N, D) array of coordinates.
        colors: array with N rows, one per point.
        labels: array with N entries, one per point.
        factor: target decimation ratio (must be > 0); bounds the number of
            grid cells to ``N / factor``.
        res_init: initial voxel size of the search (must be > 0).
        res_step: voxel-size increment per search step (must be > 0).
        verbose: when True, print the input size and the chosen grid size.

    Returns:
        Tuple ``(decimated_points, decimated_colors, decimated_labels)`` with
        one row per non-empty voxel, ordered by each voxel's first point in
        the input. Points keep a floating dtype of the input (float64 for
        integer input); colors and labels keep their input dtype.

    Raises:
        ValueError: if ``factor``, ``res_init`` or ``res_step`` is not positive.
    """
    if factor <= 0:
        raise ValueError(f"factor must be positive, got {factor!r}")
    if res_init <= 0 or res_step <= 0:
        raise ValueError("res_init and res_step must be positive")

    points = np.asarray(points)
    colors = np.asarray(colors)
    labels = np.asarray(labels)
    n_points, dim = points.shape
    out_dtype = points.dtype if np.issubdtype(points.dtype, np.floating) else np.float64

    if n_points == 0:
        return np.empty((0, dim), dtype=out_dtype), colors[:0], labels[:0]

    # The bounding box does not depend on the resolution: compute it once
    # instead of rescanning every point at each search step.
    xyz_min = points.min(axis=0)
    extent = points.max(axis=0) - xyz_min

    # A grid always has at least one cell, so a target below 1 (N < factor)
    # could never be reached and the search would loop forever.
    max_cells = max(n_points / factor, 1)

    step = 0
    res = res_init
    n_cells = np.prod(extent // res + 1)
    while n_cells > max_cells:
        step += 1
        # Recompute from res_init rather than accumulating, to avoid drift.
        res = res_init + step * res_step
        n_cells = np.prod(extent // res + 1)
    # `res` is now exactly the resolution whose grid satisfied the bound.

    if verbose:
        print(n_points)
        print(n_cells)

    # Integer voxel coordinates of every point.
    cells = ((points - xyz_min) // res).astype(np.int64)
    _, first, inverse, counts = np.unique(
        cells, axis=0, return_index=True, return_inverse=True, return_counts=True
    )
    # Some NumPy versions return a non-1-D inverse when `axis` is given.
    inverse = inverse.reshape(-1)

    # np.unique sorts voxels lexicographically; reorder them by first
    # occurrence so the output follows the input order.
    order = np.argsort(first, kind="stable")
    rank = np.empty_like(order)
    rank[order] = np.arange(order.size)
    group = rank[inverse]
    first = first[order]
    counts = counts[order]

    # Barycenter of each voxel, accumulated one coordinate at a time.
    sums = np.stack(
        [np.bincount(group, weights=points[:, d], minlength=order.size) for d in range(dim)],
        axis=1,
    )
    decimated_points = (sums / counts[:, None]).astype(out_dtype)
    decimated_colors = colors[first]
    decimated_labels = labels[first]

    return decimated_points, decimated_colors, decimated_labels



if __name__ == '__main__':

    # Load point cloud
    # ****************
    #
    #   Load the file '../data/indoor_scan.ply'
    #   (See read_ply function)
    #

    # Path of the file
    file_path = '../data/indoor_scan.ply'

    # Load point cloud
    data = read_ply(file_path)

    # Concatenate data into (N, 3) coordinate/color arrays and a label array
    points = np.vstack((data['x'], data['y'], data['z'])).T
    colors = np.vstack((data['red'], data['green'], data['blue'])).T
    labels = data['label']

    # Decimate the point cloud
    # ************************
    #
    #   BRUTE=True keeps every factor-th point (cloud_decimation);
    #   otherwise the cloud is subsampled on a voxel grid whose cell count
    #   is bounded by N / factor (cloud_decimation_grid).
    #

    BRUTE = False

    # Decimation factor, used by both methods
    factor = 300

    # Decimate
    t0 = time.time()
    if BRUTE:
        decimated_points, decimated_colors, decimated_labels = cloud_decimation(points, colors, labels, factor)
    else:
        decimated_points, decimated_colors, decimated_labels = cloud_decimation_grid(points, colors, labels, factor)
    t1 = time.time()
    print('decimation done in {:.3f} seconds'.format(t1 - t0))

    # Save a single output file, whichever method ran
    write_ply('../decimated.ply', [decimated_points, decimated_colors, decimated_labels], ['x', 'y', 'z', 'red', 'green', 'blue', 'label'])

    print('Done')