package com.amazon.randomcutforest.interpolation;

import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.InterpolationMeasure;
import com.amazon.randomcutforest.tree.BoundingBox;
import com.amazon.randomcutforest.tree.Node;
import java.util.Arrays;

/* loaded from: input_file:com/amazon/randomcutforest/interpolation/SimpleInterpolationVisitor.class */
/**
 * Visitor that walks a random cut tree for a single query point and accumulates an
 * {@link InterpolationMeasure} (created as a {@link DensityOutput}). For every dimension and
 * direction (high / low) it holds a blended probability-mass, measure (field) and distance value.
 *
 * <p>State is carried between calls. {@link #acceptLeaf} initializes the accumulators and records
 * whether the query point adds no range to the leaf's box. Each subsequent {@link #accept} call
 * then blends in the contribution of an internal node. NOTE(review): this presumes the traversal
 * visits the leaf first and then its ancestors toward the root; confirm against the tree code.
 * Instances hold unsynchronized mutable state and appear to be single-use, so they are not
 * thread-safe.
 *
 * <p>The package-private hooks ({@code fieldExt}, {@code influenceExt}, {@code fieldPoint},
 * {@code influencePoint}, {@code selfField}, {@code selfInfluence}) supply the values being
 * blended. Subclasses may override them.
 */
public class SimpleInterpolationVisitor implements Visitor<InterpolationMeasure> {

    /** Private copy of the query point. */
    private final double[] pointToScore;

    /** Sample size given at construction. NOTE(review): not read anywhere in this class. */
    private final long sampleSize;

    /** Flag forwarded to {@link #fieldExt} and {@link #influenceExt}; ignored by the defaults. */
    private final boolean centerOfMass;

    /** Accumulated result; exposed directly and also returned by {@link #getResult()}. */
    public InterpolationMeasure stored;

    /**
     * Per-dimension distance terms.
     *
     * <p>Index {@code 2*i} holds the high side and {@code 2*i+1} the low side.
     */
    double[] directionalDistanceVector;

    /**
     * Per-dimension range growth between the small and large box.
     *
     * <p>Uses the same {@code 2*i} (high) / {@code 2*i+1} (low) layout.
     */
    double[] differenceInRangeVector;

    /**
     * Dimensions in which the large box added no range.
     *
     * <p>Later {@link #updateForCompute} calls skip these dimensions.
     */
    boolean[] coordInsideBox;

    /**
     * Union of the sibling boxes seen so far.
     *
     * <p>Used only after the query point has been found to coincide with the leaf.
     */
    private BoundingBox theShadowBox;

    /** Mass handed to the hook methods; set by {@link #acceptLeaf}. */
    private double savedMass;

    /** Mass attributed to the query point itself. */
    private final double pointMass;

    /** Sum over all dimensions of the large box's range (last {@link #updateForCompute} call). */
    double sumOfNewRange = 0.0d;

    /** Sum over non-skipped dimensions of high + low range growth (last update call). */
    double sumOfDifferenceInRange = 0.0d;

    /** Once true, further {@link #accept} calls are no-ops. */
    boolean pointInsideBox = false;

    /** True when the query point added no range to the leaf's bounding box. */
    private boolean pointEqualsLeaf = false;

    /**
     * Creates a visitor for one query point.
     *
     * @param pointToScore the query point. It is copied, so later caller changes are not seen.
     * @param sampleSize   the sample size, also passed to the {@link DensityOutput} constructor
     * @param pointMass    the mass attributed to the query point
     * @param centerOfMass flag forwarded to {@link #fieldExt} and {@link #influenceExt}
     */
    public SimpleInterpolationVisitor(double[] pointToScore, int sampleSize, double pointMass,
            boolean centerOfMass) {
        int dimensions = pointToScore.length;
        this.pointToScore = Arrays.copyOf(pointToScore, dimensions);
        this.sampleSize = sampleSize;
        this.pointMass = pointMass;
        this.stored = new DensityOutput(dimensions, sampleSize);
        // two slots per dimension: [2*i] = high side, [2*i+1] = low side
        this.directionalDistanceVector = new double[2 * dimensions];
        this.differenceInRangeVector = new double[2 * dimensions];
        this.centerOfMass = centerOfMass;
        this.coordInsideBox = new boolean[dimensions];
    }

    /** Returns the accumulated measure (the same object as {@link #stored}). */
    @Override
    public InterpolationMeasure getResult() {
        return this.stored;
    }

    /**
     * Blends the contribution of an internal node into {@link #stored}.
     *
     * <p>If the query point did not coincide with the leaf, the node's box is compared with that
     * box extended to contain the point. Otherwise the node's box is compared with the shadow box:
     * the union of the boxes of the children not on the point's side, accumulated over calls.
     * Each stored value becomes {@code share * newValue + (1 - ratio) * oldValue}. Here
     * {@code ratio} is the fraction of the large box's total range lying outside the small box.
     * Once that ratio is not positive, the point is marked as inside and later calls do nothing.
     *
     * @param node  the internal node being visited
     * @param depth depth of the node (not used here)
     */
    @Override
    public void accept(Node node, int depth) {
        if (pointInsideBox) {
            return;
        }

        BoundingBox smallBox;
        BoundingBox largeBox;
        if (pointEqualsLeaf) {
            largeBox = node.getBoundingBox();
            // the child that is NOT on the query point's side of this node's cut
            Node sibling = Node.isLeftOf(pointToScore, node) ? node.getRightChild() : node.getLeftChild();
            BoundingBox siblingBox = sibling.getBoundingBox();
            theShadowBox = (theShadowBox == null) ? siblingBox : theShadowBox.getMergedBox(siblingBox);
            smallBox = theShadowBox;
        } else {
            smallBox = node.getBoundingBox();
            largeBox = smallBox.getMergedBox(pointToScore);
        }
        updateForCompute(smallBox, largeBox);

        double ratio = sumOfDifferenceInRange / sumOfNewRange;
        // '!(ratio > 0)' also catches NaN (0/0 for a degenerate large box), which would otherwise
        // be blended into every accumulator; for finite ratios it is identical to 'ratio <= 0'.
        if (!(ratio > 0.0d)) {
            pointInsideBox = true;
            return;
        }

        double field = fieldExt(node, centerOfMass, savedMass, pointToScore);
        double influence = influenceExt(node, centerOfMass, savedMass, pointToScore);
        double retained = 1.0d - ratio;
        for (int i = 0; i < pointToScore.length; i++) {
            int high = 2 * i;
            int low = high + 1;
            double highShare = differenceInRangeVector[high] / sumOfNewRange;
            double lowShare = differenceInRangeVector[low] / sumOfNewRange;

            stored.probMass.high[i] = (highShare * influence) + (retained * stored.probMass.high[i]);
            stored.measure.high[i] = (highShare * field) + (retained * stored.measure.high[i]);
            stored.distances.high[i] = (highShare * directionalDistanceVector[high] * influence)
                    + (retained * stored.distances.high[i]);

            stored.probMass.low[i] = (lowShare * influence) + (retained * stored.probMass.low[i]);
            stored.measure.low[i] = (lowShare * field) + (retained * stored.measure.low[i]);
            stored.distances.low[i] = (lowShare * directionalDistanceVector[low] * influence)
                    + (retained * stored.distances.low[i]);
        }
    }

    /**
     * Initializes {@link #stored} from the leaf reached by the query point.
     *
     * <p>If extending the leaf's box to include the point adds no range, the point coincides with
     * the leaf. In that case {@code savedMass} becomes the point mass plus the leaf mass, and
     * every high/low slot gets half of the self-field / self-influence divided by the
     * dimensionality. Otherwise {@code savedMass} is just the point mass, and each slot gets its
     * share of the range growth times {@code fieldPoint} / {@code influencePoint}.
     *
     * @param leaf  the leaf node
     * @param depth depth of the leaf (not used here)
     */
    @Override
    public void acceptLeaf(Node leaf, int depth) {
        BoundingBox leafBox = leaf.getBoundingBox();
        updateForCompute(leafBox, leafBox.getMergedBox(pointToScore));
        int dimensions = pointToScore.length;

        if (sumOfDifferenceInRange <= 0.0d) {
            savedMass = pointMass + leaf.getMass();
            pointEqualsLeaf = true;
            // loop-invariant: evaluate each hook once rather than once per dimension
            double selfFieldShare = (0.5d * selfField(leaf, savedMass)) / dimensions;
            double selfInfluenceShare = (0.5d * selfInfluence(leaf, savedMass)) / dimensions;
            for (int i = 0; i < dimensions; i++) {
                stored.measure.low[i] = selfFieldShare;
                stored.measure.high[i] = selfFieldShare;
                stored.probMass.low[i] = selfInfluenceShare;
                stored.probMass.high[i] = selfInfluenceShare;
            }
            // updateForCompute just flagged every dimension as inside; clear the flags so the
            // following accept() calls (which compare against the shadow box) use all dimensions
            Arrays.fill(coordInsideBox, false);
            return;
        }

        savedMass = pointMass;
        double field = fieldPoint(leaf, savedMass, pointToScore);
        double influence = influencePoint(leaf, savedMass, pointToScore);
        for (int i = 0; i < dimensions; i++) {
            int high = 2 * i;
            int low = high + 1;
            double highShare = differenceInRangeVector[high] / sumOfNewRange;
            double lowShare = differenceInRangeVector[low] / sumOfNewRange;

            stored.probMass.high[i] = highShare * influence;
            stored.measure.high[i] = highShare * field;
            stored.distances.high[i] = highShare * directionalDistanceVector[high] * influence;

            stored.probMass.low[i] = lowShare * influence;
            stored.measure.low[i] = lowShare * field;
            stored.distances.low[i] = lowShare * directionalDistanceVector[low] * influence;
        }
    }

    /**
     * Refreshes the range bookkeeping by comparing {@code smallBox} with {@code largeBox}.
     *
     * <p>{@code sumOfNewRange} is set to the sum of {@code largeBox}'s ranges over all dimensions.
     * For each dimension not already flagged in {@link #coordInsideBox}, the method computes how
     * far {@code largeBox} extends above {@code smallBox} (stored at {@code 2*i}) and below it
     * (stored at {@code 2*i+1}). Both gaps are added to {@code sumOfDifferenceInRange}. A
     * directional distance (gap plus {@code smallBox}'s range) is recorded on the high side when
     * that gap is positive, otherwise on the low side. Dimensions with no gap are flagged and
     * skipped by later calls until the flags are cleared.
     *
     * @param smallBox the inner box
     * @param largeBox the enclosing box
     */
    void updateForCompute(BoundingBox smallBox, BoundingBox largeBox) {
        sumOfNewRange = 0.0d;
        sumOfDifferenceInRange = 0.0d;
        Arrays.fill(directionalDistanceVector, 0.0d);
        Arrays.fill(differenceInRangeVector, 0.0d);

        for (int i = 0; i < pointToScore.length; i++) {
            sumOfNewRange += largeBox.getRange(i);
            if (coordInsideBox[i]) {
                continue;
            }

            double highGap = Math.max(largeBox.getMaxValue(i) - smallBox.getMaxValue(i), 0.0d);
            double lowGap = Math.max(smallBox.getMinValue(i) - largeBox.getMinValue(i), 0.0d);
            if (highGap + lowGap > 0.0d) {
                sumOfDifferenceInRange += lowGap + highGap;
                differenceInRangeVector[2 * i] = highGap;
                differenceInRangeVector[(2 * i) + 1] = lowGap;
                // only one directional slot is written, preferring the high side
                if (highGap > 0.0d) {
                    directionalDistanceVector[2 * i] = highGap + smallBox.getRange(i);
                } else {
                    directionalDistanceVector[(2 * i) + 1] = lowGap + smallBox.getRange(i);
                }
            } else {
                coordInsideBox[i] = true;
            }
        }
    }

    /**
     * Hook: field value blended in by {@link #accept}.
     *
     * <p>The default returns the node's mass plus {@code mass}.
     */
    double fieldExt(Node node, boolean centerOfMass, double mass, double[] point) {
        return node.getMass() + mass;
    }

    /**
     * Hook: influence weight used by {@link #accept}.
     *
     * <p>The default returns 1.
     */
    double influenceExt(Node node, boolean centerOfMass, double mass, double[] point) {
        return 1.0d;
    }

    /**
     * Hook: field value set by {@link #acceptLeaf} when the point differs from the leaf.
     *
     * <p>The default returns the node's mass plus {@code mass}.
     */
    double fieldPoint(Node node, double mass, double[] point) {
        return node.getMass() + mass;
    }

    /**
     * Hook: influence weight used by {@link #acceptLeaf} when the point differs from the leaf.
     *
     * <p>The default returns 1.
     */
    double influencePoint(Node node, double mass, double[] point) {
        return 1.0d;
    }

    /**
     * Hook: field value used when the point coincides with the leaf.
     *
     * <p>The default returns {@code mass}.
     */
    double selfField(Node node, double mass) {
        return mass;
    }

    /**
     * Hook: influence value used when the point coincides with the leaf.
     *
     * <p>The default returns 1.
     */
    double selfInfluence(Node node, double mass) {
        return 1.0d;
    }
}
