#include <math.h>
#include <fstream>
#include <iostream>
#include <limits>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include "csvstream.h"
using namespace std;
class classifier {
private:
//map representing the number of posts for each unique label
map<string, int> num_post_label;
//map of map representing the number of posts for each unique word given a label
map<string, map<string, int>> num_post_label_with_word;
//map representing the number of occurances for each unique word in the whole file
map<string, int> num_words;
//number of posts in the file
int num_posts;
//number of unique words in the file
int num_unique_words;
public:
classifier(): num_posts(0), num_unique_words(0) {}
set<string> unique_words(const string &str) {
istringstream source(str);
set<string> words;
string word;
while (source >> word) {
words.insert(word);
}
return words;
}
void process_data(string file_in) {
//read in a single csv file, with each line representing a post,
//reading in and store each line as a node in row_value;
;
csvstream name(file_in);
map<string, string> row_value;
while (name >> row_value) {
num_post_label[row_value["tag"]]++;
set<string> split_words = unique_words(row_value["content"]);
for (const auto &i : split_words) {
num_words[i]++;
num_post_label_with_word[row_value["tag"]][i]++;
}
num_posts += 1;
}
num_unique_words = static_cast<int>(num_words.size());
}
void process_print_norm(string file_in) {
process_data(file_in);
cout << "trained on " << num_posts << " examples" << endl;
}
void process_print_debug(string file_in) {
process_data(file_in);
cout << "training data:" << endl;
map<string, string> row_value;
csvstream csvin(file_in);
while (csvin >> row_value) {
string label = row_value["tag"];
string content = row_value["content"];
string predicted = find_most_probable_label(content).first;
cout << " label = " << label << ", content = " << content << endl;
}
cout << "trained on " << num_posts << " examples" << endl
<< "vocabulary size = " << num_unique_words << endl << endl;
cout << "classes:" << endl;
for(const auto &i : num_post_label) {
cout << " " << i.first << ", " << i.second << " examples, log-prior = "
<< compute_log_labelc(i.first) << endl;
}
cout << "classifier parameters:" << endl;
for(const auto &i : num_post_label_with_word) {
for (const auto &j : i.second) {
cout << " " << i.first << ":" << j.first << ", count = "
<< j.second << ", log-likelihood = "
<< compute_log_labelc_givenw(i.first, j.first) << endl;
}
}
}
double compute_log_labelc(string label) {
//computes the log probability of finding a post of a given label
//case 1: label doesnt exist
if (num_post_label.find(label) == num_post_label.end()) {
return log(1.0 / num_posts);
}
//case 2: label exists
else {
return log(num_post_label[label] / static_cast<double>(num_posts));
}
}
double compute_log_labelc_givenw(string label, string word) {
//computes the log probability of finding a post with a word given a label
//case 1: word does not exist
if (num_words.find(word) == num_words.end()) {
return log(1.0 / num_posts);
}
//case 2: word exists but not in label
else if (num_post_label_with_word[label].find(word) ==
num_post_label_with_word[label].end()) {
return log(num_words[word] / static_cast<double>(num_posts));
}
//case 3: word exists and is within label
else {
return log(num_post_label_with_word[label][word] /
static_cast<double>(num_post_label[label]));
}
}
double compute_prob(set<string> post_content, string label) {
//computes probability of a post having the given label
double prob_post = compute_log_labelc(label);
for (const auto &word : post_content) {
prob_post += compute_log_labelc_givenw(label, word);
}
return prob_post;
}
pair<string, double> find_most_probable_label(string post_in) {
//finds the label of the highest proability for a given post
set<string> set_of_unique_words = unique_words(post_in);
string label = num_post_label.begin()->first;
double max_prob = compute_prob(set_of_unique_words, label);
for (const auto &l : num_post_label) {
double prob = compute_prob(set_of_unique_words, l.first);
if (prob > max_prob) {
label = l.first;
max_prob = prob;
}
}
pair<string, double> pair (label, max_prob);
return pair;
}
// double compute_lps(string post_in) {
// // computes lps of predicted label based on post
// set<string> set_of_unique_words = unique_words(post_in);
// string label = find_most_probable_label(post_in);
// return compute_prob(set_of_unique_words, label);
// }
void predict(string file_in) {
//predicts labels for each post prints results to cout
int count_correct = 0, count_total = 0;
cout << endl << "test data:" << endl;
map<string, string> row_value;
csvstream csvin(file_in);
while (csvin >> row_value) {
string correct = row_value["tag"];
string content = row_value["content"];
string predicted = find_most_probable_label(content).first;
// double lps = compute_lps(row_value["content"]);
double lps = find_most_probable_label(content).second;
cout << " correct = " << correct << ", predicted = " << predicted
<< ", log-probability score = " << lps << endl
<< " content = " << content << endl << endl;
if (correct == predicted) {count_correct++;}
count_total++;
}
cout << "performance: " << count_correct << " / " << count_total
<< " posts predicted correctly" << endl;
}
};
int main(int argc, char **argv) {
cout.precision(3);
string executable, train_file, test_file, debug = "";
if (!(argc == 3 || argc == 4)) {
cout << "Usage: main.exe TRAIN_FILE TEST_FILE [--debug]" << endl;
return -1;
}
else if (argc == 3) {
executable = argv[0]; train_file = argv[1]; test_file = argv[2];
}
else {
executable = argv[0]; train_file = argv[1];
test_file = argv[2]; debug = argv[3];
if (debug != "--debug") {
cout << "Usage: main.exe TRAIN_FILE TEST_FILE [--debug]" << endl;
return -1;
}
}
ifstream fin_train(train_file);
ifstream fin_test(test_file);
if (!fin_train.is_open()) {
cout << "Error opening file: " << train_file << endl;
return -1;
}
if (!fin_test.is_open()) {
cout << "Error opening file: " << test_file << endl;
return -1;
}
classifier test;
if (debug == "--debug") {test.process_print_debug(train_file);}
else {test.process_print_norm(train_file);}
test.predict(test_file);
return 0;
}