package com.amazon.randomcutforest.imputation;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.MultiVisitor;
import com.amazon.randomcutforest.tree.BoundingBox;
import com.amazon.randomcutforest.tree.Node;
import java.util.Arrays;

/* loaded from: input_file:com/amazon/randomcutforest/imputation/ImputeVisitor.class */
/**
 * A {@link MultiVisitor} that fills in ("imputes") the missing coordinates of a query point.
 *
 * <p>The visitor keeps a private copy of the query point and a {@code rank} score. At a leaf, the
 * missing coordinates are overwritten from the leaf's bounding box and the rank is reset
 * (see {@link #acceptLeaf(Node, int)}). At an internal node, the rank is blended with the unseen
 * score (see {@link #accept(Node, int)}). {@link #trigger(Node)} reports whether a node cuts on a
 * missing dimension. {@link #newCopy()} creates an independent copy. {@link #combine(MultiVisitor)}
 * keeps whichever candidate has the lower rank.
 *
 * <p>NOTE(review): presumably the tree traversal uses {@code trigger}/{@code newCopy}/{@code combine}
 * to explore both sides of cuts on missing dimensions — confirm against the MultiVisitor traversal.
 */
public class ImputeVisitor implements MultiVisitor<double[]> {

    /** Rank assigned to a freshly created (or copied) visitor before any node has been visited. */
    private static final double DEFAULT_INIT_VALUE = 10.0d;

    /** {@code missing[d]} is {@code true} when dimension {@code d} of the query point is to be imputed. */
    private final boolean[] missing;

    /** Number of missing values as supplied to the public constructor. */
    private final int numberOfMissingValues;

    /** Private copy of the query point; missing coordinates are overwritten in {@link #acceptLeaf}. */
    private final double[] queryPoint;

    /** Score of the current imputation candidate; {@link #combine} prefers the lower value. */
    private double rank;

    /**
     * Creates a visitor for the given query point.
     *
     * @param queryPoint the point to impute; it is copied, so the caller's array is never modified
     * @param numberOfMissingValues how many leading entries of {@code missingIndexes} are used
     * @param missingIndexes indexes of the missing dimensions; may be {@code null}, which is treated
     *     as an empty array (only valid when {@code numberOfMissingValues} is not positive)
     * @throws RuntimeException raised through {@code CommonUtils.checkArgument} (presumably an
     *     {@code IllegalArgumentException} — confirm) when {@code numberOfMissingValues} exceeds the
     *     number of supplied indexes, or when a used index is outside {@code [0, queryPoint.length)}
     */
    public ImputeVisitor(double[] queryPoint, int numberOfMissingValues, int[] missingIndexes) {
        this.queryPoint = Arrays.copyOf(queryPoint, queryPoint.length);
        this.missing = new boolean[queryPoint.length];
        int[] indexes = (missingIndexes == null) ? new int[0] : missingIndexes;

        // Without this check a count larger than the index array (including a null array) surfaces
        // as an uninformative ArrayIndexOutOfBoundsException inside the loop below.
        CommonUtils.checkArgument(numberOfMissingValues <= indexes.length,
                "numberOfMissingValues must not exceed the number of missing value indexes");

        this.numberOfMissingValues = numberOfMissingValues;
        for (int k = 0; k < numberOfMissingValues; k++) {
            int index = indexes[k];
            CommonUtils.checkArgument(0 <= index && index < queryPoint.length,
                    "Missing value indexes must be between 0 (inclusive) and queryPoint.length (exclusive)");
            this.missing[index] = true;
        }
        this.rank = DEFAULT_INIT_VALUE;
    }

    /**
     * Copy constructor used by {@link #newCopy()}. Copies the (possibly partially imputed) query
     * point and the missing-dimension flags; the rank is reset to its initial value rather than
     * copied.
     *
     * @param original the visitor to copy
     */
    ImputeVisitor(ImputeVisitor original) {
        int length = original.queryPoint.length;
        this.queryPoint = Arrays.copyOf(original.queryPoint, length);
        this.missing = Arrays.copyOf(original.missing, length);
        this.numberOfMissingValues = original.numberOfMissingValues;
        this.rank = DEFAULT_INIT_VALUE;
    }

    /**
     * Returns the rank of the current imputation candidate.
     *
     * @return the current rank; lower values are preferred by {@link #combine(MultiVisitor)}
     */
    public double getRank() {
        return this.rank;
    }

    /**
     * Blends this internal node's unseen score into the running rank.
     *
     * <p>The weight is the probability that the node's bounding box separates the current query
     * point. When that probability is not positive, the rank is left unchanged.
     *
     * @param node the internal node being visited
     * @param depthOfNode depth of {@code node} in the tree
     */
    @Override
    public void accept(Node node, int depthOfNode) {
        double probabilityOfSeparation =
                CommonUtils.getProbabilityOfSeparation(node.getBoundingBox(), this.queryPoint);
        if (probabilityOfSeparation <= 0.0d) {
            return;
        }
        double unseen = scoreUnseen(depthOfNode, node.getMass());
        this.rank = probabilityOfSeparation * unseen + (1.0d - probabilityOfSeparation) * this.rank;
    }

    /**
     * Imputes the missing coordinates from a leaf and resets the rank.
     *
     * <p>Each missing coordinate is set to the leaf bounding box's minimum in that dimension
     * (presumably a leaf's box is a single point, so this is the leaf point — confirm). The rank then
     * becomes one of three values:
     * <ul>
     *   <li>the unseen score, if the box still has a positive probability of separating the point;</li>
     *   <li>{@code 0}, if it does not and the leaf is at depth 0;</li>
     *   <li>otherwise, the seen score.</li>
     * </ul>
     *
     * @param node the leaf node being visited
     * @param depthOfNode depth of {@code node} in the tree
     */
    @Override
    public void acceptLeaf(Node node, int depthOfNode) {
        // Bounding box does not depend on queryPoint, so fetch it once for all uses below.
        BoundingBox leafBox = node.getBoundingBox();
        for (int d = 0; d < this.queryPoint.length; d++) {
            if (this.missing[d]) {
                this.queryPoint[d] = leafBox.getMinValue(d);
            }
        }
        if (CommonUtils.getProbabilityOfSeparation(leafBox, this.queryPoint) > 0.0d) {
            this.rank = scoreUnseen(depthOfNode, node.getMass());
        } else if (depthOfNode == 0) {
            this.rank = 0.0d;
        } else {
            this.rank = scoreSeen(depthOfNode, node.getMass());
        }
    }

    /**
     * Returns the imputed query point.
     *
     * @return this visitor's internal array (not a copy); later visits or combines will mutate it
     */
    @Override
    public double[] getResult() {
        return this.queryPoint;
    }

    /**
     * Reports whether the node's cut is on one of the missing dimensions.
     *
     * @param node the node whose cut dimension is inspected
     * @return {@code true} if the cut dimension is marked as missing
     */
    @Override
    public boolean trigger(Node node) {
        return this.missing[node.getCut().getDimension()];
    }

    /**
     * Creates an independent copy of this visitor via the copy constructor (rank reset to default).
     *
     * @return a new {@code ImputeVisitor}
     */
    @Override
    public MultiVisitor<double[]> newCopy() {
        return new ImputeVisitor(this);
    }

    /**
     * Adopts the other visitor's candidate when it has a strictly lower rank. In that case its
     * query point is copied into this visitor's array and its rank is taken over.
     *
     * @param other another visitor, which must be an {@code ImputeVisitor}
     * @throws ClassCastException if {@code other} is not an {@code ImputeVisitor}
     */
    @Override
    public void combine(MultiVisitor<double[]> other) {
        ImputeVisitor candidate = (ImputeVisitor) other;
        if (candidate.getRank() < this.rank) {
            System.arraycopy(candidate.queryPoint, 0, this.queryPoint, 0, this.queryPoint.length);
            this.rank = candidate.getRank();
        }
    }

    /**
     * Score used when the (imputed) point matches the leaf; delegates to
     * {@code CommonUtils.defaultScoreSeenFunction}. Subclasses may override.
     *
     * @param depth depth of the node
     * @param mass mass of the node
     * @return the seen score
     */
    protected double scoreSeen(int depth, int mass) {
        return CommonUtils.defaultScoreSeenFunction(depth, mass);
    }

    /**
     * Score used when the point is separated from the node's box; delegates to
     * {@code CommonUtils.defaultScoreUnseenFunction}. Subclasses may override.
     *
     * @param depth depth of the node
     * @param mass mass of the node
     * @return the unseen score
     */
    protected double scoreUnseen(int depth, int mass) {
        return CommonUtils.defaultScoreUnseenFunction(depth, mass);
    }
}
