/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.AbstractBeliefPropagation;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Variable;
import gnu.trove.THashSet;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;

/**
 * Belief propagation using a two-pass tree schedule.
 *
 * <p>A "lambda" pass sends messages inward toward a root variable: each node's
 * message is sent to its parent only after all of its unvisited neighbours have
 * been processed. A "pi" pass then sends messages outward from the root. Both
 * passes are depth-first traversals that alternate between variables and the
 * factors containing them.
 *
 * <p>Every connected component gets its own root, so forests are handled: all
 * components receive messages, not only the one containing the graph's first
 * variable.
 *
 * <p>On a graph with cycles the traversal still terminates, because nodes that
 * were already visited are skipped. However, messages then flow only along a
 * spanning tree, so results on cyclic graphs should not be trusted.
 */
public class TreeBP
extends AbstractBeliefPropagation {
    /** Variables and factors already visited by the current pass; reset before each pass. */
    private transient THashSet marked;
    /** First variable of the graph, used as the root of its component; null for an empty graph. */
    private transient Variable root;

    /**
     * Creates a TreeBP instance configured with the max-product message strategy.
     *
     * @return a new TreeBP using {@code MaxProductMessageStrategy}
     */
    public static TreeBP createForMaxProduct() {
        return (TreeBP) new TreeBP().setMessager(new AbstractBeliefPropagation.MaxProductMessageStrategy());
    }

    /**
     * Runs the lambda pass and then the pi pass over every connected component of {@code fg}.
     *
     * <p>The component containing the graph's first variable is rooted at that
     * variable. Each other component is rooted at its first variable in
     * {@code fg.variablesIterator()} order. All lambda passes finish before any
     * pi pass starts, and the pi passes reuse the same roots in the same order.
     *
     * @param fg the factor graph to run inference on; expected to be acyclic
     */
    @Override
    public void computeMarginals(FactorGraph fg) {
        this.initForGraph(fg);
        List<Variable> componentRoots = new ArrayList<Variable>();

        this.marked = new THashSet();
        if (this.root != null) {
            componentRoots.add(this.root);
            this.lambdaPropagation(fg, null, this.root);
        }
        // Components not connected to the root (fg is a forest) would otherwise
        // never receive any messages.
        for (Iterator it = fg.variablesIterator(); it.hasNext(); ) {
            Variable var = (Variable) it.next();
            if (this.marked.contains(var)) {
                continue;
            }
            componentRoots.add(var);
            this.lambdaPropagation(fg, null, var);
        }

        this.marked = new THashSet();
        for (Variable componentRoot : componentRoots) {
            this.piPropagation(fg, componentRoot);
        }
    }

    /**
     * Initializes base-class state for {@code fg} and picks the first variable as the root.
     *
     * @param fg the factor graph about to be processed
     */
    @Override
    protected void initForGraph(FactorGraph fg) {
        super.initForGraph(fg);
        Iterator it = fg.variablesIterator();
        // A graph without variables has no root; computeMarginals then does nothing.
        this.root = it.hasNext() ? (Variable) it.next() : null;
    }

    /**
     * Lambda (inward) step at a variable.
     *
     * <p>Recurses into every unvisited factor containing {@code child}, then sends
     * {@code child}'s message to {@code parent}.
     *
     * @param mdl    the factor graph
     * @param parent the factor {@code child} was reached from, or null when
     *               {@code child} is the root of its component
     * @param child  the variable being visited
     */
    private void lambdaPropagation(FactorGraph mdl, Factor parent, Variable child) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("lambda propagation " + parent + " , " + child);
        }
        this.marked.add(child);
        // Explicit iterator + cast works whether the API returns a raw or generic collection.
        for (Iterator it = mdl.allFactorsContaining(child).iterator(); it.hasNext(); ) {
            Factor gchild = (Factor) it.next();
            if (!this.marked.contains(gchild)) {
                this.lambdaPropagation(mdl, child, gchild);
            }
        }
        if (parent != null) {
            this.sendMessage(mdl, child, parent);
        }
    }

    /**
     * Lambda (inward) step at a factor.
     *
     * <p>Recurses into every unvisited variable of {@code child}, then sends
     * {@code child}'s message to {@code parent}.
     *
     * @param mdl    the factor graph
     * @param parent the variable {@code child} was reached from
     * @param child  the factor being visited
     */
    private void lambdaPropagation(FactorGraph mdl, Variable parent, Factor child) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("lambda propagation " + parent + " , " + child);
        }
        this.marked.add(child);
        for (Iterator it = child.varSet().iterator(); it.hasNext(); ) {
            Variable gchild = (Variable) it.next();
            if (!this.marked.contains(gchild)) {
                this.lambdaPropagation(mdl, child, gchild);
            }
        }
        if (parent != null) {
            this.sendMessage(mdl, child, parent);
        }
    }

    /**
     * Pi (outward) step at a variable.
     *
     * <p>For each unvisited factor containing {@code var}, sends {@code var}'s
     * message to that factor and then recurses into it.
     *
     * @param mdl the factor graph
     * @param var the variable being visited
     */
    private void piPropagation(FactorGraph mdl, Variable var) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Pi propagation from " + var);
        }
        this.marked.add(var);
        for (Iterator it = mdl.allFactorsContaining(var).iterator(); it.hasNext(); ) {
            Factor child = (Factor) it.next();
            if (this.marked.contains(child)) {
                continue;
            }
            this.sendMessage(mdl, var, child);
            this.piPropagation(mdl, child);
        }
    }

    /**
     * Pi (outward) step at a factor.
     *
     * <p>For each unvisited variable of {@code factor}, sends {@code factor}'s
     * message to that variable and then recurses into it.
     *
     * @param mdl    the factor graph
     * @param factor the factor being visited
     */
    private void piPropagation(FactorGraph mdl, Factor factor) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Pi propagation from " + factor);
        }
        this.marked.add(factor);
        for (Iterator it = factor.varSet().iterator(); it.hasNext(); ) {
            Variable child = (Variable) it.next();
            if (this.marked.contains(child)) {
                continue;
            }
            this.sendMessage(mdl, factor, child);
            this.piPropagation(mdl, child);
        }
    }
}

