package com.amazon.randomcutforest;

import com.amazon.randomcutforest.anomalydetection.AnomalyAttributionVisitor;
import com.amazon.randomcutforest.anomalydetection.AnomalyScoreVisitor;
import com.amazon.randomcutforest.imputation.ImputeVisitor;
import com.amazon.randomcutforest.inspect.NearNeighborVisitor;
import com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.InterpolationMeasure;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDiVectorAccumulator;
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDoubleAccumulator;
import com.amazon.randomcutforest.sampler.SimpleStreamSampler;
import com.amazon.randomcutforest.tree.RandomCutTree;
import com.amazon.randomcutforest.util.ShingleBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;

/* loaded from: input_file:com/amazon/randomcutforest/RandomCutForest.class */
public class RandomCutForest {
    /** Default number of points each tree's sampler retains. */
    public static final int DEFAULT_SAMPLE_SIZE = 256;
    /** Default {@code outputAfter}, expressed as a fraction of {@code sampleSize}. */
    public static final double DEFAULT_OUTPUT_AFTER_FRACTION = 0.25d;
    /** Default lambda is {@code 1 / (DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA * sampleSize)}. */
    public static final double DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA = 10.0d;
    /** Default number of trees in the forest. */
    public static final int DEFAULT_NUMBER_OF_TREES = 50;
    public static final boolean DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED = false;
    public static final boolean DEFAULT_CENTER_OF_MASS_ENABLED = false;
    public static final boolean DEFAULT_PARALLEL_EXECUTION_ENABLED = false;
    /** Passed to the one-sided converging accumulators used by the approximate scoring methods. */
    public static final boolean DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL = true;
    /** Passed to the one-sided converging accumulators used by the approximate scoring methods. */
    public static final double DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION = 0.1d;
    /** Passed to the one-sided converging accumulators used by the approximate scoring methods. */
    public static final int DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED = 5;

    /** Source of the per-tree sampler and tree seeds; seeded from the builder when a seed is given. */
    protected final Random rng;
    protected final int dimensions;
    protected final int sampleSize;
    /** Number of updates required before {@link #isOutputReady()} returns true. */
    protected final int outputAfter;
    protected final int numberOfTrees;
    /** Value handed to every {@code SimpleStreamSampler}. */
    protected final double lambda;
    protected final boolean storeSequenceIndexesEnabled;
    protected final boolean centerOfMassEnabled;
    protected final boolean parallelExecutionEnabled;
    /** Pool size given to the parallel executor, or 0 when execution is sequential. */
    protected final int threadPoolSize;
    /** Runs updates and traversals over all trees. */
    protected final AbstractForestTraversalExecutor executor;

    /**
     * Fluent builder for {@link RandomCutForest}.
     *
     * <p>The self-referential type parameter lets subclass builders return their own type from the
     * inherited setters.
     *
     * @param <T> the concrete builder type returned by each setter
     */
    public static class Builder<T extends Builder<T>> {
        private int dimensions;
        private int sampleSize = DEFAULT_SAMPLE_SIZE;
        private Optional<Integer> outputAfter = Optional.empty();
        private int numberOfTrees = DEFAULT_NUMBER_OF_TREES;
        private Optional<Double> lambda = Optional.empty();
        private Optional<Long> randomSeed = Optional.empty();
        private boolean storeSequenceIndexesEnabled = DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED;
        private boolean centerOfMassEnabled = DEFAULT_CENTER_OF_MASS_ENABLED;
        private boolean parallelExecutionEnabled = DEFAULT_PARALLEL_EXECUTION_ENABLED;
        private Optional<Integer> threadPoolSize = Optional.empty();

        /** Sets the required length of every point passed to the forest. */
        public T dimensions(int dimensions) {
            this.dimensions = dimensions;
            return self();
        }

        /** Sets the capacity of each tree's sampler. */
        public T sampleSize(int sampleSize) {
            this.sampleSize = sampleSize;
            return self();
        }

        /** Sets how many updates must be seen before the forest produces non-default output. */
        public T outputAfter(int outputAfter) {
            this.outputAfter = Optional.of(outputAfter);
            return self();
        }

        /** Sets the number of trees in the forest. */
        public T numberOfTrees(int numberOfTrees) {
            this.numberOfTrees = numberOfTrees;
            return self();
        }

        /** Sets the lambda value handed to every tree's sampler. */
        public T lambda(double lambda) {
            this.lambda = Optional.of(lambda);
            return self();
        }

        /** Sets the seed of the forest's random number generator, making construction reproducible. */
        public T randomSeed(long randomSeed) {
            this.randomSeed = Optional.of(randomSeed);
            return self();
        }

        /** Sets the flag forwarded to each tree's builder. */
        public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
            this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
            return self();
        }

        /** Sets the flag forwarded to each tree's builder and to the density visitor. */
        public T centerOfMassEnabled(boolean centerOfMassEnabled) {
            this.centerOfMassEnabled = centerOfMassEnabled;
            return self();
        }

        /** Chooses the parallel executor (true) or the sequential executor (false). */
        public T parallelExecutionEnabled(boolean parallelExecutionEnabled) {
            this.parallelExecutionEnabled = parallelExecutionEnabled;
            return self();
        }

        /** Sets the thread pool size; only used when parallel execution is enabled. */
        public T threadPoolSize(int threadPoolSize) {
            this.threadPoolSize = Optional.of(threadPoolSize);
            return self();
        }

        /** Validates the configured values and constructs the forest. */
        public RandomCutForest build() {
            return new RandomCutForest(this);
        }

        /** Returns this builder typed as the concrete builder {@code T}. */
        @SuppressWarnings("unchecked")
        private T self() {
            // Safe by convention: T is bound to the concrete builder class extending Builder<T>.
            return (T) this;
        }
    }

    /**
     * Validates the builder's settings and creates {@code numberOfTrees} sampler/tree pairs.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected}; it
     * is left {@code public} here so existing callers keep compiling.
     *
     * @param builder the configured builder
     */
    public RandomCutForest(Builder<?> builder) {
        CommonUtils.checkArgument(builder.numberOfTrees > 0, "numberOfTrees must be greater than 0");
        CommonUtils.checkArgument(builder.sampleSize > 0, "sampleSize must be greater than 0");
        builder.outputAfter.ifPresent(value -> {
            CommonUtils.checkArgument(value > 0, "outputAfter must be greater than 0");
            CommonUtils.checkArgument(value <= builder.sampleSize, "outputAfter must be smaller or equal to sampleSize");
        });
        CommonUtils.checkArgument(builder.dimensions > 0, "dimensions must be greater than 0");
        builder.lambda.ifPresent(value -> {
            CommonUtils.checkArgument(value >= 0.0d, "lambda must be greater than or equal to 0");
        });
        builder.threadPoolSize.ifPresent(value -> {
            CommonUtils.checkArgument(value > 0, "threadPoolSize must be greater than 0. To disable thread pool, set parallel execution to 'false'.");
        });

        this.numberOfTrees = builder.numberOfTrees;
        this.sampleSize = builder.sampleSize;
        this.outputAfter = builder.outputAfter.orElse((int) (this.sampleSize * DEFAULT_OUTPUT_AFTER_FRACTION));
        this.dimensions = builder.dimensions;
        this.lambda = builder.lambda.orElse(1.0d / (DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA * this.sampleSize));
        this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
        this.centerOfMassEnabled = builder.centerOfMassEnabled;
        this.parallelExecutionEnabled = builder.parallelExecutionEnabled;
        this.rng = builder.randomSeed.map(Random::new).orElseGet(Random::new);

        ArrayList<TreeUpdater> treeUpdaters = new ArrayList<>(this.numberOfTrees);
        for (int i = 0; i < this.numberOfTrees; i++) {
            // The sampler seed is drawn before the tree seed; keep this order so seeded forests stay reproducible.
            long samplerSeed = this.rng.nextLong();
            long treeSeed = this.rng.nextLong();
            RandomCutTree tree = RandomCutTree.builder()
                    .storeSequenceIndexesEnabled(this.storeSequenceIndexesEnabled)
                    .centerOfMassEnabled(this.centerOfMassEnabled)
                    .randomSeed(treeSeed)
                    .build();
            treeUpdaters.add(new TreeUpdater(new SimpleStreamSampler(this.sampleSize, this.lambda, samplerSeed), tree));
        }

        if (this.parallelExecutionEnabled) {
            // availableProcessors() - 1 is 0 on a single-core host; the builder itself rejects
            // threadPoolSize <= 0, so the default is clamped to honour the same invariant.
            int defaultPoolSize = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
            this.threadPoolSize = builder.threadPoolSize.orElse(defaultPoolSize);
            this.executor = new ParallelForestTraversalExecutor(treeUpdaters, this.threadPoolSize);
        } else {
            this.threadPoolSize = 0;
            this.executor = new SequentialForestTraversalExecutor(treeUpdaters);
        }
    }

    /** Returns a new builder with default settings; {@code dimensions} must still be set. */
    @SuppressWarnings("rawtypes")
    public static Builder builder() {
        return new Builder();
    }

    /** Creates a forest with default settings, the given dimensions, and a fixed random seed. */
    public static RandomCutForest defaultForest(int dimensions, long randomSeed) {
        return builder().dimensions(dimensions).randomSeed(randomSeed).build();
    }

    /** Creates a forest with default settings and the given dimensions. */
    public static RandomCutForest defaultForest(int dimensions) {
        return builder().dimensions(dimensions).build();
    }

    public int getNumberOfTrees() {
        return this.numberOfTrees;
    }

    public int getSampleSize() {
        return this.sampleSize;
    }

    public int getOutputAfter() {
        return this.outputAfter;
    }

    public int getDimensions() {
        return this.dimensions;
    }

    public double getLambda() {
        return this.lambda;
    }

    public boolean storeSequenceIndexesEnabled() {
        return this.storeSequenceIndexesEnabled;
    }

    public boolean centerOfMassEnabled() {
        return this.centerOfMassEnabled;
    }

    public boolean parallelExecutionEnabled() {
        return this.parallelExecutionEnabled;
    }

    /** Returns the pool size used by the parallel executor, or 0 when execution is sequential. */
    public int getThreadPoolSize() {
        return this.threadPoolSize;
    }

    /**
     * Validates {@code point} and passes it to the executor as an update.
     *
     * @param point a point of length {@link #getDimensions()}
     */
    public void update(double[] point) {
        validatePoint(point);
        this.executor.update(point);
    }

    /**
     * Visits every tree with a visitor, reduces the per-tree results and applies a finisher.
     *
     * @param point the query point, of length {@link #getDimensions()}
     * @param visitorFactory creates the visitor used on each tree
     * @param accumulator combines two per-tree results
     * @param finisher maps the combined result to the returned value
     */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        validatePoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /** Visits every tree with a visitor and gathers the per-tree results with {@code collector}. */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        validatePoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(collector, "collector must not be null");
        return this.executor.traverseForest(point, visitorFactory, collector);
    }

    /**
     * Visits trees with a visitor, feeding per-tree results into a {@link ConvergingAccumulator}.
     * How many trees are visited is decided by the executor and accumulator, which are not visible here.
     */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, ConvergingAccumulator<R> accumulator, Function<R, S> finisher) {
        validatePoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /** Like {@link #traverseForest(double[], Function, BinaryOperator, Function)}, but with a {@link MultiVisitor}. */
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        validatePoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForestMulti(point, visitorFactory, accumulator, finisher);
    }

    /** Like {@link #traverseForest(double[], Function, Collector)}, but with a {@link MultiVisitor}. */
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        validatePoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(collector, "collector must not be null");
        return this.executor.traverseForestMulti(point, visitorFactory, collector);
    }

    /**
     * Returns the per-tree {@code AnomalyScoreVisitor} results summed and divided by the number of trees.
     *
     * @return the averaged score, or 0 before {@link #isOutputReady()} is true
     */
    public double getAnomalyScore(double[] point) {
        if (!isOutputReady()) {
            return 0.0d;
        }
        Function<RandomCutTree, Visitor<Double>> visitorFactory = tree -> new AnomalyScoreVisitor(point, tree.getRoot().getMass());
        BinaryOperator<Double> accumulator = Double::sum;
        Function<Double, Double> finisher = sum -> sum / this.numberOfTrees;
        return traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Approximate variant of {@link #getAnomalyScore(double[])} that uses a one-sided converging
     * accumulator configured with the {@code DEFAULT_APPROXIMATE_*} constants.
     *
     * <p>The accumulated sum is divided by the number of values the accumulator accepted, not by
     * {@code numberOfTrees}.
     *
     * @return the approximate score, or 0 before {@link #isOutputReady()} is true
     */
    public double getApproximateAnomalyScore(double[] point) {
        if (!isOutputReady()) {
            return 0.0d;
        }
        Function<RandomCutTree, Visitor<Double>> visitorFactory = tree -> new AnomalyScoreVisitor(point, tree.getRoot().getMass());
        OneSidedConvergingDoubleAccumulator accumulator = new OneSidedConvergingDoubleAccumulator(
                DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED,
                this.numberOfTrees);
        Function<Double, Double> finisher = sum -> sum / accumulator.getValuesAccepted();
        return traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Combines the per-tree {@code AnomalyAttributionVisitor} results with {@code DiVector::addToLeft}
     * and scales the result by {@code 1 / numberOfTrees}.
     *
     * @return the attribution, or a fresh {@code DiVector(dimensions)} before {@link #isOutputReady()} is true
     */
    public DiVector getAnomalyAttribution(double[] point) {
        if (!isOutputReady()) {
            return new DiVector(this.dimensions);
        }
        Function<RandomCutTree, Visitor<DiVector>> visitorFactory = tree -> new AnomalyAttributionVisitor(point, tree.getRoot().getMass());
        BinaryOperator<DiVector> accumulator = DiVector::addToLeft;
        Function<DiVector, DiVector> finisher = vector -> vector.scale(1.0d / this.numberOfTrees);
        return traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Approximate variant of {@link #getAnomalyAttribution(double[])}.
     *
     * <p>It uses a one-sided converging accumulator and scales by {@code 1 / valuesAccepted}.
     *
     * @return the attribution, or a fresh {@code DiVector(dimensions)} before {@link #isOutputReady()} is true
     */
    public DiVector getApproximateAnomalyAttribution(double[] point) {
        if (!isOutputReady()) {
            return new DiVector(this.dimensions);
        }
        Function<RandomCutTree, Visitor<DiVector>> visitorFactory = tree -> new AnomalyAttributionVisitor(point, tree.getRoot().getMass());
        OneSidedConvergingDiVectorAccumulator accumulator = new OneSidedConvergingDiVectorAccumulator(
                this.dimensions,
                DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED,
                this.numberOfTrees);
        Function<DiVector, DiVector> finisher = vector -> vector.scale(1.0d / accumulator.getValuesAccepted());
        return traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Gathers per-tree {@code SimpleInterpolationVisitor} results into a {@link DensityOutput}.
     *
     * <p>Unlike the scoring methods, this is gated on {@link #samplersFull()}, not on {@link #isOutputReady()}.
     *
     * @return the density output, or {@code new DensityOutput(dimensions, sampleSize)} before the samplers are full
     */
    public DensityOutput getSimpleDensity(double[] point) {
        if (!samplersFull()) {
            return new DensityOutput(this.dimensions, this.sampleSize);
        }
        InterpolationMeasure measure = traverseForest(
                point,
                tree -> new SimpleInterpolationVisitor(point, this.sampleSize, 1.0d, this.centerOfMassEnabled),
                InterpolationMeasure.collector(this.dimensions, this.sampleSize, this.numberOfTrees));
        return new DensityOutput(measure);
    }

    /**
     * Imputes the entries of {@code point} at the given indexes, using one {@code ImputeVisitor} per tree.
     *
     * <ul>
     *   <li>For a single missing value, each tree contributes its value at {@code missingIndexes[0]}.
     *       The values are sorted and the one at position {@code numberOfTrees / 2} is written into a
     *       copy of {@code point}.
     *   <li>For several missing values, the per-tree candidate points are ordered by
     *       {@link #getAnomalyScore(double[])}. The candidate at position {@code numberOfTrees / 4} is
     *       returned.
     * </ul>
     *
     * @param point the point with missing entries
     * @param numberOfMissingValues how many leading entries of {@code missingIndexes} to impute
     * @param missingIndexes positions of the missing entries
     * @return a copy of {@code point} when nothing is missing; {@code new double[dimensions]} before
     *     {@link #isOutputReady()} is true; otherwise the imputed point
     */
    public double[] imputeMissingValues(double[] point, int numberOfMissingValues, int[] missingIndexes) {
        CommonUtils.checkArgument(numberOfMissingValues >= 0, "numberOfMissingValues must be greater than or equal to 0");
        CommonUtils.checkNotNull(point, "point must not be null");
        if (numberOfMissingValues == 0) {
            return Arrays.copyOf(point, point.length);
        }
        CommonUtils.checkNotNull(missingIndexes, "missingIndexes must not be null");
        CommonUtils.checkArgument(numberOfMissingValues <= missingIndexes.length, "numberOfMissingValues must be less than or equal to missingIndexes.length");
        if (!isOutputReady()) {
            return new double[this.dimensions];
        }

        Function<RandomCutTree, MultiVisitor<double[]>> visitorFactory = tree -> new ImputeVisitor(point, numberOfMissingValues, missingIndexes);
        Collector<double[], List<double[]>, List<double[]>> toList = Collector.of(
                ArrayList::new,
                List::add,
                (left, right) -> {
                    left.addAll(right);
                    return left;
                });
        List<double[]> candidates = traverseForestMulti(point, visitorFactory, toList);

        if (numberOfMissingValues != 1) {
            return selectByAnomalyScoreRank(candidates, this.numberOfTrees / 4);
        }

        int missingIndex = missingIndexes[0];
        double[] imputedValues = new double[candidates.size()];
        for (int k = 0; k < imputedValues.length; k++) {
            imputedValues[k] = candidates.get(k)[missingIndex];
        }
        Arrays.sort(imputedValues);
        double[] result = Arrays.copyOf(point, this.dimensions);
        result[missingIndex] = imputedValues[this.numberOfTrees / 2];
        return result;
    }

    /**
     * Returns the candidate at position {@code rank} when candidates are ordered by ascending anomaly score.
     *
     * <p>Each score is computed once, since every score is a full forest traversal. The index sort is
     * stable, so ties keep collection order.
     */
    private double[] selectByAnomalyScoreRank(List<double[]> candidates, int rank) {
        double[] scores = new double[candidates.size()];
        List<Integer> order = new ArrayList<>(candidates.size());
        for (int k = 0; k < scores.length; k++) {
            scores[k] = getAnomalyScore(candidates.get(k));
            order.add(k);
        }
        order.sort(Comparator.comparingDouble((Integer k) -> scores[k]));
        return candidates.get(order.get(rank));
    }

    /**
     * Extrapolates {@code horizon} blocks of {@code blockSize} values by repeatedly imputing the next block.
     *
     * @param point the current shingled point; copied (padded/truncated) to {@link #getDimensions()} entries
     * @param horizon number of blocks to forecast
     * @param blockSize size of one input block; must divide {@code dimensions} and be smaller than it
     * @param cyclic whether the shingle is cyclic (written in place) rather than sliding
     * @param shingleIndex for cyclic shingles, the block index (0-based) of the next block to write
     * @return an array of {@code blockSize * horizon} forecast values
     */
    public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) {
        CommonUtils.checkArgument(0 < blockSize && blockSize < this.dimensions, "blockSize must be between 0 and dimensions (exclusive)");
        CommonUtils.checkArgument(this.dimensions % blockSize == 0, "dimensions must be evenly divisible by blockSize");
        CommonUtils.checkArgument(0 <= shingleIndex && shingleIndex < this.dimensions / blockSize, "shingleIndex must be between 0 (inclusive) and dimensions / blockSize");
        CommonUtils.checkNotNull(point, "point must not be null");

        double[] result = new double[blockSize * horizon];
        int[] missingIndexes = new int[blockSize];
        double[] queryPoint = Arrays.copyOf(point, this.dimensions);
        if (cyclic) {
            extrapolateBasicCyclic(result, horizon, blockSize, shingleIndex, queryPoint, missingIndexes);
        } else {
            extrapolateBasicSliding(result, horizon, blockSize, queryPoint, missingIndexes);
        }
        return result;
    }

    /** Same as {@link #extrapolateBasic(double[], int, int, boolean, int)} with {@code shingleIndex} 0. */
    public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic) {
        return extrapolateBasic(point, horizon, blockSize, cyclic, 0);
    }

    /** Extrapolates using the shingle, block size, cyclic flag and shingle index held by {@code shingleBuilder}. */
    public double[] extrapolateBasic(ShingleBuilder shingleBuilder, int horizon) {
        return extrapolateBasic(shingleBuilder.getShingle(), horizon, shingleBuilder.getInputPointSize(), shingleBuilder.isCyclic(), shingleBuilder.getShingleIndex());
    }

    /**
     * Sliding-shingle extrapolation: each step shifts {@code queryPoint} left by one block and imputes
     * the last block. The imputed values are written back into {@code queryPoint} and appended to {@code result}.
     */
    void extrapolateBasicSliding(double[] result, int horizon, int blockSize, double[] queryPoint, int[] missingIndexes) {
        int resultIndex = 0;
        int lastBlockStart = this.dimensions - blockSize;
        Arrays.fill(missingIndexes, 0);
        for (int i = 0; i < blockSize; i++) {
            missingIndexes[i] = lastBlockStart + i;
        }
        for (int step = 0; step < horizon; step++) {
            // Drop the oldest block; the stale last block is then overwritten by imputation.
            System.arraycopy(queryPoint, blockSize, queryPoint, 0, lastBlockStart);
            double[] imputed = imputeMissingValues(queryPoint, blockSize, missingIndexes);
            for (int i = 0; i < blockSize; i++) {
                int position = lastBlockStart + i;
                queryPoint[position] = imputed[position];
                result[resultIndex++] = imputed[position];
            }
        }
    }

    /**
     * Cyclic-shingle extrapolation: each step imputes the block at the current write position and
     * advances it by one block, wrapping at {@code dimensions}.
     *
     * @param shingleIndex block index (not element offset) of the first block to impute, matching the
     *     validation in {@code extrapolateBasic} and {@code ShingleBuilder#getShingleIndex()}
     */
    void extrapolateBasicCyclic(double[] result, int horizon, int blockSize, int shingleIndex, double[] queryPoint, int[] missingIndexes) {
        int resultIndex = 0;
        // Convert the block index into an element offset; using shingleIndex directly would straddle blocks.
        int startPosition = shingleIndex * blockSize;
        Arrays.fill(missingIndexes, 0);
        for (int step = 0; step < horizon; step++) {
            for (int i = 0; i < blockSize; i++) {
                missingIndexes[i] = (startPosition + i) % this.dimensions;
            }
            double[] imputed = imputeMissingValues(queryPoint, blockSize, missingIndexes);
            for (int i = 0; i < blockSize; i++) {
                int position = (startPosition + i) % this.dimensions;
                queryPoint[position] = imputed[position];
                result[resultIndex++] = imputed[position];
            }
            startPosition = (startPosition + blockSize) % this.dimensions;
        }
    }

    /**
     * Collects {@code NearNeighborVisitor} results from every tree using {@link Neighbor#collector()}.
     *
     * @param point the query point
     * @param distanceThreshold must be strictly positive
     * @return the collected neighbors, or an empty list before {@link #isOutputReady()} is true
     */
    public List<Neighbor> getNearNeighborsInSample(double[] point, double distanceThreshold) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkArgument(distanceThreshold > 0.0d, "distanceThreshold must be greater than 0");
        if (!isOutputReady()) {
            return Collections.emptyList();
        }
        return traverseForest(point, tree -> new NearNeighborVisitor(point, distanceThreshold), Neighbor.collector());
    }

    /** Same as {@link #getNearNeighborsInSample(double[], double)} with an infinite distance threshold. */
    public List<Neighbor> getNearNeighborsInSample(double[] point) {
        return getNearNeighborsInSample(point, Double.POSITIVE_INFINITY);
    }

    /** Returns true once the executor has seen at least {@code outputAfter} updates. */
    public boolean isOutputReady() {
        return this.executor.getTotalUpdates() >= (long) this.outputAfter;
    }

    /** Returns true once the executor has seen at least {@code sampleSize} updates. */
    public boolean samplersFull() {
        return this.executor.getTotalUpdates() >= (long) this.sampleSize;
    }

    /** Returns the number of updates recorded by the executor. */
    public long getTotalUpdates() {
        return this.executor.getTotalUpdates();
    }

    /** Shared null/length check for query and update points; messages match the original inline checks. */
    private void validatePoint(double[] point) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkArgument(point.length == this.dimensions, String.format("point.length must equal %d", this.dimensions));
    }
}
