""" Basic functions for presenting and recording stimuli to network. Used by other .py files like bar.py, pasu_shape.py, and grating.py Tony Fu, Sep 14, 2022 """ import os import sys import math import warnings import datetime import numpy as np import matplotlib.pyplot as plt from matplotlib.backends.backend_pdf import PdfPages __all__ = ['clip', 'stimset_gridx_map', 'print_progress', 'add_weighted_map', 'add_non_overlap_map', 'summarize_TB1', 'summarize_TBn', 'record_stim_counts', 'record_splist', 'record_center_responses', 'record_script_log', 'mapstat_comr_1', 'make_map_pdf'] #######################################.####################################### # # # IMPORT JIT # # # # Numba may not work with the lastest version of NumPy. In that case, a # # do-nothing decorator also named jit is used. # # # ############################################################################### try: from numba import jit except: warnings.warn("stimulus.py cannot import Numba.") def jit(func): """ A do-nothing decorator in place of the actual njit in case that Python cannot import Numba. """ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper #######################################.####################################### # # # CLIP # # # ############################################################################### @jit def clip(val, vmin, vmax): """Limits value to be vmin <= val <= vmax""" if vmin > vmax: raise Exception("vmin should be smaller than vmax.") val = min(val, vmax) val = max(val, vmin) return val #######################################.####################################### # # # STIMSET_GRIDX_MAP # # # # Given a stimulus length and maximum RF size (both in pixels), return a list# # of the x-coordinates of the grid points relative to the center of the # # image field. # # # # I believe the following are true: # # (1) The center coordinate "0.0" will always be included # # (2) There will be an odd number of coordinates # # (3) The extreme coordinates will never be more then half of a stimulus # # length outside of the maximum RF ('max_rf') # # # ############################################################################### def stimset_gridx_map(max_rf,stim_len): """ Parameters ---------- max_rf - maximum RF size (pix)\n stim_len - stimulus length (pix)\n """ dx = stim_len / 2.0 # Grid spacing is 1/2 of stimulus length xmax = round((max_rf/dx) / 2.0) * dx # Max offset of grid point from center xlist = np.arange(-xmax,xmax+1,dx) return xlist # Test if __name__ == '__main__': max_rf = 49 blen = 5 print(stimset_gridx_map(max_rf,blen)) #######################################.####################################### # # # PRINT_PROGRESS # # # ############################################################################### def print_progress(text): """ Prints progress (whatever text) without printing a new line everytime. """ sys.stdout.write('\r') sys.stdout.write(text) sys.stdout.flush() #######################################.####################################### # # # ADD_WEIGHTED_MAP # # # ############################################################################### def add_weighted_map(new_stim, sum_map, response): """ Add the new_stim, weighted by the unit's rectified response to the stimulus, to the sum_map. Parameters ---------- new_stim - stimulus to be added to the map. sum_map - cumulative stimulus map. response - the unit's response to the new_stim. """ sum_map += new_stim * response #######################################.####################################### # # # ADD_NON_OVERLAP_MAP # # # ############################################################################### def add_non_overlap_map(new_stim, sum_map, stim_thr): """ Add the new_stim to the map if the new_stim is not overlapping with any existing stimuli. The new_stim is first binarized with the {stim_thr} threshold to get rid of some of the anti-aliasing pixels. Parameters ---------- new_stim - stimulus to be added to the map.\n sum_map - cumulative stimulus map.\n stim_thr - stimulus pixels w/ a value below stim_thr will be excluded. Returns ------- True if the new_stim has been included in the """ # Binarize new_stim new_stim[new_stim < stim_thr] = 0 new_stim[new_stim >= stim_thr] = 1 # Only add the new stim if it is not overlapping with any existing stimuli. if not np.any(np.logical_and(sum_map>0, new_stim>0)): sum_map += new_stim return True return False #######################################.####################################### # # # SUMMARIZE_TB1 # # # ############################################################################### def summarize_TB1(splist, center_responses, layer_name, txt_path): """ Summarize the top and bottom stimuli in a .txt file in format: layer_name, unit_i, top_idx, top_x, top_y, top_r, bot_idx, bot_x, bot_y Parameters ---------- splist - stimulus parameter list.\n center_responses - responses of center unit in [stim_i, unit_i] format.\n model_name - name of the model. Used for file naming.\n layer_name - name of the layer. Used as file entries/primary key.\n txt_path - path name of the file, must end with '.txt'\n """ num_units = center_responses.shape[1] # shape = [stim, unit] with open(txt_path, 'a') as f: for unit_i in range(num_units): isort = np.argsort(center_responses[:, unit_i]) # Ascending top_i = isort[-1] bot_i = isort[0] top_r = center_responses[top_i, unit_i] bot_r = center_responses[bot_i, unit_i] top_stim = splist[top_i] bot_stim = splist[bot_i] f.write(f"{layer_name:} {unit_i:} ") f.write(f"{top_i:} {top_stim['x0']:.2f} {top_stim['y0']:.2f} {top_r:.4f} ") f.write(f"{bot_i:} {bot_stim['x0']:.2f} {bot_stim['y0']:.2f} {bot_r:.4f}\n") #######################################.####################################### # # # SUMMARIZE_TBn # # # ############################################################################### def summarize_TBn(splist, center_responses, layer_name, txt_path, top_n=20): """ Summarize the top- and bottom-n stimuli in a .txt file in format: layer_name, unit_i, top_avg_x, top_avg_y, bot_avg_x, bot_avg_y Parameters ---------- splist - the stimulus parameter list.\n center_responses - the responses of center unit in [stim_i, unit_i] format.\n model_name - name of the model. Used for file naming.\n layer_name - name of the layer. Used as file entries/primary key.\n txt_dir - the path name of the file, must end with '.txt'\n top_n - the top and bottom N stimuli to record. """ num_units = center_responses.shape[1] # shape = [stim, unit] with open(txt_path, 'a') as f: for unit_i in range(num_units): isort = np.argsort(center_responses[:, unit_i]) # Ascending # Initializations top_avg_x = 0 top_avg_y = 0 bot_avg_x = 0 bot_avg_y = 0 for i in range(top_n): top_i = isort[-i-1] bot_i = isort[i] # Equally weighted sum for avg coordinates of stimuli. top_avg_x += splist[top_i]['x0']/top_n top_avg_y += splist[top_i]['y0']/top_n bot_avg_x += splist[bot_i]['x0']/top_n bot_avg_y += splist[bot_i]['y0']/top_n f.write(f"{layer_name} {unit_i} ") f.write(f"{top_avg_x:.2f} {top_avg_y:.2f} ") f.write(f"{bot_avg_x:.2f} {bot_avg_y:.2f}\n") #######################################.####################################### # # # RECORD_STIM_COUNTS # # # ############################################################################### def record_stim_counts(txt_path, layer_name, unit_i, num_max_stim, num_min_stim): """Write the numbers of stimuli used in the top and bottom maps.""" with open(txt_path, 'a') as f: f.write(f"{layer_name} {unit_i} {num_max_stim} {num_min_stim}\n") #######################################.####################################### # # # RECORD_SPLIST # # # ############################################################################### def record_splist(txt_path, splist): """Write the contents of splist into a text file.""" with open(txt_path, 'a') as f: for stimulus_idx, params in enumerate(splist): f.write(f"{stimulus_idx}") for val in params.values(): f.write(f" {val}") f.write('\n') #######################################.####################################### # # # RECORD_CENTER_RESPONSES # # # ############################################################################### def record_center_responses(txt_path, center_responses, top_n, is_top): """ Write the indicies and responses of the top- and bottom-N into a text file. """ num_units = center_responses.shape[1] # in dimension: [stimulus, unit] center_responses_sorti = np.argsort(center_responses, axis=0) if is_top: center_responses_sorti = np.flip(center_responses_sorti, 0) with open(txt_path, 'a') as f: for unit_i in range(num_units): for i, stimulus_idx in enumerate(center_responses_sorti[:, unit_i]): if i >= top_n: break f.write(f"{unit_i} {i} {stimulus_idx} ") f.write(f"{center_responses[stimulus_idx, unit_i]:.4f}\n") # Format: unit_i, rank, stimulus_index, response_value #######################################.####################################### # # # RECORD_SCRIPT_LOG # # # ############################################################################### def record_script_log(txt_path, layer_name, batch_size, response_thres, num_stim): with open(txt_path, 'a') as f: now = datetime.datetime.now() f.write(f"{now} {layer_name} batch_size={batch_size} response_thres={response_thres} num_stim={num_stim}\n") #######################################.####################################### # # # MAPSTAT_COMR_1 # # # # For a 2D array 'map', compute the (x,y) center of mass and the radius # # that contains a fraction 'f' of the area of the map. # # # ############################################################################### def mapstat_comr_1(map,f): # # map = map.copy() map = map - map.min() xn = len(map) list0 = np.sum(map,1) # Sum along axis 1, to ultimately get COM on axis 0 list1 = np.sum(map,0) # Sum along axis 0, to ultimately get COM on axis 1 total = np.sum(list0) # Overall total weight of entire map xvals = np.arange(xn) # X-values (pix) prod0 = xvals*list0 prod1 = xvals*list1 if (total > 0.0): com0 = np.sum(prod0)/total com1 = np.sum(prod1)/total else: com0 = com1 = -1 dist2 = [] # empty list to hold squared distances from COM magn = [] # empty list to hold magnitude for i in range(xn): di2 = (i-com0)*(i-com0) for j in range(xn): dj2 = (j-com1)*(j-com1) if (map[i,j] > 0.0): dist2.append(di2 + dj2) magn.append(map[i,j]) isort = np.argsort(dist2) # Get list of indices that sort list (least 1st) if (com0 == -1): return -1, -1, -1 n = len(dist2) # Go down the sorted list, adding up the magnitudes, until the fractional # criterion is exceeded. Compute the radius for the final point added. # tot = 0.0 k = 0 for i in range(n): # for each non-zero position in the map k = isort[i] # Get index 'k' of the next point closest to the COM tot += magn[k] if (tot/total >= f): break radius = np.sqrt(dist2[k]) return com0, com1, radius #######################################.####################################### # # # MAKE_MAP_PDF # # # ############################################################################### def make_map_pdf(max_maps, min_maps, pdf_path): """ Make a pdf, one unit per page. Parameters ---------- maps - maps with dimensions [unit_i, y, x] (black and white) or [unit_i, y, x, rgb] (color)\n pdf_path - path name of the file, must end with '.pdf'\n """ yn, xn = max_maps.shape[1:3] with PdfPages(pdf_path) as pdf: fig, (ax1, ax2) = plt.subplots(1, 2) fig.set_size_inches(10, 5) im1 = ax1.imshow(np.zeros((yn, xn, 3)), vmax=1, vmin=0, cmap='gray') im2 = ax2.imshow(np.zeros((yn, xn, 3)), vmax=1, vmin=0, cmap='gray') for unit_i, (max_map, min_map) in enumerate(zip(max_maps, min_maps)): print_progress(f"Making pdf for unit {unit_i}...") fig.suptitle(f"no.{unit_i}", fontsize=20) vmax = max_map.max() vmin = max_map.min() vrange = vmax - vmin if math.isclose(vrange, 0): vrange = 1 im1.set_data((max_map-vmin)/vrange) ax1.set_title('max') vmax = min_map.max() vmin = min_map.min() vrange = vmax - vmin if math.isclose(vrange, 0): vrange = 1 im2.set_data((min_map-vmin)/vrange) ax2.set_title('min') plt.show() pdf.savefig(fig) plt.close()