package com.amazon.randomcutforest.sampler;

import com.amazon.randomcutforest.CommonUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Random;
import java.util.stream.Collectors;

/* loaded from: input_file:com/amazon/randomcutforest/sampler/SimpleStreamSampler.class */
/**
 * A fixed-capacity stream sampler that retains the {@code sampleSize} points with the smallest
 * random weights seen so far.
 *
 * <p>Each incoming point is assigned the weight {@code -lambda * timestamp + log(-log(u))}, where
 * {@code u} is drawn uniformly from the open interval (0, 1) by {@link #computeWeight(long)}. Once
 * the sampler is full, a newcomer replaces the current largest-weight sample only if its own
 * weight is strictly smaller. With {@code lambda == 0} (see {@link #uniformSampler(int, long)})
 * the weights are independent of the timestamp; with {@code lambda > 0} points with larger
 * timestamps receive smaller weights and are therefore more likely to be retained.
 *
 * <p>This class performs no internal synchronization.
 */
public class SimpleStreamSampler {

    /**
     * Orders points by descending weight, so the head of a queue built with it is the sample with
     * the largest weight (the next candidate for eviction).
     */
    static final Comparator<WeightedPoint> POINT_COMPARATOR =
            Comparator.comparingDouble(WeightedPoint::getWeight).reversed();

    /** Current sample; a max-heap on weight ordered by {@link #POINT_COMPARATOR}. */
    private final Queue<WeightedPoint> weightedSamples;
    /** Maximum number of points retained. Always positive. */
    private final int sampleSize;
    /** Time-decay rate applied to the timestamp when computing weights. */
    private final double lambda;
    /** Source of the uniform variates used in {@link #computeWeight(long)}. */
    private final Random random;
    /** Number of calls made to {@link #sample(double[], long)} so far. */
    private long entriesSeen;
    /** Point removed by the most recent {@link #sample(double[], long)} call, or null if none. */
    private transient WeightedPoint evictedPoint;

    /**
     * Priority queue pre-configured with {@link #POINT_COMPARATOR}.
     *
     * <p>The type parameter is bounded by {@link WeightedPoint} so that the shared comparator is a
     * valid {@code Comparator<? super T>}. An unbounded parameter named {@code WeightedPoint} would
     * shadow the real class.
     *
     * @param <T> element type; must be a {@link WeightedPoint}
     */
    static class PriorityQueueWrapper<T extends WeightedPoint> extends PriorityQueue<T> {
        PriorityQueueWrapper() {
            super(SimpleStreamSampler.POINT_COMPARATOR);
        }
    }

    /**
     * Creates a sampler whose random source is seeded deterministically.
     *
     * @param sampleSize maximum number of points to retain; must be positive
     * @param lambda time-decay rate applied to timestamps
     * @param seed seed for the internal {@link Random}
     * @throws IllegalArgumentException if {@code sampleSize <= 0}
     */
    public SimpleStreamSampler(int sampleSize, double lambda, long seed) {
        this(sampleSize, lambda, new Random(seed));
    }

    /**
     * Creates a sampler that draws its randomness from the supplied {@link Random}.
     *
     * @param sampleSize maximum number of points to retain; must be positive
     * @param lambda time-decay rate applied to timestamps
     * @param random source of uniform variates; must not be null
     * @throws IllegalArgumentException if {@code sampleSize <= 0}
     * @throws NullPointerException if {@code random} is null
     */
    protected SimpleStreamSampler(int sampleSize, double lambda, Random random) {
        // A non-positive capacity would make sample() call element() on an empty queue and throw
        // NoSuchElementException on its first call, so reject it up front.
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("sampleSize must be positive, got " + sampleSize);
        }
        this.sampleSize = sampleSize;
        this.entriesSeen = 0L;
        this.weightedSamples = new PriorityQueueWrapper<>();
        this.random = Objects.requireNonNull(random, "random");
        this.lambda = lambda;
    }

    /**
     * Creates a sampler with {@code lambda == 0}, so weights do not depend on the timestamp.
     *
     * @param sampleSize maximum number of points to retain; must be positive
     * @param seed seed for the internal {@link Random}
     * @return a new sampler with no time decay
     * @throws IllegalArgumentException if {@code sampleSize <= 0}
     */
    public static SimpleStreamSampler uniformSampler(int sampleSize, long seed) {
        return new SimpleStreamSampler(sampleSize, 0.0d, seed);
    }

    /**
     * Offers a point to the sampler.
     *
     * <p>The first {@code sampleSize} points are always accepted. After that, the point is accepted
     * only if its freshly computed weight is strictly smaller than the largest weight currently
     * held; in that case the largest-weight sample is evicted and becomes available through
     * {@link #getEvictedPoint()}.
     *
     * @param point the point values; stored by reference, not copied
     * @param timestamp the point's timestamp, used in the weight computation
     * @return the newly stored {@link WeightedPoint}, or null if the point was rejected
     */
    public WeightedPoint sample(double[] point, long timestamp) {
        evictedPoint = null;
        double weight = computeWeight(timestamp);
        entriesSeen++;

        // Short-circuit keeps element() from being called while the queue is still filling.
        boolean accept = entriesSeen <= sampleSize
                || weight < weightedSamples.element().getWeight();
        if (!accept) {
            return null;
        }

        if (isFull()) {
            // The head is the largest weight, i.e. the least desirable sample.
            evictedPoint = weightedSamples.poll();
        }
        WeightedPoint stored = new WeightedPoint(point, timestamp, weight);
        weightedSamples.add(stored);
        CommonUtils.checkState(weightedSamples.size() <= sampleSize, "The number of points in the sampler is greater than the sample size");
        return stored;
    }

    /**
     * Returns the point evicted by the most recent call to {@link #sample(double[], long)}.
     *
     * @return the evicted point, or null if that call evicted nothing (or no call has been made)
     */
    public WeightedPoint getEvictedPoint() {
        return evictedPoint;
    }

    /**
     * Returns the raw point arrays currently held, in unspecified (heap) order.
     *
     * @return a new list of the stored point arrays; the arrays themselves are not copied
     */
    public List<double[]> getSamples() {
        return weightedSamples.stream()
                .map(WeightedPoint::getPoint)
                .collect(Collectors.toList());
    }

    /**
     * Returns the weighted points currently held, in unspecified (heap) order.
     *
     * @return a new mutable list; modifying it does not affect the sampler
     */
    public List<WeightedPoint> getWeightedSamples() {
        return new ArrayList<>(weightedSamples);
    }

    /**
     * Reports whether the sampler holds at least {@code sampleSize / 4} points (integer
     * division). For {@code sampleSize < 4} this is true even when the sampler is empty.
     *
     * @return true once the readiness threshold is reached
     */
    public boolean isReady() {
        return weightedSamples.size() >= sampleSize / 4;
    }

    /**
     * Reports whether the sampler holds exactly {@code sampleSize} points.
     *
     * @return true if the sample is at capacity
     */
    public boolean isFull() {
        return weightedSamples.size() == sampleSize;
    }

    /**
     * Computes a random weight for a point with the given timestamp:
     * {@code -lambda * timestamp + log(-log(u))} with {@code u} uniform in (0, 1).
     *
     * @param timestamp the point's timestamp
     * @return the weight; smaller weights are more likely to be retained
     */
    protected double computeWeight(long timestamp) {
        // nextDouble() lies in [0, 1); resample on 0 so that log(u) is finite.
        double u;
        do {
            u = random.nextDouble();
        } while (u == 0.0d);
        // Negate the double rather than the long: -timestamp overflows for Long.MIN_VALUE.
        return (-lambda) * timestamp + Math.log(-Math.log(u));
    }

    /**
     * Returns the maximum number of points this sampler retains.
     *
     * @return the configured sample size
     */
    public long getCapacity() {
        return sampleSize;
    }

    /**
     * Returns the number of points currently held.
     *
     * @return the current sample count, at most {@link #getCapacity()}
     */
    public long getSize() {
        return weightedSamples.size();
    }

    /**
     * Returns the time-decay rate used in weight computation.
     *
     * @return lambda
     */
    public double getLambda() {
        return lambda;
    }
}
