// DasherJava/src/dasherJava/core/languageModeling/TrainingFileReader.java
package dasherJava.core.languageModeling;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;

/**
 * Reads a UTF-8 encoded training file one Unicode code point at a time and exposes the
 * content as a sequence of language-model symbol indices.
 *
 * <p>The reader keeps a one-code-point lookahead in {@code nextChar}. The lookahead is filled
 * once in the constructor and refilled on every {@code next()} call of an iterator. All iterators
 * returned by {@link #iterator()} share that lookahead (and the underlying stream), so the file
 * can only be traversed once; a second iterator continues where the first one stopped.
 *
 * <p>Being a {@link BufferedReader}, the instance must be closed by the caller once it has been
 * constructed successfully.
 */
public class TrainingFileReader extends BufferedReader implements Iterable<Integer> {
	
	/** Maps code points to language-model symbol indices. */
	private final LanguageAlphabet alphabet;
	/** Lookahead: the next code point to hand out, or -1 once {@link #read()} reported EOF. */
	private int nextChar;
	
	/**
	 * Opens {@code filename} as UTF-8 text and reads the first code point as lookahead.
	 *
	 * @param filename path of the training file to open
	 * @param alphabet alphabet used by the iterator to translate code points into symbol indices
	 * @throws IOException if the file cannot be opened or the first read fails
	 * @throws TrainingFileReaderException if the file starts with an invalid surrogate pair
	 */
	public TrainingFileReader(String filename, LanguageAlphabet alphabet) throws IOException {
		super(new InputStreamReader(new FileInputStream(filename), StandardCharsets.UTF_8));
		this.alphabet=alphabet;
		try {
			nextChar=read();
		} catch (IOException | RuntimeException ex) {
			// The caller never gets a reference to a half-constructed reader, so the
			// already opened file would leak unless it is closed here.
			try {
				close();
			} catch (IOException closeEx) {
				ex.addSuppressed(closeEx);
			}
			throw ex;
		}
	}
	
	/**
	 * Returns an iterator over the symbol indices of the remaining code points.
	 *
	 * <p>The iterator operates on the reader's shared lookahead; it does not restart the file.
	 *
	 * @return an iterator whose {@code next()} consumes one code point from this reader
	 */
	@Override
	public Iterator<Integer> iterator() {
		return new Iterator<>() {
			/** @return {@code true} while the lookahead holds a code point (EOF not reached) */
			@Override
			public boolean hasNext() {
				return nextChar!=-1;
			}
			/**
			 * Hands out the symbol index of the lookahead code point and refills the lookahead.
			 *
			 * <p>The lookahead is advanced before the alphabet lookup, so the code point counts
			 * as consumed even if the lookup throws.
			 *
			 * @return the value of {@code alphabet.getLanguageModelSymbolIndex} for the code point
			 * @throws TrainingFileReaderException if called after EOF (note: not
			 *         {@code NoSuchElementException}), or if refilling the lookahead fails
			 */
			@Override
			public Integer next() {
				if (nextChar==-1) throw new TrainingFileReaderException("next() has been called while hasNext() "
						+"is false (EOF reached)");
				int tmp = nextChar;
				try {
					nextChar=read();
				} catch (IOException ex) {
					throw new TrainingFileReaderException("IOException: "+ex.getMessage(), ex);
				}
				return alphabet.getLanguageModelSymbolIndex(tmp); //may throw UnicodeNotFoundException
			}
		};
	}
	
	/**
	 * Reads one Unicode code point instead of one UTF-16 code unit, combining a high and a low
	 * surrogate into a single supplementary code point.
	 *
	 * <p>A char that is not a high surrogate is returned unchanged. If the stream ends directly
	 * after a high surrogate, -1 is returned and the dangling surrogate is dropped.
	 *
	 * @return the code point read, or -1 at end of stream
	 * @throws IOException if the underlying reader fails
	 * @throws TrainingFileReaderException if a high surrogate is followed by a char that is not a
	 *         low surrogate
	 */
	@Override
	public int read() throws IOException { //adapted from https://stackoverflow.com/a/53271251
		int high = super.read();
		if (high<0) return -1; //EOF
		if (!Character.isHighSurrogate((char) high)) return high;
		int low = super.read();
		if (low<0) return -1; //EOF
		if (!Character.isLowSurrogate((char) low)) throw new TrainingFileReaderException("Invalid surrogate pair");
		return Character.toCodePoint((char) high, (char) low);
	}
	
	/**
	 * Unchecked exception for failures while reading or iterating a training file: iterating past
	 * EOF, an I/O error during iteration, or an invalid surrogate pair.
	 */
	public static class TrainingFileReaderException extends RuntimeException {
		/**
		 * @param message description of the failure
		 */
		public TrainingFileReaderException(String message) {
			super(message);
		}
		/**
		 * @param message description of the failure
		 * @param cause underlying exception that triggered this one
		 */
		public TrainingFileReaderException(String message, Throwable cause) {
			super(message, cause);
		}
	}
}