package com.amazon.opendistroforelasticsearch.ad.ml;

import com.amazon.opendistroforelasticsearch.ad.DetectorModelSize;
import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.ResourceNotFoundException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager;
import com.amazon.opendistroforelasticsearch.ad.ml.rcf.CombinedRcfResult;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.serialize.RandomCutForestSerDe;
import com.google.gson.Gson;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.CheckedConsumer;

/* loaded from: input_file:com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.class */
/**
 * Hosts, trains, restores, checkpoints and scores anomaly-detection models on this node.
 *
 * <p>Single-entity detectors use several RCF partitions (kept in {@code forests}) plus one
 * thresholding model (kept in {@code thresholds}). Both maps are keyed by model id. Models that
 * are not in memory are restored lazily through {@link CheckpointDao}. During maintenance, models
 * are checkpointed again once {@code checkpointInterval} has elapsed and are dropped once they
 * expire relative to {@code modelTtl}. Multi-entity models ({@link EntityModel}) live in a
 * {@link ModelState} supplied by the caller; this class only trains and scores them.
 */
public class ModelManager implements DetectorModelSize {
    /** Regex whose first group is the detector id of an RCF/threshold model id. */
    protected static final String DETECTOR_ID_PATTERN = "(.*)_model_.+";
    /**
     * Compiled once. {@link #getDetectorIdForModelId} is called for every hosted model during
     * clear and size scans, so recompiling the regex on each call would be wasted work.
     */
    private static final Pattern DETECTOR_ID_REGEX = Pattern.compile(DETECTOR_ID_PATTERN);
    // Short keys, presumably for entity checkpoint serialization (not referenced in this class).
    protected static final String ENTITY_SAMPLE = "sp";
    protected static final String ENTITY_RCF = "rcf";
    protected static final String ENTITY_THRESHOLD = "th";
    /** Once totalUpdates * lambda reaches this exponent, RCF confidence is reported as exactly 1. */
    private static final double FULL_CONFIDENCE_EXPONENT = 18.43d;
    private static final Logger logger = LogManager.getLogger(ModelManager.class);

    /** In-memory RCF partitions keyed by model id; memory usage is tracked by {@link MemoryTracker}. */
    private RCFMemoryAwareConcurrentHashmap<String> forests;
    /** In-memory thresholding models keyed by model id. */
    private Map<String, ModelState<ThresholdingModel>> thresholds = new ConcurrentHashMap<>();
    private final int rcfNumTrees;
    private final int rcfNumSamplesInTree;
    private final double rcfTimeDecay;
    private final int rcfNumMinSamples;
    private final double thresholdMinPvalue;
    private final double thresholdMaxRankError;
    private final double thresholdMaxScore;
    private final int thresholdNumLogNormalQuantiles;
    private final int thresholdDownsamples;
    private final long thresholdMaxSamples;
    private final Class<? extends ThresholdingModel> thresholdingModelClass;
    private final int minPreviewSize;
    private final Duration modelTtl;
    private final Duration checkpointInterval;
    private final RandomCutForestSerDe rcfSerde;
    private final CheckpointDao checkpointDao;
    private final Gson gson;
    private final Clock clock;
    public FeatureManager featureManager;
    private EntityColdStarter entityColdStarter;
    private ModelPartitioner modelPartitioner;
    private MemoryTracker memoryTracker;

    /** Kinds of models tracked here; {@link #getName()} is passed along when a {@link ModelState} is created. */
    public enum ModelType {
        RCF("rcf"),
        THRESHOLD("threshold"),
        ENTITY("entity");

        private final String name;

        ModelType(String name) {
            this.name = name;
        }

        /** @return the short type name of this model kind */
        public String getName() {
            return this.name;
        }
    }

    /**
     * Creates a model manager.
     *
     * @param rcfSerde serializer used for RCF checkpoints
     * @param checkpointDao storage for model checkpoints
     * @param gson serializer used for thresholding-model checkpoints
     * @param clock time source for last-used / last-checkpoint bookkeeping
     * @param rcfNumTrees total number of RCF trees per detector
     * @param rcfNumSamplesInTree sample size of each RCF tree
     * @param rcfTimeDecay RCF lambda (time decay)
     * @param rcfNumMinSamples RCF {@code outputAfter}; also caps buffered entity samples
     * @param thresholdMinPvalue thresholding-model parameter
     * @param thresholdMaxRankError thresholding-model parameter
     * @param thresholdMaxScore thresholding-model parameter
     * @param thresholdNumLogNormalQuantiles thresholding-model parameter
     * @param thresholdDownsamples thresholding-model parameter
     * @param thresholdMaxSamples thresholding-model parameter
     * @param thresholdingModelClass concrete class used to deserialize threshold checkpoints
     * @param minPreviewSize minimum number of points accepted by {@link #getPreviewResults}
     * @param modelTtl idle time after which maintenance evicts a model
     * @param checkpointInterval minimum time between two checkpoints of the same model
     * @param entityColdStarter trainer for multi-entity models that lack an RCF/threshold
     * @param modelPartitioner decides RCF partitioning and model ids
     * @param featureManager feature manager (exposed publicly)
     * @param memoryTracker memory accounting for hosted RCF models
     */
    public ModelManager(
        RandomCutForestSerDe rcfSerde,
        CheckpointDao checkpointDao,
        Gson gson,
        Clock clock,
        int rcfNumTrees,
        int rcfNumSamplesInTree,
        double rcfTimeDecay,
        int rcfNumMinSamples,
        double thresholdMinPvalue,
        double thresholdMaxRankError,
        double thresholdMaxScore,
        int thresholdNumLogNormalQuantiles,
        int thresholdDownsamples,
        long thresholdMaxSamples,
        Class<? extends ThresholdingModel> thresholdingModelClass,
        int minPreviewSize,
        Duration modelTtl,
        Duration checkpointInterval,
        EntityColdStarter entityColdStarter,
        ModelPartitioner modelPartitioner,
        FeatureManager featureManager,
        MemoryTracker memoryTracker
    ) {
        this.rcfSerde = rcfSerde;
        this.checkpointDao = checkpointDao;
        this.gson = gson;
        this.clock = clock;
        this.rcfNumTrees = rcfNumTrees;
        this.rcfNumSamplesInTree = rcfNumSamplesInTree;
        this.rcfTimeDecay = rcfTimeDecay;
        this.rcfNumMinSamples = rcfNumMinSamples;
        this.thresholdMinPvalue = thresholdMinPvalue;
        this.thresholdMaxRankError = thresholdMaxRankError;
        this.thresholdMaxScore = thresholdMaxScore;
        this.thresholdNumLogNormalQuantiles = thresholdNumLogNormalQuantiles;
        this.thresholdDownsamples = thresholdDownsamples;
        this.thresholdMaxSamples = thresholdMaxSamples;
        this.thresholdingModelClass = thresholdingModelClass;
        this.minPreviewSize = minPreviewSize;
        this.modelTtl = modelTtl;
        this.checkpointInterval = checkpointInterval;
        this.forests = new RCFMemoryAwareConcurrentHashmap<>(memoryTracker);
        this.entityColdStarter = entityColdStarter;
        this.modelPartitioner = modelPartitioner;
        this.featureManager = featureManager;
        this.memoryTracker = memoryTracker;
    }

    /**
     * Combines per-partition RCF results into one result weighted by forest size.
     *
     * @param rcfResults results from the RCF partitions of one detector
     * @param numFeatures number of trailing attribution entries to combine
     * @return combined score/confidence/attribution; all zeros/empty when there is nothing to combine
     */
    public CombinedRcfResult combineRcfResults(List<RcfResult> rcfResults, int numFeatures) {
        if (rcfResults.isEmpty()) {
            return new CombinedRcfResult(0.0d, 0.0d, new double[0]);
        }
        int totalForestSize = rcfResults.stream().mapToInt(RcfResult::getForestSize).sum();
        if (totalForestSize == 0) {
            return new CombinedRcfResult(0.0d, 0.0d, new double[0]);
        }
        double score = rcfResults.stream().mapToDouble(r -> r.getScore() * r.getForestSize()).sum() / totalForestSize;
        // Divide by at least rcfNumTrees so that missing partitions lower the confidence.
        double confidence = rcfResults.stream().mapToDouble(r -> r.getConfidence() * r.getForestSize()).sum()
            / Math.max(this.rcfNumTrees, totalForestSize);
        return new CombinedRcfResult(score, confidence, combineAttribution(rcfResults, numFeatures, totalForestSize));
    }

    /**
     * Forest-size-weighted average of the last {@code numFeatures} attribution entries of each
     * result, normalized so that the returned entries sum to 1 (all zeros if every entry is 0).
     */
    private double[] combineAttribution(List<RcfResult> rcfResults, int numFeatures, int totalForestSize) {
        double[] combined = new double[numFeatures];
        double total = 0.0d;
        for (RcfResult result : rcfResults) {
            double[] attribution = result.getAttribution();
            // Only the trailing numFeatures entries are used (presumably the newest point of a shingle).
            int offset = attribution.length - numFeatures;
            for (int i = 0; i < numFeatures; i++) {
                double weighted = attribution[offset + i] * result.getForestSize() / totalForestSize;
                combined[i] += weighted;
                total += weighted;
            }
        }
        // Guard against 0/0, which would turn every attribution into NaN.
        if (total != 0.0d) {
            for (int i = 0; i < numFeatures; i++) {
                combined[i] /= total;
            }
        }
        return combined;
    }

    /**
     * Extracts the detector id from an RCF/threshold model id.
     *
     * @throws IllegalArgumentException if {@code modelId} does not match {@link #DETECTOR_ID_PATTERN}
     */
    public String getDetectorIdForModelId(String modelId) {
        Matcher matcher = DETECTOR_ID_REGEX.matcher(modelId);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Invalid model id " + modelId);
        }
        return matcher.group(1);
    }

    /**
     * Scores {@code point} with the RCF partition {@code modelId}. If the partition is not in
     * memory, it is first restored from its checkpoint.
     */
    public void getRcfResult(String detectorId, String modelId, double[] point, ActionListener<RcfResult> listener) {
        // Single lookup: a containsKey/get pair could race with maintenance evicting the model.
        ModelState<RandomCutForest> modelState = this.forests.get(modelId);
        if (modelState != null) {
            getRcfResult(modelState, point, listener);
            return;
        }
        this.checkpointDao
            .getModelCheckpoint(
                modelId,
                ActionListener.wrap(
                    checkpoint -> processRcfCheckpoint(checkpoint, modelId, detectorId, point, listener),
                    listener::onFailure
                )
            );
    }

    /** Scores the point, then updates the forest with it, and reports the score, confidence, tree count and attribution. */
    private void getRcfResult(ModelState<RandomCutForest> modelState, double[] point, ActionListener<RcfResult> listener) {
        modelState.setLastUsedTime(this.clock.instant());
        RandomCutForest forest = modelState.getModel();
        double score = forest.getAnomalyScore(point);
        double confidence = computeRcfConfidence(forest);
        int numTrees = forest.getNumberOfTrees();
        double[] attribution = getAnomalyAttribution(forest, point);
        forest.update(point);
        listener.onResponse(new RcfResult(score, confidence, numTrees, attribution));
    }

    /** Per-dimension high+low attribution after renormalizing the attribution vector to total 1. */
    private double[] getAnomalyAttribution(RandomCutForest forest, double[] point) {
        DiVector vector = forest.getAnomalyAttribution(point);
        vector.renormalize(1.0d);
        double[] attribution = new double[vector.getDimensions()];
        for (int i = 0; i < attribution.length; i++) {
            attribution[i] = vector.getHighLowSum(i);
        }
        return attribution;
    }

    /** Deserializes an RCF checkpoint and wraps it in a model state if the memory tracker allows hosting it. */
    private Optional<ModelState<RandomCutForest>> restoreCheckpoint(Optional<String> checkpoint, String modelId, String detectorId) {
        return checkpoint
            .map(json -> AccessController.doPrivileged((PrivilegedAction<RandomCutForest>) () -> this.rcfSerde.fromJson(json)))
            .filter(forest -> this.memoryTracker.isHostingAllowed(detectorId, forest))
            .map(forest -> ModelState.createSingleEntityModelState(forest, modelId, detectorId, ModelType.RCF.getName(), this.clock));
    }

    /**
     * Caches the restored forest and scores {@code point} with it.
     *
     * @throws ResourceNotFoundException if no usable checkpoint exists (surfaced via the wrapping listener)
     */
    private void processRcfCheckpoint(
        Optional<String> checkpoint,
        String modelId,
        String detectorId,
        double[] point,
        ActionListener<RcfResult> listener
    ) {
        Optional<ModelState<RandomCutForest>> restored = restoreCheckpoint(checkpoint, modelId, detectorId);
        if (!restored.isPresent()) {
            throw new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + modelId);
        }
        this.forests.put(modelId, restored.get());
        getRcfResult(restored.get(), point, listener);
    }

    /** Caches the restored forest and reports its total update count, or fails with {@link ResourceNotFoundException}. */
    private void processRcfCheckpoint(Optional<String> checkpoint, String modelId, String detectorId, ActionListener<Long> listener) {
        logger.info("Restoring checkpoint for {}", modelId);
        Optional<ModelState<RandomCutForest>> restored = restoreCheckpoint(checkpoint, modelId, detectorId);
        if (!restored.isPresent()) {
            listener.onFailure(new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + modelId));
            return;
        }
        this.forests.put(modelId, restored.get());
        listener.onResponse(restored.get().getModel().getTotalUpdates());
    }

    /**
     * Grades an RCF score with the thresholding model {@code modelId}. If the model is not in
     * memory, it is first restored from its checkpoint.
     */
    public void getThresholdingResult(String detectorId, String modelId, double score, ActionListener<ThresholdingResult> listener) {
        // Single lookup: a containsKey/get pair could race with maintenance evicting the model.
        ModelState<ThresholdingModel> modelState = this.thresholds.get(modelId);
        if (modelState != null) {
            getThresholdingResult(modelState, score, listener);
            return;
        }
        this.checkpointDao
            .getModelCheckpoint(
                modelId,
                ActionListener.wrap(
                    checkpoint -> processThresholdCheckpoint(checkpoint, modelId, detectorId, score, listener),
                    listener::onFailure
                )
            );
    }

    /** Grades the score and reports the grade and confidence; only positive scores update the model. */
    private void getThresholdingResult(ModelState<ThresholdingModel> modelState, double score, ActionListener<ThresholdingResult> listener) {
        ThresholdingModel threshold = modelState.getModel();
        double grade = threshold.grade(score);
        double confidence = threshold.confidence();
        if (score > 0.0d) {
            threshold.update(score);
        }
        modelState.setLastUsedTime(this.clock.instant());
        listener.onResponse(new ThresholdingResult(grade, confidence, score));
    }

    /**
     * Restores and caches a thresholding model, then grades {@code score} with it.
     *
     * @throws ResourceNotFoundException if no checkpoint exists (surfaced via the wrapping listener)
     */
    private void processThresholdCheckpoint(
        Optional<String> checkpoint,
        String modelId,
        String detectorId,
        double score,
        ActionListener<ThresholdingResult> listener
    ) {
        Optional<ModelState<ThresholdingModel>> restored = checkpoint
            .map(
                json -> AccessController
                    .doPrivileged((PrivilegedAction<ThresholdingModel>) () -> this.gson.fromJson(json, this.thresholdingModelClass))
            )
            .map(model -> ModelState.createSingleEntityModelState(model, modelId, detectorId, ModelType.THRESHOLD.getName(), this.clock));
        if (!restored.isPresent()) {
            throw new ResourceNotFoundException(detectorId, CommonErrorMessages.NO_CHECKPOINT_ERR_MSG + modelId);
        }
        this.thresholds.put(modelId, restored.get());
        getThresholdingResult(restored.get(), score, listener);
    }

    /** @return the ids of all RCF and thresholding models currently in memory */
    public Set<String> getAllModelIds() {
        return Stream.concat(this.forests.keySet().stream(), this.thresholds.keySet().stream()).collect(Collectors.toSet());
    }

    /** @return the states of all RCF and thresholding models currently in memory */
    public List<ModelState<?>> getAllModels() {
        return Stream.<ModelState<?>>concat(this.forests.values().stream(), this.thresholds.values().stream()).collect(Collectors.toList());
    }

    /**
     * Removes the model from memory and checkpoints it synchronously if it is due.
     *
     * @deprecated use {@link #stopModel(String, String, ActionListener)}
     */
    @Deprecated
    public void stopModel(String detectorId, String modelId) {
        logger.info("Stopping detector {} model {}", detectorId, modelId);
        stopModel(this.forests, modelId, this::toCheckpoint);
        stopModel(this.thresholds, modelId, this::toCheckpoint);
    }

    /** Removes {@code modelId}; checkpoints it only if its last checkpoint is older than checkpointInterval. */
    private <T> void stopModel(Map<String, ModelState<T>> models, String modelId, Function<T, String> toCheckpoint) {
        Instant now = this.clock.instant();
        ModelState<T> removed = models.remove(modelId);
        if (removed != null && removed.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now)) {
            this.checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(removed.getModel()));
        }
    }

    /** Removes the model from both maps (forests first, then thresholds), checkpointing it when due. */
    public void stopModel(String detectorId, String modelId, ActionListener<Void> listener) {
        logger.info("Stopping detector {} model {}", detectorId, modelId);
        stopModel(
            this.forests,
            modelId,
            this::toCheckpoint,
            ActionListener.wrap(r -> stopModel(this.thresholds, modelId, this::toCheckpoint, listener), listener::onFailure)
        );
    }

    /** Async variant: removes {@code modelId}, checkpoints it if due, then completes {@code listener}. */
    private <T> void stopModel(
        Map<String, ModelState<T>> models,
        String modelId,
        Function<T, String> toCheckpoint,
        ActionListener<Void> listener
    ) {
        Instant now = this.clock.instant();
        ModelState<T> removed = models.remove(modelId);
        if (removed != null && removed.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now)) {
            this.checkpointDao
                .putModelCheckpoint(modelId, toCheckpoint.apply(removed.getModel()), ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure));
        } else {
            listener.onResponse(null);
        }
    }

    /**
     * Drops all in-memory models of the detector and deletes their checkpoints synchronously.
     *
     * @deprecated use {@link #clear(String, ActionListener)}
     */
    @Deprecated
    public void clear(String detectorId) {
        clearModels(detectorId, this.forests);
        clearModels(detectorId, this.thresholds);
    }

    /** Drops all in-memory models of the detector and deletes their checkpoints, one at a time. */
    public void clear(String detectorId, ActionListener<Void> listener) {
        clearModels(
            detectorId,
            this.forests,
            ActionListener.wrap(r -> clearModels(detectorId, this.thresholds, listener), listener::onFailure)
        );
    }

    private void clearModels(String detectorId, Map<String, ?> models, ActionListener<Void> listener) {
        clearModelForIterator(detectorId, models, models.keySet().iterator(), listener);
    }

    /**
     * Advances {@code modelIds} to the next model of {@code detectorId}, removes it and deletes its
     * checkpoint, then continues from the delete callback. Non-matching ids are skipped in a loop
     * rather than by recursion, so a node hosting many models cannot overflow the stack.
     */
    private void clearModelForIterator(String detectorId, Map<String, ?> models, Iterator<String> modelIds, ActionListener<Void> listener) {
        while (modelIds.hasNext()) {
            String modelId = modelIds.next();
            if (getDetectorIdForModelId(modelId).equals(detectorId)) {
                models.remove(modelId);
                this.checkpointDao
                    .deleteModelCheckpoint(
                        modelId,
                        ActionListener.wrap(r -> clearModelForIterator(detectorId, models, modelIds, listener), listener::onFailure)
                    );
                return;
            }
        }
        listener.onResponse(null);
    }

    /**
     * Trains RCF partitions and a thresholding model, and checkpoints them synchronously.
     *
     * @throws IllegalArgumentException if the data is empty or its width is not enabledFeatures * shingleSize
     * @deprecated use {@link #trainModel(AnomalyDetector, double[][], ActionListener)}
     */
    @Deprecated
    public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints) {
        int shingleSize = anomalyDetector.getShingleSize().intValue();
        if (dataPoints.length == 0 || dataPoints[0].length == 0) {
            throw new IllegalArgumentException("Data points must not be empty.");
        }
        int expectedDimensions = anomalyDetector.getEnabledFeatureIds().size() * shingleSize;
        if (dataPoints[0].length != expectedDimensions) {
            throw new IllegalArgumentException(
                String
                    .format(
                        Locale.ROOT,
                        "Feature dimension is not correct, we expect %s but get %d",
                        Integer.valueOf(expectedDimensions),
                        Integer.valueOf(dataPoints[0].length)
                    )
            );
        }
        int dimensions = dataPoints[0].length;
        Map.Entry<Integer, Integer> partitionedForestSizes = this.modelPartitioner.getPartitionedForestSizes(anomalyDetector);
        int numForests = partitionedForestSizes.getKey().intValue();
        int numTrees = partitionedForestSizes.getValue().intValue();
        double[] scoreSums = new double[dataPoints.length];
        for (int partition = 0; partition < numForests; partition++) {
            RandomCutForest forest = trainForestPartition(dataPoints, dimensions, numTrees, scoreSums);
            this.checkpointDao.putModelCheckpoint(this.modelPartitioner.getRcfModelId(anomalyDetector.getDetectorId(), partition), toCheckpoint(forest));
        }
        HybridThresholdingModel threshold = trainThresholdModel(scoreSums, numForests);
        this.checkpointDao
            .putModelCheckpoint(
                this.modelPartitioner.getThresholdModelId(anomalyDetector.getDetectorId()),
                AccessController.doPrivileged((PrivilegedAction<String>) () -> this.gson.toJson(threshold))
            );
    }

    /**
     * Trains RCF partitions and then a thresholding model, checkpointing each one before moving on.
     * Fails the listener on empty data or when the memory limit for the forest would be exceeded.
     */
    public void trainModel(AnomalyDetector anomalyDetector, double[][] dataPoints, ActionListener<Void> listener) {
        if (dataPoints.length == 0 || dataPoints[0].length == 0) {
            listener.onFailure(new IllegalArgumentException("Data points must not be empty."));
            return;
        }
        int dimensions = dataPoints[0].length;
        try {
            // Forest used only to size the partitioning; it is never trained or stored.
            RandomCutForest sizingForest = RandomCutForest
                .builder()
                .dimensions(dimensions)
                .sampleSize(this.rcfNumSamplesInTree)
                .numberOfTrees(this.rcfNumTrees)
                .outputAfter(this.rcfNumSamplesInTree)
                .parallelExecutionEnabled(false)
                .build();
            Map.Entry<Integer, Integer> partitionedForestSizes = this.modelPartitioner
                .getPartitionedForestSizes(sizingForest, anomalyDetector.getDetectorId());
            int numForests = partitionedForestSizes.getKey().intValue();
            int numTrees = partitionedForestSizes.getValue().intValue();
            double[] scoreSums = new double[dataPoints.length];
            trainModelForStep(anomalyDetector, dataPoints, dimensions, numForests, numTrees, scoreSums, 0, listener);
        } catch (LimitExceededException e) {
            listener.onFailure(e);
        }
    }

    /** Trains and checkpoints partition {@code step}; after the last partition it trains and checkpoints the threshold model. */
    private void trainModelForStep(
        AnomalyDetector detector,
        double[][] dataPoints,
        int dimensions,
        int numForests,
        int numTrees,
        double[] scoreSums,
        int step,
        ActionListener<Void> listener
    ) {
        String detectorId = detector.getDetectorId();
        if (step >= numForests) {
            HybridThresholdingModel threshold = trainThresholdModel(scoreSums, numForests);
            String thresholdId = this.modelPartitioner.getThresholdModelId(detectorId);
            String thresholdJson = AccessController.doPrivileged((PrivilegedAction<String>) () -> this.gson.toJson(threshold));
            this.checkpointDao.putModelCheckpoint(thresholdId, thresholdJson, ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure));
            return;
        }
        RandomCutForest forest = trainForestPartition(dataPoints, dimensions, numTrees, scoreSums);
        String rcfModelId = this.modelPartitioner.getRcfModelId(detectorId, step);
        this.checkpointDao
            .putModelCheckpoint(
                rcfModelId,
                toCheckpoint(forest),
                ActionListener
                    .wrap(
                        r -> trainModelForStep(detector, dataPoints, dimensions, numForests, numTrees, scoreSums, step + 1, listener),
                        listener::onFailure
                    )
            );
    }

    /** Builds one RCF partition, scoring each point before updating with it; each score is added into {@code scoreSums}. */
    private RandomCutForest trainForestPartition(double[][] dataPoints, int dimensions, int numTrees, double[] scoreSums) {
        RandomCutForest forest = RandomCutForest
            .builder()
            .dimensions(dimensions)
            .sampleSize(this.rcfNumSamplesInTree)
            .numberOfTrees(numTrees)
            .lambda(this.rcfTimeDecay)
            .outputAfter(this.rcfNumMinSamples)
            .parallelExecutionEnabled(false)
            .build();
        for (int i = 0; i < dataPoints.length; i++) {
            scoreSums[i] += forest.getAnomalyScore(dataPoints[i]);
            forest.update(dataPoints[i]);
        }
        return forest;
    }

    /** Trains a thresholding model on the positive summed scores, averaged over {@code numForests}. */
    private HybridThresholdingModel trainThresholdModel(double[] scoreSums, int numForests) {
        double[] averagedScores = DoubleStream.of(scoreSums).filter(s -> s > 0.0d).map(s -> s / numForests).toArray();
        HybridThresholdingModel threshold = new HybridThresholdingModel(
            this.thresholdMinPvalue,
            this.thresholdMaxRankError,
            this.thresholdMaxScore,
            this.thresholdNumLogNormalQuantiles,
            this.thresholdDownsamples,
            this.thresholdMaxSamples
        );
        threshold.train(averagedScores);
        return threshold;
    }

    /** Synchronously removes every model of the detector from {@code models} and deletes its checkpoint. */
    private void clearModels(String detectorId, Map<String, ?> models) {
        models.keySet().stream().filter(modelId -> getDetectorIdForModelId(modelId).equals(detectorId)).forEach(modelId -> {
            models.remove(modelId);
            this.checkpointDao.deleteModelCheckpoint(modelId);
        });
    }

    private String toCheckpoint(RandomCutForest forest) {
        return AccessController.doPrivileged((PrivilegedAction<String>) () -> this.rcfSerde.toJson(forest));
    }

    private String toCheckpoint(ThresholdingModel threshold) {
        return AccessController.doPrivileged((PrivilegedAction<String>) () -> this.gson.toJson(threshold));
    }

    /**
     * Synchronously checkpoints models that are due and evicts idle ones.
     *
     * @deprecated use {@link #maintenance(ActionListener)}
     */
    @Deprecated
    public void maintenance() {
        maintenance(this.forests, this::toCheckpoint);
        maintenance(this.thresholds, this::toCheckpoint);
    }

    /** Per-model best effort: a failure is logged and does not stop maintenance of the other models. */
    private <T> void maintenance(Map<String, ModelState<T>> models, Function<T, String> toCheckpoint) {
        for (Map.Entry<String, ModelState<T>> entry : models.entrySet()) {
            String modelId = entry.getKey();
            try {
                ModelState<T> state = entry.getValue();
                Instant now = this.clock.instant();
                if (state.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now)) {
                    this.checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(state.getModel()));
                    state.setLastCheckpointTime(now);
                }
                if (state.getLastUsedTime().plus(this.modelTtl).isBefore(now)) {
                    models.remove(modelId);
                }
            } catch (Exception e) {
                logger.warn("Failed to finish maintenance for model id " + modelId, e);
            }
        }
    }

    /** Evicts expired models and checkpoints due ones: forests first, then thresholds. */
    public void maintenance(ActionListener<Void> listener) {
        maintenanceForIterator(
            this.forests,
            this::toCheckpoint,
            this.forests.entrySet().iterator(),
            ActionListener
                .wrap(
                    r -> maintenanceForIterator(this.thresholds, this::toCheckpoint, this.thresholds.entrySet().iterator(), listener),
                    listener::onFailure
                )
        );
    }

    /**
     * Walks {@code entries}. An expired model is removed, but it is still checkpointed if a
     * checkpoint is due. A failed checkpoint is logged and the walk continues. Entries that need
     * no checkpoint are handled in a loop rather than by recursion.
     */
    private <T> void maintenanceForIterator(
        Map<String, ModelState<T>> models,
        Function<T, String> toCheckpoint,
        Iterator<Map.Entry<String, ModelState<T>>> entries,
        ActionListener<Void> listener
    ) {
        while (entries.hasNext()) {
            Map.Entry<String, ModelState<T>> entry = entries.next();
            String modelId = entry.getKey();
            ModelState<T> state = entry.getValue();
            Instant now = this.clock.instant();
            if (state.expired(this.modelTtl)) {
                models.remove(modelId);
            }
            if (state.getLastCheckpointTime().plus(this.checkpointInterval).isBefore(now)) {
                this.checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(state.getModel()), ActionListener.wrap(r -> {
                    state.setLastCheckpointTime(now);
                    maintenanceForIterator(models, toCheckpoint, entries, listener);
                }, e -> {
                    logger.warn("Failed to finish maintenance for model id " + modelId, e);
                    maintenanceForIterator(models, toCheckpoint, entries, listener);
                }));
                return;
            }
        }
        listener.onResponse(null);
    }

    /**
     * Trains a throw-away forest (fixed seed 0) and threshold on {@code dataPoints}, then grades
     * each point in order while continuing to update both models.
     *
     * @throws IllegalArgumentException if there are no points or fewer than {@code minPreviewSize}
     */
    public List<ThresholdingResult> getPreviewResults(double[][] dataPoints) {
        // The empty check prevents dataPoints[0] from throwing when minPreviewSize <= 0.
        if (dataPoints.length == 0 || dataPoints.length < this.minPreviewSize) {
            throw new IllegalArgumentException("Insufficient data for preview results. Minimum required: " + this.minPreviewSize);
        }
        RandomCutForest forest = RandomCutForest
            .builder()
            .randomSeed(0L)
            .dimensions(dataPoints[0].length)
            .sampleSize(this.rcfNumSamplesInTree)
            .numberOfTrees(this.rcfNumTrees)
            .lambda(this.rcfTimeDecay)
            .outputAfter(this.rcfNumSamplesInTree)
            .parallelExecutionEnabled(false)
            .build();
        double[] trainingScores = new double[dataPoints.length];
        for (int i = 0; i < dataPoints.length; i++) {
            trainingScores[i] = forest.getAnomalyScore(dataPoints[i]);
            forest.update(dataPoints[i]);
        }
        HybridThresholdingModel threshold = new HybridThresholdingModel(
            this.thresholdMinPvalue,
            this.thresholdMaxRankError,
            this.thresholdMaxScore,
            this.thresholdNumLogNormalQuantiles,
            this.thresholdDownsamples,
            this.thresholdMaxSamples
        );
        threshold.train(DoubleStream.of(trainingScores).filter(s -> s > 0.0d).toArray());

        List<ThresholdingResult> results = new ArrayList<>(dataPoints.length);
        for (double[] point : dataPoints) {
            double score = forest.getAnomalyScore(point);
            forest.update(point);
            results.add(new ThresholdingResult(threshold.grade(score), threshold.confidence(), score));
            threshold.update(score);
        }
        return results;
    }

    /**
     * Returns (e^{lambda*t} - e^{lambda*min(t, sampleSize)}) / (e^{lambda*t} - 1), clamped to be
     * at least 0, where t is the total update count. The degenerate cases that would otherwise
     * compute 0/0 (NaN) are handled explicitly: t == 0 returns 0, and lambda == 0 uses the
     * lambda -> 0 limit (t - min(t, sampleSize)) / t.
     */
    private double computeRcfConfidence(RandomCutForest forest) {
        long totalUpdates = forest.getTotalUpdates();
        double lambda = forest.getLambda();
        double totalExponent = totalUpdates * lambda;
        if (totalExponent >= FULL_CONFIDENCE_EXPONENT) {
            return 1.0d;
        }
        if (totalUpdates <= 0) {
            return 0.0d;
        }
        long retained = Math.min(totalUpdates, forest.getSampleSize());
        if (lambda <= 0.0d) {
            return Math.max(0.0d, (double) (totalUpdates - retained) / totalUpdates);
        }
        // expm1 keeps precision (and a non-zero denominator) when totalExponent is tiny.
        double denominator = Math.expm1(totalExponent);
        return Math.max(0.0d, (denominator - Math.expm1(lambda * retained)) / denominator);
    }

    /** Estimated in-memory size, in bytes, of each RCF and threshold model of the detector, keyed by model id. */
    @Override
    public Map<String, Long> getModelSize(String detectorId) {
        Map<String, Long> sizes = new HashMap<>();
        for (Map.Entry<String, ModelState<RandomCutForest>> entry : this.forests.entrySet()) {
            if (getDetectorIdForModelId(entry.getKey()).equals(detectorId)) {
                sizes.put(entry.getKey(), Long.valueOf(this.memoryTracker.estimateModelSize(entry.getValue().getModel())));
            }
        }
        for (String modelId : this.thresholds.keySet()) {
            if (getDetectorIdForModelId(modelId).equals(detectorId)) {
                sizes.put(modelId, Long.valueOf(this.memoryTracker.getThresholdModelBytes()));
            }
        }
        return sizes;
    }

    /** Reports the forest's total update count, restoring the forest from its checkpoint if needed. */
    public void getTotalUpdates(String modelId, String detectorId, ActionListener<Long> listener) {
        ModelState<RandomCutForest> modelState = this.forests.get(modelId);
        if (modelState != null) {
            listener.onResponse(modelState.getModel().getTotalUpdates());
            return;
        }
        this.checkpointDao
            .getModelCheckpoint(
                modelId,
                ActionListener.wrap(checkpoint -> processRcfCheckpoint(checkpoint, modelId, detectorId, listener), listener::onFailure)
            );
    }

    /**
     * Buffers {@code datapoint} in the entity model, then trains (if needed) and scores it.
     * When the buffer exceeds {@code rcfNumMinSamples}, the oldest sample is dropped.
     *
     * @return the result, or an all-zero result when {@code modelState} is null or the model is untrained
     */
    public ThresholdingResult getAnomalyResultForEntity(
        String detectorId,
        double[] datapoint,
        String entityName,
        ModelState<EntityModel> modelState,
        String modelId
    ) {
        if (modelState == null) {
            return new ThresholdingResult(0.0d, 0.0d, 0.0d);
        }
        Queue<double[]> samples = modelState.getModel().getSamples();
        samples.add(datapoint);
        if (samples.size() > this.rcfNumMinSamples) {
            samples.remove();
        }
        return maybeTrainBeforeScore(modelState, entityName);
    }

    /** Drains {@code samples} through the RCF and threshold; the result reflects the last sample drained. */
    private ThresholdingResult score(Queue<double[]> samples, String modelId, ModelState<EntityModel> modelState) {
        EntityModel model = modelState.getModel();
        RandomCutForest rcf = model.getRcf();
        ThresholdingModel threshold = model.getThreshold();
        double lastScore = 0.0d;
        for (double[] sample = samples.poll(); sample != null; sample = samples.poll()) {
            lastScore = rcf.getAnomalyScore(sample);
            rcf.update(sample);
            threshold.update(lastScore);
        }
        ThresholdingResult result = new ThresholdingResult(
            threshold.grade(lastScore),
            computeRcfConfidence(rcf) * threshold.confidence(),
            lastScore
        );
        modelState.setLastUsedTime(this.clock.instant());
        return result;
    }

    /** @return {@code detectorId + "_entity_" + entityValue} */
    public String getEntityModelId(String detectorId, String entityValue) {
        return detectorId + "_entity_" + entityValue;
    }

    /**
     * Installs a restored entity checkpoint (carrying over any buffered samples), or marks the
     * model as due for a checkpoint when none exists. Then trains and scores it if possible.
     */
    public void processEntityCheckpoint(
        Optional<Map.Entry<EntityModel, Instant>> checkpoint,
        String modelId,
        String entityName,
        ModelState<EntityModel> modelState
    ) {
        if (checkpoint.isPresent()) {
            Map.Entry<EntityModel, Instant> restored = checkpoint.get();
            EntityModel restoredModel = restored.getKey();
            EntityModel current = modelState.getModel();
            // The current model may be null (it is null-checked below); only carry over samples when it exists.
            if (current != null) {
                combineSamples(current, restoredModel);
            }
            modelState.setModel(restoredModel);
            modelState.setLastCheckpointTime(restored.getValue());
        } else {
            // Backdate the last checkpoint so that the next maintenance pass checkpoints this model.
            modelState.setLastCheckpointTime(this.clock.instant().minus(this.checkpointInterval));
        }
        if (modelState.getModel() == null) {
            modelState.setModel(new EntityModel(modelId, new ArrayDeque<>(), null, null));
        }
        maybeTrainBeforeScore(modelState, entityName);
    }

    /** Moves all samples buffered in {@code from} into {@code to}. */
    private void combineSamples(EntityModel from, EntityModel to) {
        Queue<double[]> samples = from.getSamples();
        for (double[] sample = samples.poll(); sample != null; sample = samples.poll()) {
            to.addSample(sample);
        }
    }

    /** Cold-starts the entity model if its RCF or threshold is missing; scores only when both exist. */
    private ThresholdingResult maybeTrainBeforeScore(ModelState<EntityModel> modelState, String entityName) {
        EntityModel model = modelState.getModel();
        Queue<double[]> samples = model.getSamples();
        String modelId = model.getModelId();
        String detectorId = modelState.getDetectorId();
        if (model.getRcf() == null || model.getThreshold() == null) {
            this.entityColdStarter.trainModel(samples, modelId, entityName, detectorId, modelState);
        }
        if (model.getRcf() == null || model.getThreshold() == null) {
            return new ThresholdingResult(0.0d, 0.0d, 0.0d);
        }
        return score(samples, modelId, modelState);
    }
}
