/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter.neural_network;

import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralEntity;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;

public class NeuralNetworkUtil {

    private NeuralNetworkUtil() {
    }

    /**
     * Creates one {@link NeuralInput} per feature.
     *
     * <p>Inputs are identified as {@code "input/1"}, {@code "input/2"}, ... in feature order. Each
     * input wraps a continuous {@link DerivedField} of the given data type whose expression is:
     * <ul>
     *   <li>a {@link NormDiscrete} on the feature name and value, for a {@link BinaryFeature};</li>
     *   <li>a {@link NormDiscrete} on the feature name and {@code "true"}, for a {@link BooleanFeature};</li>
     *   <li>a field reference to {@link Feature#toContinuousFeature()}, for any other feature.</li>
     * </ul>
     *
     * @param features the input features, in network input order
     * @param dataType the data type assigned to every input's derived field
     * @return a new {@link NeuralInputs} holding one input per feature
     */
    public static NeuralInputs createNeuralInputs(List<? extends Feature> features, DataType dataType) {
        NeuralInputs neuralInputs = new NeuralInputs();

        for (int i = 0; i < features.size(); i++) {
            Feature feature = features.get(i);

            // Must be Expression rather than FieldRef: the binary/boolean branches produce a
            // NormDiscrete, which is not a FieldRef.
            Expression expression;

            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                expression = new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue());
            } else if (feature instanceof BooleanFeature) {
                BooleanFeature booleanFeature = (BooleanFeature) feature;

                expression = new NormDiscrete(booleanFeature.getName(), "true");
            } else {
                ContinuousFeature continuousFeature = feature.toContinuousFeature();

                expression = continuousFeature.ref();
            }

            DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, dataType)
                .setExpression(expression);

            // Input ids are 1-based.
            NeuralInput neuralInput = new NeuralInput()
                .setId("input/" + (i + 1))
                .setDerivedField(derivedField);

            neuralInputs.addNeuralInputs(neuralInput);
        }

        return neuralInputs;
    }

    /**
     * Creates a neuron that is connected to the given entities with the given weights.
     *
     * <p>A connection is omitted when its weight is {@code null} or zero-like according to
     * {@link ValueUtil#isZeroLike}. Likewise, the bias is only set when it is non-null and not
     * zero-like. The returned neuron has no id; callers are expected to assign one.
     *
     * @param entities the source entities, one per weight
     * @param weights the connection weights, parallel to {@code entities}; elements may be {@code null}
     * @param bias the neuron bias, or {@code null}
     * @return a new neuron
     * @throws IllegalArgumentException if {@code entities} and {@code weights} differ in size
     */
    public static Neuron createNeuron(List<? extends NeuralEntity> entities, List<Double> weights, Double bias) {
        if (entities.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + entities.size() + " weights, got " + weights.size());
        }

        Neuron neuron = new Neuron();

        for (int i = 0; i < entities.size(); i++) {
            NeuralEntity entity = entities.get(i);
            Double weight = weights.get(i);

            // Zero-weight connections contribute nothing, so they are left out of the model.
            if (weight == null || ValueUtil.isZeroLike(weight)) {
                continue;
            }

            Connection connection = new Connection()
                .setFrom((String) entity.getId())
                .setWeight(weight.doubleValue());

            neuron.addConnections(connection);
        }

        if (bias != null && !ValueUtil.isZeroLike(bias)) {
            neuron.setBias(bias);
        }

        return neuron;
    }

    /**
     * Creates two layers that turn a single entity's output into a pair of binary event neurons.
     *
     * <p>The first layer uses the logistic activation function. It holds one neuron
     * {@code "logistic/1"}, which has no bias and a weight-1.0 connection from {@code entity}.
     * The second layer uses the identity activation function. It holds {@code "event/false"}
     * (bias 1.0, weight -1.0 from the logistic neuron) and {@code "event/true"} (no bias,
     * weight 1.0 from the logistic neuron).
     *
     * @param entity the entity whose output is fed into the logistic neuron
     * @return a fixed-size list of the logistic layer followed by the output layer
     */
    public static List<NeuralLayer> createBinaryLogisticTransformation(NeuralEntity entity) {
        NeuralLayer inputLayer = new NeuralLayer()
            .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

        Neuron logisticNeuron = new Neuron()
            .setId("logistic/1")
            .setBias(null)
            .addConnections(new Connection((String) entity.getId(), 1.0));

        inputLayer.addNeurons(logisticNeuron);

        String logisticId = (String) logisticNeuron.getId();

        NeuralLayer outputLayer = new NeuralLayer()
            .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY);

        // event/false = 1.0 - p, event/true = p, where p is the logistic neuron's output.
        Neuron noEventNeuron = new Neuron()
            .setId("event/false")
            .setBias(1.0)
            .addConnections(new Connection(logisticId, -1.0));

        Neuron eventNeuron = new Neuron()
            .setId("event/true")
            .setBias(null)
            .addConnections(new Connection(logisticId, 1.0));

        outputLayer.addNeurons(noEventNeuron, eventNeuron);

        return Arrays.asList(inputLayer, outputLayer);
    }

    /**
     * Creates the outputs of a regression network.
     *
     * <p>The single entity is mapped to a continuous derived field that references the label's
     * field, using the label's data type.
     *
     * @param entities the output entities; must contain exactly one element
     * @param continuousLabel the target label
     * @return a new {@link NeuralOutputs} holding one output
     * @throws IllegalArgumentException if {@code entities} does not contain exactly one element
     */
    public static NeuralOutputs createRegressionNeuralOutputs(List<? extends NeuralEntity> entities, ContinuousLabel continuousLabel) {
        if (entities.size() != 1) {
            throw new IllegalArgumentException("Expected exactly one neural entity, got " + entities.size());
        }

        NeuralEntity entity = Iterables.getOnlyElement(entities);

        DerivedField derivedField = new DerivedField(OpType.CONTINUOUS, continuousLabel.getDataType())
            .setExpression(new FieldRef(continuousLabel.getName()));

        NeuralOutput neuralOutput = new NeuralOutput()
            .setOutputNeuron((String) entity.getId())
            .setDerivedField(derivedField);

        return new NeuralOutputs()
            .addNeuralOutputs(neuralOutput);
    }

    /**
     * Creates the outputs of a classification network, one per category of the label.
     *
     * <p>The i-th entity is mapped to a categorical derived field. Its expression is a
     * {@link NormDiscrete} on the label name and the label's i-th value.
     *
     * @param entities the output entities, parallel to the label's values
     * @param categoricalLabel the target label
     * @return a new {@link NeuralOutputs} holding one output per category
     */
    public static NeuralOutputs createClassificationNeuralOutputs(List<? extends NeuralEntity> entities, CategoricalLabel categoricalLabel) {
        // NOTE(review): presumably rejects a size mismatch between entities and label values -- confirm in SchemaUtil.
        SchemaUtil.checkSize(entities.size(), categoricalLabel);

        NeuralOutputs neuralOutputs = new NeuralOutputs();

        for (int i = 0; i < categoricalLabel.size(); i++) {
            NeuralEntity entity = entities.get(i);

            DerivedField derivedField = new DerivedField(OpType.CATEGORICAL, categoricalLabel.getDataType())
                .setExpression(new NormDiscrete(categoricalLabel.getName(), categoricalLabel.getValue(i)));

            NeuralOutput neuralOutput = new NeuralOutput()
                .setOutputNeuron((String) entity.getId())
                .setDerivedField(derivedField);

            neuralOutputs.addNeuralOutputs(neuralOutput);
        }

        return neuralOutputs;
    }
}

