# MiniTorch / project / show_tensor.py
import numpy as np
import plotly.graph_objects as go

# Cell centres of a 2-wide, 3-tall grid used by the 2-D diagrams below.
x = [1, 2, 1, 2, 1, 2]
y = [1, 1, 2, 2, 3, 3]
# The same grid shifted 3 units to the right (drawn by plot_map/plot_zip/plot_reduce).
x1 = np.array(x) + 3
y1 = y.copy()
# The grid shifted a further 3 units right (the "zip" output in plot_zip).
# Must stay element-aligned with y2: one x per y.
x2 = x1 + 3
y2 = y.copy()
# 2 x 10 matrix: first row is five 1s then five 2s, second row is 1..5 twice.
initial_matrix = np.vstack(
    [
        np.hstack([[1, 1, 1, 1, 1], np.ones(5) * 2]),
        np.array([1, 2, 3, 4, 5, 1, 2, 3, 4, 5]),
    ]
)
# Scene-axis titles used by tensor_figure for the x, y and z axes respectively.
axis_default = ["i", "k", "j"]


def permute(mat, x, y):
    """Return ``mat`` with its two axes reordered as ``(x, y)``.

    Delegates to ``mat.transpose``, so any object exposing a numpy-style
    ``transpose(*axes)`` method is accepted.
    """
    axis_order = (x, y)
    return mat.transpose(*axis_order)


def plot_matrix(x, y, title, w=300, h=500, bg="white"):
    """Show a figure of open black squares placed at ``(x[i], y[i])``.

    Args:
        x, y: marker coordinates.
        title: figure title, centred horizontally.
        w, h: figure width and height in pixels.
        bg: colour used for both the paper and the plot background.
    """
    hidden_axis = {"showgrid": False, "showticklabels": False}
    cells = go.Scatter(
        x=x,
        y=y,
        mode="markers",
        hoverinfo="skip",
        marker=dict(color="black", size=50, symbol="square-open"),
    )
    layout = go.Layout(
        title={"text": title, "x": 0.5, "y": 0.9, "xanchor": "center"},
        font={"family": "Raleway", "size": 40, "color": "black"},
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        paper_bgcolor=bg,
        plot_bgcolor=bg,
        autosize=False,
        width=w,
        height=h,
        showlegend=False,
    )
    go.Figure(data=[cells], layout=layout).show()


def plot_map():
    """Show the "map" diagram.

    The diagram has the grid at the module-level ``(x, y)`` as open black squares
    and the shifted grid ``(x1, y1)`` as filled squares, with two arrows from the
    source grid to the shifted one.
    """
    outline = dict(color="black", size=50, symbol="square-open")
    filled = dict(color="#69BAC9", size=50, symbol="square")
    traces = [
        go.Scatter(hoverinfo="skip", mode="markers", x=x, y=y, marker=outline),
        go.Scatter(hoverinfo="skip", mode="markers", x=x1, y=y1, marker=filled),
    ]

    transparent = "rgba(0,0,0,0)"
    hidden_axis = {"showgrid": False, "showticklabels": False}
    layout = go.Layout(
        title={"text": "map", "x": 0.5, "y": 0.9, "xanchor": "center"},
        font={"family": "Raleway", "size": 40, "color": "black"},
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        autosize=False,
        width=500,
        height=400,
        showlegend=False,
    )
    fig = go.Figure(data=traces, layout=layout)

    # Each arrow: (head_x, head_y, tail_x, tail_y), all in data coordinates.
    arrows = [(5, 2, 2, 2), (4, 3, 1, 3)]
    for head_x, head_y, tail_x, tail_y in arrows:
        fig.add_annotation(
            x=head_x,
            y=head_y,
            ax=tail_x,
            ay=tail_y,
            xref="x",
            yref="y",
            axref="x",
            ayref="y",
            showarrow=True,
            arrowhead=3,
            font=dict(size=15, color="black"),
        )
    fig.show()


def plot_zip():
    """Show the "zip" diagram.

    Two input grids ``(x, y)`` and ``(x1, y1)`` are drawn as open black squares.
    The output grid ``(x2, y2)`` is drawn as filled squares. One arrow runs from
    each input toward the output.
    """
    outline = dict(color="black", size=50, symbol="square-open")
    filled = dict(color="#69BAC9", size=50, symbol="square")
    traces = [
        go.Scatter(hoverinfo="skip", mode="markers", x=x, y=y, marker=outline),
        go.Scatter(hoverinfo="skip", mode="markers", x=x1, y=y1, marker=outline),
        go.Scatter(hoverinfo="skip", mode="markers", x=x2, y=y2, marker=filled),
    ]

    transparent = "rgba(0,0,0,0)"
    hidden_axis = {"showgrid": False, "showticklabels": False}
    layout = go.Layout(
        title={"text": "zip", "x": 0.5, "y": 0.9, "xanchor": "center"},
        font={"family": "Raleway", "size": 40, "color": "black"},
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        autosize=False,
        width=700,
        height=400,
        showlegend=False,
    )
    fig = go.Figure(data=traces, layout=layout)

    # Each arrow: (head_x, head_y, tail_x, tail_y), all in data coordinates.
    arrows = [(8, 3, 2, 3), (8, 3, 5, 2.8)]
    for head_x, head_y, tail_x, tail_y in arrows:
        fig.add_annotation(
            x=head_x,
            y=head_y,
            ax=tail_x,
            ay=tail_y,
            xref="x",
            yref="y",
            axref="x",
            ayref="y",
            showarrow=True,
            arrowhead=3,
            font=dict(size=15, color="black"),
        )
    fig.show()


def plot_reduce():
    """Show the "reduce" diagram.

    The grid ``(x, y)`` is drawn as open black squares. Its reduced result, the
    first two cells of ``(x1, y1)``, is drawn filled. Arrows collapse each column
    downward, and one arrow leads to the result.
    """
    outline = dict(color="black", size=50, symbol="square-open")
    filled = dict(color="#69BAC9", size=50, symbol="square")
    traces = [
        go.Scatter(hoverinfo="skip", mode="markers", x=x, y=y, marker=outline),
        go.Scatter(
            hoverinfo="skip", mode="markers", x=x1[:2], y=y1[:2], marker=filled
        ),
    ]

    transparent = "rgba(0,0,0,0)"
    hidden_axis = {"showgrid": False, "showticklabels": False}
    layout = go.Layout(
        title={"text": "reduce", "x": 0.5, "y": 0.9, "xanchor": "center"},
        font={"family": "Raleway", "size": 40, "color": "black"},
        xaxis=dict(hidden_axis),
        yaxis=dict(hidden_axis),
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        autosize=False,
        width=700,
        height=400,
        showlegend=False,
    )
    fig = go.Figure(data=traces, layout=layout)

    # Each arrow: (head_x, head_y, tail_x, tail_y), all in data coordinates.
    arrows = [(2, 1, 2, 3), (1, 1, 1, 3), (3.5, 1, 2.3, 1)]
    for head_x, head_y, tail_x, tail_y in arrows:
        fig.add_annotation(
            x=head_x,
            y=head_y,
            ax=tail_x,
            ay=tail_y,
            xref="x",
            yref="y",
            axref="x",
            ayref="y",
            showarrow=True,
            arrowhead=3,
            font=dict(size=15, color="black"),
        )
    fig.show()


def plot_tensor(x, y, z, active=5):
    """Build (without showing) a 3-D figure of an ``x`` by ``y`` by ``z`` grid of unit boxes.

    Each cell becomes its own ``go.Mesh3d`` trace. Cells sit on a lattice with
    pitch 1.1, so adjacent unit boxes are separated by a 0.1 gap.

    Traces are added with the x index varying fastest, then y, then z. The trace
    at position ``active`` in that order is coloured ``#69bac9``; every other
    trace is white.

    Each trace is named ``"(ix,iz,iy)"`` from the cell's integer indices. With
    ``tensor_figure``'s default axis titles (x="i", y="k", z="j") this reads as
    ``(i,j,k)``.

    Args:
        x, y, z: number of cells along the x, y and z plot axes.
        active: trace position to highlight.

    Returns:
        The populated ``go.Figure``.
    """
    pitch = 1.1  # lattice spacing; > 1 leaves a visible gap between boxes
    highlight = "#69bac9"

    # Corner offsets of a unit cube (x alternates fastest, then y, then z).
    dxs = (0, 1, 0, 1, 0, 1, 0, 1)
    dys = (0, 0, 1, 1, 0, 0, 1, 1)
    dzs = (0, 0, 0, 0, 1, 1, 1, 1)

    fig = go.Figure()
    cells = [(ix, iy, iz) for iz in range(z) for iy in range(y) for ix in range(x)]
    for ind, (ix, iy, iz) in enumerate(cells):
        ox, oy, oz = ix * pitch, iy * pitch, iz * pitch
        fig.add_trace(
            go.Mesh3d(
                hoverinfo="skip",
                opacity=1.0,
                x=[ox + d for d in dxs],
                y=[oy + d for d in dys],
                z=[oz + d for d in dzs],
                color=highlight if ind == active else "white",
                flatshading=True,
                alphahull=0,  # draw the convex hull of the 8 corner points
                # Built from the integer indices, not from numpy's array repr,
                # which pads elements and would mangle multi-digit indices.
                name=f"({ix},{iz},{iy})",
                showscale=False,
                visible=True,
                lighting=dict(ambient=0.5, diffuse=0.6),
                lightposition=dict(x=0, y=0, z=0),
            )
        )
    return fig


def tensor_figure(
    x,
    y,
    z,
    active,
    title,
    xr=None,
    yr=None,
    zr=None,
    axisTitles=axis_default,
    slider=True,
    eye=None,
    show_fig=True,
):
    """Render a styled 3-D box-grid figure of shape ``(x, y, z)``.

    Args:
        x, y, z: grid size, forwarded to ``plot_tensor``.
        active: trace position to highlight; also the slider's initial step.
        title: figure title text.
        xr, yr, zr: explicit ``[start, end]`` ranges for the scene axes. Each
            one left as ``None`` gets its own default: ``[x + 0.2, 0]``,
            ``[0, y + 0.2]`` and ``[z + 0.2, 0]``. The x and z defaults run
            reversed.
        axisTitles: titles for the x, y and z scene axes, in that order.
        slider: if true, add a slider with one step per trace. Selecting a
            step colours that trace ``#69bac9``, turns the rest white, and sets
            the title to ``"Tensor Index: <trace name>"``.
        eye: camera eye position; ``None`` means ``dict(x=2.8, y=1.6, z=1.6)``.
        show_fig: if true, call ``fig.show()`` and return ``None``; otherwise
            return the figure without showing it.
    """
    fig = plot_tensor(x, y, z, active=active)

    # Default each range independently so an explicit yr/zr is never
    # discarded just because xr was omitted.
    # NOTE(review): plot_tensor places boxes at a 1.1 pitch, so the far face
    # of the last box sits at 1.1*n - 0.1, which exceeds n + 0.2 once n >= 4.
    # Confirm whether larger grids should widen these defaults.
    if xr is None:
        xr = [x + 0.2, 0]
    if yr is None:
        yr = [0, y + 0.2]
    if zr is None:
        zr = [z + 0.2, 0]
    if eye is None:
        eye = dict(x=2.8, y=1.6, z=1.6)

    if slider:
        n_traces = len(fig.data)
        steps = []
        for i, trace in enumerate(fig.data):
            # Highlight only the i-th trace for this step.
            colors = ["white"] * n_traces
            colors[i] = "#69bac9"
            steps.append(
                dict(
                    method="update",
                    args=[
                        {"opacity": [1.0] * n_traces, "color": colors},
                        {"title": "Tensor Index: " + trace["name"]},
                    ],
                )
            )

        fig.update_layout(
            sliders=[
                dict(
                    active=active,
                    steps=steps,
                    currentvalue=dict(visible=False),
                    tickcolor="#fcfcfc",
                    font=dict(color="#fcfcfc"),
                ),
            ],
        )

    camera = dict(up=dict(x=0, y=0, z=1), center=dict(x=0, y=0, z=0), eye=eye)

    fig.update_layout(
        title={"text": title, "x": 0.5, "y": 0.9, "xanchor": "center"},
        font={"family": "Raleway", "size": 40, "color": "black"},
        scene_camera=camera,
        paper_bgcolor="#fcfcfc",
        # Applied after `font` above, so the effective base font size is 20.
        font_size=20,
        hovermode="x unified",
        hoverlabel=dict(bgcolor="#fcfcfc", font_size=30, font_family="Times New Roman"),
        scene=dict(
            xaxis=dict(
                showbackground=False,
                zerolinecolor="#fcfcfc",
                showticklabels=False,
                range=xr,
                title=axisTitles[0],
            ),
            yaxis=dict(
                showbackground=False,
                zerolinecolor="#fcfcfc",
                showticklabels=False,
                range=yr,
                title=axisTitles[1],
            ),
            zaxis=dict(
                showbackground=False,
                zerolinecolor="#fcfcfc",
                showticklabels=False,
                range=zr,
                title=axisTitles[2],
            ),
        ),
        width=500,
        height=500,
        margin=dict(r=10, l=10, b=10, t=50),
    )
    if show_fig:
        fig.show()
        return None
    return fig