package com.spbsu.direct;
import com.spbsu.commons.func.Action;
import com.spbsu.commons.io.StreamTools;
import com.spbsu.commons.io.codec.seq.DictExpansion;
import com.spbsu.commons.io.codec.seq.ListDictionary;
import com.spbsu.commons.math.io.Vec2CharSequenceConverter;
import com.spbsu.commons.math.vectors.impl.vectors.SparseVec;
import com.spbsu.commons.seq.*;
import com.spbsu.commons.util.ThreadTools;
import com.spbsu.direct.gen.SimpleGenerativeModel;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TObjectDoubleMap;
import gnu.trove.map.custom_hash.TObjectDoubleCustomHashMap;
import gnu.trove.strategy.HashingStrategy;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
/**
 * Command-line tool that processes query logs. The first argument selects the mode:
 * <ul>
 *   <li>{@code -dict <size> <output> <logs...>} feeds queries into a {@link DictExpansion} and writes every
 *   dictionary the expansion reports to {@code <output>-<n>.dict};</li>
 *   <li>{@code -depends <dict> <statistics> <logs...>} feeds queries and pairs of consecutive queries of one
 *   user into a {@link SimpleGenerativeModel} and prints the model to {@code output-<n>.txt};</li>
 *   <li>{@code -stats <dict> <output> <logs...>} counts how often dictionary entries of one query are followed
 *   by dictionary entries of the next query of the same user and writes {@code <output>-<n>.stats};</li>
 *   <li>{@code -query <dict> <model>} reads queries from standard input and prints the best expansion found
 *   by the model for each of them.</li>
 * </ul>
 * Every log line must be a {@code <uid>\t<ts>\t<query>} triplet. Records whose uid starts with
 * {@code "uu/"} or {@code "r"} are skipped.
 *
 * User: solar
 * Date: 07.10.15
 * Time: 15:56
 */
public class BroadMatch {
  /** Debug switch; it is not read anywhere in this class. */
  public static boolean debug = true;
  /** Number of queries accepted so far by the {@code -depends} / {@code -stats} modes; drives periodic dumps. */
  volatile static int index = 0;
  /**
   * Number of debug lines printed by {@code -dict} since the last dictionary dump. It is incremented from
   * executor threads without synchronization, so the "first 10" limit is only approximate.
   */
  volatile static int windex = 0;

  /** A dump is written every time {@link #index} reaches a multiple of this value. */
  private static final int DUMP_PERIOD = 10000000;
  /** Two queries of one user belong to the same session when their timestamps differ by less than this. */
  private static final long SESSION_TIMEOUT_SECONDS = TimeUnit.MINUTES.toSeconds(30);
  /** Capacity of the word buffer used by {@link #convertToSeq(CharSequence)}. */
  private static final int MAX_QUERY_WORDS = 100;
  /** Punctuation removed from queries; compiled once instead of on every call. */
  private static final Pattern PUNCTUATION = Pattern.compile("[;,.:\\(\\)\"\'«»!\\]\\[\\{\\}<>]");
  /** Runs of whitespace that are collapsed into a single space. */
  private static final Pattern WHITESPACE = Pattern.compile("\\s+");

  /**
   * Dispatches to the mode named by {@code args[0]}.
   *
   * @param args mode followed by the mode specific arguments, see the class description
   * @throws IOException              if an input file cannot be read
   * @throws IllegalArgumentException if fewer than three arguments are given (every mode reads
   *                                  {@code args[2]}) or the mode is unknown
   */
  public static void main(String[] args) throws IOException {
    if (args.length < 3)
      throw new IllegalArgumentException("Need at least three arguments: mode, dictionary (dictionary size for -dict) and output/statistics/model file");
    switch (args[0]) {
      case "-dict":
        buildDictionary(args);
        break;
      case "-depends":
        collectDependencies(args);
        break;
      case "-stats":
        collectStats(args);
        break;
      case "-query":
        answerQueries(args);
        break;
      default:
        throw new IllegalArgumentException("Unknown mode: " + args[0]);
    }
  }

  /**
   * {@code -dict} mode: normalizes every query of the given logs on a background executor and passes it to a
   * {@link DictExpansion}. Consecutive lines with the same raw query are fed only once.
   */
  private static void buildDictionary(String[] args) throws IOException {
    final DictExpansion<CharSeq> expansion = new DictExpansion<>(Integer.parseInt(args[1]), System.out);
    final String outputFile = args[2];
    final Action<DictExpansion<CharSeq>> printer = new Action<DictExpansion<CharSeq>>() {
      int dictIndex = 0;

      @Override
      public void invoke(DictExpansion<CharSeq> result) {
        System.out.println("Dump dictionary #" + dictIndex);
        // The writer used to be handed over without ever being closed; close it once printing is done.
        try (final FileWriter out = new FileWriter(StreamTools.stripExtension(outputFile) + "-" + dictIndex + ".dict")) {
          result.print(out);
          dictIndex++;
          windex = 0;
        } catch (Exception e) {
          // best effort: a failed dump must not stop the expansion
          e.printStackTrace();
        }
      }
    };
    expansion.addListener(printer);

    final ThreadPoolExecutor executor = ThreadTools.createBGExecutor("Creating DictExpansion", 100000);
    for (int i = 3; i < args.length; i++) {
      final String inputFile = args[i];
      CharSeqTools.processLines(StreamTools.openTextFile(inputFile), new Action<CharSequence>() {
        String current;

        @Override
        public void invoke(CharSequence line) {
          // report the file that is being read (the message used to name the output file)
          final CharSequence[] parts = splitRecord(line, inputFile);
          if (isSkippedUser(parts[0]))
            return;
          final String uid = parts[0].toString();
          final String query = parts[2].toString();
          if (query.equals(current))
            return;
          current = query;
          enqueue(executor, () -> {
            final String normalizedQuery = normalizeQuery(query);
            final ArraySeq<CharSeq> seq = convertToSeq(normalizedQuery);
            if (windex++ < 10)
              System.out.println(uid + ": " + normalizedQuery + " -> " + seq);
            expansion.accept(seq);
          });
        }
      });
    }

    // Let the queued work finish before returning; the executor was previously never shut down.
    executor.shutdown();
    try {
      executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
  }

  /**
   * Hands {@code item} to the executor. When the executor's queue has no free slot the call blocks on
   * {@code put} until one appears, otherwise the task is submitted through {@code execute}.
   */
  private static void enqueue(ThreadPoolExecutor executor, Runnable item) {
    final BlockingQueue<Runnable> queue = executor.getQueue();
    if (queue.remainingCapacity() == 0) {
      try {
        queue.put(item);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt(); // keep the interruption visible to the caller
        throw new RuntimeException(e);
      }
    } else {
      executor.execute(item);
    }
  }

  /**
   * {@code -depends} mode: parses every query with the dictionary, passes it to
   * {@link SimpleGenerativeModel#processSeq} and, when the previous accepted query belongs to the same user and
   * is less than 30 minutes older, passes the pair to {@code processGeneration}. The model is printed every
   * {@link #DUMP_PERIOD} accepted queries and once more after all logs are read.
   */
  private static void collectDependencies(String[] args) throws IOException {
    final double alpha = 0.5;
    final TIntList freqsLA = new TIntArrayList();
    final ListDictionary<CharSeq> dict = loadDictionaryWithFreqs(args[1], freqsLA);
    final SimpleGenerativeModel model = new SimpleGenerativeModel(dict, freqsLA);
    model.loadStatistics(args[2]);
    for (int i = 3; i < args.length; i++) {
      final String fileName = args[i];
      CharSeqTools.processLines(StreamTools.openTextFile(fileName), new Action<CharSequence>() {
        long ts;
        String query;
        String user;
        IntSeq prevQSeq;

        @Override
        public void invoke(CharSequence line) {
          final CharSequence[] parts = splitRecord(line, fileName);
          if (isSkippedUser(parts[0]))
            return;
          final long ts = CharSeqTools.parseLong(parts[1]);
          final String query = normalizeQuery(parts[2].toString());
          if (query.equals(this.query)) { // repeated query: only refresh the timestamp
            this.ts = ts;
            return;
          }
          final IntSeq currentQSeq = dropUnknown(dict.parse(convertToSeq(query), model.freqs, model.totalFreq));
          if (currentQSeq == null) {
            this.ts = ts;
            return;
          }
          model.processSeq(currentQSeq);
          // compare as Strings so the result does not depend on the CharSequence implementation's equals
          final String uid = parts[0].toString();
          final boolean sameSession = uid.equals(this.user) && ts - this.ts < SESSION_TIMEOUT_SECONDS;
          // remember the normalized form: the raw text used to be stored here, which made the
          // duplicate check above compare normalized text against raw text
          this.query = query;
          this.user = uid;
          this.ts = ts;
          if (sameSession && prevQSeq != null) {
            model.processGeneration(prevQSeq, currentQSeq, alpha);
          }
          prevQSeq = currentQSeq;
          if (++index % DUMP_PERIOD == 0) {
            printModel(model, true);
          }
        }
      });
    }
    printModel(model, false);
  }

  /**
   * Prints the model to {@code output-<index / DUMP_PERIOD>.txt}; I/O failures are reported and ignored.
   *
   * @param flag passed unchanged as the second argument of {@code SimpleGenerativeModel.print}
   */
  private static void printModel(SimpleGenerativeModel model, boolean flag) {
    try (final Writer out = new OutputStreamWriter(new FileOutputStream("output-" + (index / DUMP_PERIOD) + ".txt"))) {
      model.print(out, flag);
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /**
   * {@code -stats} mode. {@code stats[a]} accumulates how often dictionary entry {@code a} of a query is
   * followed by each entry of the next query of the same user within 30 minutes; the extra last row
   * {@code stats[dict.size()]} accumulates the entries of the first query seen for a user. All given log files
   * are processed, and the statistics are written every {@link #DUMP_PERIOD} accepted queries and once at the end.
   */
  private static void collectStats(String[] args) throws IOException {
    final Vec2CharSequenceConverter converter = new Vec2CharSequenceConverter();
    final TIntList freqs = new TIntArrayList();
    final ListDictionary<CharSeq> dict = loadDictionaryWithFreqs(args[1], freqs);
    final SparseVec[] stats = new SparseVec[dict.size() + 1];
    for (int i = 0; i < stats.length; i++) {
      stats[i] = new SparseVec(dict.size());
    }
    final String outputFile = args[2];
    for (int i = 3; i < args.length; i++) {
      // a final copy is required: the loop variable itself cannot be captured by the anonymous class
      final String fileName = args[i];
      CharSeqTools.processLines(StreamTools.openTextFile(fileName), new Action<CharSequence>() {
        long ts;
        String query;
        String user;
        IntSeq prevQSeq;
        double totalFreq = freqs.sum();

        @Override
        public void invoke(CharSequence line) {
          final CharSequence[] parts = splitRecord(line, fileName);
          if (isSkippedUser(parts[0]))
            return;
          final long ts = CharSeqTools.parseLong(parts[1]);
          final String query = normalizeQuery(parts[2].toString());
          if (query.equals(this.query)) { // repeated query: only refresh the timestamp
            this.ts = ts;
            return;
          }
          final IntSeq currentQSeq = dropUnknown(dict.parse(convertToSeq(query), freqs, totalFreq));
          if (currentQSeq == null) {
            this.prevQSeq = null;
            this.query = null;
            return;
          }
          for (int k = 0; k < currentQSeq.length(); k++) {
            final int symbol = currentQSeq.intAt(k);
            if (symbol >= freqs.size())
              freqs.fill(freqs.size(), symbol + 1, 0); // grow the list so that the symbol has a slot
            freqs.set(symbol, freqs.get(symbol) + 1);
          }
          // compare as Strings so the result does not depend on the CharSequence implementation's equals
          final String uid = parts[0].toString();
          final boolean sameUser = uid.equals(this.user);
          if (!sameUser) { // session start
            this.prevQSeq = null;
            for (int k = 0; k < currentQSeq.length(); k++) {
              stats[dict.size()].adjust(currentQSeq.intAt(k), 1.);
            }
          }
          final IntSeq previous = sameUser && ts - this.ts < SESSION_TIMEOUT_SECONDS ? this.prevQSeq : null;
          this.query = query;
          this.user = uid;
          this.ts = ts;
          if (previous != null) {
            for (int a = 0; a < previous.length(); a++) {
              for (int b = 0; b < currentQSeq.length(); b++) {
                stats[previous.intAt(a)].adjust(currentQSeq.intAt(b), 1.);
              }
            }
          }
          this.prevQSeq = currentQSeq;
          if (++index % DUMP_PERIOD == 0) {
            dumpStats(dict, stats, converter, outputFile);
          }
        }
      });
    }
    // Without this, runs shorter than DUMP_PERIOD queries produced no output at all.
    dumpStats(dict, stats, converter, outputFile);
  }

  /**
   * Writes one line per row of {@code stats} to {@code <outputFile without extension>-<index / DUMP_PERIOD>.stats}:
   * the dictionary entry (or {@code SimpleGenerativeModel.EMPTY_ID} for the extra last row), a tab and the
   * converted vector. I/O failures are reported and ignored.
   */
  private static void dumpStats(ListDictionary<CharSeq> dict, SparseVec[] stats, Vec2CharSequenceConverter converter, String outputFile) {
    final String outputFileI = StreamTools.stripExtension(outputFile) + "-" + (index / DUMP_PERIOD) + ".stats";
    System.out.println("Dump " + outputFileI);
    try (final Writer out = new OutputStreamWriter(new FileOutputStream(outputFileI))) {
      for (int i = 0; i < stats.length; i++) {
        final SparseVec stat = stats[i];
        if (i < dict.size())
          out.append(dict.get(i).toString());
        else
          out.append(SimpleGenerativeModel.EMPTY_ID);
        out.append("\t");
        out.append(converter.convertTo(stat));
        out.append("\n");
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  /**
   * {@code -query} mode: loads the dictionary and the model, then prints the best expansion for every line
   * read from standard input.
   */
  private static void answerQueries(String[] args) throws IOException {
    final TIntList freqs = new TIntArrayList();
    final ListDictionary<CharSeq> dict = loadDictionaryWithFreqs(args[1], freqs);
    final SimpleGenerativeModel model = new SimpleGenerativeModel(dict, freqs);
    model.load(args[2]);
    CharSeqTools.processLines(new InputStreamReader(System.in, StreamTools.UTF), (Action<CharSequence>) arg -> {
      System.out.println(model.findTheBestExpansion(convertToSeq(normalizeQuery(arg.toString()))));
    });
  }

  /**
   * Splits a log line into its three tab separated fields.
   *
   * @param line     raw log line
   * @param fileName name of the file being read, used only in the error message
   * @return array of uid, timestamp and query
   * @throws IllegalArgumentException if the line does not consist of exactly three fields
   */
  private static CharSequence[] splitRecord(CharSequence line, String fileName) {
    final CharSequence[] parts = new CharSequence[3];
    if (CharSeqTools.split(line, '\t', parts).length != 3)
      throw new IllegalArgumentException("Each input line must contain <uid>\\t<ts>\\t<query> triplet. This one: [" + line + "]@" + fileName + ":" + index + " does not.");
    return parts;
  }

  /** Tells whether records of this user are ignored: the uid starts with {@code "uu/"} or {@code "r"}. */
  private static boolean isSkippedUser(CharSequence uid) {
    return CharSeqTools.startsWith(uid, "uu/") || CharSeqTools.startsWith(uid, "r");
  }

  /**
   * Removes punctuation, collapses whitespace runs into a single space and lower-cases the query.
   * Lower-casing uses {@link Locale#ROOT} so the result does not depend on the JVM default locale.
   * Leading and trailing spaces are not trimmed.
   *
   * @param query raw query text, must not be {@code null}
   * @return the normalized query, never {@code null}
   */
  @NotNull
  private static String normalizeQuery(String query) {
    final String withoutPunctuation = PUNCTUATION.matcher(query).replaceAll("");
    return WHITESPACE.matcher(withoutPunctuation).replaceAll(" ").toLowerCase(Locale.ROOT);
  }

  /**
   * Reads a dictionary file whose lines look like {@code [word1, word2]<TAB>frequency}: the first and last
   * character of the first field are dropped, the rest is split on {@code ", "}.
   *
   * @param arg   name of the dictionary file
   * @param freqs receives, at index {@code i}, the frequency (truncated to {@code int}) of {@code dict.get(i)}
   * @return the dictionary built from all sequences of the file
   * @throws IOException if the file cannot be read
   */
  @NotNull
  private static ListDictionary<CharSeq> loadDictionaryWithFreqs(String arg, final TIntList freqs) throws IOException {
    final ListDictionary<CharSeq> dict;
    { // dict
      final TObjectDoubleMap<Seq<CharSeq>> freqsHash = new TObjectDoubleCustomHashMap<>(new HashingStrategy<Object>() {
        @Override
        public int computeHashCode(Object object) {
          return object.hashCode();
        }

        @Override
        public boolean equals(Object o1, Object o2) {
          return o1.equals(o2);
        }
      });
      final List<Seq<CharSeq>> dictSeqs = new ArrayList<>();
      CharSeqTools.processLines(StreamTools.openTextFile(arg), (Action<CharSequence>) line -> {
        final CharSequence[] split = CharSeqTools.split(line, '\t', new CharSequence[2]);
        final CharSequence[] parts = CharSeqTools.split(split[0].subSequence(1, split[0].length() - 1), ", ");
        final SeqBuilder<CharSeq> builder = new ArraySeqBuilder<>(CharSeq.class);
        for (final CharSequence part : parts) {
          builder.add(CharSeq.create(part.toString()));
        }
        final Seq<CharSeq> seq = builder.build();
        dictSeqs.add(seq);
        freqsHash.put(seq, CharSeqTools.parseDouble(split[1]));
      });
      //noinspection unchecked
      dict = new ListDictionary<>(dictSeqs.toArray(new Seq[dictSeqs.size()]));
      // fill() already extends the list to dict.size() entries, so the values must be stored with set();
      // appending them with add() left zeros at the dictionary indices and the real values after them.
      freqs.fill(0, dict.size(), 0);
      for (int i = 0; i < dict.size(); i++) {
        freqs.set(i, (int) freqsHash.get(dict.get(i)));
      }
    }
    return dict;
  }

  /**
   * Keeps only the non-negative symbols of a parsed query.
   *
   * @param parse symbols produced by the dictionary
   * @return the non-negative symbols in their original order, or {@code null} when there are none
   */
  @Nullable
  private static IntSeq dropUnknown(IntSeq parse) {
    final IntSeqBuilder known = new IntSeqBuilder();
    for (int i = 0; i < parse.length(); i++) {
      final int symbol = parse.intAt(i);
      if (symbol >= 0)
        known.add(symbol);
    }
    if (known.length() == 0)
      return null;
    return known.build();
  }

  /**
   * Splits a query on single spaces into a sequence of words.
   * NOTE(review): the word buffer holds {@link #MAX_QUERY_WORDS} entries; what {@code CharSeqTools.trySplit}
   * does with longer queries is not visible here - confirm.
   */
  private static ArraySeq<CharSeq> convertToSeq(CharSequence word) {
    final CharSeq[] words = new CharSeq[MAX_QUERY_WORDS];
    final String query = word.toString();
    final int wcount = CharSeqTools.trySplit(CharSeq.create(query), ' ', words);
    return new ArraySeq<>(words, 0, wcount);
  }
}