"""Streamlit sandbox for exploring tensor storage, strides, and indexing."""

import numpy as np
import plotly.graph_objects as go
import streamlit as st
from project.interface.streamlit_utils import render_function
from show_tensor import tensor_figure

from minitorch import SimpleBackend, Tensor, index_to_position, operators, to_index
from minitorch.tensor_data import TensorData


def st_select_index(tensor_shape, n_cols=3):
    """Render one number input per dimension and return the chosen index.

    Args:
        tensor_shape: Sequence of dimension sizes; each input is limited to
            the range ``0 .. dim - 1``.
        n_cols: Number of Streamlit columns the inputs are spread across
            (inputs wrap around with ``idx % n_cols``).

    Returns:
        A list with one selected value per dimension, in dimension order.
    """
    cols = st.columns(n_cols)
    return [
        cols[idx % n_cols].number_input(
            f"Dimension {idx} index:", value=0, min_value=0, max_value=dim - 1
        )
        for idx, dim in enumerate(tensor_shape)
    ]


def st_visualize_storage(tensor: Tensor, selected_position: int, max_size=10):
    """Draw the first ``max_size`` storage cells as a row of squares.

    The cell whose position equals ``selected_position`` is coloured; all
    other cells are light gray. A warning is shown when the storage holds
    more than ``max_size`` elements, since only the first ones are drawn.

    Args:
        tensor: Tensor whose underlying storage is displayed.
        selected_position: Storage position to highlight.
        max_size: Maximum number of storage cells to draw.
    """
    storage = tensor._tensor._storage
    tensor_size = len(storage)
    if tensor_size > max_size:
        st.warning(f"Showing first {max_size} elements from the tensor storage.")

    x = list(range(min(tensor_size, max_size)))
    y = [0] * len(x)
    colors = ["#69BAC9" if x_ == selected_position else "lightgray" for x_ in x]
    data = [
        go.Scatter(
            hoverinfo="skip",
            mode="markers+text",
            x=x,
            y=y,
            marker=dict(size=50, symbol="square", color=colors),
            text=storage[:max_size],
            textposition="middle center",
            textfont_size=20,
        )
    ]

    # Fewer cells get wider side margins.
    if len(x) >= 9:
        lr_margin = 25
    elif len(x) >= 6:
        lr_margin = 75
    else:
        lr_margin = 175

    layout = go.Layout(
        title={"text": "Tensor Storage", "x": 0.5, "y": 1.0, "xanchor": "center"},
        font={"family": "Raleway", "size": 20, "color": "black"},
        xaxis={"showgrid": False, "showticklabels": False},
        yaxis={"showgrid": False, "showticklabels": False},
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        width=650,
        height=125,
        showlegend=False,
        margin=dict(l=lr_margin, r=lr_margin, t=0, b=0),
    )
    fig = go.Figure(data=data, layout=layout)
    st.write(fig)


def st_visualize_tensor(
    tensor: Tensor, highlighted_index, strides=None, show_value=True
):
    """Render a 3-D tensor figure with one index highlighted.

    Only tensors with exactly three dimensions are drawn; for any other
    rank an error is shown and nothing else is rendered.

    Args:
        tensor: Tensor to visualize.
        highlighted_index: Multi-dimensional index to highlight. Components
            beyond the third are ignored; missing components count as 0.
        strides: Strides used to convert the index to a storage position.
            Defaults to the tensor's own strides.
        show_value: When true, also print the storage value at the computed
            position (or a warning if that position is outside the storage).
    """
    shape = tensor.shape
    # Check the rank before unpacking so a shape with fewer (or zero)
    # dimensions shows the error below instead of raising on shape[0].
    if len(shape) != 3:
        # TODO: Fix visualization instead of showing warning
        st.error("Can only visualize a tensor which has 3 dimensions")
        return
    depth, rows, columns = shape

    if strides is None:
        strides = tensor._tensor.strides

    position_in_storage = index_to_position(highlighted_index, strides)
    if show_value:
        storage = tensor._tensor._storage
        # User-supplied strides can point outside the storage; report that
        # instead of failing on the lookup.
        if 0 <= position_in_storage < len(storage):
            st.write(
                f"**Value at position {position_in_storage}:** {storage[position_in_storage]}"
            )
        else:
            st.warning(
                f"Position {position_in_storage} is outside the tensor storage "
                f"(size {len(storage)})."
            )

    # NOTE(review): looks like leftover debug output - confirm before removing.
    st.write("highlighted", highlighted_index)

    # Map index to highlight since tensor_figure doesn't know about strides
    weights = (1, depth * columns, depth)
    highlighted_position = sum(
        component * weight for component, weight in zip(highlighted_index, weights)
    )

    fig = tensor_figure(
        depth,
        columns,
        rows,
        highlighted_position,
        f"Storage position: {position_in_storage}, Index: {highlighted_index}",
        # Fix for weirdness in tensor_figure axis
        axisTitles=["depth (i)", "columns (k)", "rows (j)"],
        show_fig=False,
        slider=False,
    )
    fig.update_layout(
        width=650,
        height=500,
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=50, r=50, t=0, b=0),
    )
    st.write(fig)


def interface_visualize_tensor(tensor: Tensor, hide_function_defs: bool):
    """Pick a storage position with a slider and show its index and cell.

    Args:
        tensor: Tensor to explore.
        hide_function_defs: Unused here; accepted so every interface
            function can be called with the same arguments.
    """
    st.write(f"**Tensor strides:** {tensor._tensor.strides}")
    selected_position = st.slider(
        "Selected position in storage", 0, len(tensor._tensor._storage) - 1, value=0
    )
    out_index = [0] * len(tensor.shape)
    # to_index writes its result into out_index.
    to_index(selected_position, tensor.shape, out_index)
    st.write(f"**Corresponding index:** {out_index}")
    st_visualize_tensor(tensor, out_index, show_value=False)
    st_visualize_storage(tensor, selected_position)


def interface_index_to_position(tensor: Tensor, hide_function_defs: bool):
    """Let the user type an index and strides and visualize the result.

    Args:
        tensor: Tensor to explore.
        hide_function_defs: When false, an expander showing the source of
            ``index_to_position`` is rendered first.
    """
    if not hide_function_defs:
        with st.expander("Show function definition"):
            render_function(index_to_position)

    col1, col2 = st.columns(2)
    idx = st_eval_error_message(
        col1.text_input(
            "Multi-dimensional index", value=str([0] * len(tensor._tensor.strides))
        ),
        "Multi-dimensional index must be defined as an in-line list, e.g. [0, 0, 0]",
    )
    tensor_strides = st_eval_error_message(
        col2.text_input("Tensor strides", value=str(tensor._tensor.strides)),
        "Tensor strides must be defined as an in-line tuple, e.g. (4, 2, 1)",
    )
    st_visualize_tensor(tensor, idx, tensor_strides)


def interface_to_index(tensor: Tensor, hide_function_defs: bool):
    """Let the user pick a storage position and show the matching index.

    Args:
        tensor: Tensor to explore.
        hide_function_defs: When false, an expander showing the source of
            ``to_index`` is rendered first.
    """
    if not hide_function_defs:
        with st.expander("Show function definition"):
            render_function(to_index)

    tensor_shape = tensor.shape
    storage = tensor._tensor._storage
    st.write(f"**Tensor strides:** {tensor._tensor.strides}")
    selected_position = st.number_input(
        "Position in storage",
        value=0,
        min_value=0,
        max_value=len(storage) - 1,
    )
    out_index = [0] * len(tensor_shape)
    # to_index writes its result into out_index.
    to_index(selected_position, tensor_shape, out_index)
    st.write(
        f"**Value at position {selected_position}:** {storage[selected_position]}"
    )
    st.write("**Out index:**", out_index)
    st_visualize_tensor(tensor, out_index, show_value=False)


def interface_strides(tensor: Tensor, hide_function_defs: bool):
    """Let the user edit the strides and select an index to visualize.

    Args:
        tensor: Tensor to explore.
        hide_function_defs: Unused here; accepted so every interface
            function can be called with the same arguments.
    """
    strides = st_eval_error_message(
        st.text_input("Tensor strides", value=str(tensor._tensor.strides)),
        "Tensor strides must be defined as an in-line tuple, e.g. (4, 2, 1)",
    )
    st.write("**Try it out:**")
    out_index = st_select_index(tensor.shape)
    cols = st.columns(len(strides))
    for dim, stride in enumerate(strides):
        cols[dim].write(f"*moves {stride} positions in storage*")
    st_visualize_tensor(tensor, out_index, strides, show_value=True)


def interface_permute(tensor: Tensor, hide_function_defs: bool):
    """Permute the tensor's dimensions and visualize either version.

    The default permutation reverses the dimension order.

    Args:
        tensor: Tensor to explore.
        hide_function_defs: When false, an expander showing the source of
            ``TensorData.permute`` is rendered first.
    """
    if not hide_function_defs:
        with st.expander("Show function definition"):
            render_function(TensorData.permute)

    st.write(f"**Tensor strides:** {tensor._tensor.strides}")
    default_permutation = list(reversed(range(len(tensor.shape))))
    # Pass the default as an explicit string, matching the other text inputs.
    permutation = st_eval_error_message(
        st.text_input("Tensor permutation", value=str(default_permutation)),
        "Tensor permutation must be defined as an in-line list, e.g. [2, 1, 0]",
    )
    p_tensor = tensor.permute(*permutation)
    p_tensor_strides = p_tensor._tensor.strides
    st.write(f"**Permuted tensor strides:** {p_tensor_strides}")

    st.write("**Try selecting a tensor value by index:**")
    out_index = st_select_index(tensor.shape)
    viz_type = st.selectbox(
        "Choose tensor visualization", options=["Original tensor", "Permuted tensor"]
    )
    viz_tensor = tensor if viz_type == "Original tensor" else p_tensor
    st_visualize_tensor(viz_tensor, out_index, show_value=False)
    st_visualize_storage(
        tensor, index_to_position(out_index, viz_tensor._tensor.strides)
    )


def st_eval_error_message(expression: str, error_msg: str):
    """Evaluate ``expression`` and show ``error_msg`` in the UI on failure.

    Args:
        expression: Python expression text, typically from a text input.
        error_msg: Message shown with ``st.error`` if evaluation fails.

    Returns:
        The value the expression evaluates to.

    Raises:
        Exception: Whatever the evaluation raised, re-raised after the
            error message has been displayed.
    """
    # SECURITY: eval executes arbitrary Python typed into the UI. This is only
    # acceptable while the app is run locally by a trusted user; do not expose
    # it to untrusted input without replacing eval with a restricted parser.
    try:
        return eval(expression)
    except Exception:
        st.error(error_msg)
        raise


def render_tensor_sandbox(hide_function_defs: bool):
    """Render the tensor sandbox page.

    The user defines a tensor shape and its storage (random or typed in),
    then picks one of the interface views to explore the resulting tensor.

    Args:
        hide_function_defs: Forwarded to the selected interface function.
    """
    st.write("## Sandbox for Tensors")
    st.write("**Define your tensor**")

    tensor_shape = st_eval_error_message(
        st.text_input("Tensor shape", value="(2, 2, 2)"),
        "Tensor shape must be defined as an in-line tuple, i.e. (2, 2, 2)",
    )
    tensor_size = int(operators.prod(tensor_shape))

    random_tensor = st.checkbox("Fill tensor with random numbers", value=True)
    if random_tensor:
        # Consistent random number generator
        rng = np.random.RandomState(42)
        tensor_data = np.round(rng.rand(tensor_size), 2)
        st.write("**Tensor data storage:**")
        # Visualize horizontally
        st.write(tensor_data.reshape(1, -1))
    else:
        tensor_data = st_eval_error_message(
            st.text_input("Tensor data storage", value=str(list(range(tensor_size)))),
            "Tensor data storage must be defined as an in-line list, i.e. [1, 2, 3, 4]",
        )

    try:
        test_tensor = Tensor.make(tensor_data, tensor_shape, backend=SimpleBackend)
    except AssertionError as e:
        storage_size = len(tensor_data)
        if tensor_size != storage_size:
            st.error(
                f"Tensor data storage must define all values in shape ({tensor_size} != {storage_size})"
            )
        else:
            st.error(e)
        return

    select_fn = {
        "Visualize Tensor Definition": interface_visualize_tensor,
        "Visualize Tensor Strides": interface_strides,
        "function: index_to_position": interface_index_to_position,
        "function: to_index": interface_to_index,
        "function: TensorData.permute": interface_permute,
    }
    selected_fn = st.selectbox("Select an interface", options=list(select_fn))
    select_fn[selected_fn](test_tensor, hide_function_defs)