package dasherJava.core.languageModeling;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SortedMap;
import java.util.TreeMap;
import dasherJava.DasherJava;
import dasherJava.core.languageModeling.LanguageAlphabet.UnicodeNotFoundException;
import dasherJava.core.languageModeling.TrainingFileReader.TrainingFileReaderException;
/**
 * Character-level language model consisting of a PPM model (backed by a context trie) plus a set of
 * characters whose probabilities are fixed and independent of the current context.
 *
 * <p>Probabilities are returned as non-negative integers that sum to {@link #NORMALIZATION}.
 */
public class LanguageModel {
    /** Sum of all integer probabilities returned by {@link #getProbabilities(Context)}. */
    public static final int NORMALIZATION = 1 << 16;

    private final LanguageAlphabet alphabet;
    // This cannot be modified without rebuilding the model.
    private final int maxOrder;
    private final ContextTrieNode contextTrieRoot = new ContextTrieNode(null, null);

    // These could be modified without rebuilding the model.
    // alpha and beta enter the PPM estimate as (100*count - beta) / (100*total + alpha), i.e. they are
    // expressed in hundredths of a count.
    private final int alpha;
    private final int beta;
    // Share of the non-fixed probability mass that is spread uniformly over all non-fixed characters,
    // in thousandths (every non-fixed character receives at least 1 regardless).
    private final int uniform;

    // Characters with fixed probabilities are treated separately from the PPM model: since they are
    // independent of the current context, we can precompute their normalized probabilities. These are
    // stored in the array below, which contains a negative value for every character that does not have
    // a fixed probability. When given a context at runtime, we take this mask and fill out all these
    // "holes" with the data returned by the PPM.
    private final int[] fixedProbabilitiesMask;
    // NORMALIZATION minus the sum of all non-negative values in the array above, i.e. the actual
    // normalization value for the PPM.
    private final int nonFixedProbabilitiesNorm;
    // Number of "holes" (negative entries) in fixedProbabilitiesMask. This is the length of the array
    // returned by getNonFixedProbabilities(), which getProbabilities() consumes hole by hole.
    private final int numOfNonFixedProbabilityCharacters;

    /**
     * Creates an untrained model and precomputes the normalized fixed probabilities.
     *
     * @param alphabet the alphabet to model; must contain at least one character
     * @param maxOrder maximum context order (cannot be changed without rebuilding the model)
     * @param alpha PPM parameter added to the scaled total count, in hundredths of a count
     * @param beta PPM parameter subtracted from each scaled child count, in hundredths of a count
     * @param uniform share of the non-fixed mass distributed uniformly, in thousandths
     * @throws IllegalArgumentException if the alphabet is empty or a fixed probability is negative
     * @throws RuntimeException if the fixed probabilities add up to more than {@link #NORMALIZATION}
     */
    public LanguageModel(LanguageAlphabet alphabet, int maxOrder, int alpha, int beta, int uniform) {
        this.alphabet = alphabet;
        this.maxOrder = maxOrder;
        this.alpha = alpha;
        this.beta = beta;
        this.uniform = uniform;
        int numOfCharacters = alphabet.getNumOfCharacters();
        // An empty alphabet would otherwise crash below with a division by zero.
        if (numOfCharacters <= 0) {
            throw new IllegalArgumentException("LanguageModel: alphabet must contain at least one character, "
                    + "but has " + numOfCharacters);
        }
        Map<Integer, LanguageCharacter> fixedProbabilityCharacters = alphabet.getFixedProbabilityCharacters();
        fixedProbabilitiesMask = new int[numOfCharacters];
        int fixedProbabilitiesSum = 0;
        int numOfFixedProbabilityCharacters = 0;
        for (int i = 0; i < fixedProbabilitiesMask.length; i++) {
            LanguageCharacter fixedProbabilityCharacter = fixedProbabilityCharacters.get(i);
            if (fixedProbabilityCharacter != null) {
                int prob = (int) (NORMALIZATION * fixedProbabilityCharacter.getFixedProbability()); //always round down
                // A negative entry would be mistaken for a non-fixed "hole" in the mask.
                if (prob < 0) {
                    throw new IllegalArgumentException("LanguageModel: Negative fixed probability "
                            + fixedProbabilityCharacter.getFixedProbability() + " for character index " + i);
                }
                fixedProbabilitiesMask[i] = prob;
                fixedProbabilitiesSum += prob;
                numOfFixedProbabilityCharacters++;
            } else {
                fixedProbabilitiesMask[i] = -1; //non-fixed-probability character
            }
        }
        // If there are only characters with fixed probabilities (and thus no PPM is necessary) we need to
        // ensure that the normalized fixed probabilities sum to the correct value by evenly distributing any
        // leftover space. Only a deficit (caused by rounding down) is distributed; an excess is reported by
        // the "too large" check below instead of being silently subtracted.
        if (numOfFixedProbabilityCharacters == numOfCharacters && fixedProbabilitiesSum < NORMALIZATION) {
            int addToAll = (NORMALIZATION - fixedProbabilitiesSum) / numOfFixedProbabilityCharacters;
            for (int i = 0; i < fixedProbabilitiesMask.length; i++) {
                fixedProbabilitiesMask[i] += addToAll;
                fixedProbabilitiesSum += addToAll;
            }
            int remaining = NORMALIZATION - fixedProbabilitiesSum;
            if (remaining < 0 || remaining >= numOfFixedProbabilityCharacters) throw new RuntimeException("LanguageModel: "
                    + "remaining is " + remaining + ", but should be in [0, " + (numOfFixedProbabilityCharacters - 1) + "]");
            // Hand out the division remainder one unit at a time, starting at index 0.
            for (int i = 0; i < fixedProbabilitiesMask.length && remaining > 0; i++) {
                fixedProbabilitiesMask[i]++;
                fixedProbabilitiesSum++;
                remaining--;
            }
            if (fixedProbabilitiesSum != NORMALIZATION) throw new RuntimeException("LanguageModel: "
                    + "fixedProbabilitiesSum is " + fixedProbabilitiesSum + ", but should be " + NORMALIZATION);
        }
        nonFixedProbabilitiesNorm = NORMALIZATION - fixedProbabilitiesSum;
        numOfNonFixedProbabilityCharacters = numOfCharacters - numOfFixedProbabilityCharacters;
        System.out.println("Language model creation done (no training yet). Non-fixed probabilities norm: "
                + nonFixedProbabilitiesNorm + ". Fixed probabilities mask: " + Arrays.toString(fixedProbabilitiesMask));
        if (nonFixedProbabilitiesNorm < 0) throw new RuntimeException("LanguageModel: Sum of fixed probabilities is "
                + "too large: " + Arrays.toString(fixedProbabilitiesMask));
    }

    /**
     * Feeds every symbol index produced by {@code trainingData} into a fresh training context.
     *
     * <p>Symbols the iterator rejects with {@link UnicodeNotFoundException} are counted as skipped and
     * training continues; a {@link TrainingFileReaderException} aborts training and is reported through
     * {@code DasherJava.showErrorMessage}.
     *
     * @param trainingData source of symbol indices
     * @param trainingStatsReport receives progress and skipped-symbol statistics
     */
    public void train(Iterable<Integer> trainingData, LanguageModelTrainingStats trainingStatsReport) {
        Context trainingContext = createEmptyContext();
        // Cannot use an enhanced for-loop here since we need to catch exceptions thrown by next()
        // and then continue the loop.
        Iterator<Integer> iterator = trainingData.iterator();
        try {
            while (iterator.hasNext()) {
                try {
                    Integer symbolIndex = iterator.next(); //may throw UnicodeNotFoundException
                    trainingContext.learnSymbol(symbolIndex, trainingStatsReport);
                    trainingStatsReport.incrementNumOfSymbolsRead();
                } catch (UnicodeNotFoundException ex) {
                    trainingStatsReport.incrementSkippedSymbolCount(ex.getUnicode());
                }
            }
        } catch (TrainingFileReaderException ex) {
            DasherJava.showErrorMessage("Language model training aborted", ex);
        }
    }

    /**
     * Computes the PPM probabilities of all non-fixed-probability characters for the given context.
     *
     * <p>Starting at the context's head node and following {@code getVine()} links, every node that has
     * children with a non-zero total count hands out part of the still-unspent mass to its children in
     * proportion to {@code (100*count - beta) / (100*total + alpha)}. Whatever is left afterwards is spread
     * evenly over all characters, and finally every character receives the uniform share.
     *
     * @param context the current context
     * @return one entry per non-fixed character (indexed like the trie's child keys), summing to the
     *     non-fixed normalization value; empty if every character has a fixed probability
     * @throws RuntimeException if a probability comes out negative or the sum check fails
     */
    public int[] getNonFixedProbabilities(Context context) {
        int[] result = new int[numOfNonFixedProbabilityCharacters];
        if (result.length == 0) return result; //only fixed probability characters in this model
        // Computed in long so that a large 'uniform' cannot overflow the intermediate product.
        int uniformAdd = Math.max(1, (int) ((long) nonFixedProbabilitiesNorm * uniform / 1000 / result.length));
        int toSpend = nonFixedProbabilitiesNorm - result.length * uniformAdd;
        for (ContextTrieNode node = context.getHead(); node != null; node = node.getVine()) {
            // Counts are accumulated and scaled in long: 100*total overflows an int once a node has seen
            // more than about 21 million symbols.
            long total = 0;
            for (ContextTrieNode child : node.getChildren()) {
                total += child.getCount();
            }
            if (total != 0) {
                int sizeOfSlice = toSpend;
                long denominator = 100L * total + alpha;
                for (Entry<Integer, ContextTrieNode> childEntry : node.getChildrenEntries()) {
                    long numerator = 100L * childEntry.getValue().getCount() - beta;
                    int p = (int) (sizeOfSlice * numerator / denominator);
                    result[childEntry.getKey()] += p;
                    toSpend -= p;
                }
            }
        }
        // Spread the mass not handed out by the trie evenly over all characters...
        int evenShare = toSpend / result.length;
        for (int symbolIndex = 0; symbolIndex < result.length; symbolIndex++) {
            result[symbolIndex] += evenShare;
        }
        toSpend -= evenShare * result.length;
        // ...then distribute the division remainder (the last iteration takes everything that is left,
        // so toSpend ends at 0) and add the uniform share.
        int left = result.length;
        for (int symbolIndex = 0; symbolIndex < result.length; symbolIndex++) {
            int p = toSpend / left;
            result[symbolIndex] += p + uniformAdd;
            left--;
            toSpend -= p;
        }
        // Check sum and non-negativity:
        int sum = 0;
        for (int i : result) {
            sum += i;
            if (i < 0) throw new RuntimeException("getNonFixedProbabilities(): Negative probability: " + i);
        }
        if (sum != nonFixedProbabilitiesNorm) throw new RuntimeException("getNonFixedProbabilities(): Sum is " + sum
                + ", but should be " + nonFixedProbabilitiesNorm);
        return result;
    }

    /**
     * Computes the probabilities of all characters of the alphabet for the given context: fixed
     * probabilities are copied from the precomputed mask, and the remaining slots are filled in order
     * with the result of {@link #getNonFixedProbabilities(Context)}.
     *
     * @param context the current context
     * @return one non-negative entry per alphabet character, summing to {@link #NORMALIZATION}
     * @throws RuntimeException if a probability comes out negative or the sum check fails
     */
    public int[] getProbabilities(Context context) {
        int[] result = new int[fixedProbabilitiesMask.length];
        int[] nonFixedProbabilities = getNonFixedProbabilities(context);
        int nonFixedProbabilitiesIndex = 0;
        for (int i = 0; i < result.length; i++) {
            if (fixedProbabilitiesMask[i] >= 0) {
                result[i] = fixedProbabilitiesMask[i]; //copy fixed probability
            } else { //use computed probability
                result[i] = nonFixedProbabilities[nonFixedProbabilitiesIndex];
                nonFixedProbabilitiesIndex++;
            }
        }
        // Check sum and non-negativity:
        int sum = 0;
        for (int i : result) {
            sum += i;
            if (i < 0) throw new RuntimeException("getProbabilities(): Negative probability: " + i);
        }
        if (sum != NORMALIZATION) throw new RuntimeException("getProbabilities(): Sum is " + sum
                + ", but should be " + NORMALIZATION);
        return result;
    }

    /** Returns a new context positioned at the root of this model's context trie. */
    public Context createEmptyContext() {
        return new Context(this, contextTrieRoot);
    }

    /** Returns the alphabet this model was built for. */
    public LanguageAlphabet getAlphabet() {
        return alphabet;
    }

    /** Returns the maximum context order given at construction time. */
    int getMaxOrder() {
        return maxOrder;
    }

    /** Mutable counters collected while training a {@link LanguageModel}. */
    public static class LanguageModelTrainingStats {
        private int numOfSymbolsRead = 0;
        private int numOfContextTrieNodes = 0;
        private final SortedMap<Integer, Integer> skippedSymbols = new TreeMap<>(); //sorted for output

        /** Counts one learned symbol and prints a progress line every 1,000,000 symbols. */
        public void incrementNumOfSymbolsRead() {
            numOfSymbolsRead++;
            if (numOfSymbolsRead % 1000000 == 0) System.out.println("Language model training in progress: Read "
                    + numOfSymbolsRead + " symbols so far");
        }

        /** Counts one created trie node and prints a progress line every 1,000,000 nodes. */
        public void incrementNumOfContextTrieNodes() {
            numOfContextTrieNodes++;
            if (numOfContextTrieNodes % 1000000 == 0) System.out.println("Language model training in progress: Created "
                    + numOfContextTrieNodes + " context trie nodes so far");
        }

        /** Records one more occurrence of a skipped (not-in-alphabet) unicode value. */
        public void incrementSkippedSymbolCount(int unicode) {
            skippedSymbols.merge(unicode, 1, Integer::sum);
        }

        public int getNumOfSymbolsRead() {
            return numOfSymbolsRead;
        }

        public int getNumOfContextTrieNodes() {
            return numOfContextTrieNodes;
        }

        /**
         * Returns the skipped unicode values mapped to their occurrence counts, in ascending key order.
         * This is the live internal map, not a copy.
         */
        public Map<Integer, Integer> getSkippedSymbols() {
            return skippedSymbols;
        }
    }
}