# hbarto
import sys
import heapq
import math

import numpy as np

# row/column offsets
# note that rows increase in the down direction
# cols increase to the right
ALL_ACTIONS = [
    (1, 0),    # down
    (0, 1),    # right
    (-1, 0),   # up
    (0, -1),   # left
    (1, 1),    # down-right
    (1, -1),   # down-left
    (-1, 1),   # up-right
    (-1, -1),  # up-left
]


def reachable_neighbors(env_map, x):
    """Return the neighbor cells of ``x`` that can be entered.

    Every offset in ALL_ACTIONS is tried; a resulting cell is reachable
    when it lies inside the map and is free space.

    Args:
        env_map: 2D Boolean numpy array; True marks a free cell, False a
            blocked one.
        x: (row, col) tuple of the current cell.

    Returns:
        A list of (row, col) tuples of plain ints, in ALL_ACTIONS order.
    """
    n_rows, n_cols = env_map.shape
    row, col = x
    reachable = []
    for d_row, d_col in ALL_ACTIONS:
        r = int(row + d_row)
        c = int(col + d_col)
        # Bounds must be checked before indexing: a negative index would
        # silently wrap to the opposite edge of the map, and an index past
        # the end would raise IndexError.
        if 0 <= r < n_rows and 0 <= c < n_cols and env_map[r, c]:
            reachable.append((r, c))
    return reachable


def distance(x_cur, x_next):
    """Return the Euclidean distance between two (row, col) grid cells."""
    return math.hypot(x_next[0] - x_cur[0], x_next[1] - x_cur[1])


def heuristic(x_cur, x_goal, heuristic_inflation):
    """Return the Euclidean distance from x_cur to x_goal, scaled.

    heuristic_inflation multiplies the distance; 0 disables the heuristic
    entirely, and values above 1 make it inadmissible.
    """
    return heuristic_inflation * distance(x_cur, x_goal)


# Find a path on a 2D grid.
# Parameters:
#
# env_map is an H-by-W Boolean 2D numpy array indicating whether each
# cell is free (True) or blocked (False).
#
# x_start and x_goal are (row, col) index tuples into the array for the
# starting and ending states of the search.
#
# heuristic_inflation is a floating-point scale factor on the
# heuristic. Setting heuristic_inflation equal to zero should be
# equivalent to Dijkstra's algorithm. Any inflation greater than 1.0
# makes the heuristic not strictly admissible (i.e. the search is not
# guaranteed to find an optimal path).
#
# Returns a tuple consisting of the following:
#
# success - True if a path was found, False otherwise.
#
# path - A list of (row, col) states along the path found (optimal
#        when heuristic_inflation <= 1.0), beginning with the start
#        state and ending with the goal state, or None if no path exists.
#
# path_cost - The total cost to come along the path, or np.inf if no
#             path exists.
#
# cost_to_come - An H-by-W array of floats giving the lowest cost found
#                to reach each visited state, or np.inf if a state was
#                not visited.
#
# pred - An H-by-W-by-2 array of integer indices giving the
#        predecessor (row, col) of each visited state except the
#        initial state. Any state without a predecessor has
#        (-1, -1) stored in the array.
#
# Note that even if the search was unsuccessful, the returned arrays
# cost_to_come and pred may contain useful information.
def astar_search(env_map, x_start, x_goal, heuristic_inflation):
    """Find a path from x_start to x_goal on a 2D occupancy grid with A*.

    Moves come from reachable_neighbors(), each step costs distance()
    between the two cells, and a queued cell's priority is its cost to
    come plus heuristic(); heuristic_inflation == 0 gives Dijkstra's
    algorithm.

    Returns:
        (success, path, path_cost, cost_to_come, pred) where
        - path is a list of (row, col) tuples from x_start to x_goal, or
          None when no path exists;
        - path_cost is the path's total cost, or np.inf when no path exists;
        - cost_to_come is an H-by-W float array, np.inf for unreached cells;
        - pred is an H-by-W-by-2 int array of predecessor (row, col)
          indices, (-1, -1) for cells without a predecessor.
    """
    print('searching for path from {} to {} in map of shape {}'.format(
        x_start, x_goal, env_map.shape))

    # np.bool is missing in NumPy 1.24-1.26; np.bool_ works on every version.
    assert len(env_map.shape) == 2 and env_map.dtype == np.bool_

    # Normalise to int tuples: the goal test compares tuples, so a list
    # argument would otherwise never compare equal to a queued cell.
    x_start = (int(x_start[0]), int(x_start[1]))
    x_goal = (int(x_goal[0]), int(x_goal[1]))

    n_rows, n_cols = env_map.shape
    for cell in (x_start, x_goal):
        # Checked explicitly because negative indices would silently wrap.
        assert 0 <= cell[0] < n_rows and 0 <= cell[1] < n_cols, \
            'cell {} is outside the map'.format(cell)
    assert env_map[x_start] and env_map[x_goal]

    success = False
    path = None  # stays None unless the goal is reached (documented contract)
    path_cost = np.inf

    cost_to_come = np.inf * np.ones(env_map.shape)
    pred = -np.ones(env_map.shape + (2,), dtype=int)

    cost_to_come[x_start] = 0.0
    start_priority = heuristic(x_start, x_goal, heuristic_inflation)
    Q = [(start_priority, x_start)]

    progress_chunk_size = int(env_map.sum() / 100.0)
    progress_counter = 0

    while Q:
        # extract the item in the Q with the lowest priority
        priority, x_cur = heapq.heappop(Q)

        # heapq has no decrease-key, so a cell can be queued several times.
        # An entry whose priority no longer matches the cell's current cost
        # is stale; a fresher, lower-priority entry was already expanded.
        current = cost_to_come[x_cur] + heuristic(
            x_cur, x_goal, heuristic_inflation)
        if priority > current:
            continue

        if x_cur == x_goal:
            success = True
            path_cost = cost_to_come[x_cur]
            path = _reconstruct_path(pred, x_start, x_goal)
            break

        for x_next in reachable_neighbors(env_map, x_cur):
            g = cost_to_come[x_cur] + distance(x_cur, x_next)
            # Relax the edge: keep only strictly cheaper routes to x_next.
            if g < cost_to_come[x_next]:
                cost_to_come[x_next] = g
                pred[x_next] = x_cur
                h = heuristic(x_next, x_goal, heuristic_inflation)
                heapq.heappush(Q, (g + h, x_next))

        progress_counter += 1
        if progress_counter > progress_chunk_size:
            # "progress" indicator
            sys.stdout.write('.')
            sys.stdout.flush()
            progress_counter = 0

    print('\ndone!')

    return success, path, path_cost, cost_to_come, pred


def _reconstruct_path(pred, x_start, x_goal):
    """Follow pred links back from x_goal and return the start-to-goal path.

    Raises RuntimeError if a cell on the chain has no predecessor (the
    (-1, -1) sentinel), which would otherwise index the last map cell.
    """
    path = [x_goal]
    cell = x_goal
    while cell != x_start:
        prev = (int(pred[cell][0]), int(pred[cell][1]))
        if prev == (-1, -1):
            raise RuntimeError('broken predecessor chain at {}'.format(cell))
        cell = prev
        path.append(cell)
    # Links were followed goal -> start; callers expect start -> goal.
    path.reverse()
    return path


######################################################################

def make_debug_image(env_map, cost_to_come, path, palette,
                     filename='debug_output.png'):
    """Write an RGB visualisation of a search result to ``filename``.

    Blocked cells are dark gray and free cells light gray. Cells with a
    finite cost_to_come are colored from palette, cycling once per
    0.35 * min(H, W) units of cost. Cells on path are red.

    palette is converted with np.array and indexed by row; its values are
    expected in [0, 1] because the image is scaled by 255 before conversion
    to uint8.
    """
    from PIL import Image  # only the image I/O needs Pillow

    palette = np.array(palette)

    debug_image = np.empty(env_map.shape + (3,), dtype=float)
    debug_image[~env_map] = 0.35
    debug_image[env_map] = 0.65

    valid_mask = ~np.isinf(cost_to_come)
    valid_costs = cost_to_come[valid_mask]

    # Guarded because min()/max() raise on an empty array.
    if valid_costs.size:
        sz = 0.35 * min(env_map.shape[0], env_map.shape[1])
        valid_idx = valid_costs * len(palette) / sz
        valid_idx = np.round(valid_idx).astype(int)
        valid_idx = valid_idx % len(palette)
        print('idx range', valid_idx.min(), valid_idx.max())
        debug_image[valid_mask] = palette[valid_idx]

    debug_image = (debug_image * 255).astype(np.uint8)

    # An empty path has no (row, col) columns to index, so skip it too.
    if path is not None and len(path) > 0:
        path = np.array(path)
        debug_image[path[:, 0], path[:, 1]] = (255, 0, 0)

    # A uint8 H-by-W-by-3 array is inferred as RGB by Pillow.
    Image.fromarray(debug_image).save(filename)
    print('wrote {}'.format(filename))


######################################################################

def main():
    """Command-line entry point.

    usage: astar.py IMAGEFILENAME STARTROW STARTCOL GOALROW GOALCOL INFLATION
    """
    from PIL import Image  # only the image I/O needs Pillow

    if len(sys.argv) != 7:
        print('usage: astar.py IMAGEFILENAME STARTROW STARTCOL '
              'GOALROW GOALCOL INFLATION')
        sys.exit(1)

    imgfile = sys.argv[1]
    coords = [int(arg) for arg in sys.argv[2:6]]
    heuristic_inflation = float(sys.argv[6])

    x_start = (coords[0], coords[1])
    x_goal = (coords[2], coords[3])

    # open image, convert to grayscale, threshold: bright pixels are free
    env_map = np.array(Image.open(imgfile).convert('L')) > 127

    result = astar_search(env_map, x_start, x_goal, heuristic_inflation)
    success, path, path_cost, cost_to_come, pred = result

    print('search expanded {} of {} possible nodes'.format(
        np.sum(~np.isinf(cost_to_come)), env_map.sum()))

    if success:
        print('search returned successful path of length {:.2f}'.format(path_cost))
        assert path is not None
        assert not np.isinf(path_cost)
        assert cost_to_come[x_goal] == path_cost
    else:
        print('no path existed')
        assert path is None
        assert np.isinf(path_cost)
        assert cost_to_come[x_goal] == np.inf

    # twilight color map https://github.com/bastibe/twilight
    palette = np.loadtxt('twilight.txt', delimiter=',').reshape(-1, 3)

    make_debug_image(env_map, cost_to_come, path, palette)


######################################################################

if __name__ == '__main__':
    main()