package com.amazon.opendistroforelasticsearch.ad.ml;

import com.amazon.opendistroforelasticsearch.ad.AnomalyDetectorPlugin;
import com.amazon.opendistroforelasticsearch.ad.NodeStateManager;
import com.amazon.opendistroforelasticsearch.ad.common.exception.AnomalyDetectionException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.EndRunException;
import com.amazon.opendistroforelasticsearch.ad.dataprocessor.Interpolator;
import com.amazon.opendistroforelasticsearch.ad.feature.FeatureManager;
import com.amazon.opendistroforelasticsearch.ad.feature.SearchFeatureDao;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.model.IntervalTimeConfiguration;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.amazon.randomcutforest.RandomCutForest;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.util.Throwables;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;

/* loaded from: input_file:com/amazon/opendistroforelasticsearch/ad/ml/EntityColdStarter.class */
/**
 * Trains per-entity RCF and thresholding models.
 *
 * <p>When enough samples are supplied by the caller the model is trained directly from them;
 * otherwise a "cold start" fetches historical feature samples for the entity via
 * {@link SearchFeatureDao}, interpolates and shingles them, and either trains the model or stores
 * the samples on the entity model when there are still too few.
 *
 * <p>Cold starts are rate limited in three ways:
 * <ul>
 *   <li>only one cold start per detector runs on this node at a time (tracked by
 *       {@link NodeStateManager#markColdStartRunning});</li>
 *   <li>a model id does not trigger a new cold start while it is still present in the
 *       {@code lastColdStartTime} cache;</li>
 *   <li>after an attempt fails with a {@link RejectedExecutionException}, all cold starts pause for
 *       the configured {@code COOLDOWN_MINUTES}.</li>
 * </ul>
 */
public class EntityColdStarter {
    private static final Logger logger = LogManager.getLogger(EntityColdStarter.class);

    private final Clock clock;
    private final ThreadPool threadPool;
    private final NodeStateManager nodeStateManager;
    private final int rcfSampleSize;
    private final int numberOfTrees;
    private final double rcfTimeDecay;
    // Minimum number of samples needed to train; also used as the forest's outputAfter.
    private final int numMinSamples;
    private final double thresholdMinPvalue;
    private final double thresholdMaxRankError;
    private final double thresholdMaxScore;
    private final int thresholdNumLogNormalQuantiles;
    private final int thresholdDownsamples;
    private final long thresholdMaxSamples;
    // Number of detection intervals between consecutive sampled ranges during cold start.
    private final int maxSampleStride;
    private final int maxTrainSamples;
    private final Interpolator interpolator;
    private final SearchFeatureDao searchFeatureDao;
    private final int shingleSize;
    // Written by cold-start failure callbacks (which may run on the AD thread pool) and read by
    // callers of trainModel, so it is volatile to make updates visible across threads.
    private volatile Instant lastThrottledColdStartTime = Instant.MIN;
    private final FeatureManager featureManager;
    // Model ids that recently triggered a cold start; presence suppresses another cold start.
    private final Cache<String, Instant> lastColdStartTime;
    private final CheckpointDao checkpointDao;
    private final int coolDownMinutes;

    /**
     * Creates a cold starter.
     *
     * @param clock clock used for all timestamps
     * @param threadPool thread pool providing the AD executor
     * @param nodeStateManager tracks per-detector cold-start state on this node
     * @param rcfSampleSize RCF sample size
     * @param numberOfTrees number of RCF trees
     * @param rcfTimeDecay RCF time decay (lambda)
     * @param numMinSamples minimum samples required to train a model
     * @param maxSampleStride detection intervals between sampled ranges during cold start
     * @param maxTrainSamples maximum number of ranges sampled during cold start
     * @param interpolator interpolator used to fill the gaps between sampled points
     * @param searchFeatureDao source of historical entity features
     * @param shingleSize shingle size applied to training points
     * @param thresholdMinPvalue thresholding model min p-value
     * @param thresholdMaxRankError thresholding model max rank error
     * @param thresholdMaxScore thresholding model max score
     * @param thresholdNumLogNormalQuantiles thresholding model log-normal quantiles
     * @param thresholdDownsamples thresholding model downsamples
     * @param thresholdMaxSamples thresholding model max samples
     * @param featureManager shingling / transposition helper
     * @param lastColdStartTimestampTtl expire-after-access TTL of the per-model cold-start cache
     * @param maxCacheSize maximum number of entries in the per-model cold-start cache
     * @param checkpointDao persists model state
     * @param settings node settings; {@code COOLDOWN_MINUTES} is read from here
     */
    public EntityColdStarter(
        Clock clock,
        ThreadPool threadPool,
        NodeStateManager nodeStateManager,
        int rcfSampleSize,
        int numberOfTrees,
        double rcfTimeDecay,
        int numMinSamples,
        int maxSampleStride,
        int maxTrainSamples,
        Interpolator interpolator,
        SearchFeatureDao searchFeatureDao,
        int shingleSize,
        double thresholdMinPvalue,
        double thresholdMaxRankError,
        double thresholdMaxScore,
        int thresholdNumLogNormalQuantiles,
        int thresholdDownsamples,
        long thresholdMaxSamples,
        FeatureManager featureManager,
        Duration lastColdStartTimestampTtl,
        long maxCacheSize,
        CheckpointDao checkpointDao,
        Settings settings
    ) {
        this.clock = clock;
        this.threadPool = threadPool;
        this.nodeStateManager = nodeStateManager;
        this.rcfSampleSize = rcfSampleSize;
        this.numberOfTrees = numberOfTrees;
        this.rcfTimeDecay = rcfTimeDecay;
        this.numMinSamples = numMinSamples;
        this.maxSampleStride = maxSampleStride;
        this.maxTrainSamples = maxTrainSamples;
        this.interpolator = interpolator;
        this.searchFeatureDao = searchFeatureDao;
        this.shingleSize = shingleSize;
        this.thresholdMinPvalue = thresholdMinPvalue;
        this.thresholdMaxRankError = thresholdMaxRankError;
        this.thresholdMaxScore = thresholdMaxScore;
        this.thresholdNumLogNormalQuantiles = thresholdNumLogNormalQuantiles;
        this.thresholdDownsamples = thresholdDownsamples;
        this.thresholdMaxSamples = thresholdMaxSamples;
        this.featureManager = featureManager;
        // Millisecond precision: toHours() truncates, so a sub-hour TTL became 0, which makes
        // Guava cache nothing and silently disables per-model cold-start throttling.
        // NOTE(review): expireAfterAccess is refreshed by getIfPresent, i.e. by the throttle check
        // in coldStart itself; confirm expireAfterWrite was not intended.
        this.lastColdStartTime = CacheBuilder
            .newBuilder()
            .expireAfterAccess(lastColdStartTimestampTtl.toMillis(), TimeUnit.MILLISECONDS)
            .maximumSize(maxCacheSize)
            .concurrencyLevel(1)
            .build();
        this.checkpointDao = checkpointDao;
        TimeValue coolDown = AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings);
        this.coolDownMinutes = (int) coolDown.getMinutes();
    }

    /**
     * Starts an asynchronous cold start for a model unless one of the throttles applies.
     *
     * <p>The per-detector "cold start running" marker is always released when the attempt
     * finishes, including when the AD executor rejects the task.
     *
     * @param modelId model id; key of the per-model throttle cache
     * @param entityName entity whose historical data is fetched
     * @param detectorId detector the entity belongs to
     * @param modelState state that receives the trained model or buffered samples
     */
    private void coldStart(String modelId, String entityName, String detectorId, ModelState<EntityModel> modelState) {
        // Guards are evaluated in this order so the cache is only touched when no cold start is running.
        if (nodeStateManager.isColdStartRunning(detectorId)) {
            return;
        }
        if (lastColdStartTime.getIfPresent(modelId) != null) {
            return;
        }
        Instant coolDownEnd = lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes));
        if (!coolDownEnd.isBefore(clock.instant())) {
            return;
        }

        Releasable coldStartRunning = nodeStateManager.markColdStartRunning(detectorId);
        logger.debug("Trigger cold start for {}", modelId);

        ActionListener<Optional<List<double[][]>>> trainingListener = ActionListener.wrap(
            trainingData -> onColdStartData(trainingData, modelId, modelState),
            exception -> onColdStartFailure(exception, modelId, detectorId)
        );
        // Release the marker however the attempt ends (success, failure, or rejection).
        ActionListener<Optional<List<double[][]>>> listener = ActionListener.runAfter(trainingListener, coldStartRunning::close);

        try {
            threadPool.executor(AnomalyDetectorPlugin.AD_THREAD_POOL_NAME).execute(() -> {
                try {
                    getEntityColdStartData(
                        detectorId,
                        entityName,
                        shingleSize,
                        new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, listener, false)
                    );
                } catch (Exception taskFailure) {
                    // A synchronous failure must still complete the listener, or the marker leaks.
                    listener.onFailure(taskFailure);
                }
            });
        } catch (Exception rejection) {
            // Typically the executor rejecting the task: report it through the listener so the
            // cool-down is applied and the marker is released instead of leaking.
            listener.onFailure(rejection);
            return;
        }
        lastColdStartTime.put(modelId, clock.instant());
    }

    /** Trains the model, or buffers the samples when there are still too few, once cold-start data arrives. */
    private void onColdStartData(Optional<List<double[][]>> trainingData, String modelId, ModelState<EntityModel> modelState) {
        if (!trainingData.isPresent()) {
            logger.info("Cannot get training data for {}", modelId);
            return;
        }
        List<double[][]> dataSegments = trainingData.get();
        if (hasEnoughSample(dataSegments, modelState)) {
            trainModelFromDataSegments(dataSegments, modelId, modelState);
        } else {
            combineTrainSamples(dataSegments, modelId, modelState);
        }
        logger.info("Succeeded in training entity: {}", modelId);
    }

    /**
     * Handles a failed cold start: starts the cool-down on rejection, records AD exceptions for
     * the detector, and logs anything else.
     */
    private void onColdStartFailure(Exception exception, String modelId, String detectorId) {
        Throwable rootCause = Throwables.getRootCause(exception);
        if (rootCause instanceof RejectedExecutionException) {
            logger.error("too many requests");
            lastThrottledColdStartTime = clock.instant();
        } else if (exception instanceof AnomalyDetectionException) {
            nodeStateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) exception);
        } else if (rootCause instanceof AnomalyDetectionException) {
            // Cast the root cause: the outer exception is not an AnomalyDetectionException here,
            // and casting it (as before) threw ClassCastException.
            nodeStateManager.setLastColdStartException(detectorId, (AnomalyDetectionException) rootCause);
        } else {
            logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception);
        }
    }

    /**
     * Builds a new RCF from the data segments and trains a thresholding model on the resulting
     * positive anomaly scores, then writes the model state to the checkpoint.
     *
     * @param dataPoints non-empty list of segments; each segment is an array of points
     * @param modelId model id
     * @param entityState state whose model receives the forest and threshold
     * @throws IllegalArgumentException if the data points are empty
     */
    private void trainModelFromDataSegments(List<double[][]> dataPoints, String modelId, ModelState<EntityModel> entityState) {
        if (dataPoints == null || dataPoints.isEmpty() || dataPoints.get(0) == null || dataPoints.get(0).length == 0) {
            throw new IllegalArgumentException("Data points must not be empty.");
        }
        int dimensions = dataPoints.get(0)[0].length;
        RandomCutForest rcf = RandomCutForest
            .builder()
            .dimensions(dimensions)
            .sampleSize(rcfSampleSize)
            .numberOfTrees(numberOfTrees)
            .lambda(rcfTimeDecay)
            .outputAfter(numMinSamples)
            .parallelExecutionEnabled(false)
            .build();

        // Feed every segment through the same forest, keeping the per-segment positive scores in
        // order; they are concatenated to train the threshold.
        List<double[]> segmentScores = new ArrayList<>(dataPoints.size());
        for (double[][] segment : dataPoints) {
            segmentScores.add(trainRCFModel(segment, rcf));
        }
        double[] allScores = segmentScores.stream().flatMapToDouble(DoubleStream::of).toArray();

        EntityModel model = entityState.getModel();
        if (model == null) {
            // NOTE(review): this new model does not appear to be attached to entityState, so the
            // checkpoint write below presumably does not persist it. Confirm whether a state
            // without a model can reach here and, if so, attach the model to the state.
            model = new EntityModel(modelId, new ArrayDeque<>(), null, null);
        }
        model.setRcf(rcf);

        HybridThresholdingModel threshold = new HybridThresholdingModel(
            thresholdMinPvalue,
            thresholdMaxRankError,
            thresholdMaxScore,
            thresholdNumLogNormalQuantiles,
            thresholdDownsamples,
            thresholdMaxSamples
        );
        threshold.train(allScores);
        model.setThreshold(threshold);

        entityState.setLastUsedTime(clock.instant());
        checkpointDao.write(entityState, modelId, true);
    }

    /**
     * Scores each point and then feeds it to the forest.
     *
     * @return the strictly positive scores, in input order
     * @throws IllegalArgumentException if there are no points or the points have no dimensions
     */
    private double[] trainRCFModel(double[][] dataPoints, RandomCutForest rcf) {
        if (dataPoints.length == 0 || dataPoints[0].length == 0) {
            throw new IllegalArgumentException("Data points must not be empty.");
        }
        double[] scores = new double[dataPoints.length];
        for (int j = 0; j < dataPoints.length; j++) {
            scores[j] = rcf.getAnomalyScore(dataPoints[j]);
            rcf.update(dataPoints[j]);
        }
        // Presumably drops the 0 scores RCF reports before it has seen outputAfter points — confirm.
        return DoubleStream.of(scores).filter(score -> score > 0.0d).toArray();
    }

    /**
     * Fetches the detector, then the entity's data time range, then sampled features, and completes
     * the listener with interpolated, shingled segments (empty when no usable data exists).
     *
     * <p>The listener is always completed, including when the detector is missing, so callers that
     * release resources on completion do not leak them.
     */
    private void getEntityColdStartData(
        String detectorId,
        String entityName,
        int entityShingleSize,
        ActionListener<Optional<List<double[][]>>> listener
    ) {
        ActionListener<Optional<AnomalyDetector>> getDetectorListener = ActionListener.wrap(detectorOp -> {
            if (!detectorOp.isPresent()) {
                // Fail the listener instead of only recording the exception: previously the listener
                // was never completed here. EndRunException is an AnomalyDetectionException, so the
                // caller's failure handler still records it for the detector.
                listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", true));
                return;
            }
            AnomalyDetector detector = detectorOp.get();
            ActionListener<Map.Entry<Optional<Long>, Optional<Long>>> minMaxListener = ActionListener.wrap(
                minMaxTime -> getTrainSamples(detector, entityName, entityShingleSize, minMaxTime, listener),
                listener::onFailure
            );
            searchFeatureDao.getEntityMinMaxDataTime(
                detector,
                entityName,
                new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, minMaxListener, false)
            );
        }, listener::onFailure);

        nodeStateManager.getAnomalyDetector(
            detectorId,
            new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false)
        );
    }

    /** Requests feature samples for ranges inside [min, max] data time; responds empty if either bound is missing. */
    private void getTrainSamples(
        AnomalyDetector detector,
        String entityName,
        int entityShingleSize,
        Map.Entry<Optional<Long>, Optional<Long>> minMaxTime,
        ActionListener<Optional<List<double[][]>>> listener
    ) {
        Optional<Long> earliest = minMaxTime.getKey();
        Optional<Long> latest = minMaxTime.getValue();
        if (!earliest.isPresent() || !latest.isPresent()) {
            listener.onResponse(Optional.empty());
            return;
        }
        List<Map.Entry<Long, Long>> sampleRanges = getTrainSampleRanges(
            detector,
            earliest.get(),
            latest.get(),
            maxSampleStride,
            maxTrainSamples
        );
        ActionListener<List<Optional<double[]>>> samplesListener = ActionListener.wrap(
            featureSamples -> listener.onResponse(toTrainingSegments(featureSamples, entityShingleSize)),
            listener::onFailure
        );
        searchFeatureDao.getColdStartSamplesForPeriods(
            detector,
            sampleRanges,
            entityName,
            false,
            new ThreadedActionListener<>(logger, threadPool, AnomalyDetectorPlugin.AD_THREAD_POOL_NAME, samplesListener, false)
        );
    }

    /**
     * Splits sampled features into runs of consecutive present samples (a missing sample ends a
     * run) and interpolates and shingles each run.
     *
     * @return the segments, or empty when no run contained any sample
     */
    private Optional<List<double[][]>> toTrainingSegments(List<Optional<double[]>> featureSamples, int entityShingleSize) {
        List<double[][]> segments = new ArrayList<>();
        List<double[]> run = new ArrayList<>(maxTrainSamples);
        for (Optional<double[]> sample : featureSamples) {
            if (sample.isPresent()) {
                run.add(sample.get());
            } else if (!run.isEmpty()) {
                segments.add(interpolateAndShingle(run, entityShingleSize));
                run.clear();
            }
        }
        if (!run.isEmpty()) {
            segments.add(interpolateAndShingle(run, entityShingleSize));
        }
        return segments.isEmpty() ? Optional.empty() : Optional.of(segments);
    }

    /**
     * Interpolates a run of sampled points back to one point per detection interval and shingles it.
     * Sampled ranges are maxSampleStride intervals apart (see getTrainSampleRanges), so n samples
     * span maxSampleStride * (n - 1) + 1 intervals.
     */
    private double[][] interpolateAndShingle(List<double[]> run, int entityShingleSize) {
        double[][] sampled = run.toArray(new double[0][0]);
        int numInterpolants = maxSampleStride * (sampled.length - 1) + 1;
        double[][] interpolated = interpolator.interpolate(featureManager.transpose(sampled), numInterpolants);
        double[][] points = featureManager.transpose(interpolated);
        return featureManager.batchShingle(points, entityShingleSize);
    }

    /**
     * Builds up to {@code maxSamples} one-interval ranges, walking backwards from {@code endMilli}
     * in steps of {@code stride} intervals. Ranges are returned in descending order of end time.
     *
     * @param detector detector supplying the detection interval
     * @param startMilli earliest data time
     * @param endMilli latest data time; end of the first range
     * @param stride number of intervals between consecutive ranges
     * @param maxSamples upper bound on the number of ranges
     * @return list of (rangeStart, rangeEnd) entries in epoch millis; empty if endMilli is before startMilli
     */
    private List<Map.Entry<Long, Long>> getTrainSampleRanges(
        AnomalyDetector detector,
        long startMilli,
        long endMilli,
        int stride,
        int maxSamples
    ) {
        long bucketSize = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMillis();
        long numBuckets = (endMilli - startMilli) / bucketSize;
        // Clamp to [0, maxSamples]; a negative count would make Stream.limit throw. Kept in long
        // so a very wide time range cannot overflow an int cast.
        long numRanges = Math.max(0L, Math.min(numBuckets / stride, (long) maxSamples));
        return Stream
            .iterate(endMilli, end -> end - stride * bucketSize)
            .limit(numRanges)
            .<Map.Entry<Long, Long>>map(end -> new AbstractMap.SimpleImmutableEntry<Long, Long>(end - bucketSize, end))
            .collect(Collectors.toList());
    }

    /**
     * Trains the entity model from the given samples, or triggers a (throttled, asynchronous) cold
     * start when fewer than the minimum number of samples are available.
     *
     * @param queue samples available for the entity
     * @param modelId model id
     * @param entityName entity name used to fetch historical data during cold start
     * @param detectorId detector id
     * @param modelState state receiving the trained model
     */
    public void trainModel(Queue<double[]> queue, String modelId, String entityName, String detectorId, ModelState<EntityModel> modelState) {
        if (queue.size() < numMinSamples) {
            coldStart(modelId, entityName, detectorId, modelState);
            return;
        }
        double[][] samples = queue.toArray(new double[0][0]);
        double[][] shingled = featureManager.batchShingle(samples, shingleSize);
        trainModelFromDataSegments(Collections.singletonList(shingled), modelId, modelState);
    }

    /** Returns true when the segments plus the samples already on the model reach numMinSamples. */
    private boolean hasEnoughSample(List<double[][]> dataPoints, ModelState<EntityModel> entityState) {
        int total = dataPoints.stream().mapToInt(segment -> segment.length).sum();
        EntityModel model = entityState.getModel();
        if (model != null) {
            total += model.getSamples().size();
        }
        return total >= numMinSamples;
    }

    /** Appends every point of every segment to the model's samples and writes the checkpoint. */
    private void combineTrainSamples(List<double[][]> dataPoints, String modelId, ModelState<EntityModel> entityState) {
        EntityModel model = entityState.getModel();
        if (model == null) {
            // NOTE(review): this new model does not appear to be attached to entityState, so the
            // samples added below are presumably lost. Confirm whether a state without a model can
            // reach here and, if so, attach the model to the state.
            model = new EntityModel(modelId, new ArrayDeque<>(), null, null);
        }
        for (double[][] segment : dataPoints) {
            for (double[] point : segment) {
                model.addSample(point);
            }
        }
        checkpointDao.write(entityState, modelId, true);
    }
}
