package marf.Classification.NeuralNetwork;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Random;
import java.util.Vector;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import marf.Classification.Classification;
import marf.Classification.ClassificationException;
import marf.FeatureExtraction.IFeatureExtraction;
import marf.MARF;
import marf.Storage.ITrainingSample;
import marf.Storage.Result;
import marf.Storage.StorageException;
import marf.util.Debug;
import org.w3c.dom.Document;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.xml.sax.ErrorHandler;
import org.xml.sax.SAXException;
import org.xml.sax.SAXParseException;

/* loaded from: input_file:marf/Classification/NeuralNetwork/NeuralNetwork.class */
public class NeuralNetwork extends Classification {
    public static final int DEFAULT_OUTPUT_NEURON_BITS = 32;
    public static final double DEFAULT_TRAINING_CONSTANT = 1.0d;
    public static final int DEFAULT_EPOCH_NUMBER = 64;
    public static final double DEFAULT_MIN_ERROR = 0.1d;
    private ArrayList oLayers;
    private transient Layer oCurrentLayer;
    private transient int iCurrenLayer;
    private transient int iCurrLayerBuf;
    private transient Neuron oCurrNeuron;
    private transient int iNeuronType;
    private Layer oInputs;
    private Layer oOutputs;
    public static final String OUTPUT_ENCODING = "UTF-8";
    public static final String JAXP_SCHEMA_LANGUAGE = "http://java.sun.com/xml/jaxp/properties/schemaLanguage";
    public static final String W3C_XML_SCHEMA = "http://www.w3.org/2001/XMLSchema";
    public static final String JAXP_SCHEMA_SOURCE = "http://java.sun.com/xml/jaxp/properties/schemaSource";
    private static final long serialVersionUID = 6116721242820120028L;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:marf/Classification/NeuralNetwork/NeuralNetwork$NeuralNetworkErrorHandler.class */
    public static class NeuralNetworkErrorHandler implements ErrorHandler {
        private PrintWriter oOut;

        NeuralNetworkErrorHandler(PrintWriter printWriter) {
            this.oOut = printWriter;
        }

        private String getParseExceptionInfo(SAXParseException sAXParseException) {
            String systemId = sAXParseException.getSystemId();
            if (systemId == null) {
                systemId = "null";
            }
            return new StringBuffer().append("URI=").append(systemId).append(" Line=").append(sAXParseException.getLineNumber()).append(": ").append(sAXParseException.getMessage()).toString();
        }

        @Override // org.xml.sax.ErrorHandler
        public void warning(SAXParseException sAXParseException) throws SAXException {
            this.oOut.println(new StringBuffer().append("WARNING: ").append(getParseExceptionInfo(sAXParseException)).toString());
        }

        @Override // org.xml.sax.ErrorHandler
        public void error(SAXParseException sAXParseException) throws SAXException {
            throw new SAXException(new StringBuffer().append("ERROR: ").append(getParseExceptionInfo(sAXParseException)).toString());
        }

        @Override // org.xml.sax.ErrorHandler
        public void fatalError(SAXParseException sAXParseException) throws SAXException {
            throw new SAXException(new StringBuffer().append("FATAL: ").append(getParseExceptionInfo(sAXParseException)).toString());
        }
    }

    public NeuralNetwork(IFeatureExtraction iFeatureExtraction) {
        super(iFeatureExtraction);
        this.oLayers = new ArrayList();
        this.iCurrenLayer = 0;
        this.iCurrLayerBuf = 0;
        this.iNeuronType = -1;
        this.oInputs = new Layer();
        this.oOutputs = new Layer();
        this.iCurrentDumpMode = 0;
        this.strFilename = getDefaultFilename();
    }

    @Override // marf.Classification.Classification, marf.Classification.IClassification
    public final boolean train() throws ClassificationException {
        return trainImplementation(null);
    }

    @Override // marf.Classification.Classification, marf.Classification.IClassification
    public final boolean train(double[] dArr) throws ClassificationException {
        return trainImplementation(dArr);
    }

    private final boolean trainImplementation(double[] dArr) throws ClassificationException {
        try {
            int i = this.iCurrentDumpMode;
            this.iCurrentDumpMode = 0;
            if (dArr == null) {
                super.train();
            } else {
                super.train(dArr);
            }
            this.iCurrentDumpMode = i;
            double d = 1.0d;
            int i2 = 64;
            double d2 = 0.1d;
            if (MARF.getModuleParams() != null) {
                Vector classificationParams = MARF.getModuleParams().getClassificationParams();
                if (classificationParams.size() > 1) {
                    d = ((Double) classificationParams.elementAt(1)).doubleValue();
                    i2 = ((Integer) classificationParams.elementAt(2)).intValue();
                    d2 = ((Double) classificationParams.elementAt(3)).doubleValue();
                }
            }
            restore();
            Vector clusters = this.oTrainingSet.getClusters();
            int i3 = 0;
            double d3 = d2 + 1.0d;
            while (d3 > d2 && i3 < i2) {
                for (int i4 = 0; i4 < clusters.size(); i4++) {
                    ITrainingSample iTrainingSample = (ITrainingSample) clusters.get(i4);
                    train(iTrainingSample.getMeanVector(), iTrainingSample.getSubjectID(), d);
                    commit();
                }
                double d4 = 0.0d;
                int i5 = 0;
                while (i5 < clusters.size()) {
                    ITrainingSample iTrainingSample2 = (ITrainingSample) clusters.get(i5);
                    setInputs(iTrainingSample2.getMeanVector());
                    runNNet();
                    d4 += d2 * Math.abs(iTrainingSample2.getSubjectID() - r0);
                    Debug.debug(new StringBuffer().append("Expected: ").append(iTrainingSample2.getSubjectID()).append(", Got: ").append(interpretAsBinary()).append(", Error: ").append(d4).toString());
                    i5++;
                }
                if (i5 == 0) {
                    throw new ClassificationException("NeuralNetwork.train() --- There are no training samples!");
                }
                d3 = d4 / i5;
                i3++;
                Debug.debug(new StringBuffer().append("Epoch: error = ").append(d3).append(", limit = ").append(i3).toString());
            }
            dump();
            return true;
        } catch (NullPointerException e) {
            e.printStackTrace(System.err);
            throw new ClassificationException(new StringBuffer().append("NeuralNetwork.train(): Missing required ModuleParam (").append((Object) null).append(") or TrainingSample (").append((Object) null).append(")").toString());
        } catch (StorageException e2) {
            e2.printStackTrace(System.err);
            throw new ClassificationException(new StringBuffer().append("StorageException while dumping/restoring neural net: ").append(e2.getMessage()).toString(), e2);
        }
    }

    @Override // marf.Classification.IClassification
    public final boolean classify(double[] dArr) throws ClassificationException {
        try {
            restore();
            if (dArr.length != this.oInputs.size()) {
                throw new ClassificationException(new StringBuffer().append("Input array size (").append(dArr.length).append(") not consistent with input layer (").append(this.oInputs.size()).append(")").toString());
            }
            for (int i = 0; i < dArr.length; i++) {
                this.oInputs.get(i).dResult = dArr[i];
            }
            runNNet();
            this.oResultSet.addResult(new Result(interpretAsBinary()));
            this.oResultSet.addResult(new Result(interpretAsBinary() + 1));
            return true;
        } catch (StorageException e) {
            e.printStackTrace(System.err);
            throw new ClassificationException(e);
        }
    }

    public final void eval() {
        runNNet();
    }

    private final void runNNet() {
        for (int i = 0; i < this.oLayers.size(); i++) {
            ((Layer) this.oLayers.get(i)).eval();
        }
    }

    public final void initialize(String str, boolean z) throws StorageException {
        try {
            Debug.debug("Initializing XML parser...");
            DocumentBuilderFactory newInstance = DocumentBuilderFactory.newInstance();
            newInstance.setNamespaceAware(true);
            newInstance.setValidating(z);
            DocumentBuilder newDocumentBuilder = newInstance.newDocumentBuilder();
            newDocumentBuilder.setErrorHandler(new NeuralNetworkErrorHandler(new PrintWriter((Writer) new OutputStreamWriter(System.err, OUTPUT_ENCODING), true)));
            Debug.debug("Parsing XML file...");
            Document parse = newDocumentBuilder.parse(new File(str));
            this.oLayers.add(this.oInputs);
            Debug.debug("Making the NNet structure...");
            buildNetwork(parse);
            this.oLayers.add(this.oOutputs);
            Debug.debug("Setting the inputs and outputs for each Neuron...");
            this.iCurrenLayer = 0;
            createLinks(parse);
        } catch (FileNotFoundException e) {
            try {
                generate();
                dump();
            } catch (ClassificationException e2) {
                e2.printStackTrace(System.err);
                throw new StorageException(e2);
            }
        } catch (Exception e3) {
            e3.printStackTrace(System.err);
            throw new StorageException(e3);
        }
    }

    public void generate() throws ClassificationException {
        Debug.debug("Generating new net...");
        int length = this.oFeatureExtraction.getFeaturesArray().length;
        int abs = Math.abs(length - 32) / 2;
        if (abs == 0) {
            abs = length / 2;
        }
        generate(length, new int[]{length * 2, length, abs}, 32);
        Debug.debug("Dumping newly generated net...");
    }

    private final void buildNetwork(Node node) {
        if (node.getNodeType() == 1) {
            String nodeName = node.getNodeName();
            if (nodeName.equals("input") || nodeName.equals("output")) {
                return;
            }
            NamedNodeMap attributes = node.getAttributes();
            if (nodeName.equals("layer")) {
                for (int i = 0; i < attributes.getLength(); i++) {
                    Node item = attributes.item(i);
                    String nodeName2 = item.getNodeName();
                    String nodeValue = item.getNodeValue();
                    if (nodeName2.equals("type")) {
                        if (nodeValue.equals("input")) {
                            this.oCurrentLayer = this.oInputs;
                            this.iNeuronType = 0;
                        } else if (nodeValue.equals("output")) {
                            this.oCurrentLayer = this.oOutputs;
                            this.iNeuronType = 2;
                        } else {
                            this.oCurrentLayer = new Layer();
                            this.oLayers.add(this.oCurrentLayer);
                            this.iNeuronType = 1;
                        }
                    } else if (nodeName2.equals("index")) {
                        Debug.debug("Indexing layers currently not supported... Assumings written order.");
                    } else {
                        System.err.println(new StringBuffer().append("Unknown layer attribute: ").append(nodeName2).toString());
                    }
                }
            } else if (nodeName.equals("neuron")) {
                String str = new String();
                double d = 0.0d;
                for (int i2 = 0; i2 < attributes.getLength(); i2++) {
                    Node item2 = attributes.item(i2);
                    String nodeName3 = item2.getNodeName();
                    String nodeValue2 = item2.getNodeValue();
                    if (nodeName3.equals("index")) {
                        str = new String(nodeValue2);
                    } else if (nodeName3.equals("thresh")) {
                        try {
                            d = Double.valueOf(nodeValue2.trim()).doubleValue();
                        } catch (NumberFormatException e) {
                            System.err.println(new StringBuffer().append("NumberFormatException: ").append(e.getMessage()).toString());
                            e.printStackTrace(System.err);
                        }
                    } else {
                        System.err.println(new StringBuffer().append("Unknown layer attribute: ").append(nodeName3).toString());
                    }
                }
                Neuron neuron = new Neuron(str, this.iNeuronType);
                neuron.dThreshold = d;
                this.oCurrentLayer.add(neuron);
            }
        }
        Node firstChild = node.getFirstChild();
        while (true) {
            Node node2 = firstChild;
            if (node2 == null) {
                return;
            }
            buildNetwork(node2);
            firstChild = node2.getNextSibling();
        }
    }

    private final void createLinks(Node node) throws ClassificationException {
        if (node.getNodeType() == 1) {
            String nodeName = node.getNodeName();
            NamedNodeMap attributes = node.getAttributes();
            if (nodeName.equals("layer")) {
                for (int i = 0; i < attributes.getLength(); i++) {
                    Node item = attributes.item(i);
                    String nodeName2 = item.getNodeName();
                    String nodeValue = item.getNodeValue();
                    if (nodeName2.equals("type")) {
                        if (nodeValue.equals("input")) {
                            this.oCurrentLayer = this.oInputs;
                            this.iCurrenLayer = 0;
                        } else if (nodeValue.equals("output")) {
                            this.oCurrentLayer = this.oOutputs;
                            this.iCurrenLayer = this.oLayers.size() - 1;
                        } else {
                            int i2 = this.iCurrLayerBuf + 1;
                            this.iCurrLayerBuf = i2;
                            this.iCurrenLayer = i2;
                            this.oCurrentLayer = (Layer) this.oLayers.get(this.iCurrenLayer);
                        }
                    }
                }
            } else if (nodeName.equals("neuron")) {
                String str = new String();
                for (int i3 = 0; i3 < attributes.getLength(); i3++) {
                    Node item2 = attributes.item(i3);
                    String nodeName3 = item2.getNodeName();
                    String nodeValue2 = item2.getNodeValue();
                    if (nodeName3.equals("index")) {
                        str = new String(nodeValue2);
                    }
                }
                this.oCurrNeuron = this.oCurrentLayer.getNeuron(str);
            } else if (nodeName.equals("input")) {
                String str2 = null;
                double d = -1.0d;
                for (int i4 = 0; i4 < attributes.getLength(); i4++) {
                    Node item3 = attributes.item(i4);
                    String nodeName4 = item3.getNodeName();
                    String nodeValue3 = item3.getNodeValue();
                    if (nodeName4.equals("ref")) {
                        str2 = new String(nodeValue3);
                    } else if (nodeName4.equals("weight")) {
                        try {
                            d = Double.valueOf(nodeValue3.trim()).doubleValue();
                        } catch (NumberFormatException e) {
                            System.err.println(new StringBuffer().append("NumberFormatException: ").append(e.getMessage()).toString());
                            e.printStackTrace(System.err);
                        }
                    }
                }
                if (str2 == null || str2.equals("")) {
                    throw new ClassificationException(new StringBuffer().append("No 'ref' value assigned for neuron ").append(this.oCurrNeuron.strName).append(" in layer ").append(this.iCurrenLayer).toString());
                }
                if (this.iCurrenLayer <= 0) {
                    throw new ClassificationException("Input element not allowed in input layer");
                }
                Neuron neuron = ((Layer) this.oLayers.get(this.iCurrenLayer - 1)).getNeuron(str2);
                if (neuron == null) {
                    throw new ClassificationException(new StringBuffer().append("Cannot find neuron ").append(str2).append(" in layer ").append(this.iCurrenLayer - 1).toString());
                }
                this.oCurrNeuron.addInput(neuron, d);
            } else if (nodeName.equals("output")) {
                String str3 = null;
                for (int i5 = 0; i5 < attributes.getLength(); i5++) {
                    Node item4 = attributes.item(i5);
                    String nodeName5 = item4.getNodeName();
                    String nodeValue4 = item4.getNodeValue();
                    if (nodeName5.equals("ref")) {
                        str3 = new String(nodeValue4);
                    }
                }
                if (str3 == null || str3.equals("")) {
                    throw new ClassificationException(new StringBuffer().append("No 'ref' value assigned for neuron ").append(this.oCurrNeuron.strName).append(" in layer ").append(this.iCurrenLayer).toString());
                }
                if (this.iCurrenLayer >= 0) {
                    Neuron neuron2 = ((Layer) this.oLayers.get(this.iCurrenLayer + 1)).getNeuron(str3);
                    if (neuron2 == null) {
                        throw new ClassificationException(new StringBuffer().append("Cannot find neuron ").append(str3).append(" in layer ").append(this.iCurrenLayer + 1).toString());
                    }
                    this.oCurrNeuron.addOutput(neuron2);
                }
            }
        }
        Node firstChild = node.getFirstChild();
        while (true) {
            Node node2 = firstChild;
            if (node2 == null) {
                return;
            }
            createLinks(node2);
            firstChild = node2.getNextSibling();
        }
    }

    public final void setInputs(double[] dArr) throws ClassificationException {
        if (dArr.length != this.oInputs.size()) {
            throw new ClassificationException("Input array size not consistent with input layer.");
        }
        for (int i = 0; i < dArr.length; i++) {
            this.oInputs.get(i).dResult = dArr[i];
        }
    }

    public double[] getOutputResults() {
        double[] dArr = new double[this.oOutputs.size()];
        for (int i = 0; i < this.oOutputs.size(); i++) {
            dArr[i] = this.oOutputs.get(i).dResult;
        }
        return dArr;
    }

    public static final void indent(BufferedWriter bufferedWriter, int i) throws IOException {
        for (int i2 = 0; i2 < i; i2++) {
            bufferedWriter.write("\t");
        }
    }

    public final void dumpXML(String str) throws StorageException {
        try {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(str));
            bufferedWriter.write("<?xml version=\"1.0\"?>");
            bufferedWriter.newLine();
            bufferedWriter.write("<net>");
            bufferedWriter.newLine();
            for (int i = 0; i < this.oLayers.size(); i++) {
                Layer layer = (Layer) this.oLayers.get(i);
                indent(bufferedWriter, 1);
                bufferedWriter.write("<layer type=\"");
                if (i == 0) {
                    bufferedWriter.write("input");
                } else if (i == this.oLayers.size() - 1) {
                    bufferedWriter.write("output");
                } else {
                    bufferedWriter.write("hidden");
                }
                bufferedWriter.write(new StringBuffer("\" index=\"").append(i).append("\">").toString());
                bufferedWriter.newLine();
                for (int i2 = 0; i2 < layer.size(); i2++) {
                    layer.get(i2).printXML(bufferedWriter, 2);
                }
                indent(bufferedWriter, 1);
                bufferedWriter.write("</layer>");
                bufferedWriter.newLine();
            }
            bufferedWriter.write("</net>");
            bufferedWriter.newLine();
            bufferedWriter.close();
        } catch (Exception e) {
            throw new StorageException(e);
        }
    }

    public final void generate(int i, int[] iArr, int i2) throws ClassificationException {
        if (iArr == null || iArr.length == 0) {
            throw new ClassificationException("Number of hidden layers may not be null or of 0 length.");
        }
        for (int i3 = 1; i3 <= 1 + iArr.length + 1; i3++) {
            if (i3 == 1) {
                for (int i4 = 1; i4 <= i; i4++) {
                    Neuron neuron = new Neuron(new StringBuffer().append("").append(i4).toString(), 0);
                    neuron.dThreshold = 1.0d;
                    this.oInputs.add(neuron);
                }
                this.oLayers.add(this.oInputs);
            } else if (i3 == 1 + iArr.length + 1) {
                for (int i5 = 1; i5 <= i2; i5++) {
                    Neuron neuron2 = new Neuron(new StringBuffer().append("").append(i5).toString(), 2);
                    neuron2.dThreshold = 1.0d;
                    this.oOutputs.add(neuron2);
                }
                this.oLayers.add(this.oOutputs);
            } else {
                Layer layer = new Layer();
                for (int i6 = 1; i6 <= iArr[i3 - 2]; i6++) {
                    Neuron neuron3 = new Neuron(new StringBuffer().append("").append(i6).toString(), 1);
                    neuron3.dThreshold = 1.0d;
                    layer.add(neuron3);
                }
                this.oLayers.add(layer);
            }
        }
        Debug.debug("Setting the inputs and outputs for each Neuron...");
        for (int i7 = 0; i7 < this.oLayers.size() - 1; i7++) {
            Layer layer2 = (Layer) this.oLayers.get(i7);
            for (int i8 = 0; i8 < layer2.size(); i8++) {
                Neuron neuron4 = layer2.get(i8);
                Layer layer3 = (Layer) this.oLayers.get(i7 + 1);
                for (int i9 = 0; i9 < layer3.size(); i9++) {
                    Neuron neuron5 = layer3.get(i9);
                    neuron4.addOutput(neuron5);
                    neuron5.addInput(neuron4, (new Random().nextDouble() * 2.0d) - 1.0d);
                }
            }
        }
    }

    public final void train(double[] dArr, int i, double d) throws ClassificationException {
        if (d <= 0.0d) {
            throw new ClassificationException(new StringBuffer().append("NeuralNetwork.train(): Training constant must be > 0.0, supplied: ").append(d).toString());
        }
        if (dArr.length != this.oInputs.size()) {
            throw new ClassificationException(new StringBuffer().append("NeuralNetwork.train(): Input array size (").append(dArr.length).append(") not consistent with input layer (").append(this.oInputs.size()).append(")").toString());
        }
        setInputs(dArr);
        runNNet();
        for (int size = this.oOutputs.size() - 1; size >= 0; size--) {
            int i2 = i % 2;
            i /= 2;
            this.oOutputs.get(size).train(i2, d, 1.0d);
        }
        for (int size2 = this.oLayers.size() - 2; size2 >= 0; size2--) {
            ((Layer) this.oLayers.get(size2)).train(d);
        }
    }

    public final void commit() {
        for (int i = 0; i < this.oLayers.size(); i++) {
            ((Layer) this.oLayers.get(i)).commit();
        }
    }

    private final int interpretAsBinary() {
        int i = 0;
        for (int i2 = 0; i2 < this.oOutputs.size(); i2++) {
            i *= 2;
            if (this.oOutputs.get(i2).dResult > 0.5d) {
                i++;
            }
            Debug.debug(new StringBuffer().append(this.oOutputs.get(i2).dResult).append(",").toString());
        }
        Debug.debug(new StringBuffer().append("Interpreted binary result (ID) = ").append(i).toString());
        return i;
    }

    /* JADX WARN: Failed to find 'out' block for switch in B:3:0x0004. Please report as an issue. */
    @Override // marf.Classification.Classification, marf.Storage.StorageManager, marf.Storage.IStorageManager
    public void dump() throws StorageException {
        try {
            switch (this.iCurrentDumpMode) {
                case 0:
                case 2:
                    if (this.oInputs.size() == 0) {
                        generate();
                    }
                    Vector vector = new Vector(3);
                    vector.add(this.oInputs);
                    vector.add(this.oLayers);
                    vector.add(this.oOutputs);
                    this.oObjectToSerialize = vector;
                default:
                    switch (this.iCurrentDumpMode) {
                        case 0:
                            dumpGzipBinary();
                            return;
                        case 2:
                            dumpBinary();
                            return;
                        default:
                            super.dump();
                            return;
                    }
            }
        } catch (ClassificationException e) {
            e.printStackTrace(System.err);
            throw new StorageException(e);
        }
    }

    @Override // marf.Classification.Classification, marf.Storage.StorageManager, marf.Storage.IStorageManager
    public void restore() throws StorageException {
        switch (this.iCurrentDumpMode) {
            case 0:
                restoreGzipBinary();
                return;
            case 2:
                restoreBinary();
                return;
            default:
                super.restore();
                return;
        }
    }

    @Override // marf.Storage.StorageManager, marf.Storage.IStorageManager
    public void dumpXML() throws StorageException {
        dumpXML(getDefaultFilename());
    }

    @Override // marf.Storage.StorageManager, marf.Storage.IStorageManager
    public void restoreXML() throws StorageException {
        initialize(getDefaultFilename(), false);
    }

    @Override // marf.Storage.StorageManager
    public void backSynchronizeObject() {
        Vector vector = (Vector) this.oObjectToSerialize;
        this.oInputs = (Layer) vector.firstElement();
        this.oLayers = (ArrayList) vector.elementAt(1);
        this.oOutputs = (Layer) vector.lastElement();
    }

    protected String getDefaultFilename() {
        return new StringBuffer().append(getClass().getName()).append(".").append(MARF.getPreprocessingMethod()).append(".").append(MARF.getFeatureExtractionMethod()).append(".").append(getDefaultExtension()).toString();
    }

    @Override // marf.Classification.IClassification
    public Result getResult() {
        return this.oResultSet.getMinimumResult();
    }

    public static String getMARFSourceCodeRevision() {
        return "$Revision: 1.61 $";
    }
}
