// DasherJava/src/dasherJava/core/languageModeling/LanguageModel.java
package dasherJava.core.languageModeling;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.SortedMap;
import java.util.TreeMap;

import dasherJava.DasherJava;
import dasherJava.core.languageModeling.LanguageAlphabet.UnicodeNotFoundException;
import dasherJava.core.languageModeling.TrainingFileReader.TrainingFileReaderException;

/**
 * PPM-style language model over a {@link LanguageAlphabet}.
 *
 * <p>Probabilities are returned as non-negative integers that sum to {@link #NORMALIZATION}. Characters that the
 * alphabet marks as having a fixed probability are handled outside the PPM: their scaled probabilities are
 * precomputed once in the constructor, and the PPM only distributes the remaining mass among the other characters.
 */
public class LanguageModel {
	
	public static final int NORMALIZATION = 1<<16;
	
	private final LanguageAlphabet alphabet;
	private final int maxOrder; //This cannot be modified without rebuilding the model
	private final ContextTrieNode contextTrieRoot = new ContextTrieNode(null, null);
	
	//These could be modified without rebuilding the model
	private final int alpha;
	private final int beta;
	private final int uniform;
	
	//Characters with fixed probabilities are treated separately from the PPM model: Since they are independent of
	//the current context, we can precompute their normalized probabilities. These are stored in the array below,
	//which contains a negative value for every character that does not have a fixed probability. When given a
	//context at runtime, we take this mask and fill out all these "holes" with the data returned by the PPM.
	private final int[] fixedProbabilitiesMask;
	//This value is the global NORMALIZATION value minus the sum of all non-negative values in the array above,
	//which is the actual normalization value for the PPM.
	private final int nonFixedProbabilitiesNorm;
	//Number of negative entries ("holes") in fixedProbabilitiesMask, i.e. the length of the array returned by
	//getNonFixedProbabilities(). Derived from the mask so that getProbabilities() always finds exactly one
	//computed value per hole.
	private final int numOfNonFixedProbabilityCharacters;
	
	/**
	 * Creates an untrained model and precomputes the scaled probabilities of all fixed-probability characters.
	 *
	 * @param alphabet the characters this model predicts; must not be empty
	 * @param maxOrder maximum context order; cannot be changed without rebuilding the model
	 * @param alpha added to 100 times a context's total count in the denominator of each child's share
	 * @param beta subtracted from 100 times each child's count in the numerator of its share
	 * @param uniform share of the non-fixed probability mass, in thousandths, that is spread evenly over all
	 *        non-fixed characters (each of them gets at least 1)
	 * @throws IllegalArgumentException if the alphabet has no characters
	 * @throws RuntimeException if the fixed probabilities sum to more than {@link #NORMALIZATION}
	 */
	public LanguageModel(LanguageAlphabet alphabet, int maxOrder, int alpha, int beta, int uniform) {
		this.alphabet=alphabet;
		this.maxOrder=maxOrder;
		this.alpha=alpha;
		this.beta=beta;
		this.uniform=uniform;
		int numOfCharacters = alphabet.getNumOfCharacters();
		//An empty alphabet cannot be normalized (and previously failed with a division by zero below).
		if (numOfCharacters<=0) throw new IllegalArgumentException("LanguageModel: alphabet has no characters");
		Map<Integer, LanguageCharacter> fixedProbabilityCharacters = alphabet.getFixedProbabilityCharacters();
		fixedProbabilitiesMask=new int[numOfCharacters];
		int fixedProbabilitiesSum = 0;
		int numOfFixedProbabilityCharacters = 0;
		for (int i = 0; i<numOfCharacters; i++) {
			LanguageCharacter fixedProbabilityCharacter = fixedProbabilityCharacters.get(i);
			if (fixedProbabilityCharacter!=null) {
				int prob = (int) (NORMALIZATION*fixedProbabilityCharacter.getFixedProbability()); //always round down
				fixedProbabilitiesMask[i]=prob;
				fixedProbabilitiesSum+=prob;
				numOfFixedProbabilityCharacters++;
			} else fixedProbabilitiesMask[i]=-1; //non-fixed-probability character
		}
		//Reject an over-full mask before any redistribution: if every character were fixed, the redistribution
		//below would subtract from the entries and fail with a misleading "remaining is ..." message instead.
		if (fixedProbabilitiesSum>NORMALIZATION) throw new RuntimeException("LanguageModel: Sum of fixed probabilities is "
				+"too large: "+Arrays.toString(fixedProbabilitiesMask));
		//If there are only characters with fixed probabilities (and thus no PPM is necessary) we need to ensure that
		//the normalized fixed probabilities sum to the correct value by evenly distributing any leftover space.
		if (numOfFixedProbabilityCharacters==numOfCharacters && fixedProbabilitiesSum!=NORMALIZATION) {
			int addToAll = (NORMALIZATION-fixedProbabilitiesSum)/numOfFixedProbabilityCharacters;
			for (int i = 0; i<fixedProbabilitiesMask.length; i++) {
				fixedProbabilitiesMask[i]+=addToAll;
				fixedProbabilitiesSum+=addToAll;
			}
			//Integer division leaves fewer than numOfFixedProbabilityCharacters units; hand them out one by one.
			int remaining = NORMALIZATION-fixedProbabilitiesSum;
			if (remaining<0 || remaining>=numOfFixedProbabilityCharacters) throw new RuntimeException("LanguageModel: "
					+"remaining is "+remaining+", but should be in [0, "+(numOfFixedProbabilityCharacters-1)+"]");
			for (int i = 0; i<fixedProbabilitiesMask.length; i++) {
				if (remaining==0) break;
				fixedProbabilitiesMask[i]++;
				fixedProbabilitiesSum++;
				remaining--;
			}
			if (fixedProbabilitiesSum!=NORMALIZATION) throw new RuntimeException("LanguageModel: "
					+"fixedProbabilitiesSum is "+fixedProbabilitiesSum+", but should be "+NORMALIZATION);
		}
		numOfNonFixedProbabilityCharacters=numOfCharacters-numOfFixedProbabilityCharacters;
		nonFixedProbabilitiesNorm=NORMALIZATION-fixedProbabilitiesSum;
		System.out.println("Language model creation done (no training yet). Non-fixed probabilities norm: "
				+nonFixedProbabilitiesNorm+". Fixed probabilities mask: "+Arrays.toString(fixedProbabilitiesMask));
	}
	
	/**
	 * Feeds a sequence of symbol indices into the model, starting from an empty context.
	 *
	 * <p>Symbols whose iteration throws {@link UnicodeNotFoundException} are recorded as skipped and training
	 * continues. A {@link TrainingFileReaderException} aborts training and is reported via
	 * {@link DasherJava#showErrorMessage}; whatever was learned up to that point is kept.
	 *
	 * @param trainingData the symbol indices to learn, in order
	 * @param trainingStatsReport receives progress counters and skipped-symbol statistics
	 */
	public void train(Iterable<Integer> trainingData, LanguageModelTrainingStats trainingStatsReport) {
		Context trainingContext = createEmptyContext();
		//Cannot use enhanced foreach-loop here since we need to catch exceptions
		//and then continue the loop
		Iterator<Integer> iterator = trainingData.iterator();
		try {
			while (iterator.hasNext()) {
				try {
					Integer symbolIndex = iterator.next(); //may throw UnicodeNotFoundException
					trainingContext.learnSymbol(symbolIndex, trainingStatsReport);
					trainingStatsReport.incrementNumOfSymbolsRead();
				} catch (UnicodeNotFoundException ex) {
					trainingStatsReport.incrementSkippedSymbolCount(ex.getUnicode());
				}
			}
		} catch (TrainingFileReaderException ex) {
			DasherJava.showErrorMessage("Language model training aborted", ex);
		}
	}
	
	/**
	 * Computes the probabilities of all characters that do not have a fixed probability.
	 *
	 * <p>First, a {@code uniform}/1000 share of the available mass (at least 1 per character) is reserved for an
	 * even spread. The rest is handed out by walking from the context's head node along its vine links. At every
	 * node with a non-zero total child count, each child receives a share of the mass still unspent, proportional
	 * to {@code (100*count-beta)/(100*total+alpha)}. Whatever is left afterwards is spread evenly. The integer
	 * rounding remainder is distributed so that the result sums exactly to the non-fixed normalization.
	 *
	 * @param context the context to predict from
	 * @return one probability per non-fixed character, in alphabet order. The values sum to {@link #NORMALIZATION}
	 *         minus the fixed mass. The array is empty if every character has a fixed probability.
	 * @throws RuntimeException if a probability comes out negative or the values do not sum to the expected norm
	 */
	public int[] getNonFixedProbabilities(Context context) {
		int[] result = new int[numOfNonFixedProbabilityCharacters];
		if (result.length==0) return result; //only fixed probability characters in this model
		//Product computed in long so that a large 'uniform' value cannot overflow int.
		int uniformAdd = (int) Math.max(1, (long) nonFixedProbabilitiesNorm*uniform/1000/result.length);
		int toSpend = nonFixedProbabilitiesNorm-result.length*uniformAdd;
		for (ContextTrieNode temp = context.getHead(); temp!=null; temp=temp.getVine()) {
			//Counts are accumulated and scaled in long: 100*total (or 100*count) overflows int once a context has
			//been seen more than about 21 million times, which would corrupt every share below.
			long total = 0;
			for (ContextTrieNode child : temp.getChildren()) {
				total+=child.getCount();
			}
			if (total!=0) {
				long sizeOfSlice = toSpend; //mass still unspent when this node is reached
				long denominator = 100L*total+alpha;
				//NOTE(review): the child key is used directly as an index into the compact non-fixed array; this
				//assumes the trie only holds non-fixed indices — confirm against Context.learnSymbol().
				for (Entry<Integer, ContextTrieNode> childEntry : temp.getChildrenEntries()) {
					int p = (int) (sizeOfSlice*(100L*childEntry.getValue().getCount()-beta)/denominator);
					result[childEntry.getKey()]+=p;
					toSpend-=p;
				}
			}
		}
		//Spread what the PPM did not spend evenly over all characters...
		int evenShare = toSpend/result.length;
		for (int symbolIndex = 0; symbolIndex<result.length; symbolIndex++) {
			result[symbolIndex]+=evenShare;
		}
		toSpend-=evenShare*result.length;
		//...then hand out the integer-division remainder and add the reserved uniform part.
		int left = result.length;
		for (int symbolIndex = 0; symbolIndex<result.length; symbolIndex++) {
			int p = toSpend/left;
			result[symbolIndex]+=p+uniformAdd;
			left--;
			toSpend-=p;
		}
		//Check sum and non-negativity:
		int sum = 0;
		for (int i : result) {
			sum+=i;
			if (i<0) throw new RuntimeException("getNonFixedProbabilities(): Negative probability: "+i);
		}
		if (sum!=nonFixedProbabilitiesNorm) throw new RuntimeException("getNonFixedProbabilities(): Sum is "+sum
				+", but should be "+nonFixedProbabilitiesNorm);
		return result;
	}
	
	/**
	 * Computes the probability of every character in the alphabet for the given context.
	 *
	 * <p>Fixed-probability characters take their precomputed value from the mask. Every other character takes the
	 * next value from {@link #getNonFixedProbabilities(Context)}, in alphabet order.
	 *
	 * @param context the context to predict from
	 * @return one probability per alphabet character, summing to {@link #NORMALIZATION}
	 * @throws RuntimeException if a probability is negative or the values do not sum to {@link #NORMALIZATION}
	 */
	public int[] getProbabilities(Context context) {
		int[] result = new int[fixedProbabilitiesMask.length];
		int[] nonFixedProbabilities = getNonFixedProbabilities(context);
		int nonFixedProbabilitiesIndex = 0;
		for (int i = 0; i<result.length; i++) {
			if (fixedProbabilitiesMask[i]>=0) result[i]=fixedProbabilitiesMask[i]; //copy fixed probability
			else { //use computed probability
				result[i]=nonFixedProbabilities[nonFixedProbabilitiesIndex];
				nonFixedProbabilitiesIndex++;
			}
		}
		//Check sum and non-negativity:
		int sum = 0;
		for (int i : result) {
			sum+=i;
			if (i<0) throw new RuntimeException("getProbabilities(): Negative probability: "+i);
		}
		if (sum!=NORMALIZATION) throw new RuntimeException("getProbabilities(): Sum is "+sum
				+", but should be "+NORMALIZATION);
		return result;
	}
	
	/** Returns a new context positioned at the root of this model's context trie. */
	public Context createEmptyContext() {
		return new Context(this, contextTrieRoot);
	}
	
	/** Returns the alphabet this model was built for. */
	public LanguageAlphabet getAlphabet() {
		return alphabet;
	}
	
	/** Returns the maximum context order this model was built with. */
	int getMaxOrder() {
		return maxOrder;
	}
	
	/**
	 * Mutable counters collected during {@link LanguageModel#train}. Every millionth symbol read and every
	 * millionth trie node created is reported on standard output.
	 */
	public static class LanguageModelTrainingStats {
		private int numOfSymbolsRead = 0;
		private int numOfContextTrieNodes = 0;
		private final SortedMap<Integer, Integer> skippedSymbols = new TreeMap<>(); //sorted for output
		public void incrementNumOfSymbolsRead() {
			numOfSymbolsRead++;
			if (numOfSymbolsRead%1000000==0) System.out.println("Language model training in progress: Read "
					+numOfSymbolsRead+" symbols so far");
		}
		public void incrementNumOfContextTrieNodes() {
			numOfContextTrieNodes++;
			if (numOfContextTrieNodes%1000000==0) System.out.println("Language model training in progress: Created "
					+numOfContextTrieNodes+" context trie nodes so far");
		}
		/** Records one more occurrence of a code point that could not be mapped to the alphabet. */
		public void incrementSkippedSymbolCount(int unicode) {
			skippedSymbols.merge(unicode, 1, Integer::sum);
		}
		public int getNumOfSymbolsRead() {
			return numOfSymbolsRead;
		}
		public int getNumOfContextTrieNodes() {
			return numOfContextTrieNodes;
		}
		/** Returns the skipped code points mapped to how often each was skipped, in ascending code point order. */
		public Map<Integer, Integer> getSkippedSymbols() {
			return skippedSymbols;
		}
	}
}