package dasherJava.core.languageModeling;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SortedMap;
import java.util.TreeMap;

import dasherJava.DasherJava;
import dasherJava.core.languageModeling.LanguageAlphabet.UnicodeNotFoundException;
import dasherJava.core.languageModeling.TrainingFileReader.TrainingFileReaderException;

// NOTE(review): Parts of this file appear to have been lost. Text that sat between a '<' and the
// next '>' seems to have been stripped (as an HTML sanitizer would do): generic type arguments are
// missing (raw Map, Iterator, Entry, SortedMap) and three loop headers run straight into later
// code. Each lost span is marked with a NOTE(review) below. They are NOT reconstructed here and
// must be restored from version control; the file does not compile until then.
// The file had also been collapsed onto a few physical lines, which turns everything after the
// first "//" on a line into a comment. Line breaks were re-inserted where the original comments
// evidently ended.

/**
 * PPM language model over a {@link LanguageAlphabet}.
 *
 * <p>Training data is learned into a context trie rooted at {@code contextTrieRoot}. Characters
 * that the alphabet reports as having a fixed probability are kept out of the PPM: their
 * probabilities are precomputed in {@code fixedProbabilitiesMask}. The PPM only distributes
 * {@code nonFixedProbabilitiesNorm} over the remaining characters. The full distribution built by
 * {@code getProbabilities()} is integer-valued and is checked to sum to {@link #NORMALIZATION}.
 */
public class LanguageModel {
    /** Total that the integer probabilities of a full distribution must sum to (2^16). */
    public static final int NORMALIZATION = 1<<16;

    private final LanguageAlphabet alphabet;
    private final int maxOrder; //This cannot be modified without rebuilding the model

    //Root of the context trie; every Context returned by createEmptyContext() starts here.
    private final ContextTrieNode contextTrieRoot = new ContextTrieNode(null, null);

    //These could be modified without rebuilding the model
    //alpha and beta are in hundredths of a count: in getNonFixedProbabilities() a child with count c,
    //under a node whose children have total count t, receives (100*c-beta)/(100*t+alpha) of the
    //budget that was left when that node was reached.
    private final int alpha;
    private final int beta;
    //Share of the PPM normalization, in thousandths, that getNonFixedProbabilities() reserves for
    //spreading uniformly over the non-fixed characters (at least 1 per character).
    private final int uniform;

    //Characters with fixed probabilities are treated separately from the PPM model: Since they are independent of
    //the current context, we can precompute their normalized probabilities. These are stored in the array below,
    //which contains a negative value for every character that does not have a fixed probability. When given a
    //context at runtime, we take this mask and fill out all these "holes" with the data returned by the PPM.
    private final int[] fixedProbabilitiesMask;
    //This value is the global NORMALIZATION value minus the sum of all non-negative values in the array above,
    //which is the actual normalization value for the PPM.
    private final int nonFixedProbabilitiesNorm;

    /**
     * Creates a model over the given alphabet and precomputes the fixed-probability mask.
     *
     * @param alphabet the alphabet whose characters the model predicts
     * @param maxOrder stored and returned by {@link #getMaxOrder()}; presumably the maximum context
     *                 length (TODO confirm against Context)
     * @param alpha    added, in hundredths of a count, to each trie node's child-count total
     * @param beta     subtracted, in hundredths of a count, from each child's count
     * @param uniform  share of the PPM normalization, in thousandths, spread uniformly
     * @throws RuntimeException if the computed {@code remaining} value lies outside
     *                          {@code [0, numOfFixedProbabilityCharacters-1]}
     */
    public LanguageModel(LanguageAlphabet alphabet, int maxOrder, int alpha, int beta, int uniform) {
        this.alphabet=alphabet;
        this.maxOrder=maxOrder;
        this.alpha=alpha;
        this.beta=beta;
        this.uniform=uniform;

        // NOTE(review): raw Map; its type arguments appear to have been lost.
        Map fixedProbabilityCharacters = alphabet.getFixedProbabilityCharacters();
        fixedProbabilitiesMask=new int[alphabet.getNumOfCharacters()];
        int fixedProbabilitiesSum = 0;
        int numOfFixedProbabilityCharacters = 0;
        // NOTE(review): text lost here. The loop header below runs straight into a range check on a
        // variable "remaining" that is never declared in the surviving text. Also missing:
        //   - the body of this loop (which presumably fills fixedProbabilitiesMask),
        //   - the declaration of "remaining",
        //   - the assignment of nonFixedProbabilitiesNorm.
        // Restore from version control.
        for (int i = 0; i=numOfFixedProbabilityCharacters) throw new RuntimeException("LanguageModel: "
                +"remaining is "+remaining+", but should be in [0, "+(numOfFixedProbabilityCharacters-1)+"]");
        // NOTE(review): text lost again inside the next loop header. Missing:
        //   - the rest of that loop and the end of the constructor,
        //   - the start of the training method's signature: modifiers, return type, name, and the
        //     declared type of "trainingData".
        for (int i = 0; i

    /*
     * Training method (header lost, see the NOTE above).
     *
     * Feeds every symbol index produced by trainingData into a Context that starts from
     * createEmptyContext(), and counts each learned symbol in trainingStatsReport.
     * If one symbol raises UnicodeNotFoundException, that symbol is skipped: it is counted per
     * unicode value via incrementSkippedSymbolCount() and training continues.
     * A TrainingFileReaderException raised while iterating aborts training. It is reported through
     * DasherJava.showErrorMessage() and not rethrown.
     */
    trainingData, LanguageModelTrainingStats trainingStatsReport) {
        Context trainingContext = createEmptyContext();
        //Cannot use enhanced foreach-loop here since we need to catch exceptions
        //and then continue the loop
        // NOTE(review): raw Iterator; presumably Iterator<Integer> given the assignment to Integer
        // below (type argument appears lost).
        Iterator iterator = trainingData.iterator();
        try {
            while (iterator.hasNext()) {
                try {
                    Integer symbolIndex = iterator.next(); //may throw UnicodeNotFoundException
                    trainingContext.learnSymbol(symbolIndex, trainingStatsReport);
                    trainingStatsReport.incrementNumOfSymbolsRead();
                } catch (UnicodeNotFoundException ex) {
                    //Skip this symbol only; keep reading the rest of the training data.
                    trainingStatsReport.incrementSkippedSymbolCount(ex.getUnicode());
                }
            }
        } catch (TrainingFileReaderException ex) {
            DasherJava.showErrorMessage("Language model training aborted", ex);
        }
    }

    /**
     * Computes the PPM probabilities, in the given context, of all characters that do not have a
     * fixed probability.
     *
     * <p>The result has one slot per non-fixed character.
     * <ol>
     * <li>A uniform share ({@code uniformAdd}, at least 1) is reserved per slot.</li>
     * <li>The rest of {@code nonFixedProbabilitiesNorm} forms the budget {@code toSpend}.</li>
     * <li>Starting at {@code context.getHead()} and following {@code getVine()} links (presumably to
     * successively shorter contexts), each node whose children have a non-zero total count hands
     * every child the fraction (100*count-beta)/(100*total+alpha) of the budget left when that node
     * was reached. The amounts handed out are subtracted from the budget.</li>
     * </ol>
     *
     * @param context the context to predict from
     * @return an array of length (number of characters - number of fixed-probability characters),
     *         indexed by the children's keys (presumably indices among the non-fixed characters;
     *         TODO confirm). It is empty if every character has a fixed probability.
     */
    public int[] getNonFixedProbabilities(Context context) {
        int[] result = new int[alphabet.getNumOfCharacters()-alphabet.getFixedProbabilityCharacters().size()];
        if (result.length==0) return result; //only fixed probability characters in this model
        int uniformAdd = Math.max(1, nonFixedProbabilitiesNorm*uniform/1000/result.length);
        int toSpend = nonFixedProbabilitiesNorm-result.length*uniformAdd;
        for (ContextTrieNode temp = context.getHead(); temp!=null; temp=temp.getVine()) {
            int total = 0;
            for (ContextTrieNode child : temp.getChildren()) {
                total+=child.getCount();
            }
            if (total!=0) {
                //Snapshot the budget so all children of this node are weighed against the same amount.
                int sizeOfSlice = toSpend;
                // NOTE(review): raw Entry; presumably Entry<Integer, ContextTrieNode> (type arguments
                // appear lost). TODO confirm against ContextTrieNode.getChildrenEntries().
                for (Entry childEntry : temp.getChildrenEntries()) {
                    //Widened to long before multiplying to avoid int overflow. p is negative if
                    //beta exceeds 100*count.
                    int p = (int) ((long) sizeOfSlice*(100*childEntry.getValue().getCount()-beta)/(100*total+alpha));
                    result[childEntry.getKey()]+=p;
                    toSpend-=p;
                }
            }
        }
        int sizeOfSlice2 = toSpend;
        // NOTE(review): text lost here. The loop over symbolIndex is cut off. Also missing:
        //   - the rest of getNonFixedProbabilities: what is done with sizeOfSlice2 and uniformAdd, and
        //     the final return;
        //   - the header of the next method, which the error messages below identify as
        //     getProbabilities();
        //   - that method's declarations of "result", "nonFixedProbabilities" and
        //     "nonFixedProbabilitiesIndex", and the opening of its loop over i.
        // The surviving condition of that loop ends in ">=0", presumably a test on
        // fixedProbabilitiesMask[i] (see the field comment). TODO confirm.
        for (int symbolIndex = 0; symbolIndex

    /*
     * getProbabilities() (header lost, see the NOTE above).
     *
     * Builds the full distribution over the alphabet. Slots with a fixed probability copy their
     * value from fixedProbabilitiesMask. The other slots are filled, in order, from
     * nonFixedProbabilities (presumably the output of getNonFixedProbabilities(); TODO confirm).
     * It then checks that no entry is negative and that the entries sum to NORMALIZATION, and throws
     * RuntimeException otherwise.
     */
    =0) result[i]=fixedProbabilitiesMask[i]; //copy fixed probability
            else { //use computed probability
                result[i]=nonFixedProbabilities[nonFixedProbabilitiesIndex];
                nonFixedProbabilitiesIndex++;
            }
        }
        //Check sum and non-negativity:
        int sum = 0;
        for (int i : result) {
            sum+=i;
            if (i<0) throw new RuntimeException("getProbabilities(): Negative probability: "+i);
        }
        if (sum!=NORMALIZATION) throw new RuntimeException("getProbabilities(): Sum is "+sum
                +", but should be "+NORMALIZATION);
        return result;
    }

    /**
     * Creates a new {@link Context} bound to this model that starts at the root of the context trie.
     *
     * @return a fresh context
     */
    public Context createEmptyContext() {
        return new Context(this, contextTrieRoot);
    }

    /**
     * @return the alphabet this model was built for
     */
    public LanguageAlphabet getAlphabet() {
        return alphabet;
    }

    /**
     * Package-private accessor for the maxOrder value passed to the constructor.
     *
     * @return maxOrder
     */
    int getMaxOrder() {
        return maxOrder;
    }

    /**
     * Mutable counters collected while training a {@link LanguageModel}. It tracks:
     * <ul>
     * <li>the number of symbols read,</li>
     * <li>the number of context-trie nodes created,</li>
     * <li>the number of skipped symbols, per unicode value.</li>
     * </ul>
     * A progress line is printed to {@code System.out} every 1,000,000 symbols read and every
     * 1,000,000 nodes created.
     */
    public static class LanguageModelTrainingStats {
        private int numOfSymbolsRead = 0;
        private int numOfContextTrieNodes = 0;
        // NOTE(review): raw SortedMap; keys are unicode values and values are skip counts, so
        // presumably SortedMap<Integer, Integer> (type arguments appear lost).
        private final SortedMap skippedSymbols = new TreeMap<>(); //sorted for output

        /** Counts one more symbol read; prints progress on every millionth. */
        public void incrementNumOfSymbolsRead() {
            numOfSymbolsRead++;
            if (numOfSymbolsRead%1000000==0) System.out.println("Language model training in progress: Read "
                    +numOfSymbolsRead+" symbols so far");
        }

        /** Counts one more context-trie node created; prints progress on every millionth. */
        public void incrementNumOfContextTrieNodes() {
            numOfContextTrieNodes++;
            if (numOfContextTrieNodes%1000000==0) System.out.println("Language model training in progress: Created "
                    +numOfContextTrieNodes+" context trie nodes so far");
        }

        /**
         * Records one more skipped occurrence of the given unicode value.
         *
         * @param unicode the code point that could not be mapped to a symbol
         */
        public void incrementSkippedSymbolCount(int unicode) {
            //NOTE(review): needs the lost SortedMap type arguments to compile (Object to Integer).
            Integer alreadySkippedCount = skippedSymbols.getOrDefault(unicode, 0);
            skippedSymbols.put(unicode, alreadySkippedCount+1);
        }

        /** @return the number of symbols read so far */
        public int getNumOfSymbolsRead() {
            return numOfSymbolsRead;
        }

        /** @return the number of context-trie nodes created so far */
        public int getNumOfContextTrieNodes() {
            return numOfContextTrieNodes;
        }

        /**
         * Returns the skipped unicode values mapped to their skip counts, in ascending key order.
         * This is the live internal map, not a copy.
         *
         * @return the skipped-symbol map
         */
        public Map getSkippedSymbols() {
            return skippedSymbols;
        }
    }
}