package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Pair;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:stanford-parser.jar:edu/stanford/nlp/parser/lexparser/UnknownGTTrainer.class */
/**
 * Collects word/tag counts from tagged training data and, once training is
 * finished, derives a per-tag log score for unseen words.
 *
 * <p>Typical use: call one of the {@code train} overloads any number of
 * times, then call {@link #finishTraining()} once; the results are stored in
 * {@link #unknownGT}.
 *
 * <p>NOTE(review): the "GT" in the name presumably stands for Good-Turing;
 * the code itself only shows the formula documented on
 * {@link #finishTraining()}.
 */
public class UnknownGTTrainer {
    /** Accumulated weight of every (word, tag) pair passed to {@code train}. */
    ClassicCounter<Pair<String, String>> wtCount = new ClassicCounter<>();
    /** Accumulated weight of every tag passed to {@code train}. */
    ClassicCounter<String> tagCount = new ClassicCounter<>();
    /** Per tag: number of (word, tag) pairs whose accumulated weight is exactly 1.0. */
    ClassicCounter<String> r1 = new ClassicCounter<>();
    /** Per tag: number of seen words that never occurred with that tag. */
    ClassicCounter<String> r0 = new ClassicCounter<>();
    /** Every distinct word passed to {@code train}. */
    Set<String> seenWords = new HashSet<>();
    /** Sum of the weights of all tokens passed to {@code train}. */
    double tokens = 0.0d;
    /** Per tag: the log score computed by {@link #finishTraining()}. */
    HashMap<String, Float> unknownGT = new HashMap<>();

    /**
     * Trains on every tree in the collection with a weight of 1.0 per token.
     *
     * @param collection the trees to train on
     */
    public void train(Collection<Tree> collection) {
        train(collection, 1.0d);
    }

    /**
     * Trains on every tree in the collection.
     *
     * @param collection the trees to train on
     * @param d the weight added for each token
     */
    public void train(Collection<Tree> collection, double d) {
        for (Tree tree : collection) {
            train(tree, d);
        }
    }

    /**
     * Trains on every tagged word in the tagged yield of the tree.
     *
     * @param tree the tree whose tagged yield is counted
     * @param d the weight added for each token
     */
    public void train(Tree tree, double d) {
        for (TaggedWord taggedWord : tree.taggedYield()) {
            train(taggedWord, d);
        }
    }

    /**
     * Records a single tagged word: adds {@code d} to the token total, to the
     * (word, tag) count and to the tag count, and remembers the word as seen.
     *
     * @param taggedWord the word/tag observation to record
     * @param d the weight of this observation
     */
    public void train(TaggedWord taggedWord, double d) {
        this.tokens += d;
        String word = taggedWord.word();
        String tag = taggedWord.tag();
        this.wtCount.incrementCount(new Pair<>(word, tag), d);
        this.tagCount.incrementCount(tag, d);
        this.seenWords.add(word);
    }

    /**
     * Prints summary statistics to {@code System.err} and fills
     * {@link #unknownGT} with, for every tag seen in training,
     * {@code log(r1(tag) / (tagCount(tag) * r0(tag)))}.
     *
     * <p>NOTE(review): the result is not finite when {@code r1(tag)} or
     * {@code r0(tag)} is zero (negative infinity, positive infinity, or NaN
     * when both are zero). This is kept unchanged from the original
     * arithmetic; confirm that consumers of {@link #unknownGT} tolerate it.
     */
    public void finishTraining() {
        System.err.println("Total tokens: " + this.tokens);
        System.err.println("Total WordTag types: " + this.wtCount.keySet().size());
        System.err.println("Total tag types: " + this.tagCount.keySet().size());
        System.err.println("Total word types: " + this.seenWords.size());

        // r1: for each tag, how many (word, tag) pairs have a count of exactly 1.0.
        for (Pair<String, String> pair : this.wtCount.keySet()) {
            if (this.wtCount.getCount(pair) == 1.0d) {
                this.r1.incrementCount(pair.second());
            }
        }

        // r0: for each tag, how many seen words never occurred with that tag.
        // wtCount is not modified in this loop, so its key set is looked up once.
        Set<Pair<String, String>> seenPairs = this.wtCount.keySet();
        for (String tag : this.tagCount.keySet()) {
            for (String word : this.seenWords) {
                Pair<String, String> candidate = new Pair<>(word, tag);
                if (!seenPairs.contains(candidate)) {
                    this.r0.incrementCount(tag);
                }
            }
        }

        // Final per-tag score, narrowed to float exactly as before.
        for (String tag : this.tagCount.keySet()) {
            double ratio = this.r1.getCount(tag) / (this.tagCount.getCount(tag) * this.r0.getCount(tag));
            this.unknownGT.put(tag, Float.valueOf((float) Math.log(ratio)));
        }
    }
}
