// xquery-engine/src/main/expressions/Writer.java
package main.expressions;

import main.parsers.XGrammarLexer;
import main.parsers.XGrammarParser;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

import java.util.*;

/**
 * Rewrites an XQuery whose {@code for} clause implicitly joins several independent variable
 * groups into an equivalent query built from explicit {@code join(...)} calls.
 *
 * <p>Pipeline (driven by {@link #RewriteFullQuery(String)}):
 * <ol>
 *   <li>{@link #initiateWriter(String)} parses the query and collects a {@code JoinContainer}
 *       through {@code RewriteBuilder}.</li>
 *   <li>{@link #buildJoinTerms()} turns every variable group into a stand-alone
 *       {@code for ... where ... return <tuple>...</tuple>} term.</li>
 *   <li>{@link #determineJoinOrdering(ArrayList)} orders the pairwise (two-group) conditions so
 *       that each pair shares a group with an earlier one.</li>
 *   <li>{@link #performEntireJoin()} nests the terms into {@code join(...)} calls and appends the
 *       final return clause.</li>
 * </ol>
 * If any step decides that a rewrite is impossible, the original query is returned unchanged.
 *
 * <p>All state lives in static fields. The class is therefore not thread-safe, and only one
 * rewrite may be in progress at a time.
 */
public class Writer {

    /** Data collected from the parse tree of the query currently being rewritten. */
    public static JoinContainer joinInfo;
    /** Maps each variable-group id to its stand-alone join term. */
    public static HashMap<Integer, String> joinTerms;
    /** Pairs of group ids, in the order in which their conditions are folded into the join. */
    public static ArrayList<ArrayList<Integer>> joinOrderings;
    /** The rewritten query; only meaningful while {@link #joinRewritePossible} is true. */
    public static String joinRewriteResult;
    /** Cleared by any pipeline step that finds the rewrite cannot be performed. */
    public static boolean joinRewritePossible = true;

    /**
     * Rewrites {@code query} into its explicit-join form.
     *
     * @param query the original XQuery text
     * @return the rewritten query, or {@code query} itself when no rewrite is possible
     */
    public static String RewriteFullQuery(String query) {
        initiateWriter(query); // also resets all per-query static state
        if (!joinRewritePossible) {
            return query;
        }

        buildJoinTerms();
        if (!joinRewritePossible) {
            return query;
        }

        determineJoinOrdering(new ArrayList<>(Writer.joinInfo.doubleVarConditions.keySet()));
        if (!joinRewritePossible) {
            return query;
        }

        performEntireJoin();
        return joinRewritePossible ? joinRewriteResult : query;
    }

    /**
     * Resets the per-query state, parses {@code query}, and stores the collected join data in
     * {@link #joinInfo}. Clears {@link #joinRewritePossible} when no condition relating two
     * variable groups was found, because then there is nothing to join.
     *
     * @param query the original XQuery text
     */
    public static void initiateWriter(String query) {
        // These fields are static. Without this reset, a failure on one query (or a stale
        // result) would leak into every later call.
        joinRewritePossible = true;
        joinTerms = null;
        joinOrderings = null;
        joinRewriteResult = null;

        final XGrammarLexer lexer = new XGrammarLexer(CharStreams.fromString(query));
        final CommonTokenStream tokens = new CommonTokenStream(lexer);
        final XGrammarParser parser = new XGrammarParser(tokens);

        final ParseTree tree = parser.prog();
        final JoinContainer result = new RewriteBuilder().visit(tree);

        joinInfo = result;
        if (result == null || result.doubleVarConditions.isEmpty()) {
            joinRewritePossible = false;
        }
    }

    /**
     * Builds one stand-alone join term per variable group and stores it in {@link #joinTerms}.
     *
     * <p>Example of a generated term:
     * <pre>
     * for $b in doc("input")/book, $tb in $b/title
     * where $tb eq "x"
     * return &lt;tuple&gt;&lt;b&gt;{$b}&lt;/b&gt;, &lt;tb&gt;{$tb}&lt;/tb&gt;&lt;/tuple&gt;
     * </pre>
     * The {@code where} line is present only when the group has single-group conditions.
     */
    public static void buildJoinTerms() {
        joinTerms = new HashMap<>();
        final String tupleTag = joinInfo.tupleName.replace("$", "");

        for (Integer id : joinInfo.varVarGroupings.keySet()) {
            StringBuilder term = new StringBuilder("for ");
            ArrayList<String> forStatements = joinInfo.varForGroupings.get(id);
            term.append(String.join(", ", forStatements));

            // Conditions local to this group. A where clause follows the for bindings directly
            // (no comma), and its conditions are combined with "and".
            ArrayList<String> whereConditions = joinInfo.singleConditions.get(id);
            if (whereConditions != null && !whereConditions.isEmpty()) {
                term.append("\nwhere ").append(String.join(" and ", whereConditions));
            }
            term.append("\n");

            // Wrap every variable of the group in its own element inside the tuple element.
            term.append("return <").append(tupleTag).append(">");
            ArrayList<String> returnVars = joinInfo.varVarGroupings.get(id);
            for (int i = 0; i < returnVars.size(); i++) {
                if (i > 0) {
                    term.append(", ");
                }
                String varName = returnVars.get(i);
                String varTag = varName.replace("$", "");
                term.append("<").append(varTag).append(">")
                        .append("{").append(varName).append("}")
                        .append("</").append(varTag).append(">");
            }
            term.append("</").append(tupleTag).append(">");

            joinTerms.put(id, term.toString());
        }
    }

    /**
     * Reports whether {@code newPair} shares a group id with any of {@code pairs[0..upTo]}.
     *
     * @param pairs   pairs of group ids; only the first two entries of each pair are read
     * @param upTo    last index of {@code pairs} to inspect, inclusive; a negative value inspects none
     * @param newPair the pair to test
     * @return {@code true} if some inspected pair has an id in common with {@code newPair}
     */
    public static boolean overlapExists(ArrayList<ArrayList<Integer>> pairs, int upTo, ArrayList<Integer> newPair) {
        for (int i = 0; i <= upTo; i++) {
            ArrayList<Integer> toCheck = pairs.get(i);
            for (int a = 0; a < 2; a++) {
                for (int b = 0; b < 2; b++) {
                    // Value comparison: == on boxed Integers only works inside the Integer cache.
                    if (Objects.equals(toCheck.get(a), newPair.get(b))) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Orders the pairwise join conditions so that every pair after the first shares a group id
     * with some earlier pair. This lets the joins be nested one group at a time. For example,
     * {@code [1 5] [2 3] [3 4] [4 5]} can become {@code [4 5] [1 5] [3 4] [2 3]}.
     *
     * <p>Pairs are first sorted (stably) by how many pairs their two groups appear in, most
     * connected first. The ordering then starts from the first pair and repeatedly takes the
     * earliest remaining pair that overlaps what has already been chosen. If no remaining pair
     * overlaps (the join graph is disconnected), {@link #joinRewritePossible} is cleared.
     *
     * <p>The result is written back into {@code keys} and also stored in {@link #joinOrderings}.
     *
     * @param keys pairs of group ids taken from the two-group conditions; reordered in place
     */
    public static void determineJoinOrdering(ArrayList<ArrayList<Integer>> keys) {
        // Count how many pairs each group id takes part in.
        HashMap<Integer, Integer> occurrences = new HashMap<>();
        for (ArrayList<Integer> pair : keys) {
            occurrences.merge(pair.get(0), 1, Integer::sum);
            occurrences.merge(pair.get(1), 1, Integer::sum);
        }

        // Greatest occurrence sum first. List.sort is stable, so ties keep their input order.
        keys.sort((p, q) -> Integer.compare(
                occurrences.get(q.get(0)) + occurrences.get(q.get(1)),
                occurrences.get(p.get(0)) + occurrences.get(p.get(1))));

        ArrayList<ArrayList<Integer>> ordered = new ArrayList<>(keys.size());
        ArrayList<ArrayList<Integer>> remaining = new ArrayList<>(keys);
        if (!remaining.isEmpty()) {
            ordered.add(remaining.remove(0));
        }
        while (!remaining.isEmpty()) {
            int next = -1;
            for (int i = 0; i < remaining.size(); i++) {
                if (overlapExists(ordered, ordered.size() - 1, remaining.get(i))) {
                    next = i;
                    break;
                }
            }
            if (next < 0) {
                // Nothing left connects to the chosen groups, so no single join chain exists.
                // Stop here instead of spinning forever.
                joinRewritePossible = false;
                ordered.addAll(remaining);
                break;
            }
            ordered.add(remaining.remove(next));
        }

        keys.clear();
        keys.addAll(ordered);
        joinOrderings = keys;
    }

    /**
     * Splits the condition variables of one group pair into two attribute lists, with
     * {@code $} stripped from each name.
     *
     * @param joinIds the pair of group ids whose conditions are looked up in {@code doubleVarsOnly}
     * @return {@code [attrs of the lower group id, attrs of the higher group id]}; entry {@code i}
     *         of both lists comes from the same condition
     * @throws RuntimeException if the pair has an odd number of condition variables
     */
    public static ArrayList<ArrayList<String>> extractJoinAttrLists(ArrayList<Integer> joinIds) {
        ArrayList<String> conditionVars = joinInfo.doubleVarsOnly.get(joinIds);
        if (conditionVars.size() % 2 != 0) {
            throw new RuntimeException("there should be an even number of condition Vars that belong to the pair wise conditional");
        }
        ArrayList<String> group1 = new ArrayList<>(); // attrs of the lower group id
        ArrayList<String> group2 = new ArrayList<>(); // attrs of the higher group id
        for (int i = 0; i < conditionVars.size(); i += 2) {
            int id1 = joinInfo.varGroupIds.get(conditionVars.get(i));
            int id2 = joinInfo.varGroupIds.get(conditionVars.get(i + 1));
            if (id1 < id2) {
                group1.add(conditionVars.get(i).replace("$", ""));
                group2.add(conditionVars.get(i + 1).replace("$", ""));
            } else {
                group2.add(conditionVars.get(i).replace("$", ""));
                group1.add(conditionVars.get(i + 1).replace("$", ""));
            }
        }
        return new ArrayList<>(Arrays.asList(group1, group2));
    }

    /**
     * Nests the join terms into {@code join (left, right, [leftAttrs], [rightAttrs])} calls,
     * following {@link #joinOrderings}. The final
     * {@code for <tupleName> in join(...) return ...} query is stored in
     * {@link #joinRewriteResult}.
     *
     * <p>Groups are added one at a time. When a group is added, the conditions of every pair that
     * links it to an already-joined group are included. As a result:
     * <ul>
     *   <li>the attribute lists are always oriented as (already-joined side, newly-added side);</li>
     *   <li>conditions that close a cycle are not lost.</li>
     * </ul>
     * {@link #joinRewritePossible} is cleared in three cases: there is nothing to join, a required
     * term is missing, or some group would be left out of the join.
     */
    public static void performEntireJoin() {
        if (joinOrderings == null || joinOrderings.isEmpty() || joinTerms == null) {
            joinRewritePossible = false;
            return;
        }

        int firstId = joinOrderings.get(0).get(0);
        String firstTerm = joinTerms.get(firstId);
        if (firstTerm == null) {
            joinRewritePossible = false;
            return;
        }
        StringBuilder joinExpr = new StringBuilder(firstTerm);
        Set<Integer> joined = new LinkedHashSet<>();
        joined.add(firstId);

        for (ArrayList<Integer> pair : joinOrderings) {
            boolean firstJoined = joined.contains(pair.get(0));
            boolean secondJoined = joined.contains(pair.get(1));
            if (firstJoined && secondJoined) {
                continue; // its conditions were folded in when the later of its groups was added
            }
            if (!firstJoined && !secondJoined) {
                // The ordering should guarantee overlap with an earlier pair; give up if it does not.
                joinRewritePossible = false;
                return;
            }

            // Track every joined id, not just the previous pair, to find the group being added.
            int idToAdd = firstJoined ? pair.get(1) : pair.get(0);
            String newTerm = joinTerms.get(idToAdd);
            if (newTerm == null) {
                joinRewritePossible = false;
                return;
            }

            ArrayList<String> joinedAttrs = new ArrayList<>();
            ArrayList<String> newAttrs = new ArrayList<>();
            for (ArrayList<Integer> link : joinOrderings) {
                int a = link.get(0);
                int b = link.get(1);
                if ((a == idToAdd && joined.contains(b)) || (b == idToAdd && joined.contains(a))) {
                    appendOrientedJoinAttrs(link, idToAdd, joinedAttrs, newAttrs);
                }
            }

            joinExpr.insert(0, "join (");
            joinExpr.append(",\n").append(newTerm).append(",\n")
                    .append(joinedAttrs).append(", ").append(newAttrs).append(")");
            joined.add(idToAdd);
        }

        if (!joined.containsAll(joinTerms.keySet())) {
            // A group with no join condition would be silently dropped from the rewritten query,
            // which changes its result. Fall back to the original query instead.
            joinRewritePossible = false;
            return;
        }

        StringBuilder str = new StringBuilder();
        str.append("for ").append(joinInfo.tupleName).append(" in ").append(joinExpr);
        str.append("\n");
        str.append("return\n");
        for (int i = 0; i < joinInfo.nodesToReturn.size(); i++) {
            if (i > 0) {
                str.append(", ");
            }
            str.append(joinInfo.nodesToReturn.get(i));
        }
        joinRewriteResult = str.toString();
    }

    /**
     * Appends the condition attributes of {@code pair} to the two side lists.
     *
     * <p>Each condition variable that belongs to group {@code newId} goes to {@code newSide}. Its
     * partner goes to {@code joinedSide} at the same position. {@code $} is stripped from every
     * name.
     *
     * @throws IllegalStateException if the pair has no recorded condition variables or an odd
     *                               number of them
     */
    private static void appendOrientedJoinAttrs(ArrayList<Integer> pair, int newId,
                                                List<String> joinedSide, List<String> newSide) {
        ArrayList<String> conditionVars = joinInfo.doubleVarsOnly.get(pair);
        if (conditionVars == null) {
            throw new IllegalStateException("no condition variables recorded for group pair " + pair);
        }
        if (conditionVars.size() % 2 != 0) {
            throw new IllegalStateException("there should be an even number of condition Vars that belong to the pair wise conditional");
        }
        for (int i = 0; i < conditionVars.size(); i += 2) {
            String first = conditionVars.get(i);
            String second = conditionVars.get(i + 1);
            int firstGroup = joinInfo.varGroupIds.get(first);
            if (firstGroup == newId) {
                newSide.add(first.replace("$", ""));
                joinedSide.add(second.replace("$", ""));
            } else {
                joinedSide.add(first.replace("$", ""));
                newSide.add(second.replace("$", ""));
            }
        }
    }

    /*
        Sample return-clause fragment kept from development (presumably the shape of the
        entries in nodesToReturn; verify against RewriteBuilder):

        <book-with-prices>
            { $tb,
        <price-review>{ $a/price }</price-review>,
        <price>{ $b/price }</price> }
        </book-with-prices>
    */
}