MiniTorch / project / module_interface.py
module_interface.py
Raw
import networkx as nx
import streamlit as st
from streamlit_ace import st_ace

import minitorch

MyModule = None
minitorch


def render_module_sandbox():
    st.write("## Sandbox for Module Trees")

    st.write(
        "Visual debugging checks showing the module tree that your code constructs."
    )

    code = st_ace(
        language="python",
        height=300,
        value="""
class MyModule(minitorch.Module):
    def __init__(self):
        super().__init__()
        self.parameter1 = minitorch.Parameter(15)
""",
    )
    out = exec(code, globals())
    out = MyModule()
    st.write(dict(out.named_parameters()))
    G = nx.MultiDiGraph()
    G.add_node("base")
    stack = [(out, "base")]

    while stack:
        n, name = stack[0]
        stack = stack[1:]
        for pname, p in n.__dict__["_parameters"].items():
            G.add_node(name + "." + pname, shape="rect", penwidth=0.5)
            G.add_edge(name, name + "." + pname)

        for cname, m in n.__dict__["_modules"].items():
            G.add_edge(name, name + "." + cname)
            stack.append((m, name + "." + cname))

    G.graph["graph"] = {"rankdir": "TB"}
    st.graphviz_chart(nx.nx_pydot.to_pydot(G).to_string())