package com.amazon.opendistroforelasticsearch.ad.task;

import com.amazon.opendistroforelasticsearch.ad.MemoryTracker;
import com.amazon.opendistroforelasticsearch.ad.common.exception.DuplicateTaskException;
import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.constant.CommonErrorMessages;
import com.amazon.opendistroforelasticsearch.ad.ml.ThresholdingModel;
import com.amazon.opendistroforelasticsearch.ad.model.ADTask;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.amazon.randomcutforest.RandomCutForest;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.set.Sets;

/* loaded from: input_file:com/amazon/opendistroforelasticsearch/ad/task/ADTaskCacheManager.class */
/**
 * Node-local, in-memory bookkeeping for historical (batch) AD tasks.
 *
 * <p>Two independent caches are kept:
 * <ul>
 *   <li>{@code taskCaches}: task id -> {@link ADBatchTaskCache}, holding the per-task state
 *       (models, threshold training data, shingle, cancellation flags). Memory for each entry
 *       is reserved through the {@link MemoryTracker} when the task is added and released
 *       when it is removed.</li>
 *   <li>{@code detectors}: ids of detectors registered through {@link #add(String)}
 *       (the log messages call this the "AD task coordinating node cache").</li>
 * </ul>
 *
 * <p>Both collections are concurrent. The two {@code add} overloads are additionally
 * {@code synchronized} so that their duplicate/limit checks and the insertion happen as one step.
 */
public class ADTaskCacheManager {
    private static final Logger LOG = LogManager.getLogger(ADTaskCacheManager.class);

    /**
     * Per-number multiplier used by the memory estimates below.
     * NOTE(review): presumably the byte size of one {@code double} - confirm.
     */
    private static final int NUMBER_SIZE = 8;

    /**
     * Fixed per-entry term used by {@link #shingleMemorySize(int, int)}.
     * NOTE(review): presumably the object overhead of one shingle entry - confirm.
     */
    private static final int SHINGLE_ENTRY_OVERHEAD = 80;

    private final Map<String, ADBatchTaskCache> taskCaches;
    // Updated at runtime by the cluster-settings consumer registered in the constructor.
    private volatile Integer maxAdBatchTaskPerNode;
    private final MemoryTracker memoryTracker;
    private final Set<String> detectors;

    /**
     * Creates an empty cache manager.
     *
     * @param settings       node settings, source of the initial {@code MAX_BATCH_TASK_PER_NODE} value
     * @param clusterService used to subscribe to later updates of {@code MAX_BATCH_TASK_PER_NODE}
     * @param memoryTracker  tracker through which task cache memory is reserved and released
     */
    public ADTaskCacheManager(Settings settings, ClusterService clusterService, MemoryTracker memoryTracker) {
        this.maxAdBatchTaskPerNode = (Integer) AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE.get(settings);
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, updated -> this.maxAdBatchTaskPerNode = updated);
        this.taskCaches = new ConcurrentHashMap<>();
        this.memoryTracker = memoryTracker;
        this.detectors = Sets.newConcurrentHashSet();
    }

    /**
     * Registers a task in the task cache and reserves memory for it.
     *
     * @param aDTask task to cache
     * @throws DuplicateTaskException if the task id, or another task of the same detector, is already cached
     * @throws LimitExceededException if the per-node task limit is reached or the memory tracker
     *                                cannot reserve the estimated cache size
     */
    public synchronized void add(ADTask aDTask) {
        String taskId = aDTask.getTaskId();
        if (contains(taskId)) {
            throw new DuplicateTaskException(CommonErrorMessages.DETECTOR_IS_RUNNING);
        }
        if (containsTaskOfDetector(aDTask.getDetectorId())) {
            throw new DuplicateTaskException(CommonErrorMessages.DETECTOR_IS_RUNNING);
        }
        checkRunningTaskLimit();

        long requiredBytes = calculateADTaskCacheSize(aDTask);
        if (!this.memoryTracker.canAllocateReserved(aDTask.getDetectorId(), requiredBytes)) {
            throw new LimitExceededException("No enough memory to run detector");
        }
        this.memoryTracker.consumeMemory(requiredBytes, true, MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR);
        try {
            ADBatchTaskCache taskCache = new ADBatchTaskCache(aDTask);
            taskCache.getCacheMemorySize().set(requiredBytes);
            this.taskCaches.put(taskId, taskCache);
        } catch (RuntimeException e) {
            // The entry never made it into the map, so remove() would never give this
            // reservation back; release it here before propagating the failure.
            this.memoryTracker.releaseMemory(requiredBytes, true, MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR);
            throw e;
        }
    }

    /**
     * Registers a detector id in the running-detector cache.
     *
     * @param str detector id
     * @throws DuplicateTaskException if the detector id is already registered
     */
    public synchronized void add(String str) {
        // Set.add reports whether the id was new, so no separate contains() check is needed.
        if (!this.detectors.add(str)) {
            LOG.debug("detector is already in running detector cache, detectorId: {}", str);
            throw new DuplicateTaskException(CommonErrorMessages.DETECTOR_IS_RUNNING);
        }
        LOG.debug("add detector in running detector cache, detectorId: {}", str);
    }

    /**
     * Verifies that another task may be cached on this node.
     *
     * @throws LimitExceededException if the number of cached tasks has reached the configured maximum
     */
    public void checkRunningTaskLimit() {
        // Read the volatile setting once so the comparison and the message use the same value.
        int limit = this.maxAdBatchTaskPerNode;
        if (size() >= limit) {
            throw new LimitExceededException("Can't run more than " + limit + " historical detectors per data node");
        }
    }

    /**
     * @param str task id
     * @return the RCF model held by the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public RandomCutForest getRcfModel(String str) {
        return getBatchTaskCache(str).getRcfModel();
    }

    /**
     * @param str task id
     * @return the thresholding model held by the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public ThresholdingModel getThresholdModel(String str) {
        return getBatchTaskCache(str).getThresholdModel();
    }

    /**
     * @param str task id
     * @return the threshold model training data buffer held by the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public double[] getThresholdModelTrainingData(String str) {
        return getBatchTaskCache(str).getThresholdModelTrainingData();
    }

    /**
     * @param str task id
     * @return number of training data points currently stored for the task
     * @throws IllegalArgumentException if the task is not cached
     */
    public int getThresholdModelTrainingDataSize(String str) {
        return getBatchTaskCache(str).getThresholdModelTrainingDataSize().get();
    }

    /**
     * Appends values to the task's threshold model training data, copying at most as many values
     * as still fit below {@code AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE}.
     *
     * <p>The size counter is read once and updated after the copy; concurrent calls for the
     * same task are not serialised by this method.
     *
     * @param str  task id
     * @param dArr values to append; {@code null} or empty leaves the data untouched
     * @return number of training data points stored after the call
     * @throws IllegalArgumentException if the task is not cached
     */
    public int addThresholdModelTrainingData(String str, double... dArr) {
        ADBatchTaskCache taskCache = getBatchTaskCache(str);
        AtomicInteger storedCount = taskCache.getThresholdModelTrainingDataSize();
        int stored = storedCount.get();
        if (dArr == null || dArr.length == 0) {
            return stored;
        }
        int remaining = AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE - stored;
        if (remaining <= 0) {
            // Buffer already full: nothing can be copied.
            return stored;
        }
        int toCopy = Math.min(dArr.length, remaining);
        System.arraycopy(dArr, 0, taskCache.getThresholdModelTrainingData(), stored, toCopy);
        return storedCount.addAndGet(toCopy);
    }

    /**
     * @param str task id
     * @return whether the task's threshold model is flagged as trained
     * @throws IllegalArgumentException if the task is not cached
     */
    public boolean isThresholdModelTrained(String str) {
        return getBatchTaskCache(str).isThresholdModelTrained();
    }

    /**
     * Sets the trained flag of the task's threshold model. When set to {@code true}, the
     * training data is cleared and the memory attributed to the stored data points is
     * subtracted from the task's cache size and released from the memory tracker.
     *
     * @param str task id
     * @param z   new value of the trained flag
     * @throws IllegalArgumentException if the task is not cached
     */
    public void setThresholdModelTrained(String str, boolean z) {
        ADBatchTaskCache taskCache = getBatchTaskCache(str);
        taskCache.setThresholdModelTrained(z);
        if (!z) {
            return;
        }
        // Size must be captured before the training data is cleared.
        long freedBytes = trainingDataMemorySize(taskCache.getThresholdModelTrainingDataSize().get());
        taskCache.clearTrainingData();
        taskCache.getCacheMemorySize().getAndAdd(-freedBytes);
        this.memoryTracker.releaseMemory(freedBytes, true, MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR);
    }

    /**
     * @param str task id
     * @return the shingle deque held by the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public Deque<Map.Entry<Long, Optional<double[]>>> getShingle(String str) {
        return getBatchTaskCache(str).getShingle();
    }

    /**
     * @param str task id
     * @return {@code true} if a cache entry exists for the task id
     */
    public boolean contains(String str) {
        return this.taskCaches.containsKey(str);
    }

    /**
     * @param str detector id
     * @return {@code true} if any cached task belongs to the detector
     */
    public boolean containsTaskOfDetector(String str) {
        return this.taskCaches.values().stream().anyMatch(taskCache -> Objects.equals(str, taskCache.getDetectorId()));
    }

    /**
     * @param str detector id
     * @return ids of all cached tasks that belong to the detector; empty if there are none
     */
    public List<String> getTasksOfDetector(String str) {
        return getBatchTaskCacheByDetectorId(str).stream().map(ADBatchTaskCache::getTaskId).collect(Collectors.toList());
    }

    /**
     * Looks up a task's cache entry with a single map access, so a concurrent removal cannot
     * make this method return {@code null}.
     *
     * @throws IllegalArgumentException if the task is not cached
     */
    private ADBatchTaskCache getBatchTaskCache(String taskId) {
        ADBatchTaskCache taskCache = this.taskCaches.get(taskId);
        if (taskCache == null) {
            throw new IllegalArgumentException("AD task not in cache");
        }
        return taskCache;
    }

    /** Returns the cache entries of all cached tasks that belong to the given detector. */
    private List<ADBatchTaskCache> getBatchTaskCacheByDetectorId(String detectorId) {
        return this.taskCaches
            .values()
            .stream()
            .filter(taskCache -> Objects.equals(detectorId, taskCache.getDetectorId()))
            .collect(Collectors.toList());
    }

    /**
     * Estimates the memory to reserve for a task: model size + a full threshold training
     * data buffer + the shingle.
     */
    private long calculateADTaskCacheSize(ADTask task) {
        AnomalyDetector detector = task.getDetector();
        // NOTE(review): 100 is an inlined constant; presumably the number of trees - confirm.
        long modelBytes = this.memoryTracker.estimateModelSize(detector, 100);
        long trainingDataBytes = trainingDataMemorySize(AnomalyDetectorSettings.THRESHOLD_MODEL_TRAINING_SIZE);
        long shingleBytes = shingleMemorySize(detector.getShingleSize().intValue(), detector.getEnabledFeatureIds().size());
        return modelBytes + trainingDataBytes + shingleBytes;
    }

    /**
     * @param str task id
     * @return the memory tracker's size estimate for the task's RCF model, based on the model's
     *         dimensions and number of trees
     * @throws IllegalArgumentException if the task is not cached
     */
    public long getModelSize(String str) {
        RandomCutForest rcfModel = getBatchTaskCache(str).getRcfModel();
        return this.memoryTracker
            .estimateModelSize(rcfModel.getDimensions(), rcfModel.getNumberOfTrees(), AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE);
    }

    /**
     * Removes a task from the cache and releases the memory still attributed to it.
     * Does nothing if the task is not cached.
     *
     * @param str task id
     */
    public void remove(String str) {
        // Map.remove is atomic: only the caller that actually removed the entry releases its memory.
        ADBatchTaskCache removed = this.taskCaches.remove(str);
        if (removed != null) {
            this.memoryTracker
                .releaseMemory(removed.getCacheMemorySize().get(), true, MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR);
        }
    }

    /**
     * Removes a detector id from the running-detector cache, if present.
     *
     * @param str detector id
     */
    public void removeDetector(String str) {
        if (this.detectors.remove(str)) {
            LOG.debug("Removed detector from AD task coordinating node cache, detectorId: {}", str);
        } else {
            LOG.debug("Detector is not in AD task coordinating node cache");
        }
    }

    /**
     * Cancels a cached task.
     *
     * @param str  task id
     * @param str2 first argument forwarded to {@code ADBatchTaskCache#cancel}
     *             (NOTE(review): presumably the cancel reason - confirm)
     * @param str3 second argument forwarded to {@code ADBatchTaskCache#cancel}
     *             (NOTE(review): presumably who cancelled - confirm)
     * @return {@code NOT_FOUND} if the task is not cached, {@code ALREADY_CANCELLED} if it was
     *         cancelled before, otherwise {@code CANCELLED}
     */
    public ADTaskCancellationState cancel(String str, String str2, String str3) {
        // Single lookup: a concurrent remove() yields NOT_FOUND instead of an exception.
        ADBatchTaskCache taskCache = this.taskCaches.get(str);
        if (taskCache == null) {
            return ADTaskCancellationState.NOT_FOUND;
        }
        if (taskCache.isCancelled()) {
            return ADTaskCancellationState.ALREADY_CANCELLED;
        }
        taskCache.cancel(str2, str3);
        return ADTaskCancellationState.CANCELLED;
    }

    /**
     * Cancels every not-yet-cancelled cached task of a detector.
     *
     * @param str  detector id
     * @param str2 first argument forwarded to {@code ADBatchTaskCache#cancel}
     * @param str3 second argument forwarded to {@code ADBatchTaskCache#cancel}
     * @return {@code NOT_FOUND} if the detector has no cached task, {@code CANCELLED} if at least
     *         one task was cancelled by this call, otherwise {@code ALREADY_CANCELLED}
     */
    public ADTaskCancellationState cancelByDetectorId(String str, String str2, String str3) {
        List<ADBatchTaskCache> detectorTasks = getBatchTaskCacheByDetectorId(str);
        if (detectorTasks.isEmpty()) {
            return ADTaskCancellationState.NOT_FOUND;
        }
        ADTaskCancellationState result = ADTaskCancellationState.ALREADY_CANCELLED;
        for (ADBatchTaskCache taskCache : detectorTasks) {
            if (taskCache.isCancelled()) {
                continue;
            }
            result = ADTaskCancellationState.CANCELLED;
            taskCache.cancel(str2, str3);
        }
        return result;
    }

    /**
     * @param str task id
     * @return whether the task has been cancelled
     * @throws IllegalArgumentException if the task is not cached
     */
    public boolean isCancelled(String str) {
        return getBatchTaskCache(str).isCancelled();
    }

    /**
     * @param str task id
     * @return the cancel reason recorded in the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public String getCancelReason(String str) {
        return getBatchTaskCache(str).getCancelReason();
    }

    /**
     * @param str task id
     * @return the "cancelled by" value recorded in the task's cache
     * @throws IllegalArgumentException if the task is not cached
     */
    public String getCancelledBy(String str) {
        return getBatchTaskCache(str).getCancelledBy();
    }

    /**
     * @return number of cached tasks
     */
    public int size() {
        return this.taskCaches.size();
    }

    /**
     * Empties both the task cache and the running-detector cache.
     *
     * <p>NOTE(review): unlike {@link #remove(String)}, this does not release the memory reserved
     * for the dropped tasks through the memory tracker - confirm whether that is intended.
     */
    public void clear() {
        this.taskCaches.clear();
        this.detectors.clear();
    }

    /**
     * Estimates the memory used by threshold model training data.
     *
     * @param i number of data points
     * @return {@code 8 * i}, computed in {@code long} so large sizes cannot overflow
     */
    public long trainingDataMemorySize(int i) {
        return (long) NUMBER_SIZE * i;
    }

    /**
     * Estimates the memory used by a shingle.
     *
     * @param i  shingle size (number of entries)
     * @param i2 number of enabled features
     * @return {@code (80 + 8 * i2) * i}, computed in {@code long} so large sizes cannot overflow
     */
    public long shingleMemorySize(int i, int i2) {
        long bytesPerEntry = SHINGLE_ENTRY_OVERHEAD + (long) NUMBER_SIZE * i2;
        return bytesPerEntry * i;
    }
}
