// User:Kevin Baas/stat generator code/CountWords.java

import java.util.*;
import java.io.*;

/**
 * Accumulates word counts over a set of articles and derives per-document and
 * per-term statistics from them.
 *
 * <p>Usage order, as implied by the fields each method reads:
 * <ol>
 *   <li>{@link #countArticleWords} once per article; it updates the corpus-wide
 *       counts and writes the article's own counts to
 *       {@code count_path/<article>.csv} (one {@code word,count} line per word).</li>
 *   <li>{@link #compressMainWordIndex} to freeze the vocabulary into
 *       {@link #words}, {@link #values} and {@link #freqs}.</li>
 *   <li>{@code getAllStats()} / {@link #get_term_stats}, which read the per-article
 *       CSV files back and aggregate the statistics.</li>
 * </ol>
 */
public class CountWords {
	/** Minimum corpus-wide count a word needs to be kept by compressMainWordIndex. */
	private static final int MIN_TOTAL_COUNT = 5;
	/** Number of values returned by get_doc_term_stats for one (document, term) pair. */
	private static final int STAT_COUNT = 33;
	/** Charset of the count files; their content is ASCII-only (a-z, 0-9, ',' and newline). */
	private static final String COUNT_FILE_CHARSET = "UTF-8";

	/** Number of accepted words over all counted articles. */
	int total_word_count = 0;
	/** Corpus-wide count per accepted word. */
	Hashtable<String,Integer> total_word_counts = new Hashtable<String,Integer>();
	/** Article names in the order they were passed to countArticleWords. */
	Vector<String> articles = new Vector<String>();
	/** Array copy of {@link #articles}, filled by compressMainWordIndex. */
	public String[] sarticles;
	/** Directory the per-article count files are written to and read from. */
	public String count_path = "";
	/** Number of accepted words per article. */
	Hashtable<String,Integer> article_total_word_counts = new Hashtable<String,Integer>();

	/** Vocabulary kept by compressMainWordIndex; null until it has been called. */
	public String[] words = null;
	/** Corpus-wide count of each entry of {@link #words}. */
	public int[] values = null;
	/** Smoothed corpus-wide relative frequency of each entry of {@link #words}. */
	public double[] freqs = null;

	/**
	 * Counts the words of one article, adds them to the corpus-wide totals and
	 * writes the article's own counts to its count file.
	 *
	 * <p>A word is accepted if, after lower-casing, it consists only of the letters
	 * a-z, or starts with a digit and consists only of the digits 0-9. Null and
	 * empty entries and all other words are ignored.
	 *
	 * @param article article name; also used as the base name of the count file
	 * @param words   tokens of the article; entries may be null
	 */
	public void countArticleWords(String article, String[] words) {
		articles.add(article);
		Hashtable<String,Integer> article_word_count = new Hashtable<String,Integer>();
		System.out.println();

		int total = 0;
		for( int i = 0; i < words.length; i++) {
			String raw = words[i];
			if( raw == null)
				continue;
			// Locale.ROOT keeps the a-z test independent of the default locale
			// (a Turkish locale would lower-case "I" to a dotless i and reject the word).
			String word = raw.toLowerCase(Locale.ROOT);
			if( !isCountable(word))
				continue;

			total++;
			increment(total_word_counts, word);
			increment(article_word_count, word);
		}

		article_total_word_counts.put(article,total);
		total_word_count += total;
		System.out.print(" "+total+" ");
		write_counts_as_file(article,article_word_count);
	}

	/** Returns true if the word is all a-z, or starts with a digit and is all 0-9. */
	private static boolean isCountable(String word) {
		if( word.length() == 0)
			return false;
		char first = word.charAt(0);
		boolean numeric = first >= '0' && first <= '9';
		char lo = numeric ? '0' : 'a';
		char hi = numeric ? '9' : 'z';
		for( int j = 0; j < word.length(); j++) {
			char c = word.charAt(j);
			if( c < lo || c > hi)
				return false;
		}
		return true;
	}

	/** Adds one to the count stored under key, treating a missing entry as zero. */
	private static void increment(Hashtable<String,Integer> counts, String key) {
		Integer current = counts.get(key);
		counts.put(key, current == null ? 1 : current + 1);
	}

	/** Builds the path of the count file of an article inside count_path. */
	private String countFilePath(String article) {
		// File.separator yields the same path as the former hard-coded "\\" on
		// Windows and a usable path everywhere else.
		return count_path+File.separator+article+".csv";
	}

	/** Closes a stream, reporting (not propagating) a failure to do so. */
	private static void closeQuietly(Closeable c) {
		if( c == null)
			return;
		try {
			c.close();
		} catch (IOException ex) {
			System.err.println("CountWords: error closing stream: "+ex);
		}
	}

	/**
	 * Writes the word counts of one article as {@code word,count} lines to the
	 * article's count file, replacing any existing file.
	 *
	 * <p>Writing is best-effort: a failure is reported on standard error and does
	 * not abort the caller.
	 *
	 * @param article            article name (base name of the file)
	 * @param article_word_count count per word of that article
	 */
	void write_counts_as_file(String article, Hashtable<String,Integer> article_word_count) {
		String path = countFilePath(article);
		FileOutputStream out = null;
		try {
			StringBuilder sb = new StringBuilder();
			Enumeration<String> e = article_word_count.keys();
			while(e.hasMoreElements()) {
				String s = e.nextElement();
				sb.append(s).append(',').append(article_word_count.get(s)).append('\n');
			}
			out = new FileOutputStream(path);
			out.write(sb.toString().getBytes(COUNT_FILE_CHARSET));
		} catch (Exception ex) {
			// Deliberately broad: persisting counts must never abort a counting run,
			// but the failure is no longer hidden.
			System.err.println("CountWords: could not write "+path+": "+ex);
		} finally {
			closeQuietly(out);
		}
	}

	/**
	 * Reads the count file of an article and returns its counts aligned with
	 * {@link #words}. Requires {@link #words} to be set, i.e. compressMainWordIndex
	 * to have been called.
	 *
	 * <p>Lines without a comma or with a non-numeric count are skipped. If the file
	 * cannot be read, the problem is reported on standard error and all counts
	 * are zero.
	 *
	 * @param article article name (base name of the file)
	 * @return one count per entry of {@link #words}; zero for words not in the file
	 */
	int[] get_doc_term_counts(String article) {
		int[] counts = new int[words.length];
		String path = countFilePath(article);
		System.out.println(path);

		Hashtable<String,Integer> local = new Hashtable<String,Integer>();
		FileInputStream in = null;
		try {
			in = new FileInputStream(path);
			// A reader consumes the whole file; a single InputStream.read(byte[])
			// call is allowed to return fewer bytes than requested.
			BufferedReader reader = new BufferedReader(new InputStreamReader(in, COUNT_FILE_CHARSET));
			String line;
			while( (line = reader.readLine()) != null) {
				String[] fields = line.split(",");
				if( fields.length < 2)
					continue; // empty or malformed line
				try {
					local.put(fields[0].trim(), Integer.valueOf(fields[1].trim()));
				} catch (NumberFormatException ex) {
					System.err.println("CountWords: skipping bad line in "+path+": "+line);
				}
			}
		} catch (IOException ex) {
			System.err.println("CountWords: could not read "+path+": "+ex);
			return counts;
		} finally {
			closeQuietly(in);
		}

		for( int i = 0; i < words.length; i++) {
			Integer val = local.get(words[i]);
			counts[i] = (val == null) ? 0 : val;
		}
		return counts;
	}

	/**
	 * Freezes the vocabulary: keeps every word whose corpus-wide count is at least
	 * 5 and fills {@link #words}, {@link #values}, {@link #freqs} and
	 * {@link #sarticles}.
	 *
	 * <p>{@code freqs[i] = (values[i] + words_to_add / words.length) /
	 * (total_word_count + words_to_add)}, i.e. {@code words_to_add} pseudo-counts
	 * are spread evenly over the kept words.
	 *
	 * @param words_to_add number of pseudo-counts to distribute over the kept words
	 */
	public void compressMainWordIndex(int words_to_add) {
		Vector<String> kept = new Vector<String>();
		Enumeration<String> e = total_word_counts.keys();
		while(e.hasMoreElements()) {
			String s = e.nextElement();
			if(total_word_counts.get(s) >= MIN_TOTAL_COUNT)
				kept.add(s);
		}

		int n = kept.size();
		words = new String[n];
		values = new int[n];
		freqs = new double[n];

		// Floating-point division: the former int/int division truncated the share
		// (usually to 0) and threw ArithmeticException when no word was kept.
		double adj = (n == 0) ? 0.0 : (double)words_to_add/n;
		double denominator = (double)total_word_count+words_to_add;
		for( int i = 0; i < n; i++) {
			words[i] = kept.get(i);
			values[i] = total_word_counts.get(words[i]);
			freqs[i] = (values[i]+adj)/denominator;
		}

		sarticles = articles.toArray(new String[articles.size()]);
	}

	/**
	 * Not implemented: always returns null.
	 *
	 * <p>An earlier, disabled implementation relied on per-article counts being
	 * held in memory and computed, per word,
	 * {@code (articleCount + freqs[i]*words_to_add) / (articleTotal + words_to_add)}.
	 *
	 * @param article      article name (unused)
	 * @param words_to_add smoothing weight (unused)
	 * @return always null
	 */
	public double[] getMeanRegressedArticleWordFreq(String article, int words_to_add) {
		return null;
	}

	/**
	 * Returns, per word, the natural log of the ratio between an article's word
	 * frequency and the corpus-wide frequency in {@link #freqs}.
	 *
	 * @param article_word_frequency frequencies aligned with {@link #words}
	 * @return {@code log(article_word_frequency[i] / freqs[i])} for each word
	 */
	public double[] getWordSurprise(double[] article_word_frequency) {
		double[] surprise = new double[words.length];
		for( int i = 0; i < words.length; i++)
			surprise[i] = Math.log(article_word_frequency[i]/freqs[i]);
		return surprise;
	}

	/**
	 * Computes the 33 statistics of one (document, term) pair.
	 *
	 * <p>Layout of the result (all H-based entries scaled by 1e7):
	 * indices 0-7 are {@code H(p, q)} with p the product {@code pdoc*pterm} (0-3)
	 * or the joint probability (4-7) and q in the order ptermdoc, pterm, pdocterm,
	 * pdoc; 8-15 are complement terms with the matching entry of 0-7 added;
	 * 16-17 are {@code H(pdoc*pterm, joint)} and {@code H(joint, joint)};
	 * 18 is the raw term count; 19 is 1 if the term occurs in the document;
	 * 20-27 are pairwise differences {@code stats[2i+1]-stats[2i]};
	 * 28-32 are differences against entries 16 and 17.
	 *
	 * @param doc_term_count count of the term in the document
	 * @param doc_word_count number of words in the document
	 * @param tot_term_count corpus-wide count of the term
	 * @param tot_word_count corpus-wide number of words
	 * @return array of 33 statistics
	 */
	public double[] get_doc_term_stats(double doc_term_count, double doc_word_count, double tot_term_count, double tot_word_count) {
		double jpdocterm = doc_term_count / tot_word_count;//=ptermdoc * pdoc;
		double ptermdoc = doc_term_count / doc_word_count;
		double pdocterm = doc_term_count / tot_term_count;//= ptermdoc * pdoc / pterm;
		double pterm = tot_term_count / tot_word_count;
		double pdoc = doc_word_count / tot_word_count;
		double pre = pdoc*pterm;

		double mult = 10000000;
		double[] stats = new double[]{
				mult*H(pre,ptermdoc), //0
				mult*H(pre,pterm),
				mult*H(pre,pdocterm),
				mult*H(pre,pdoc),
				mult*H(jpdocterm,ptermdoc), //4
				mult*H(jpdocterm,pterm),
				mult*H(jpdocterm,pdocterm),
				mult*H(jpdocterm,pdoc),
				mult*pdoc*H(1-pterm,1-ptermdoc), //8
				mult*pdoc*H(1-pterm,1-pterm),
				mult*pterm*H(1-pdoc,1-pdocterm),
				mult*pterm*H(1-pdoc,1-pdoc),
				mult*pdoc*H(1-ptermdoc,1-ptermdoc),  //12
				mult*pdoc*H(1-ptermdoc,1-pterm),
				mult*pterm*H(1-pdocterm,1-pdocterm),
				mult*pterm*H(1-pdocterm,1-pdoc),
				mult*H(pre,jpdocterm),  //16
				mult*H(jpdocterm,jpdocterm),
				doc_term_count, //18: summed over docs/terms gives total term count / doc word count
				((doc_term_count >= 1) ? 1 : 0), //19: summed gives articles containing term / distinct terms in doc
				0,0,0,0, //20-32 are filled in below
				0,0,0,0,
				0,0,0,0,
				0,
		};
		for( int i = 0; i < 8; i++)
			stats[8+i]+=stats[i];
		for( int i = 0; i < 8; i++)
			stats[20+i] = stats[2*i+1]-stats[2*i];
		stats[28] = stats[16]-stats[1];
		stats[29] = stats[16]-stats[3];
		stats[30] = stats[17]-stats[5];
		stats[31] = stats[17]-stats[7];
		stats[32] = stats[17]-stats[16];
		return stats;
	}

	/**
	 * Sums the per-(document, term) statistics over terms for every document and
	 * over documents for every term. Reads each document's count file and prints
	 * one progress dot per document.
	 *
	 * @param docs            document names
	 * @param doc_word_counts word count of each document, aligned with docs
	 * @param tot_term_counts corpus-wide count of each term
	 * @param tot_word_count  corpus-wide number of words
	 * @return {@code {doc_stats, term_stats}}: 33 sums per document and per term
	 */
	public double[][][] get_term_stats(String[] docs, int[] doc_word_counts, int[] tot_term_counts, int tot_word_count) {
		double[][] term_stats = new double[tot_term_counts.length][];
		double[][] doc_stats = new double[docs.length][];
		for( int i = 0; i < term_stats.length; i++)
			term_stats[i] = new double[STAT_COUNT];
		for( int i = 0; i < doc_stats.length; i++)
			doc_stats[i] = new double[STAT_COUNT];

		for( int i = 0; i < docs.length; i++) {
			System.out.print(".");
			if( i % 100 == 0)
				System.out.println();
			int[] doc_term_counts = get_doc_term_counts(docs[i]);
			for( int j = 0; j < doc_term_counts.length; j++) {
				double[] doc_term_stats = this.get_doc_term_stats(doc_term_counts[j], doc_word_counts[i], tot_term_counts[j], tot_word_count);
				for( int k = 0; k < doc_term_stats.length; k++) {
					doc_stats[i][k] += doc_term_stats[k];
					term_stats[j][k] += doc_term_stats[k];
				}
			}
		}
		return new double[][][]{doc_stats,term_stats};
	}

	/**
	 * Returns {@code -p * ln(q)}, or 0 when either argument is zero or NaN.
	 *
	 * @param p weight
	 * @param q probability whose log is taken
	 * @return the weighted negative log, or 0 for zero/NaN input
	 */
	double H(double p,double q) {
		if( p == 0 || q == 0 || Double.isNaN(p) || Double.isNaN(q))
			return 0.0;
		return -p * Math.log(q);
	}

	/**
	 * Runs {@link #get_term_stats} over every counted article, using the corpus
	 * totals held by this object. Articles with no recorded word count are given
	 * a count of zero.
	 *
	 * @return {@code {doc_stats, term_stats}} as returned by get_term_stats
	 */
	double[][][] getAllStats() {
		String[] names = new String[articles.size()];
		int[] art_word_counts = new int[names.length];
		for( int i = 0; i < names.length; i++) {
			names[i] = articles.get(i);
			Integer n = article_total_word_counts.get(names[i]);
			art_word_counts[i] = (n == null) ? 0 : n;
		}
		return get_term_stats(names,art_word_counts,values,this.total_word_count);
	}
}