package main.expressions;
import main.parsers.XGrammarLexer;
import main.parsers.XGrammarParser;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;
import java.util.*;
public class Writer {
/*
After data collection process is complete (From walking into tree)
first determine number of joins that should be completed (length of the doubleConditionVars)
Begin query string with
"for joinObj.tupleName in "
//get two groupingIds and their ids
End with
"return //iterate through list and get all the strings "
*/
public static JoinContainer joinInfo;
public static HashMap joinTerms; //TODO: change this to hashmap to map groupId to constructed joinTerm
public static ArrayList> joinOrderings;
public static String joinRewriteResult;
public static boolean joinRewritePossible = true;
public static String RewriteFullQuery(String query){
//if joinRecordings is not possible -> return original query
//if joinInfo.doubleVarConditions is empty -> return original query
//check other cases where rewrite might not occur?
initiateWriter(query);
if (joinRewritePossible) {
buildJoinTerms();
} else{
return query;
}
if (joinRewritePossible) {
determineJoinOrdering(new ArrayList<>(Writer.joinInfo.doubleVarConditions.keySet()));
} else{
return query;
}
if (joinRewritePossible) {
performEntireJoin();
} else{
return query;
}
if (joinRewritePossible){
return joinRewriteResult;
} else{
return query;
}
}
//have a condition that checks if join can even occur in the first place
//loads in JoinContainer object necessary to rewrite implicit join queries
public static void initiateWriter(String query) {
final XGrammarLexer lexer = new XGrammarLexer(CharStreams.fromString(query));
final CommonTokenStream tokens = new CommonTokenStream(lexer);
final XGrammarParser parser = new XGrammarParser(tokens);
ParseTree tree = parser.prog();
RewriteBuilder visitor = new RewriteBuilder();
JoinContainer result = visitor.visit(tree);
if (result.doubleVarConditions.size() < 1){ //no double conditions found
joinRewritePossible = false;
}
joinInfo = result;
}
//takes all groupings and singular where conditions and builds individual join terms
public static void buildJoinTerms(){
//e.g.
/*
for $b in doc("input")/book,
$tb in $b/title
return {$b} {$tb}
*/
joinTerms = new HashMap<>();
Set groupIds = joinInfo.varVarGroupings.keySet();
for (Integer id : groupIds) {
StringBuilder joinTerm = new StringBuilder();
//get for variables
joinTerm.append("for ");
ArrayList forStatements = joinInfo.varForGroupings.get(id);
for (int i = 0; i < forStatements.size(); i++){
joinTerm.append(forStatements.get(i));
if (i != forStatements.size()-1) {
joinTerm.append(", ");
}
}
//get singular per group where conditions
ArrayList whereConditions = joinInfo.singleConditions.get(id);
if (whereConditions != null && whereConditions.size() > 0){
joinTerm.append(",\n");
joinTerm.append("where ");
for (int i = 0; i < whereConditions.size(); i++){
joinTerm.append(whereConditions.get(i));
if (i != whereConditions.size()-1){
joinTerm.append(", ");
}
}
}
joinTerm.append("\n");
//return clause of join term
//nest inside joinInfo.tupleName tag
//generate return statement with each variable being in a constructor
String tupleNameOnly = joinInfo.tupleName.replace("$", "");
joinTerm.append("return ");
joinTerm.append("<" + tupleNameOnly + ">");
ArrayList returnVars = joinInfo.varVarGroupings.get(id);
for (int i = 0; i < returnVars.size(); i++){
String varName = returnVars.get(i);
String varNameOnly = varName.replace("$", "");
joinTerm.append("<" + varNameOnly + ">");
joinTerm.append("{" + varName + "}");
joinTerm.append("" + varNameOnly + ">");
if (i != returnVars.size()-1){
joinTerm.append(", ");
}
}
joinTerm.append("" + tupleNameOnly + ">");
joinTerms.put(id, joinTerm.toString());
}
}
public static boolean overlapExists(ArrayList> pairs, int upTo, ArrayList newPair){
for (int i = 0; i <= upTo; i++){
ArrayList toCheck = pairs.get(i);
if (toCheck.get(0) == newPair.get(0)){
return true;
}
if (toCheck.get(0) == newPair.get(1)){
return true;
}
if (toCheck.get(1) == newPair.get(0)){
return true;
}
if (toCheck.get(1) == newPair.get(1)){
return true;
}
}
return false;
}
//two-way join or multi-way join (only can occur if two or more groupings and if the number of where
public static void determineJoinOrdering(ArrayList> keys){ //want overlap
//TODO: fix this (simple sort will not work) -> e.g. a case that has the following -> [1 5] [2 3] [3 4] [4 5] -> join ordering should
//[2 3] [3 4] [4 5] [1 5] or [4 5] [1 5] [3 4] [2 3]
//devise algorithm that will order pairs such that there is always an overlap in the subsequent pairs
//build a hashmap with occurences
//for each pair create a total occurences sum
//sort based on these total occurences sums (highest occurence comes first)
//rearrange if any do not match up
//wouldn't work for edge case -> how to fix?
HashMap occurrences = new HashMap<>(); //maps the join grouping id to occurrences
//ArrayList> keys = new ArrayList<>(joinInfo.doubleVarConditions.keySet());
for (int i = 0; i < keys.size(); i++){
//grouping Id1
ArrayList keyPair = keys.get(i);
if (occurrences.containsKey(keyPair.get(0))){
Integer num = occurrences.get(keyPair.get(0));
occurrences.put(keyPair.get(0), num+1);
} else{
occurrences.put(keyPair.get(0), 1);
}
//grouping Id 2
if (occurrences.containsKey(keyPair.get(1))){
Integer num = occurrences.get(keyPair.get(1));
occurrences.put(keyPair.get(1), num+1);
} else{
occurrences.put(keyPair.get(1), 1);
}
}
Comparator> comparator = (list1, list2) -> {
int sum1 = occurrences.get(list1.get(0)) + occurrences.get(list1.get(1));
int sum2 = occurrences.get(list2.get(0)) + occurrences.get(list2.get(1));
int compareVal = Integer.compare(sum2, sum1); //sort from greatest to least occurrence sum
return compareVal;
};
keys.sort(comparator); //sorted from great occurrences to the least occurrences
int startIdx = 0;
int endIdx = keys.size();
int counter = 0;
while (startIdx+1 < endIdx){
if (overlapExists(keys, startIdx, keys.get(startIdx+1)) == true){ //if there is any overlap in any of the keys up to this point -> return true
counter = 0;
startIdx++;
} else{
//push startIdx+1 pair to back and don't change startIdx
ArrayList toMove = keys.get(startIdx+1);
keys.remove(startIdx+1);
keys.add(toMove);
counter++;
if (counter > 2*(endIdx-startIdx)){
joinRewritePossible = false;
//throw new RuntimeException("Join not possible");
}
}
}
joinOrderings = keys;
//ArrayList> keys = new ArrayList<>(joinInfo.doubleVarConditions.keySet());
/*
Comparator> comparator = (list1, list2) -> {
for (int i = 0; i < 2; i++){ //least to the greatest pairs of groupingIds
int compareVal = Integer.compare(list1.get(i), list2.get(i));
if (compareVal != 0){
return compareVal;
}
}
return 0;
};
*/
}
public static ArrayList> extractJoinAttrLists(ArrayList joinIds){
ArrayList conditionVars = joinInfo.doubleVarsOnly.get(joinIds);
if (conditionVars.size() % 2 != 0) {
throw new RuntimeException("there should be an even number of condition Vars that belong to the pair wise conditional");
}
ArrayList group1 = new ArrayList<>(); //firstJoin.get(0) - has the lower groupId
ArrayList group2 = new ArrayList<>(); //firstJoin.get(1) - has the higher groupId
for (int i = 0; i < conditionVars.size(); i+=2){
int id1 = joinInfo.varGroupIds.get(conditionVars.get(i));
int id2 = joinInfo.varGroupIds.get(conditionVars.get(i+1));
if (id1 < id2){
group1.add(conditionVars.get(i).replace("$", ""));
group2.add(conditionVars.get(i+1).replace("$", ""));
} else { //(id2 < id1)
group2.add(conditionVars.get(i).replace("$", ""));
group1.add(conditionVars.get(i+1).replace("$", ""));
}
}
return new ArrayList<>(Arrays.asList(group1, group2));
}
public static void performEntireJoin(){
StringBuilder str = new StringBuilder();
//str.append("for " + joinInfo.tupleName + " in ");
if (joinOrderings.size() == 0){
joinRewritePossible = false;
return;
}
//first join
ArrayList firstJoin = joinOrderings.get(0);
str.append("join (" + joinTerms.get(firstJoin.get(0)) + ",\n" + joinTerms.get(firstJoin.get(1)) + ",\n");
//ArrayList conditionVars = joinInfo.doubleVarsOnly.get(firstJoin);
ArrayList> groupAttrs = extractJoinAttrLists(firstJoin);
//add the attr lists (from all the cond eq's)
str.append(groupAttrs.get(0));
str.append(", ");
str.append(groupAttrs.get(1));
str.append(")");
if (joinOrderings.size() > 1){
//System.out.println("Layering more joins...");
//Iteratively layer on more joins in the ordering given
for (int i = 1; i < joinOrderings.size(); i++){
//prepend( join(
//append( new joinTerm)
//append (new attrLists)
str.insert(0, "join (");
str.append(",\n");
//figure out groupingId that wasn't in the first join [e.g. [1,2], [1,3] -> 3 is needed
int idToAdd;
ArrayList prevPair = joinOrderings.get(i-1);
ArrayList currPair = joinOrderings.get(i);
if (prevPair.contains(currPair.get(0))){
idToAdd = currPair.get(1);
} else{
idToAdd = currPair.get(0);
}
String newJoinTerm = joinTerms.get(idToAdd);
str.append(newJoinTerm);
str.append(",\n");
ArrayList> attrList = extractJoinAttrLists(currPair);
str.append(attrList.get(0));
str.append(", ");
str.append(attrList.get(1));
str.append(")");
}
}
//prepend
str.insert(0, "for " + joinInfo.tupleName + " in ");
str.append("\n");
str.append("return\n");
// str.append("return\n" + "<" + joinInfo.finalTagName + ">{");
//Finally end with appending the return statement //TODO: add appropriate changes to RewriteBuilder (collection of final return statements)
for (int i = 0; i < joinInfo.nodesToReturn.size(); i++){
str.append(joinInfo.nodesToReturn.get(i));
if (i != joinInfo.nodesToReturn.size()-1){
str.append(", ");
}
}
//Refactor code -> ensure that join rewrite should occur only when appropriate
joinRewriteResult = str.toString();
}
/*
{ $tb,
{ $a/price },
{ $b/price } }
*/
}