package com.amazon.randomcutforest.tree;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.MultiVisitor;
import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.sampler.WeightedPoint;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/amazon/randomcutforest/tree/RandomCutTree.class */
/**
 * A random cut tree over {@code double[]} points.
 *
 * <p>Each internal {@link Node} splits space with a {@link Cut} drawn at random from the bounding box
 * of the points beneath it. Each leaf holds one distinct point together with its multiplicity (mass):
 * inserting a point equal to an existing leaf point only increments that leaf's mass. Points are inserted
 * with {@link #addPoint(WeightedPoint)} and removed with {@link #deletePoint(WeightedPoint)}. The tree is
 * queried by walking the root-to-leaf path of a query point with a {@link Visitor} or {@link MultiVisitor}.
 *
 * <p>Two optional features are configured through the {@link Builder}:
 * <ul>
 *   <li>recording the sequence index of every point held by a leaf;</li>
 *   <li>maintaining a running point sum on internal nodes (center-of-mass tracking).</li>
 * </ul>
 */
public class RandomCutTree {
    public static final boolean DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED = false;
    public static final boolean DEFAULT_CENTER_OF_MASS_ENABLED = false;

    /** Whether leaves record the sequence index of every point they hold. */
    private final boolean storeSequenceIndexesEnabled;
    /** Whether internal nodes maintain the sum of the points beneath them. */
    private final boolean centerOfMassEnabled;
    /** Source of randomness used to draw cuts. */
    private final Random random;
    /** Root of the tree; {@code null} when the tree holds no points. */
    protected Node root;

    /**
     * Fluent builder for {@link RandomCutTree}.
     *
     * <p>The self-referential type parameter lets builder subclasses keep their own type across chained
     * setter calls.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static class Builder<T extends Builder<T>> {
        private boolean storeSequenceIndexesEnabled = DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED;
        private boolean centerOfMassEnabled = DEFAULT_CENTER_OF_MASS_ENABLED;
        private Random random = null;

        /**
         * Sets whether leaves record sequence indexes.
         *
         * @param storeSequenceIndexesEnabled {@code true} to record sequence indexes
         * @return this builder
         */
        public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
            this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
            return self();
        }

        /**
         * Sets whether internal nodes maintain a running point sum.
         *
         * @param centerOfMassEnabled {@code true} to maintain point sums
         * @return this builder
         */
        public T centerOfMassEnabled(boolean centerOfMassEnabled) {
            this.centerOfMassEnabled = centerOfMassEnabled;
            return self();
        }

        /**
         * Sets the random number generator used to draw cuts.
         *
         * @param random the generator; {@code null} means a fresh unseeded {@link Random} is created at build time
         * @return this builder
         */
        public T random(Random random) {
            this.random = random;
            return self();
        }

        /**
         * Uses a new {@link Random} seeded with {@code seed}, which makes cut selection reproducible.
         *
         * @param seed the seed
         * @return this builder
         */
        public T randomSeed(long seed) {
            this.random = new Random(seed);
            return self();
        }

        /**
         * Builds a tree from the current settings.
         *
         * @return a new, empty tree
         */
        public RandomCutTree build() {
            return new RandomCutTree(this);
        }

        /**
         * Returns this builder typed as {@code T}.
         *
         * <p>The unchecked cast relies on the self-type convention that {@code T} is this builder's own class.
         */
        @SuppressWarnings("unchecked")
        private T self() {
            return (T) this;
        }
    }

    /**
     * Creates an empty tree configured from {@code builder}.
     *
     * @param builder the builder holding the configuration; when it has no {@link Random}, a new unseeded
     *                one is created
     */
    public RandomCutTree(Builder<?> builder) {
        this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
        this.centerOfMassEnabled = builder.centerOfMassEnabled;
        this.random = builder.random != null ? builder.random : new Random();
    }

    /** Returns a new builder with default settings. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a tree with default settings and a seeded random number generator.
     *
     * @param seed the random seed
     * @return a new, empty tree
     */
    public static RandomCutTree defaultTree(long seed) {
        return builder().randomSeed(seed).build();
    }

    /** Creates a tree with default settings and an unseeded random number generator. */
    public static RandomCutTree defaultTree() {
        return builder().build();
    }

    /** Returns whether internal nodes maintain a running point sum. */
    public boolean centerOfMassEnabled() {
        return this.centerOfMassEnabled;
    }

    /** Returns whether leaves record sequence indexes. */
    public boolean storeSequenceIndexesEnabled() {
        return this.storeSequenceIndexesEnabled;
    }

    /**
     * Draws a random cut from {@code boundingBox}.
     *
     * <p>A dimension is chosen with probability proportional to its range, and a value is chosen uniformly
     * within it. The cut value always lies in the half-open interval {@code [min, max)} of the chosen
     * dimension, so both sides of the cut are non-empty.
     *
     * @param random      source of randomness
     * @param boundingBox the box to cut; its range sum must be positive
     * @return the cut
     * @throws IllegalStateException if no dimension with positive range can be found, which only happens
     *                               when the box reports a positive range sum but no positive per-dimension range
     */
    static Cut randomCut(Random random, BoundingBox boundingBox) {
        double rangeSum = boundingBox.getRangeSum();
        CommonUtils.checkArgument(rangeSum > 0.0d, "box.getRangeSum() must be greater than 0");

        double breakPoint = random.nextDouble() * rangeSum;
        int lastPositiveDimension = -1;
        for (int i = 0; i < boundingBox.getDimensions(); i++) {
            double range = boundingBox.getRange(i);
            if (range <= 0.0d) {
                // A zero-width dimension cannot separate anything. Skipping it also keeps a breakPoint of
                // exactly 0.0 (nextDouble() may return 0.0) from producing a cut with min == value == max.
                continue;
            }
            lastPositiveDimension = i;
            if (breakPoint <= range) {
                return cutWithinDimension(boundingBox, i, boundingBox.getMinValue(i) + breakPoint);
            }
            breakPoint -= range;
        }
        if (lastPositiveDimension >= 0) {
            // breakPoint is drawn from [0, rangeSum). Subtracting the ranges one at a time can still leave a
            // rounding sliver after the last dimension, so attribute it to the end of the last positive range
            // instead of failing the update.
            return cutWithinDimension(boundingBox, lastPositiveDimension,
                    boundingBox.getMaxValue(lastPositiveDimension));
        }
        throw new IllegalStateException("The break point did not lie inside the expected range");
    }

    /**
     * Builds a cut on {@code dimension} at {@code candidate}, clamped so that it lies strictly below the
     * box's maximum whenever the dimension has positive width.
     */
    private static Cut cutWithinDimension(BoundingBox boundingBox, int dimension, double candidate) {
        double min = boundingBox.getMinValue(dimension);
        double max = boundingBox.getMaxValue(dimension);
        double value = candidate;
        // Use >= rather than ==: min + breakPoint may round past max, not only land on it.
        if (value >= max && min < max) {
            value = Math.nextAfter(max, min);
        }
        return new Cut(dimension, value);
    }

    /**
     * Puts {@code replacement} into the child slot that {@code node} occupies in its parent.
     *
     * <p>The replacement's parent is then set to {@code node}'s parent, which may be {@code null}. The parent
     * pointer of {@code node} itself is left unchanged.
     */
    static void replaceNode(Node node, Node replacement) {
        Node parent = node.getParent();
        if (parent != null) {
            if (parent.getLeftChild() == node) {
                parent.setLeftChild(replacement);
            } else {
                parent.setRightChild(replacement);
            }
        }
        replacement.setParent(parent);
    }

    /**
     * Returns the other child of {@code node}'s parent.
     *
     * @throws IllegalArgumentException if the parent has neither child equal to {@code node}
     */
    static Node getSibling(Node node) {
        CommonUtils.checkNotNull(node.getParent(), "node parent must not be null");
        Node parent = node.getParent();
        if (parent.getLeftChild() == node) {
            return parent.getRightChild();
        }
        if (parent.getRightChild() == node) {
            return parent.getLeftChild();
        }
        throw new IllegalArgumentException("node parent does not link back to node");
    }

    /**
     * Removes one occurrence of {@code weightedPoint} from the tree.
     *
     * @throws IllegalStateException if the tree is empty, or if the point (or, when sequence indexes are
     *                               stored, its sequence index) is not found at the leaf its path leads to
     */
    public void deletePoint(WeightedPoint weightedPoint) {
        CommonUtils.checkState(this.root != null, "root must not be null");
        deletePoint(this.root, weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
    }

    /**
     * Recursively removes one occurrence of {@code point} from the subtree rooted at {@code node}.
     *
     * <p>At the leaf, the mass is decremented when it exceeds 1. Otherwise the leaf and its parent are spliced
     * out and replaced by the leaf's sibling. On the way back up, every internal node on the path recomputes
     * its bounding box from its current children and updates its mass and, if enabled, its point sum.
     */
    void deletePoint(Node node, double[] point, long sequenceIndex) {
        if (node.isLeaf()) {
            if (!node.leafPointEquals(point)) {
                throw new IllegalStateException(Arrays.toString(point) + " " + Arrays.toString(node.getLeafPoint()) + " " + Arrays.equals(node.getLeafPoint(), point) + " Inconsistency in trees in delete step.");
            }
            if (this.storeSequenceIndexesEnabled && !node.getSequenceIndexes().contains(Long.valueOf(sequenceIndex))) {
                throw new IllegalStateException("Error in sequence index. Inconsistency in trees in delete step.");
            }
            if (node.getMass() > 1) {
                node.decrementMass();
                if (this.storeSequenceIndexesEnabled) {
                    node.deleteSequenceIndex(sequenceIndex);
                }
                return;
            }
            // Last copy of this point: remove the leaf together with its parent, promoting the sibling.
            Node parent = node.getParent();
            if (parent == null) {
                this.root = null;
            } else if (parent.getParent() != null) {
                replaceNode(parent, getSibling(node));
            } else {
                this.root = getSibling(node);
                this.root.setParent(null);
            }
            return;
        }

        if (Node.isLeftOf(point, node)) {
            deletePoint(node.getLeftChild(), point, sequenceIndex);
        } else {
            deletePoint(node.getRightChild(), point, sequenceIndex);
        }
        // A child may have been replaced above, so rebuild the box from the current children.
        node.setBoundingBox(node.getLeftChild().getBoundingBox().getMergedBox(node.getRightChild().getBoundingBox()));
        node.decrementMass();
        if (this.centerOfMassEnabled) {
            node.subtractFromPointSum(point);
        }
    }

    /** Inserts {@code weightedPoint}; the first point inserted into an empty tree becomes a single-leaf root. */
    public void addPoint(WeightedPoint weightedPoint) {
        if (this.root == null) {
            this.root = newLeafNode(weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
        } else {
            addPoint(this.root, weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
        }
    }

    /**
     * Recursively inserts {@code point} below {@code node}.
     *
     * <p>If {@code node} is a leaf holding an equal point, only its mass (and sequence indexes) change.
     * Otherwise, when the point lies outside {@code node}'s bounding box, a random cut is drawn in the merged
     * box. If that cut separates the point from the existing box, a new internal node is created with the new
     * leaf on the point's side and {@code node} on the other side. If not, the insertion descends along the
     * existing cut and updates the bounding box, mass and (optionally) point sum on the way back up.
     */
    private void addPoint(Node node, double[] point, long sequenceIndex) {
        if (node.isLeaf() && node.leafPointEquals(point)) {
            node.incrementMass();
            if (this.storeSequenceIndexesEnabled) {
                node.addSequenceIndex(sequenceIndex);
            }
            return;
        }
        BoundingBox existingBox = node.getBoundingBox();
        BoundingBox mergedBox = existingBox.getMergedBox(point);
        if (!existingBox.contains(point)) {
            Cut cut = randomCut(this.random, mergedBox);
            int dimension = cut.getDimension();
            double cutValue = cut.getValue();
            double minValue = existingBox.getMinValue(dimension);
            double maxValue = existingBox.getMaxValue(dimension);
            if (minValue > cutValue || maxValue <= cutValue) {
                // The cut separates the new point from every point under node: splice in a new internal node.
                Node leaf = newLeafNode(point, sequenceIndex);
                Node split = minValue > cutValue
                        ? newNode(leaf, node, cut, mergedBox)
                        : newNode(node, leaf, cut, mergedBox);
                if (node.getParent() == null) {
                    this.root = split;
                } else {
                    replaceNode(node, split);
                }
                leaf.setParent(split);
                node.setParent(split);
                return;
            }
        }
        if (Node.isLeftOf(point, node)) {
            addPoint(node.getLeftChild(), point, sequenceIndex);
        } else {
            addPoint(node.getRightChild(), point, sequenceIndex);
        }
        node.setBoundingBox(mergedBox);
        node.incrementMass();
        if (this.centerOfMassEnabled) {
            node.addToPointSum(point);
        }
    }

    /**
     * Walks the path from the root to the leaf that {@code point} falls into, then invokes the visitor.
     *
     * <p>The visitor sees the leaf first and then each internal node back up to the root, with the depth of
     * each node (root = 0).
     *
     * @return the visitor's result
     * @throws IllegalStateException if the tree is empty
     */
    public <R> R traverseTree(double[] point, Visitor<R> visitor) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkNotNull(visitor, "visitor must not be null");
        CommonUtils.checkState(this.root != null, "this tree doesn't contain any nodes");
        traversePathToLeafAndVisitNodes(point, visitor, this.root, 0);
        return visitor.getResult();
    }

    /** Recursive helper for {@link #traverseTree}: descends first, then visits on the way back up. */
    private <R> void traversePathToLeafAndVisitNodes(double[] point, Visitor<R> visitor, Node node, int depth) {
        if (node.isLeaf()) {
            visitor.acceptLeaf(node, depth);
        } else {
            Node next = Node.isLeftOf(point, node) ? node.getLeftChild() : node.getRightChild();
            traversePathToLeafAndVisitNodes(point, visitor, next, depth + 1);
            visitor.accept(node, depth);
        }
    }

    /**
     * Like {@link #traverseTree}, but at internal nodes where {@link MultiVisitor#trigger} returns
     * {@code true}, both subtrees are explored.
     *
     * <p>The left subtree uses the original visitor and the right subtree uses a copy. The copy is then
     * combined back into the original before the node itself is accepted.
     *
     * @return the visitor's result
     * @throws IllegalStateException if the tree is empty
     */
    public <R> R traverseTreeMulti(double[] point, MultiVisitor<R> multiVisitor) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkNotNull(multiVisitor, "visitor must not be null");
        CommonUtils.checkState(this.root != null, "this tree doesn't contain any nodes");
        traverseTreeMulti(point, multiVisitor, this.root, 0);
        return multiVisitor.getResult();
    }

    /** Recursive helper for {@link #traverseTreeMulti(double[], MultiVisitor)}. */
    private <R> void traverseTreeMulti(double[] point, MultiVisitor<R> multiVisitor, Node node, int depth) {
        if (node.isLeaf()) {
            multiVisitor.acceptLeaf(node, depth);
            return;
        }
        if (!multiVisitor.trigger(node)) {
            Node next = Node.isLeftOf(point, node) ? node.getLeftChild() : node.getRightChild();
            traverseTreeMulti(point, multiVisitor, next, depth + 1);
            multiVisitor.accept(node, depth);
            return;
        }
        traverseTreeMulti(point, multiVisitor, node.getLeftChild(), depth + 1);
        MultiVisitor<R> rightVisitor = multiVisitor.newCopy();
        traverseTreeMulti(point, rightVisitor, node.getRightChild(), depth + 1);
        multiVisitor.combine(rightVisitor);
        multiVisitor.accept(node, depth);
    }

    /** Creates a leaf of mass 1 for {@code point}, recording {@code sequenceIndex} if enabled. */
    private Node newLeafNode(double[] point, long sequenceIndex) {
        Node leaf = new Node(point);
        leaf.setMass(1);
        if (this.storeSequenceIndexesEnabled) {
            leaf.addSequenceIndex(sequenceIndex);
        }
        return leaf;
    }

    /**
     * Creates an internal node with the given children, cut and box.
     *
     * <p>The new node's mass is the sum of its children's masses and, if enabled, its point sum is the sum of
     * theirs.
     */
    private Node newNode(Node leftChild, Node rightChild, Cut cut, BoundingBox boundingBox) {
        Node internal = new Node(leftChild, rightChild, cut, boundingBox, this.centerOfMassEnabled);
        if (leftChild != null) {
            internal.addMass(leftChild.getMass());
            if (this.centerOfMassEnabled) {
                internal.addToPointSum(leftChild.getPointSum());
            }
        }
        if (rightChild != null) {
            internal.addMass(rightChild.getMass());
            if (this.centerOfMassEnabled) {
                internal.addToPointSum(rightChild.getPointSum());
            }
        }
        return internal;
    }

    /** Returns the root node, or {@code null} if the tree is empty. */
    public Node getRoot() {
        return this.root;
    }
}
