/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.transferlearning;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Entry points for deriving a modified copy of an existing network: {@link Builder} for a
 * {@link MultiLayerNetwork} and {@link GraphBuilder} for a {@link ComputationGraph}.
 * <p>
 * Both builders work on a clone of the original configuration. Layers that are not edited receive a copy
 * ({@code dup()}) of the original parameters; edited or newly added layers keep freshly initialised
 * parameters. Selected layers/vertices can be wrapped as frozen via {@code setFeatureExtractor}.
 */
public class TransferLearning {
    private static final Logger log = LoggerFactory.getLogger(TransferLearning.class);

    /**
     * Builder that derives a new {@link ComputationGraph} from an existing one by adding/removing vertices,
     * replacing nIn/nOut of feed-forward layers, applying a {@link FineTuneConfiguration} and freezing vertices.
     * <p>
     * NOTE(review): every nIn/nOut replacement starts from the layer configuration of the ORIGINAL graph, so
     * repeated edits of the same layer (e.g. nOutReplace followed by nInReplace on the same name) do not
     * accumulate; only an nIn forced by an upstream nOutReplace is carried over.
     */
    public static class GraphBuilder {
        private final ComputationGraph origGraph;
        private final ComputationGraphConfiguration origConfig;
        private FineTuneConfiguration fineTuneConfiguration;
        private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder;
        /** Vertex names passed to setFeatureExtractor; their ancestors are frozen as well. */
        private String[] frozenOutputAt;
        private boolean hasFrozen = false;
        /** Vertices whose parameters are NOT copied from the original graph in build(). */
        private final Set<String> editedVertices = new HashSet<>();
        private WorkspaceMode workspaceMode;
        /** {@code null} means "not set"; build() then validates the output layer configuration. */
        private Boolean validateOutputLayerConfig = null;
        /** nIn values forced onto downstream layers by nOutReplace, keyed by layer name. */
        private final Map<String, Integer> nInFromNewConfig = new HashMap<>();

        /**
         * @param origGraph graph to derive from; its configuration is cloned, and its parameters are read
         *                  again in {@link #build()} to populate layers that were not edited
         */
        public GraphBuilder(ComputationGraph origGraph) {
            this.origGraph = origGraph;
            this.origConfig = origGraph.getConfiguration().clone();
        }

        /**
         * Sets the fine-tune configuration and creates a fresh edited-configuration builder from the original
         * configuration, applying the fine-tune settings to every layer vertex.
         * <p>
         * Because a new builder is created from the original configuration, vertex edits made before this call
         * are discarded from the configuration; call this first.
         */
        public GraphBuilder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
            this.fineTuneConfiguration = fineTuneConfiguration;
            this.editedConfigBuilder = new ComputationGraphConfiguration.GraphBuilder(this.origConfig, fineTuneConfiguration.appliedNeuralNetConfigurationBuilder());
            Map<String, org.deeplearning4j.nn.conf.graph.GraphVertex> vertices = this.editedConfigBuilder.getVertices();
            for (Map.Entry<String, org.deeplearning4j.nn.conf.graph.GraphVertex> gv : vertices.entrySet()) {
                if (!(gv.getValue() instanceof LayerVertex)) {
                    continue;
                }
                LayerVertex lv = (LayerVertex) gv.getValue();
                NeuralNetConfiguration nnc = lv.getLayerConf().clone();
                fineTuneConfiguration.applyToNeuralNetConfiguration(nnc);
                // Replacing the value of an existing key: no structural modification of the map.
                vertices.put(gv.getKey(), new LayerVertex(nnc, lv.getPreProcessor()));
                nnc.getLayer().setLayerName(gv.getKey());
            }
            return this;
        }

        /**
         * Marks the given vertices to be frozen in {@link #build()}; every vertex feeding a frozen vertex is
         * frozen too (input vertices are skipped).
         */
        public GraphBuilder setFeatureExtractor(String ... layerName) {
            this.hasFrozen = true;
            this.frozenOutputAt = layerName;
            return this;
        }

        /** Replaces nOut of {@code layerName}, using {@code scheme} for it and for the layers it feeds. */
        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme) {
            return this.nOutReplace(layerName, nOut, scheme, scheme);
        }

        /** Replaces nOut of {@code layerName}, using {@code dist} for it and for the layers it feeds. */
        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist) {
            return this.nOutReplace(layerName, nOut, dist, dist);
        }

        /** Replaces nOut of {@code layerName}; {@code distNext} initialises the layers it feeds. */
        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, Distribution distNext) {
            return this.nOutReplace(layerName, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext));
        }

        /**
         * Replaces nOut of {@code layerName}; {@code dist} initialises the layers it feeds.
         *
         * @throws UnsupportedOperationException if {@code scheme} is {@link WeightInit#DISTRIBUTION}
         */
        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, Distribution dist) {
            if (scheme == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(layerName, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
            }
            return this.nOutReplace(layerName, nOut, scheme.getWeightInitFunction(), new WeightInitDistribution(dist));
        }

        /**
         * Replaces nOut of {@code layerName} (initialised from {@code dist}); {@code scheme} initialises the
         * layers it feeds.
         *
         * @throws UnsupportedOperationException if {@code scheme} is {@link WeightInit#DISTRIBUTION}
         */
        public GraphBuilder nOutReplace(String layerName, int nOut, Distribution dist, WeightInit scheme) {
            if (scheme == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(layerName, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
            }
            return this.nOutReplace(layerName, nOut, new WeightInitDistribution(dist), scheme.getWeightInitFunction());
        }

        /**
         * Replaces nOut of {@code layerName}; {@code schemeNext} initialises the layers it feeds.
         *
         * @throws UnsupportedOperationException if either scheme is {@link WeightInit#DISTRIBUTION}
         */
        public GraphBuilder nOutReplace(String layerName, int nOut, WeightInit scheme, WeightInit schemeNext) {
            if (scheme == WeightInit.DISTRIBUTION || schemeNext == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(layerName, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
            }
            return this.nOutReplace(layerName, nOut, scheme.getWeightInitFunction(), schemeNext.getWeightInitFunction());
        }

        /** Replaces nIn of {@code layerName}; delegates with a {@code null} distribution. */
        public GraphBuilder nInReplace(String layerName, int nIn, WeightInit scheme) {
            return this.nInReplace(layerName, nIn, scheme, null);
        }

        /** Enables/disables output layer configuration validation when the new configuration is built. */
        public GraphBuilder validateOutputLayerConfig(boolean validateOutputLayerConfig) {
            this.validateOutputLayerConfig = validateOutputLayerConfig;
            return this;
        }

        /** Replaces nIn of {@code layerName}, initialising it with {@code scheme} (and {@code dist}, if used). */
        public GraphBuilder nInReplace(String layerName, int nIn, WeightInit scheme, Distribution dist) {
            return this.nInReplace(layerName, nIn, scheme.getWeightInitFunction(dist));
        }

        /**
         * Replaces the feed-forward layer {@code layerName} with a copy of its original configuration whose nIn
         * is {@code nIn} and whose weight init is {@code scheme}. Its parameters are not copied in build().
         * <p>
         * If an earlier nOutReplace on an upstream layer already forced an nIn for this layer, that value takes
         * precedence over {@code nIn}.
         *
         * @throws IllegalStateException if the vertex does not exist, has no layer, or is not feed-forward
         */
        public GraphBuilder nInReplace(String layerName, int nIn, IWeightInit scheme) {
            GraphVertex vertex = this.origGraph.getVertex(layerName);
            Preconditions.checkState(vertex != null, "Layer with name %s not found", layerName);
            Preconditions.checkState(vertex.hasLayer(), "nInReplace can only be applied on vertices with layers. Vertex %s does not have a layer", layerName);
            this.initBuilderIfReq();
            Layer layerImpl = this.origGraph.getLayer(layerName).conf().getLayer().clone();
            Preconditions.checkState(layerImpl instanceof FeedForwardLayer, "Can only use nInReplace on FeedForward layers;got layer of type %s for layer name %s", layerImpl.getClass().getSimpleName(), layerName);
            // Layer-level settings are reset so the (fine-tune) global configuration applies to the new layer.
            layerImpl.resetLayerDefaultConfig();
            FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
            layerImplF.setWeightInitFn(scheme);
            layerImplF.setNIn(nIn);
            this.applyPendingNIn(layerName, layerImplF);
            this.replaceLayerVertex(layerName, layerImpl);
            return this;
        }

        /**
         * Replaces the feed-forward layer {@code layerName} with a copy whose nOut is {@code nOut}, and updates
         * nIn (with {@code schemeNext}) of every feed-forward layer that consumes it in the original graph.
         *
         * @throws IllegalArgumentException      if the vertex does not exist, has no layer, or is not feed-forward
         * @throws UnsupportedOperationException if the layer feeds a vertex without a layer
         */
        private GraphBuilder nOutReplace(String layerName, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
            this.initBuilderIfReq();
            GraphVertex vertex = this.origGraph.getVertex(layerName);
            Preconditions.checkArgument(vertex != null, "Layer with name %s not found", layerName);
            if (!vertex.hasLayer()) {
                throw new IllegalArgumentException("noutReplace can only be applied to layer vertices. " + layerName + " is not a layer vertex");
            }
            Layer layerImpl = this.origGraph.getLayer(layerName).conf().getLayer().clone();
            Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applied on FeedForward layers; got layer of type %s for layer name %s", layerImpl.getClass().getSimpleName(), layerName);
            layerImpl.resetLayerDefaultConfig();
            FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
            layerImplF.setWeightInitFn(scheme);
            layerImplF.setNOut(nOut);
            this.applyPendingNIn(layerName, layerImplF);
            this.replaceLayerVertex(layerName, layerImpl);

            for (String fanoutVertexName : this.consumersOf(layerName)) {
                if (!this.origGraph.getVertex(fanoutVertexName).hasLayer()) {
                    throw new UnsupportedOperationException("Cannot modify nOut of a layer vertex that feeds non-layer vertices. Use removeVertexKeepConnections followed by addVertex instead");
                }
                NeuralNetConfiguration fanoutConf = this.origGraph.getLayer(fanoutVertexName).conf();
                if (!(fanoutConf.getLayer() instanceof FeedForwardLayer)) {
                    continue;
                }
                // NOTE(review): unlike the replaced layer, the consumer keeps its layer-level settings (no reset).
                Layer nextImpl = fanoutConf.getLayer().clone();
                FeedForwardLayer nextImplF = (FeedForwardLayer) nextImpl;
                nextImplF.setWeightInitFn(schemeNext);
                nextImplF.setNIn(nOut);
                this.nInFromNewConfig.put(fanoutVertexName, nOut);
                this.replaceLayerVertex(fanoutVertexName, nextImpl);
                if (this.validateOutputLayerConfig != null) {
                    this.editedConfigBuilder.validateOutputLayerConfig(this.validateOutputLayerConfig);
                }
            }
            return this;
        }

        /** Names of vertices (other than {@code vertexName} itself) that take {@code vertexName} as input in the original configuration. */
        private List<String> consumersOf(String vertexName) {
            List<String> consumers = new ArrayList<>();
            for (Map.Entry<String, List<String>> entry : this.origConfig.getVertexInputs().entrySet()) {
                if (!entry.getKey().equals(vertexName) && entry.getValue().contains(vertexName)) {
                    consumers.add(entry.getKey());
                }
            }
            return consumers;
        }

        /** Re-applies an nIn previously forced by nOutReplace if the edited vertex is still a feed-forward layer. */
        private void applyPendingNIn(String layerName, FeedForwardLayer target) {
            Integer pendingNIn = this.nInFromNewConfig.get(layerName);
            if (pendingNIn == null || !this.editedVertices.contains(layerName)) {
                return;
            }
            org.deeplearning4j.nn.conf.graph.GraphVertex edited = this.editedConfigBuilder.getVertices().get(layerName);
            if (edited instanceof LayerVertex && ((LayerVertex) edited).getLayerConf().getLayer() instanceof FeedForwardLayer) {
                target.setNIn(pendingNIn.intValue());
            }
        }

        /** Swaps {@code layerName} for {@code newLayer}, keeping the original inputs and preprocessor, and marks it edited. */
        private void replaceLayerVertex(String layerName, Layer newLayer) {
            this.editedConfigBuilder.removeVertex(layerName, false);
            LayerVertex origVertex = (LayerVertex) this.origConfig.getVertices().get(layerName);
            String[] origInputs = this.origConfig.getVertexInputs().get(layerName).toArray(new String[0]);
            this.editedConfigBuilder.addLayer(layerName, newLayer, origVertex.getPreProcessor(), origInputs);
            this.editedVertices.add(layerName);
        }

        /** Removes {@code outputName} from the edited configuration via {@code removeVertex(name, false)}. */
        public GraphBuilder removeVertexKeepConnections(String outputName) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(outputName, false);
            return this;
        }

        /** Removes {@code vertexName} from the edited configuration via {@code removeVertex(name, true)}. */
        public GraphBuilder removeVertexAndConnections(String vertexName) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(vertexName, true);
            return this;
        }

        /** Adds a new layer (no preprocessor); its parameters stay freshly initialised in build(). */
        public GraphBuilder addLayer(String layerName, Layer layer, String ... layerInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addLayer(layerName, layer, (InputPreProcessor) null, layerInputs);
            this.editedVertices.add(layerName);
            return this;
        }

        /** Adds a new layer with a preprocessor; its parameters stay freshly initialised in build(). */
        public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, String ... layerInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addLayer(layerName, layer, preProcessor, layerInputs);
            this.editedVertices.add(layerName);
            return this;
        }

        /** Adds a new vertex; it is treated as edited in build(). */
        public GraphBuilder addVertex(String vertexName, org.deeplearning4j.nn.conf.graph.GraphVertex vertex, String ... vertexInputs) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addVertex(vertexName, vertex, vertexInputs);
            this.editedVertices.add(vertexName);
            return this;
        }

        /** Sets the network outputs of the edited configuration. */
        public GraphBuilder setOutputs(String ... outputNames) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.setOutputs(outputNames);
            return this;
        }

        /** Lazily creates the edited builder with a fine-tune configuration that only carries the original seed. */
        private void initBuilderIfReq() {
            if (this.editedConfigBuilder == null) {
                this.fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(this.origConfig.getDefaultConfiguration().getSeed()).build());
            }
        }

        /** Replaces the network inputs of the edited configuration. */
        public GraphBuilder setInputs(String ... inputs) {
            // Previously NPE'd when this was the first call on the builder.
            this.initBuilderIfReq();
            this.editedConfigBuilder.setNetworkInputs(Arrays.asList(inputs));
            return this;
        }

        /** Sets the input types of the edited configuration. */
        public GraphBuilder setInputTypes(InputType ... inputTypes) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.setInputTypes(inputTypes);
            return this;
        }

        /** Adds network inputs to the edited configuration. */
        public GraphBuilder addInputs(String ... inputNames) {
            this.initBuilderIfReq();
            this.editedConfigBuilder.addInputs(inputNames);
            return this;
        }

        /** Overrides the training workspace mode of the built configuration (if non-null). */
        public GraphBuilder setWorkspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        /**
         * Builds and initialises the new graph, copies the original parameters into every layer that was not
         * edited or added, and applies freezing if {@link #setFeatureExtractor} was called.
         */
        public ComputationGraph build() {
            this.initBuilderIfReq();
            boolean validate = this.validateOutputLayerConfig == null || this.validateOutputLayerConfig;
            ComputationGraphConfiguration newConfig = this.editedConfigBuilder.validateOutputLayerConfig(validate).build();
            if (this.workspaceMode != null) {
                newConfig.setTrainingWorkspaceMode(this.workspaceMode);
            }
            ComputationGraph newGraph = new ComputationGraph(newConfig);
            newGraph.init();
            int[] topologicalOrder = newGraph.topologicalSortOrder();
            GraphVertex[] vertices = newGraph.getVertices();
            // Per-layer copy: unlike a whole-network setParams, this also works when vertices were only removed.
            this.copyUnchangedParams(topologicalOrder, vertices);
            if (this.hasFrozen) {
                this.freezeVertices(newConfig, newGraph, topologicalOrder, vertices);
            }
            return newGraph;
        }

        /** Copies a dup of the original parameters into each parameterised, non-edited layer of the new graph. */
        private void copyUnchangedParams(int[] topologicalOrder, GraphVertex[] vertices) {
            for (int vertexIdx : topologicalOrder) {
                GraphVertex vertex = vertices[vertexIdx];
                if (!vertex.hasLayer()) {
                    continue;
                }
                org.deeplearning4j.nn.api.Layer layer = vertex.getLayer();
                String layerName = vertex.getVertexName();
                if (layer.numParams() <= 0L || this.editedVertices.contains(layerName)) {
                    continue;
                }
                layer.setParams(this.origGraph.getLayer(layerName).params().dup());
            }
        }

        /**
         * Walks the graph in reverse topological order, freezing every vertex in {@code frozenOutputAt} and,
         * transitively, every vertex that feeds a frozen vertex. Updates both the runtime vertices and
         * {@code newConfig}, then re-initialises the gradient view.
         * <p>
         * NOTE(review): names in {@code frozenOutputAt} that do not exist in the new graph are silently ignored.
         */
        private void freezeVertices(ComputationGraphConfiguration newConfig, ComputationGraph newGraph, int[] topologicalOrder, GraphVertex[] vertices) {
            Set<String> allFrozen = new HashSet<>();
            Collections.addAll(allFrozen, this.frozenOutputAt);
            for (int i = topologicalOrder.length - 1; i >= 0; --i) {
                GraphVertex gv = vertices[topologicalOrder[i]];
                if (!allFrozen.contains(gv.getVertexName())) {
                    continue;
                }
                if (gv.hasLayer()) {
                    freezeLayerVertex(newConfig, newGraph, gv);
                } else if (!(gv instanceof InputVertex)) {
                    org.deeplearning4j.nn.conf.graph.GraphVertex currVertexConf = newConfig.getVertices().get(gv.getVertexName());
                    newConfig.getVertices().put(gv.getVertexName(), new org.deeplearning4j.nn.conf.graph.FrozenVertex(currVertexConf));
                    vertices[topologicalOrder[i]] = new FrozenVertex(gv);
                }
                VertexIndices[] inputs = gv.getInputVertices();
                if (inputs == null) {
                    continue;
                }
                for (VertexIndices input : inputs) {
                    allFrozen.add(vertices[input.getVertexIndex()].getVertexName());
                }
            }
            newGraph.initGradientsView();
        }

        /** Freezes a layer vertex: runtime layer, its entry in the graph's layer array, and its configuration. */
        private static void freezeLayerVertex(ComputationGraphConfiguration newConfig, ComputationGraph newGraph, GraphVertex gv) {
            org.deeplearning4j.nn.api.Layer unfrozen = gv.getLayer();
            gv.setLayerAsFrozen();
            String layerName = gv.getVertexName();
            LayerVertex currLayerVertex = (LayerVertex) newConfig.getVertices().get(layerName);
            Layer origLayerConf = currLayerVertex.getLayerConf().getLayer();
            org.deeplearning4j.nn.conf.layers.misc.FrozenLayer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
            newLayerConf.setLayerName(origLayerConf.getLayerName());
            NeuralNetConfiguration newNNC = currLayerVertex.getLayerConf().clone();
            currLayerVertex.setLayerConf(newNNC);
            currLayerVertex.getLayerConf().setLayer(newLayerConf);
            // Rebuild the variable list on the cloned configuration from a copy of the current one.
            List<String> vars = currLayerVertex.getLayerConf().variables(true);
            currLayerVertex.getLayerConf().clearVariables();
            for (String s : vars) {
                newNNC.variables(false).add(s);
            }
            // The graph's layer array still references the unfrozen layer; swap in the frozen one.
            org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers();
            for (int j = 0; j < layers.length; ++j) {
                if (layers[j] == unfrozen) {
                    layers[j] = gv.getLayer();
                    break;
                }
            }
        }
    }

    /**
     * Builder that derives a new {@link MultiLayerNetwork} from an existing one by replacing nIn/nOut of
     * feed-forward layers, removing layers from the output side, appending layers, applying a
     * {@link FineTuneConfiguration} and freezing the first layers.
     * <p>
     * The edited configurations/parameters are prepared once (on the first {@link #addLayer} or
     * {@link #build()}); edits such as nOutReplace, nInReplace, removeLayersFromOutput or
     * fineTuneConfiguration made after that point are not picked up.
     */
    public static class Builder {
        private final MultiLayerConfiguration origConf;
        private final MultiLayerNetwork origModel;
        private MultiLayerNetwork editedModel;
        private FineTuneConfiguration finetuneConfiguration;
        /** Index of the last layer to freeze; -1 means no freezing. */
        private int frozenTill = -1;
        /** Number of layers to remove from the output side. */
        private int popN = 0;
        private boolean prepDone = false;
        private final Set<Integer> editedLayers = new HashSet<>();
        private final Map<Integer, Triple<Integer, IWeightInit, IWeightInit>> editedLayersMap = new HashMap<>();
        private final Map<Integer, Pair<Integer, IWeightInit>> nInEditedMap = new HashMap<>();
        /** One entry per kept layer; {@code null} entries are skipped when concatenating. */
        private final List<INDArray> editedParams = new ArrayList<>();
        private final List<NeuralNetConfiguration> editedConfs = new ArrayList<>();
        private final List<INDArray> appendParams = new ArrayList<>();
        private final List<NeuralNetConfiguration> appendConfs = new ArrayList<>();
        private Map<Integer, InputPreProcessor> inputPreProcessors = new HashMap<>();
        // NOTE(review): never assigned in this class, so constructConf always passes null.
        private InputType inputType;
        private Boolean validateOutputLayerConfig;
        private final DataType dataType;

        /** @param origModel network to derive from; its configuration is cloned. */
        public Builder(MultiLayerNetwork origModel) {
            this.origModel = origModel;
            this.origConf = origModel.getLayerWiseConfigurations().clone();
            this.dataType = origModel.getLayerWiseConfigurations().getDataType();
            this.inputPreProcessors = this.origConf.getInputPreProcessors();
        }

        /** Sets the fine-tune configuration applied to all (kept and appended) layers. */
        public Builder fineTuneConfiguration(FineTuneConfiguration finetuneConfiguration) {
            this.finetuneConfiguration = finetuneConfiguration;
            return this;
        }

        /** Freezes layers {@code 0..layerNum} (inclusive) in {@link #build()}. */
        public Builder setFeatureExtractor(int layerNum) {
            this.frozenTill = layerNum;
            return this;
        }

        /** Replaces nOut of layer {@code layerNum}, using {@code scheme} for it and for the next layer. */
        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme) {
            return this.nOutReplace(layerNum, nOut, scheme, scheme);
        }

        /** Replaces nOut of layer {@code layerNum}, using {@code dist} for it and for the next layer. */
        public Builder nOutReplace(int layerNum, int nOut, Distribution dist) {
            return this.nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(dist));
        }

        /**
         * Replaces nOut of layer {@code layerNum}; {@code schemeNext} initialises the next layer.
         *
         * @throws UnsupportedOperationException if either scheme is {@link WeightInit#DISTRIBUTION}
         */
        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, WeightInit schemeNext) {
            if (scheme == WeightInit.DISTRIBUTION || schemeNext == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext)) instead!");
            }
            return this.nOutReplace(layerNum, nOut, scheme.getWeightInitFunction(), schemeNext.getWeightInitFunction());
        }

        /** Replaces nOut of layer {@code layerNum}; {@code distNext} initialises the next layer. */
        public Builder nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) {
            return this.nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), new WeightInitDistribution(distNext));
        }

        /**
         * Replaces nOut of layer {@code layerNum}; {@code distNext} initialises the next layer.
         *
         * @throws UnsupportedOperationException if {@code scheme} is {@link WeightInit#DISTRIBUTION}
         */
        public Builder nOutReplace(int layerNum, int nOut, WeightInit scheme, Distribution distNext) {
            if (scheme == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) instead!");
            }
            return this.nOutReplace(layerNum, nOut, scheme.getWeightInitFunction(), new WeightInitDistribution(distNext));
        }

        /**
         * Replaces nOut of layer {@code layerNum} (initialised from {@code dist}); {@code schemeNext}
         * initialises the next layer.
         *
         * @throws UnsupportedOperationException if {@code schemeNext} is {@link WeightInit#DISTRIBUTION}
         */
        public Builder nOutReplace(int layerNum, int nOut, Distribution dist, WeightInit schemeNext) {
            // Same guard as the other WeightInit overloads (DISTRIBUTION needs an explicit Distribution).
            if (schemeNext == WeightInit.DISTRIBUTION) {
                throw new UnsupportedOperationException("Not supported!, Use nOutReplace(int layerNum, int nOut, Distribution dist, Distribution distNext) instead!");
            }
            return this.nOutReplace(layerNum, nOut, new WeightInitDistribution(dist), schemeNext.getWeightInitFunction());
        }

        /** Records an nOut replacement for layer {@code layerNum}; applied (in layer order) during preparation. */
        public Builder nOutReplace(int layerNum, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
            this.editedLayers.add(layerNum);
            this.editedLayersMap.put(layerNum, new Triple<>(nOut, scheme, schemeNext));
            return this;
        }

        /** Replaces nIn of layer {@code layerNum}; delegates with a {@code null} distribution. */
        public Builder nInReplace(int layerNum, int nIn, WeightInit scheme) {
            return this.nInReplace(layerNum, nIn, scheme, null);
        }

        /** Replaces nIn of layer {@code layerNum}, initialising it with {@code scheme} (and {@code dist}, if used). */
        public Builder nInReplace(int layerNum, int nIn, WeightInit scheme, Distribution dist) {
            return this.nInReplace(layerNum, nIn, scheme.getWeightInitFunction(dist));
        }

        /** Records an nIn replacement for layer {@code layerNum}; applied after all nOut replacements. */
        public Builder nInReplace(int layerNum, int nIn, IWeightInit scheme) {
            this.nInEditedMap.put(layerNum, new Pair<>(nIn, scheme));
            return this;
        }

        /** Removes the last layer (sets the number of layers to remove to 1). */
        public Builder removeOutputLayer() {
            this.popN = 1;
            return this;
        }

        /**
         * Removes the last {@code layerNum} layers.
         *
         * @throws IllegalArgumentException if a removal count was already set
         */
        public Builder removeLayersFromOutput(int layerNum) {
            if (this.popN != 0) {
                throw new IllegalArgumentException("Remove layers from can only be called once");
            }
            this.popN = layerNum;
            return this;
        }

        /**
         * Appends a layer after the kept layers. Triggers preparation on first use. The layer configuration is
         * built from the fine-tune configuration if one is set, otherwise from a default
         * {@code NeuralNetConfiguration.Builder}; its parameters are freshly initialised.
         */
        public Builder addLayer(Layer layer) {
            if (!this.prepDone) {
                this.doPrep();
            }
            // Previously NPE'd when no fine-tune configuration had been set.
            NeuralNetConfiguration.Builder confBuilder = this.finetuneConfiguration != null
                    ? this.finetuneConfiguration.appliedNeuralNetConfigurationBuilder()
                    : new NeuralNetConfiguration.Builder();
            NeuralNetConfiguration layerConf = confBuilder.layer(layer).build();
            long numParams = layer.initializer().numParams(layerConf);
            if (numParams > 0L) {
                org.deeplearning4j.nn.api.Layer someLayer = this.instantiateWithFreshParams(layerConf, layer, numParams);
                this.appendParams.add(someLayer.params());
                this.appendConfs.add(someLayer.conf());
            } else {
                this.appendConfs.add(layerConf);
            }
            return this;
        }

        /** Sets (or replaces) the input preprocessor of layer index {@code layer}. */
        public Builder setInputPreProcessor(int layer, InputPreProcessor processor) {
            this.inputPreProcessors.put(layer, processor);
            return this;
        }

        /** Enables/disables output layer configuration validation when the new configuration is built. */
        public Builder validateOutputLayerConfig(boolean validate) {
            this.validateOutputLayerConfig = validate;
            return this;
        }

        /**
         * Builds the edited network. If {@link #setFeatureExtractor} was called, layers {@code 0..frozenTill}
         * are wrapped in {@link FrozenLayer} and their configurations replaced by frozen layer configurations.
         *
         * @throws IllegalStateException if the freeze index is beyond the last layer of the new network
         */
        public MultiLayerNetwork build() {
            if (!this.prepDone) {
                this.doPrep();
            }
            this.editedModel = new MultiLayerNetwork(this.constructConf(), this.constructParams());
            if (this.frozenTill != -1) {
                org.deeplearning4j.nn.api.Layer[] layers = this.editedModel.getLayers();
                Preconditions.checkState(this.frozenTill < layers.length, "Cannot freeze layers 0 to %s: the network only has %s layers", this.frozenTill, layers.length);
                MultiLayerConfiguration editedConf = this.editedModel.getLayerWiseConfigurations();
                for (int i = this.frozenTill; i >= 0; --i) {
                    NeuralNetConfiguration origNNC = editedConf.getConf(i);
                    NeuralNetConfiguration layerNNC = origNNC.clone();
                    layers[i].setConf(layerNNC);
                    layers[i] = new FrozenLayer(layers[i]);
                    if (origNNC.getVariables() != null) {
                        // Rebuild both variable lists from a copy so they hold the same entries.
                        List<String> vars = origNNC.variables(true);
                        origNNC.clearVariables();
                        layerNNC.clearVariables();
                        for (String s : vars) {
                            origNNC.variables(false).add(s);
                            layerNNC.variables(false).add(s);
                        }
                    }
                    Layer origLayerConf = editedConf.getConf(i).getLayer();
                    org.deeplearning4j.nn.conf.layers.misc.FrozenLayer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
                    newLayerConf.setLayerName(origLayerConf.getLayerName());
                    editedConf.getConf(i).setLayer(newLayerConf);
                }
                this.editedModel.setLayers(layers);
            }
            return this.editedModel;
        }

        /**
         * One-time preparation: clones (and fine-tunes) every layer configuration, copies the original
         * parameters, applies nOut then nIn replacements in ascending layer order, and drops the last
         * {@code popN} layers.
         */
        private void doPrep() {
            this.fineTuneConfigurationBuild();
            for (int i = 0; i < this.origModel.getnLayers(); ++i) {
                org.deeplearning4j.nn.api.Layer origLayer = this.origModel.getLayer(i);
                this.editedParams.add(origLayer.numParams() > 0L ? origLayer.params().dup() : origLayer.params());
            }
            if (!this.editedLayers.isEmpty()) {
                Integer[] sortedLayers = this.editedLayers.toArray(new Integer[0]);
                Arrays.sort(sortedLayers);
                for (int layerNum : sortedLayers) {
                    Triple<Integer, IWeightInit, IWeightInit> edit = this.editedLayersMap.get(layerNum);
                    this.nOutReplaceBuild(layerNum, edit.getLeft(), edit.getMiddle(), edit.getRight());
                }
            }
            if (!this.nInEditedMap.isEmpty()) {
                Integer[] sortedLayers = this.nInEditedMap.keySet().toArray(new Integer[0]);
                Arrays.sort(sortedLayers);
                for (int layerNum : sortedLayers) {
                    Pair<Integer, IWeightInit> edit = this.nInEditedMap.get(layerNum);
                    this.nInReplaceBuild(layerNum, edit.getFirst(), edit.getSecond());
                }
            }
            Preconditions.checkArgument(this.popN <= this.editedConfs.size(), "Cannot remove %s layers from the output: the network only has %s layers", this.popN, this.editedConfs.size());
            for (int i = 0; i < this.popN; ++i) {
                // NOTE(review): removes the preprocessor keyed nLayers - i, i.e. one index above the layer being
                // popped; the preprocessor of the lowest removed layer therefore stays and applies to the first
                // appended layer. Kept unchanged because callers may rely on it.
                this.inputPreProcessors.remove(this.origModel.getnLayers() - i);
                this.editedConfs.remove(this.editedConfs.size() - 1);
                this.editedParams.remove(this.editedParams.size() - 1);
            }
            this.prepDone = true;
        }

        /** Adds a clone of every original layer configuration to {@code editedConfs}, fine-tuned if configured. */
        private void fineTuneConfigurationBuild() {
            for (NeuralNetConfiguration origLayerConf : this.origConf.getConfs()) {
                NeuralNetConfiguration layerConf = origLayerConf.clone();
                if (this.finetuneConfiguration != null) {
                    this.finetuneConfiguration.applyToNeuralNetConfiguration(layerConf);
                }
                this.editedConfs.add(layerConf);
            }
        }

        /** Applies an nIn replacement to the (already cloned) layer configuration in place and re-initialises its parameters. */
        private void nInReplaceBuild(int layerNum, int nIn, IWeightInit init) {
            Preconditions.checkArgument(layerNum >= 0 && layerNum < this.editedConfs.size(), "Invalid layer index: must be 0 to numLayers-1 = %s inclusive, got %s", this.editedConfs.size() - 1, layerNum);
            NeuralNetConfiguration layerConf = this.editedConfs.get(layerNum);
            Layer layerImpl = layerConf.getLayer();
            Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applied on FeedForward layers;got layer of type %s", layerImpl.getClass().getSimpleName());
            FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
            layerImplF.setWeightInitFn(init);
            layerImplF.setNIn(nIn);
            long numParams = layerImpl.initializer().numParams(layerConf);
            this.editedParams.set(layerNum, this.instantiateWithFreshParams(layerConf, layerImpl, numParams).params());
        }

        /**
         * Applies an nOut replacement to layer {@code layerNum} in place and re-initialises its parameters; if
         * the next layer is feed-forward, its nIn is updated and its parameters re-initialised with
         * {@code schemeNext}.
         */
        private void nOutReplaceBuild(int layerNum, int nOut, IWeightInit scheme, IWeightInit schemeNext) {
            Preconditions.checkArgument(layerNum >= 0 && layerNum < this.editedConfs.size(), "Invalid layer index: must be 0 to numLayers-1 = %s inclusive, got %s", this.editedConfs.size() - 1, layerNum);
            NeuralNetConfiguration layerConf = this.editedConfs.get(layerNum);
            Layer layerImpl = layerConf.getLayer();
            Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applied on FeedForward layers;got layer of type %s", layerImpl.getClass().getSimpleName());
            FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl;
            layerImplF.setWeightInitFn(scheme);
            layerImplF.setNOut(nOut);
            long numParams = layerImpl.initializer().numParams(layerConf);
            INDArray params = this.instantiateWithFreshParams(layerConf, layerImpl, numParams).params();
            this.editedParams.set(layerNum, params.reshape(new long[]{params.length()}));

            if (layerNum + 1 >= this.editedConfs.size()) {
                return;
            }
            NeuralNetConfiguration nextConf = this.editedConfs.get(layerNum + 1);
            Layer nextImpl = nextConf.getLayer();
            if (!(nextImpl instanceof FeedForwardLayer)) {
                return;
            }
            FeedForwardLayer nextImplF = (FeedForwardLayer) nextImpl;
            nextImplF.setWeightInitFn(schemeNext);
            nextImplF.setNIn(nOut);
            long nextNumParams = nextImpl.initializer().numParams(nextConf);
            if (nextNumParams > 0L) {
                INDArray nextParams = this.instantiateWithFreshParams(nextConf, nextImpl, nextNumParams).params();
                this.editedParams.set(layerNum + 1, nextParams.reshape(new long[]{nextParams.length()}));
            }
        }

        /**
         * Instantiates {@code layerImpl} on a newly allocated parameter array of {@code numParams} elements,
         * with the initialisation flag set to {@code true}.
         */
        private org.deeplearning4j.nn.api.Layer instantiateWithFreshParams(NeuralNetConfiguration layerConf, Layer layerImpl, long numParams) {
            INDArray params = Nd4j.create(this.origModel.getLayerWiseConfigurations().getDataType(), new long[]{numParams});
            return layerImpl.instantiate(layerConf, null, 0, params, true, this.dataType);
        }

        /**
         * Concatenates the kept-layer parameters (skipping {@code null}s) followed by the appended-layer
         * parameters. Returns {@code null} if there are none at all.
         */
        private INDArray constructParams() {
            INDArray keepView = null;
            for (INDArray aParam : this.editedParams) {
                if (aParam == null) {
                    continue;
                }
                keepView = keepView == null ? aParam : Nd4j.hstack(new INDArray[]{keepView, aParam});
            }
            if (this.appendParams.isEmpty()) {
                return keepView;
            }
            INDArray appendView = Nd4j.hstack(this.appendParams);
            // Previously hstack(null, appendView) when no kept layer had parameters.
            return keepView == null ? appendView : Nd4j.hstack(new INDArray[]{keepView, appendView});
        }

        /**
         * Builds the new configuration from kept + appended layer configurations, naming unnamed layers
         * {@code "layer" + index}, then applies the fine-tune configuration's network-level settings (if any).
         */
        private MultiLayerConfiguration constructConf() {
            List<NeuralNetConfiguration> allConfs = new ArrayList<>(this.editedConfs.size() + this.appendConfs.size());
            allConfs.addAll(this.editedConfs);
            allConfs.addAll(this.appendConfs);
            for (int i = 0; i < allConfs.size(); ++i) {
                Layer layerConf = allConfs.get(i).getLayer();
                if (layerConf.getLayerName() == null) {
                    layerConf.setLayerName("layer" + i);
                }
            }
            boolean validate = this.validateOutputLayerConfig == null || this.validateOutputLayerConfig;
            MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder()
                    .inputPreProcessors(this.inputPreProcessors)
                    .setInputType(this.inputType)
                    .confs(allConfs)
                    .validateOutputLayerConfig(validate)
                    .dataType(this.origConf.getDataType())
                    .build();
            if (this.finetuneConfiguration != null) {
                this.finetuneConfiguration.applyToMultiLayerConfiguration(conf);
            }
            return conf;
        }
    }
}

