// Source: DasherJava/src/dasherJava/core/languageModeling/ContextTrieNode.java
package dasherJava.core.languageModeling;

import java.util.Collections;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;

import dasherJava.core.languageModeling.LanguageModel.LanguageModelTrainingStats;

/**
 * A node in a context trie used for language modeling.
 *
 * <p>Each node holds an occurrence count, a lazily allocated map of children keyed by
 * symbol index, and a "vine" pointer. The root node has a {@code null} vine. A child of
 * the root has the root itself as its vine. Any deeper child reached via symbol {@code s}
 * has, as its vine, the {@code s}-child of its parent's vine. Following vines therefore
 * walks through progressively shorter contexts until the root is reached.
 *
 * <p>This class performs no synchronization.
 */
public class ContextTrieNode {
	
	/** Node for the next-shorter context, or {@code null} if this node is the root. */
	private final ContextTrieNode vine;
	/**
	 * Children keyed by symbol index; {@code null} until the first child is created.
	 * The implementation of this map is highly performance-critical: a TreeMap is used
	 * instead of a HashMap because memory usage is far more critical than lookup time.
	 */
	private Map<Integer, ContextTrieNode> children = null;
	/** Occurrence count; starts at 0 and saturates at {@link Integer#MAX_VALUE}. */
	private int count = 0;
	
	/**
	 * Creates a node with count 0 and no children.
	 *
	 * @param vine the node for the next-shorter context, or {@code null} for the root
	 * @param trainingStatsReport optional statistics sink; if non-null, its trie-node
	 *        counter is incremented once for this node
	 */
	public ContextTrieNode(ContextTrieNode vine, LanguageModelTrainingStats trainingStatsReport) {
		this.vine = vine;
		if (trainingStatsReport != null) {
			trainingStatsReport.incrementNumOfContextTrieNodes();
		}
	}
	
	/**
	 * Returns the node for the next-shorter context.
	 *
	 * @return the vine node, or {@code null} if this node is the root
	 */
	public ContextTrieNode getVine() {
		return vine;
	}
	
	/**
	 * Records one occurrence of {@code symbolIndex} following this node's context.
	 *
	 * <p>If a child for the symbol already exists, only that child's count is incremented.
	 * The vine chain is not touched in this case. Otherwise a new child with count 1 is
	 * created. Its vine is this node if this node is the root. If not, its vine is
	 * obtained by recursively calling this method on this node's vine, which also creates
	 * that vine child or increments its count.
	 *
	 * @param symbolIndex the symbol that followed this context
	 * @param trainingStatsReport optional statistics sink passed to every node constructed
	 *        (may be {@code null})
	 * @return the (existing or newly created) child for {@code symbolIndex}
	 */
	public ContextTrieNode createChild(int symbolIndex, LanguageModelTrainingStats trainingStatsReport) {
		if (children == null) {
			// TreeMap rather than HashMap: memory usage is far more critical than lookup time.
			children = new TreeMap<>();
		}
		ContextTrieNode child = children.get(symbolIndex);
		if (child != null) {
			child.incrementCount();
			return child;
		}
		// vine == null means this is the root (empty context), so the new child's vine is the
		// root itself. Otherwise the vine's child for the same symbol is created or incremented
		// first, before this node's new child is constructed and linked in.
		ContextTrieNode childVine = (vine == null) ? this : vine.createChild(symbolIndex, trainingStatsReport);
		child = new ContextTrieNode(childVine, trainingStatsReport);
		child.incrementCount(); // initial count 1
		children.put(symbolIndex, child);
		return child;
	}
	
	/**
	 * Looks up the child for a symbol without creating it.
	 *
	 * @param symbolIndex the symbol to look up
	 * @return the child node, or {@code null} if no such child exists
	 */
	public ContextTrieNode getChild(int symbolIndex) {
		if (children == null) {
			return null; // no children at all yet
		}
		return children.get(symbolIndex);
	}
	
	/**
	 * Returns a read-only, live view of this node's children in ascending symbol-index order.
	 *
	 * @return the children; empty if none have been created yet
	 */
	public Iterable<ContextTrieNode> getChildren() {
		if (children == null) {
			return Collections.emptySet(); // no children at all yet
		}
		// Unmodifiable wrapper so callers cannot remove children via Iterator.remove().
		return Collections.unmodifiableCollection(children.values());
	}
	
	/**
	 * Returns a read-only, live view of (symbol index, child) pairs in ascending
	 * symbol-index order.
	 *
	 * @return the child entries; empty if none have been created yet
	 */
	public Iterable<Entry<Integer, ContextTrieNode>> getChildrenEntries() {
		if (children == null) {
			return Collections.emptySet(); // no children at all yet
		}
		// Unmodifiable wrapper so callers cannot remove entries or replace children via setValue().
		return Collections.unmodifiableMap(children).entrySet();
	}
	
	/**
	 * Increments this node's occurrence count by one, saturating at
	 * {@link Integer#MAX_VALUE} instead of silently wrapping to a negative value.
	 */
	public void incrementCount() {
		if (count < Integer.MAX_VALUE) {
			count++;
		}
	}
	
	/**
	 * Returns this node's occurrence count.
	 *
	 * @return the count (never negative)
	 */
	public int getCount() {
		return count;
	}
}