package clarion.tools;

import clarion.system.AbstractImplicitModule;
import clarion.system.AbstractOutputChunk;
import clarion.system.AbstractTrainableImplicitModule;
import clarion.system.Dimension;
import clarion.system.DimensionValueCollection;
import clarion.system.InterfaceRuntimeTrainable;
import clarion.system.StochasticSelector;
import clarion.system.Value;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;

/* loaded from: input_file:clarion/tools/TrainableImplicitModulePreTrainer.class */
public class TrainableImplicitModulePreTrainer {
    public int NUM_TRAINING_REPEATS = GLOBAL_NUM_TRAINING_REPEATS;
    public double SUM_SQ_ERRORS_THRESHOLD = GLOBAL_SUM_SQ_ERRORS_THRESHOLD;
    public TerminationConditions TERMINATION_CONDITION = GLOBAL_TERMINATION_CONDITION;
    public boolean PRINT_PROGRESS_TO_SYSTEM_OUT = GLOBAL_PRINT_PROGRESS_TO_SYSTEM_OUT;
    public StochasticSelector SELECTOR = new StochasticSelector();
    public static int GLOBAL_NUM_TRAINING_REPEATS = 5000;
    public static double GLOBAL_SUM_SQ_ERRORS_THRESHOLD = 0.001d;
    public static TerminationConditions GLOBAL_TERMINATION_CONDITION = TerminationConditions.FIXED;
    public static boolean GLOBAL_PRINT_PROGRESS_TO_SYSTEM_OUT = false;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:clarion/tools/TrainableImplicitModulePreTrainer$SumSquaredErrorTracker.class */
    public class SumSquaredErrorTracker {
        double sumsqerr;
        int sumsqerrcounter;

        private SumSquaredErrorTracker() {
            this.sumsqerr = TrainableImplicitModulePreTrainer.this.SUM_SQ_ERRORS_THRESHOLD;
        }

        public double getMeanSumOfSquaredErrors() {
            return this.sumsqerr / this.sumsqerrcounter;
        }

        /* synthetic */ SumSquaredErrorTracker(TrainableImplicitModulePreTrainer trainableImplicitModulePreTrainer, SumSquaredErrorTracker sumSquaredErrorTracker) {
            this();
        }
    }

    /* loaded from: input_file:clarion/tools/TrainableImplicitModulePreTrainer$TerminationConditions.class */
    public enum TerminationConditions {
        FIXED,
        SUM_SQ_ERROR,
        BOTH;

        /* renamed from: values, reason: to resolve conflict with enum method */
        public static TerminationConditions[] valuesCustom() {
            TerminationConditions[] valuesCustom = values();
            int length = valuesCustom.length;
            TerminationConditions[] terminationConditionsArr = new TerminationConditions[length];
            System.arraycopy(valuesCustom, 0, terminationConditionsArr, 0, length);
            return terminationConditionsArr;
        }
    }

    public void trainModule(AbstractTrainableImplicitModule abstractTrainableImplicitModule, AbstractImplicitModule abstractImplicitModule, Collection<? extends DimensionValueCollection> collection) {
        SumSquaredErrorTracker sumSquaredErrorTracker = new SumSquaredErrorTracker(this, null);
        int i = 0;
        while (true) {
            if ((this.TERMINATION_CONDITION != TerminationConditions.FIXED || i >= this.NUM_TRAINING_REPEATS) && ((this.TERMINATION_CONDITION != TerminationConditions.SUM_SQ_ERROR || sumSquaredErrorTracker.getMeanSumOfSquaredErrors() <= this.SUM_SQ_ERRORS_THRESHOLD) && (this.TERMINATION_CONDITION != TerminationConditions.BOTH || i >= this.NUM_TRAINING_REPEATS || sumSquaredErrorTracker.getMeanSumOfSquaredErrors() <= this.SUM_SQ_ERRORS_THRESHOLD))) {
                return;
            }
            if ((this.TERMINATION_CONDITION == TerminationConditions.FIXED || this.TERMINATION_CONDITION == TerminationConditions.BOTH) && this.PRINT_PROGRESS_TO_SYSTEM_OUT) {
                System.out.println("Training Trial # " + (i + 1));
            }
            sumSquaredErrorTracker.sumsqerr = 0.0d;
            sumSquaredErrorTracker.sumsqerrcounter = 0;
            Iterator<? extends DimensionValueCollection> it = collection.iterator();
            while (it.hasNext()) {
                LinkedList linkedList = new LinkedList(it.next().values());
                dataRecursor(abstractTrainableImplicitModule, abstractImplicitModule, linkedList.listIterator(), new LinkedList(((Dimension) linkedList.getFirst()).values()).listIterator(), sumSquaredErrorTracker);
            }
            if ((this.TERMINATION_CONDITION == TerminationConditions.SUM_SQ_ERROR || this.TERMINATION_CONDITION == TerminationConditions.BOTH) && this.PRINT_PROGRESS_TO_SYSTEM_OUT) {
                System.out.println("Mean Sum of Squared Error: " + sumSquaredErrorTracker.getMeanSumOfSquaredErrors());
            }
            i++;
        }
    }

    private void dataRecursor(AbstractTrainableImplicitModule abstractTrainableImplicitModule, AbstractImplicitModule abstractImplicitModule, ListIterator<Dimension> listIterator, ListIterator<? extends Value> listIterator2, SumSquaredErrorTracker sumSquaredErrorTracker) {
        Dimension next = listIterator.next();
        Value next2 = listIterator2.next();
        if (!(next2 instanceof Range)) {
            innerLoop(abstractTrainableImplicitModule, abstractImplicitModule, listIterator, listIterator2, next, next2, sumSquaredErrorTracker);
            return;
        }
        double lowerBound = ((Range) next2).getLowerBound();
        while (true) {
            double d = lowerBound;
            if (d > ((Range) next2).getUpperBound()) {
                return;
            }
            next2.setActivation(d);
            innerLoop(abstractTrainableImplicitModule, abstractImplicitModule, listIterator, listIterator2, next, next2, sumSquaredErrorTracker);
            lowerBound = d + ((Range) next2).INCREMENT;
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    private void innerLoop(AbstractTrainableImplicitModule abstractTrainableImplicitModule, AbstractImplicitModule abstractImplicitModule, ListIterator<Dimension> listIterator, ListIterator<? extends Value> listIterator2, Dimension dimension, Value value, SumSquaredErrorTracker sumSquaredErrorTracker) {
        abstractTrainableImplicitModule.setInput(dimension.getID(), value);
        abstractImplicitModule.setInput(dimension.getID(), value);
        if (!listIterator2.hasNext() && !listIterator.hasNext()) {
            abstractTrainableImplicitModule.forwardPass();
            abstractImplicitModule.forwardPass();
            if (abstractTrainableImplicitModule instanceof InterfaceRuntimeTrainable) {
                AbstractOutputChunk abstractOutputChunk = (AbstractOutputChunk) this.SELECTOR.select(abstractTrainableImplicitModule.getOutput());
                abstractTrainableImplicitModule.setChosenOutput(abstractOutputChunk);
                ((InterfaceRuntimeTrainable) abstractTrainableImplicitModule).setFeedback(abstractImplicitModule.getOutput(abstractOutputChunk.getID()).getActivation());
            } else {
                abstractTrainableImplicitModule.setDesiredOutput(abstractImplicitModule.getOutput());
            }
            abstractTrainableImplicitModule.backwardPass();
            sumSquaredErrorTracker.sumsqerr += abstractTrainableImplicitModule.getSumSqErrors();
            sumSquaredErrorTracker.sumsqerrcounter++;
            return;
        }
        if (listIterator2.hasNext()) {
            listIterator.previous();
            dataRecursor(abstractTrainableImplicitModule, abstractImplicitModule, listIterator, listIterator2, sumSquaredErrorTracker);
            listIterator2.previous();
        } else if (listIterator.hasNext()) {
            LinkedList linkedList = new LinkedList(listIterator.next().values());
            listIterator.previous();
            dataRecursor(abstractTrainableImplicitModule, abstractImplicitModule, listIterator, linkedList.listIterator(), sumSquaredErrorTracker);
            listIterator.previous();
        }
    }
}
