// Programming-Language-Design/src/visitor/codegeneration/CodeGenerator.java
package visitor.codegeneration;

import ast.Program;
import ast.definition.classes.FuncDefinition;
import ast.definition.classes.VarDefinition;
import ast.expression.classes.ArrayIndexing;
import ast.statement.Statement;
import ast.statement.classes.IfElse;
import ast.statement.classes.WhileLoop;
import ast.type.Type;
import ast.type.classes.FunctionType;
import visitor.codegeneration.utils.ReturnDataObject;

import java.io.FileWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Writes the textual target code of a compiled program.
 *
 * <p>The output consists of instructions (tab-indented), labels, and the {@code #source} and
 * {@code #line} debugging directives. It goes to a single {@link FileWriter} that the
 * constructor opens and that {@link #program} (or an explicit {@link #close}) closes. I/O
 * failures are rethrown as {@link UncheckedIOException}.
 */
public class CodeGenerator {

    // Destination of the generated code; opened in the constructor.
    FileWriter fileWriter;
    // Terminator appended to every emitted line.
    String endLine = "\n";
    // Prefix that write() puts in front of instructions (labels go through writeNoTab()).
    String tab = "\t";

    // Next unused label number; handed out by getLabels().
    int numberOfLabels = 0;

    /**
     * Opens the output file and emits the {@code #source} directive naming the input file.
     *
     * @param inputName  name of the compiled source file, recorded in the {@code #source} line
     * @param outputName path of the file the generated code is written to (truncated if it exists)
     * @throws UncheckedIOException if the output file cannot be opened or the header cannot be written
     */
    public CodeGenerator(String inputName, String outputName) {
        FileWriter writer = null;
        try {
            writer = new FileWriter(outputName);
            writer.write("#source \"" + inputName + "\"" + endLine + endLine);
        } catch (IOException e) {
            // The file may already be open when the header write fails: release it before rethrowing.
            if (writer != null) {
                try {
                    writer.close();
                } catch (IOException suppressed) {
                    e.addSuppressed(suppressed);
                }
            }
            throw new UncheckedIOException(e);
        }
        fileWriter = writer;
    }

    /**
     * Reserves {@code howMany} consecutive label numbers.
     *
     * @param howMany number of labels the caller needs
     * @return the first reserved number; the caller owns {@code result} to {@code result + howMany - 1}
     */
    public int getLabels(int howMany) {
        int labelNumber = this.numberOfLabels;
        this.numberOfLabels += howMany;
        return labelNumber;
    }

    /** Writes {@code code} as one tab-indented line (used for instructions and comments). */
    public void write(String code) {
        append(tab + code + endLine);
    }

    /** Writes {@code code} as one line without indentation (used for labels and directives). */
    public void writeNoTab(String code) {
        append(code + endLine);
    }

    /** Sends raw text to the output file, converting I/O errors to unchecked exceptions. */
    private void append(String text) {
        try {
            fileWriter.write(text);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Flushes and closes the output file.
     *
     * @throws UncheckedIOException if closing fails
     */
    public void close() {
        try {
            fileWriter.close();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Emits a {@code #line} directive for source line {@code line}, preceded by a blank line. */
    public void writeLine(int line) {
        writeNoTab(endLine + "#line" + tab + line);
    }

    //CONVERT

    /**
     * Emits the instructions needed to convert a value of {@code type1} into {@code type2}.
     *
     * <p>The instruction list comes from {@code type1.convertTo(type2)}. An empty result emits nothing.
     */
    public void convert(Type type1, Type type2) {
        for (String instruction : type1.convertTo(type2))
            write(instruction);
    }

    //PUSH OPERATIONS

    /** Emits {@code push<suffix> <value>}, e.g. {@code pushi 4} or {@code pusha 0}. */
    public void push(String suffix, Object value) {
        write("push" + suffix + " " + value);
    }

    /** Emits {@code push bp}. */
    public void pushBP() {
        write("push bp");
    }

    /**
     * Pushes the address of a variable.
     *
     * <p>A definition with scope 0 (a global) uses its offset as an absolute address
     * ({@code pusha offset}). Any other definition is addressed relative to {@code bp}
     * ({@code push bp; pushi offset; addi}).
     */
    public void pushAddress(VarDefinition def) {
        if (def.getScope() == 0) { //Global variables
            push("a", def.getOffset());
        } else {
            pushBP();
            push("i", def.getOffset());
            write("addi");
        }
    }

    // COMPLEX OPERATIONS

    /**
     * Generates the whole program and closes the output file.
     *
     * <p>The emitted sections are, in order:
     * <ol>
     *   <li>global variable definitions;</li>
     *   <li>{@code call main} followed by {@code halt};</li>
     *   <li>every function definition, each preceded by a {@code #line} directive.</li>
     * </ol>
     * The output file is closed even if generation throws.
     *
     * @param ast              root of the program
     * @param executeCGVisitor visitor that generates code for each definition
     */
    public void program(Program ast, ExecuteCGVisitor executeCGVisitor) {
        try {
            writeNoTab("' *  Global variables");
            ast.getDefinitions().forEach(def -> {
                if (def instanceof VarDefinition) def.accept(executeCGVisitor, null);
            });

            writeNoTab(endLine + "' Invocation to the main function");
            writeNoTab("call main");
            writeNoTab("halt");

            ast.getDefinitions().forEach(def -> {
                if (def instanceof FuncDefinition) {
                    writeLine(def.getLine());
                    def.accept(executeCGVisitor, null);
                }
            });
        } finally {
            // Release the file handle even when a visitor fails half-way through.
            close();
        }
    }

    /**
     * Emits the arithmetic instruction for {@code operator}, suffixed by {@code type.suffix()}.
     *
     * @throws IllegalArgumentException if {@code operator} is not one of {@code + - * / %}
     */
    public void arithmetic(String operator, Type type) {
        String mnemonic;
        switch (operator) {
            case "+":
                mnemonic = "add";
                break;
            case "-":
                mnemonic = "sub";
                break;
            case "*":
                mnemonic = "mul";
                break;
            case "/":
                mnemonic = "div";
                break;
            case "%":
                mnemonic = "mod";
                break;
            default:
                // Emitting nothing would silently leave the operand stack inconsistent.
                throw new IllegalArgumentException("Unknown arithmetic operator: " + operator);
        }
        write(mnemonic + type.suffix());
    }

    /**
     * Emits the comparison instruction for {@code operator}, suffixed by {@code type.suffix()}.
     *
     * @throws IllegalArgumentException if {@code operator} is not one of {@code > < >= <= == !=}
     */
    public void comparison(String operator, Type type) {
        String mnemonic;
        switch (operator) {
            case ">":
                mnemonic = "gt";
                break;
            case "<":
                mnemonic = "lt";
                break;
            case ">=":
                mnemonic = "ge";
                break;
            case "<=":
                mnemonic = "le";
                break;
            case "==":
                mnemonic = "eq";
                break;
            case "!=":
                mnemonic = "ne";
                break;
            default:
                throw new IllegalArgumentException("Unknown comparison operator: " + operator);
        }
        write(mnemonic + type.suffix());
    }

    /**
     * Emits the logical instruction for {@code operator}.
     *
     * @param type ignored: the emitted instructions ({@code and}/{@code or}) carry no type suffix
     * @throws IllegalArgumentException if {@code operator} is not {@code &&} or {@code ||}
     */
    public void logical(String operator, Type type) {
        switch (operator) {
            case "&&":
                write("and");
                break;
            case "||":
                write("or");
                break;
            default:
                throw new IllegalArgumentException("Unknown logical operator: " + operator);
        }
    }

    /**
     * Finishes an array-element address computation: {@code index * elementSize + base}.
     *
     * <p>The element size is {@code ast.getType().numberOfBytes()}. This presumably expects the
     * base address and the index to be on the stack already (index on top); confirm against
     * the caller in the address visitor.
     */
    public void arrayIndexingAddress(ArrayIndexing ast) {
        push("i", ast.getType().numberOfBytes());
        write("muli");
        write("addi");
    }

    /**
     * Generates a function definition.
     *
     * <p>The output consists of, in order:
     * <ul>
     *   <li>its label;</li>
     *   <li>the parameters;</li>
     *   <li>the local variables, including the {@code enter} instruction;</li>
     *   <li>the body statements.</li>
     * </ul>
     * Body statements receive a {@link ReturnDataObject} holding the return, locals and
     * parameters byte counts. A function whose return type occupies 0 bytes always ends with
     * an explicit {@code ret}.
     */
    public void funcDefinition(FuncDefinition ast, ExecuteCGVisitor executeCGVisitor) {

        writeNoTab(" " + ast.getName() + ":");

        int bytesParams = parameters(ast, executeCGVisitor);
        int bytesLocals = localVariables(ast, executeCGVisitor);
        int bytesReturn = ((FunctionType) ast.getType()).getReturnType().numberOfBytes();

        ReturnDataObject returnObj = new ReturnDataObject(bytesReturn, bytesLocals, bytesParams);

        // localVariables() already handled the VarDefinitions; only executable statements remain.
        List<Statement> body = ast.getFunctionStmts().stream()
                .filter(stmt -> !(stmt instanceof VarDefinition))
                .collect(Collectors.toList());
        executeBlock(body, executeCGVisitor, returnObj);

        // A 0-byte return type (presumably void) may lack an explicit return statement.
        if (bytesReturn == 0)
            write("ret " + returnObj.getBytesReturn() + ", " + returnObj.getBytesLocals() + ", " + returnObj.getBytesParams());

    }

    /**
     * Visits every parameter definition of {@code ast}.
     *
     * @return the total size, in bytes, of the parameters
     */
    private int parameters(FuncDefinition ast, ExecuteCGVisitor executeCGVisitor) {
        write("' * Parameters");
        FunctionType functionType = (FunctionType) ast.getType();
        functionType.getParameters().forEach(def -> def.accept(executeCGVisitor, null));
        return functionType.getParameters().stream().mapToInt(param -> param.getType().numberOfBytes()).sum();
    }

    /**
     * Visits the local variable definitions of {@code ast} and, if there are any, emits {@code enter}.
     *
     * @return the number of bytes reserved for locals (0 when there are none)
     */
    private int localVariables(FuncDefinition ast, ExecuteCGVisitor executeCGVisitor) {
        write("' * Local variables");

        List<Statement> localVariables = ast.getFunctionStmts().stream()
                .filter(stmt -> stmt instanceof VarDefinition)
                .collect(Collectors.toList());
        localVariables.forEach(def -> def.accept(executeCGVisitor, null));

        if (localVariables.isEmpty())
            return 0;

        // Offsets are assigned elsewhere. Negating the last one yields the frame size, which
        // implies locals sit at negative, cumulative offsets below bp — confirm in the offset visitor.
        int bytesLocals = -((VarDefinition) localVariables.get(localVariables.size() - 1)).getOffset();
        write("enter " + bytesLocals);
        return bytesLocals;
    }

    /**
     * Generates a while loop whose body is visited with a {@code null} return-data parameter.
     *
     * <p>Prefer {@link #whileLoop(WhileLoop, ExecuteCGVisitor, ReturnDataObject)} so that a
     * {@code return} inside the loop sees the enclosing function's frame data.
     */
    public void whileLoop(WhileLoop ast, ExecuteCGVisitor executeCGVisitor) {
        whileLoop(ast, executeCGVisitor, null);
    }

    /**
     * Generates a while loop.
     *
     * <p>The code is laid out as:
     * <ul>
     *   <li>the condition at label {@code L};</li>
     *   <li>a {@code jz} to label {@code L+1};</li>
     *   <li>the body;</li>
     *   <li>a {@code jmp} back to {@code L}.</li>
     * </ul>
     *
     * @param returnData frame data forwarded to the body statements (may be {@code null})
     */
    public void whileLoop(WhileLoop ast, ExecuteCGVisitor executeCGVisitor, ReturnDataObject returnData) {
        int labelNumber = getLabels(2);
        write("' * While");
        writeNoTab("label" + labelNumber + ":");
        ast.getCondition().accept(executeCGVisitor.valueCGVisitor, null);
        write("jz label" + (labelNumber + 1));

        write("' * Body of the while statement");
        executeBlock(ast.getStatements(), executeCGVisitor, returnData);

        write("jmp label" + labelNumber);
        writeNoTab("label" + (labelNumber + 1) + ":");
    }

    /**
     * Generates an if/else whose branches are visited with a {@code null} return-data parameter.
     *
     * <p>Prefer {@link #ifElse(IfElse, ExecuteCGVisitor, ReturnDataObject)} so that a
     * {@code return} inside a branch sees the enclosing function's frame data.
     */
    public void ifElse(IfElse ast, ExecuteCGVisitor executeCGVisitor) {
        ifElse(ast, executeCGVisitor, null);
    }

    /**
     * Generates an if/else.
     *
     * <p>The code is laid out as:
     * <ul>
     *   <li>the condition;</li>
     *   <li>a {@code jz} to label {@code L} (the else branch);</li>
     *   <li>the if branch;</li>
     *   <li>a {@code jmp} to label {@code L+1};</li>
     *   <li>the else branch at {@code L}.</li>
     * </ul>
     *
     * @param returnData frame data forwarded to the branch statements (may be {@code null})
     */
    public void ifElse(IfElse ast, ExecuteCGVisitor executeCGVisitor, ReturnDataObject returnData) {
        write("' * If statement");
        int labelNumber = getLabels(2);
        ast.getCondition().accept(executeCGVisitor.valueCGVisitor, null);
        write("jz label" + labelNumber);

        write("' * Body of the if branch");
        executeBlock(ast.getIfStmts(), executeCGVisitor, returnData);
        write("jmp label" + (labelNumber + 1));

        writeNoTab("label" + labelNumber + ":");
        write("' * Body of the else branch");
        executeBlock(ast.getElseStmts(), executeCGVisitor, returnData);

        writeNoTab("label" + (labelNumber + 1) + ":");
    }

    /**
     * Visits {@code statements} in order, passing {@code traversalParameter} to each.
     *
     * <p>A {@code #line} directive is emitted whenever a statement's line differs from the
     * previous statement's line.
     */
    private void executeBlock(List<Statement> statements, ExecuteCGVisitor executeCGVisitor, ReturnDataObject traversalParameter) {
        int currentLine = -1;
        for (Statement stmt : statements) {
            if (currentLine != stmt.getLine())
                writeLine(stmt.getLine());
            stmt.accept(executeCGVisitor, traversalParameter);
            currentLine = stmt.getLine();
        }
    }
}