package com.amazon.randomcutforest;

import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import com.amazon.randomcutforest.tree.RandomCutTree;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/**
 * A forest traversal executor that fans work out across all trees in parallel.
 *
 * <p>Every operation is expressed as a parallel stream over {@code treeUpdaters}. That stream is
 * submitted as a task to a dedicated {@link ForkJoinPool} (parallelism {@code threadPoolSize}),
 * and the task is joined on the calling thread. This relies on the JDK behaviour that a parallel
 * stream started from inside a {@code ForkJoinPool} task forks its subtasks into that same pool,
 * which keeps this forest's work off the JVM-wide common pool.
 */
public class ParallelForestTraversalExecutor extends AbstractForestTraversalExecutor {

    /** Dedicated pool that runs all per-tree work. It is always created by the constructor. */
    private ForkJoinPool forkJoinPool;

    /** Parallelism of {@link #forkJoinPool}; also the batch size for converging traversals. */
    private final int threadPoolSize;

    /**
     * Creates an executor backed by a dedicated fork-join pool.
     *
     * @param treeUpdaters the per-tree updaters this executor drives (handed to the superclass)
     * @param threadPoolSize parallelism of the dedicated pool
     * @throws IllegalArgumentException from {@link ForkJoinPool#ForkJoinPool(int)} if
     *     {@code threadPoolSize} is not positive or exceeds the implementation limit
     */
    public ParallelForestTraversalExecutor(ArrayList<TreeUpdater> treeUpdaters, int threadPoolSize) {
        super(treeUpdaters);
        this.threadPoolSize = threadPoolSize;
        this.forkJoinPool = new ForkJoinPool(threadPoolSize);
    }

    /**
     * Forwards {@code point} and {@code sequenceIndex} unchanged to every tree updater, in
     * parallel, and blocks until all updaters have finished.
     *
     * @param point the point to add to each tree
     * @param sequenceIndex value passed through to {@code TreeUpdater.update} as its second
     *     argument (presumably the point's sequence index; confirm against TreeUpdater)
     */
    @Override
    protected void update(double[] point, long sequenceIndex) {
        submitAndJoin(() -> {
            treeUpdaters.parallelStream().forEach(updater -> updater.update(point, sequenceIndex));
            // Callable must return a value; the update is executed purely for its side effects.
            return null;
        });
    }

    /**
     * Traverses every tree in parallel and reduces the per-tree results pairwise.
     *
     * @param point the query point
     * @param visitorFactory creates a fresh visitor for each tree
     * @param accumulator associative operator that combines two per-tree results
     * @param finisher maps the fully combined result to the returned value
     * @return {@code finisher} applied to the combined result
     * @throws IllegalStateException if there are no trees, or if {@code finisher} returns
     *     {@code null} (both yield an empty {@link Optional})
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory,
            BinaryOperator<R> accumulator, Function<R, S> finisher) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> tree.traverseTree(point, visitorFactory.apply(tree)))
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Traverses every tree in parallel and gathers the per-tree results with a {@link Collector}.
     *
     * @param point the query point
     * @param visitorFactory creates a fresh visitor for each tree
     * @param collector combines the per-tree results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory,
            Collector<R, ?, S> collector) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> tree.traverseTree(point, visitorFactory.apply(tree)))
                .collect(collector));
    }

    /**
     * Traverses trees in batches of {@code threadPoolSize}, stopping early once the accumulator
     * reports convergence.
     *
     * <p>Each batch is traversed in parallel. Its results are then fed to {@code accumulator} on
     * the calling thread, in tree order (the batch is collected into an ordered list first).
     * Convergence is checked only after a whole batch has been accepted, so up to one batch of
     * extra work may be done past the point of convergence.
     *
     * @param point the query point
     * @param visitorFactory creates a fresh visitor for each tree
     * @param accumulator receives per-tree results and decides when enough have been seen
     * @param finisher maps the accumulated value to the returned value
     * @return {@code finisher} applied to the accumulator's value; if there are no trees, the
     *     accumulator is queried without having received any results
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory,
            ConvergingAccumulator<R> accumulator, Function<R, S> finisher) {
        for (int start = 0; start < treeUpdaters.size(); start += threadPoolSize) {
            final int from = start;
            final int to = Math.min(start + threadPoolSize, treeUpdaters.size());
            List<R> batchResults = submitAndJoin(() -> treeUpdaters.subList(from, to).parallelStream()
                    .map(TreeUpdater::getTree)
                    .map(tree -> tree.traverseTree(point, visitorFactory.apply(tree)))
                    .collect(Collectors.toList()));
            batchResults.forEach(accumulator::accept);
            if (accumulator.isConverged()) {
                break;
            }
        }
        return finisher.apply(accumulator.getAccumulatedValue());
    }

    /**
     * Multi-visitor variant of {@link #traverseForest(double[], Function, BinaryOperator, Function)}.
     * It traverses every tree in parallel with {@code traverseTreeMulti} and reduces the results
     * pairwise.
     *
     * @param point the query point
     * @param visitorFactory creates a fresh multi-visitor for each tree
     * @param accumulator associative operator that combines two per-tree results
     * @param finisher maps the fully combined result to the returned value
     * @return {@code finisher} applied to the combined result
     * @throws IllegalStateException if there are no trees, or if {@code finisher} returns
     *     {@code null}
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point,
            Function<RandomCutTree, MultiVisitor<R>> visitorFactory, BinaryOperator<R> accumulator,
            Function<R, S> finisher) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> tree.traverseTreeMulti(point, visitorFactory.apply(tree)))
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Multi-visitor variant of {@link #traverseForest(double[], Function, Collector)}. It
     * traverses every tree in parallel with {@code traverseTreeMulti} and gathers the results
     * with a {@link Collector}.
     *
     * @param point the query point
     * @param visitorFactory creates a fresh multi-visitor for each tree
     * @param collector combines the per-tree results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point,
            Function<RandomCutTree, MultiVisitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> tree.traverseTreeMulti(point, visitorFactory.apply(tree)))
                .collect(collector));
    }

    /**
     * Runs {@code task} on the dedicated pool and blocks until it completes.
     *
     * <p>{@link java.util.concurrent.ForkJoinTask#join()} rethrows runtime exceptions and errors
     * raised by the task.
     *
     * @param task the work to execute
     * @return the task's result
     */
    private <T> T submitAndJoin(Callable<T> task) {
        // NOTE(review): the constructor always creates the pool, so this guard looks purely
        // defensive; it is not synchronized.
        if (forkJoinPool == null) {
            forkJoinPool = new ForkJoinPool(threadPoolSize);
        }
        return forkJoinPool.submit(task).join();
    }
}
