package com.amazon.opendistroforelasticsearch.ad;

import com.amazon.opendistroforelasticsearch.ad.common.exception.LimitExceededException;
import com.amazon.opendistroforelasticsearch.ad.model.AnomalyDetector;
import com.amazon.opendistroforelasticsearch.ad.settings.AnomalyDetectorSettings;
import com.amazon.randomcutforest.RandomCutForest;
import java.util.EnumMap;
import java.util.Locale;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.monitor.jvm.JvmService;

/**
 * Tracks heap memory consumed by anomaly-detection models on this node.
 *
 * <p>Memory is accounted both in aggregate and per {@link Origin}. Any consumed amount may also be
 * flagged as "reserved"; reservation requests are checked against a heap limit computed as
 * {@code maxHeap * MODEL_MAX_SIZE_PERCENTAGE}, which is recomputed whenever that cluster setting
 * changes.
 *
 * <p>Thread-safety: the accounting counters and per-origin maps are only touched while holding this
 * object's monitor. The heap limit is {@code volatile} because it is written by the settings-update
 * callback outside the monitor and read by the unsynchronized {@link #getHeapLimit()}.
 */
public class MemoryTracker {
    private static final Logger LOG = LogManager.getLogger(MemoryTracker.class);

    /**
     * Fixed byte amount added to every model-size estimate. The value comes from the original code;
     * what it accounts for is not documented here.
     */
    private static final int DEFAULT_THRESHOLD_MODEL_BYTES = 180_000;

    // Aggregate and per-origin counters; guarded by "this".
    private long totalMemoryBytes = 0;
    private final Map<Origin, Long> totalMemoryBytesByOrigin = new EnumMap<>(Origin.class);
    private long reservedMemoryBytes = 0;
    private final Map<Origin, Long> reservedMemoryBytesByOrigin = new EnumMap<>(Origin.class);

    private final long heapSize;
    // volatile: updated by the settings consumer without holding the monitor, and read by the
    // unsynchronized getHeapLimit(); plain long reads/writes are not guaranteed atomic (JLS 17.7).
    private volatile long heapLimitBytes;
    private final long desiredModelSize;
    private final int thresholdModelBytes;
    private final int sampleSize;

    /** Categories under which memory consumption is accounted separately. */
    public enum Origin {
        SINGLE_ENTITY_DETECTOR,
        MULTI_ENTITY_DETECTOR,
        HISTORICAL_SINGLE_ENTITY_DETECTOR
    }

    /**
     * Creates a tracker sized from the JVM's maximum heap.
     *
     * @param jvmService source of the maximum heap size
     * @param modelMaxSizePercentage fraction of the max heap usable by models (initial heap limit)
     * @param modelDesiredSizePercentage fraction of the max heap reported by {@link #getDesiredModelSize()}
     * @param clusterService used to subscribe to {@code MODEL_MAX_SIZE_PERCENTAGE} updates
     * @param sampleSize sample size used by {@link #estimateModelSize(AnomalyDetector, int)}
     */
    public MemoryTracker(
        JvmService jvmService,
        double modelMaxSizePercentage,
        double modelDesiredSizePercentage,
        ClusterService clusterService,
        int sampleSize
    ) {
        this.heapSize = jvmService.info().getMem().getHeapMax().getBytes();
        this.heapLimitBytes = (long) (heapSize * modelMaxSizePercentage);
        this.desiredModelSize = (long) (heapSize * modelDesiredSizePercentage);
        this.thresholdModelBytes = DEFAULT_THRESHOLD_MODEL_BYTES;
        this.sampleSize = sampleSize;
        // Registered last so the callback (which captures "this") can never observe a
        // partially-initialised instance.
        clusterService
            .getClusterSettings()
            .addSettingsUpdateConsumer(
                AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE,
                newPercentage -> this.heapLimitBytes = (long) (heapSize * newPercentage.doubleValue())
            );
    }

    /**
     * Checks whether a model of the given forest's estimated size can be reserved.
     *
     * @param detectorId identifier forwarded to {@link LimitExceededException} on failure
     * @param rcf forest whose size is estimated via {@link #estimateModelSize(RandomCutForest)}
     * @return always {@code true} when the reservation fits
     * @throws LimitExceededException if the reservation would exceed the heap limit
     */
    public synchronized boolean isHostingAllowed(String detectorId, RandomCutForest rcf) {
        return canAllocateReserved(detectorId, estimateModelSize(rcf));
    }

    /**
     * Checks whether {@code requiredBytes} more reserved memory fits under the heap limit.
     *
     * @param detectorId identifier forwarded to {@link LimitExceededException} on failure
     * @param requiredBytes additional bytes to reserve
     * @return always {@code true} when the reservation fits
     * @throws LimitExceededException if reserved bytes plus {@code requiredBytes} exceed the limit
     */
    public synchronized boolean canAllocateReserved(String detectorId, long requiredBytes) {
        // Read the volatile limit once so the check and the error message agree.
        long limit = heapLimitBytes;
        long newReservedBytes = reservedMemoryBytes + requiredBytes;
        if (newReservedBytes <= limit) {
            return true;
        }
        throw new LimitExceededException(
            detectorId,
            String.format(
                Locale.ROOT,
                "Exceeded memory limit. New size is %d bytes and max limit is %d bytes",
                newReservedBytes,
                limit
            )
        );
    }

    /**
     * Returns whether {@code bytes} more total memory fits under the heap limit.
     *
     * @param bytes additional bytes to allocate
     * @return {@code true} if total bytes plus {@code bytes} is at most the heap limit
     */
    public synchronized boolean canAllocate(long bytes) {
        return totalMemoryBytes + bytes <= heapLimitBytes;
    }

    /**
     * Records memory consumption.
     *
     * @param memoryToConsume bytes consumed
     * @param reserved if {@code true}, the bytes are also counted as reserved
     * @param origin category to account the bytes under
     */
    public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) {
        totalMemoryBytes += memoryToConsume;
        adjustOriginMemoryConsumption(memoryToConsume, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes += memoryToConsume;
            adjustOriginMemoryConsumption(memoryToConsume, origin, reservedMemoryBytesByOrigin);
        }
    }

    /** Adds {@code delta} to the origin's entry, starting from 0 when absent. Caller holds the lock. */
    private void adjustOriginMemoryConsumption(long delta, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.merge(origin, delta, Long::sum);
    }

    /**
     * Records memory release.
     *
     * <p>Aggregate counters are always decremented. Per-origin entries are decremented only if the
     * origin already has an entry.
     *
     * @param memoryToRelease bytes released
     * @param reserved if {@code true}, the bytes are also removed from the reserved counters
     * @param origin category the bytes were accounted under
     */
    public synchronized void releaseMemory(long memoryToRelease, boolean reserved, Origin origin) {
        totalMemoryBytes -= memoryToRelease;
        adjustOriginMemoryRelease(memoryToRelease, origin, totalMemoryBytesByOrigin);
        if (reserved) {
            reservedMemoryBytes -= memoryToRelease;
            adjustOriginMemoryRelease(memoryToRelease, origin, reservedMemoryBytesByOrigin);
        }
    }

    /** Subtracts {@code delta} from the origin's entry only if one exists. Caller holds the lock. */
    private void adjustOriginMemoryRelease(long delta, Origin origin, Map<Origin, Long> mapToUpdate) {
        mapToUpdate.computeIfPresent(origin, (key, current) -> current - delta);
    }

    /**
     * Estimates the memory footprint of a forest from its dimensions, tree count and sample size.
     *
     * @param rcf the forest to estimate
     * @return estimated size in bytes
     */
    public long estimateModelSize(RandomCutForest rcf) {
        return estimateModelSize(rcf.getDimensions(), rcf.getNumberOfTrees(), rcf.getSampleSize());
    }

    /**
     * Estimates the model size for a detector.
     *
     * <p>The dimension is (enabled feature count × shingle size), and the sample size is the one
     * given to the constructor.
     *
     * @param detector detector whose features and shingle size determine the dimension
     * @param numberOfTrees number of trees in the forest
     * @return estimated size in bytes
     */
    public long estimateModelSize(AnomalyDetector detector, int numberOfTrees) {
        int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize().intValue();
        return estimateModelSize(dimension, numberOfTrees, sampleSize);
    }

    /**
     * Estimates model size from raw parameters.
     *
     * <p>The formula is {@code S * (40 * dimension + 132) + S * 36 + thresholdModelBytes}, where
     * {@code S = numberOfTrees * numSamples}. The per-sample byte constants are taken from the
     * original implementation and are not explained here.
     *
     * @param dimension number of input dimensions
     * @param numberOfTrees number of trees
     * @param numSamples samples per tree
     * @return estimated size in bytes
     */
    public long estimateModelSize(int dimension, int numberOfTrees, int numSamples) {
        // Widen before multiplying: int * int can overflow before assignment to long.
        long totalSamples = (long) numberOfTrees * numSamples;
        long dimensionDependentBytes = totalSamples * (40L * dimension + 132);
        long fixedPerSampleBytes = totalSamples * 36;
        return dimensionDependentBytes + fixedPerSampleBytes + thresholdModelBytes;
    }

    /**
     * Returns how many bytes total consumption exceeds the heap limit by.
     *
     * @return total bytes minus heap limit; zero or negative when under the limit
     */
    public synchronized long memoryToShed() {
        return totalMemoryBytes - heapLimitBytes;
    }

    /** @return the current heap limit in bytes, reflecting the latest MODEL_MAX_SIZE_PERCENTAGE */
    public long getHeapLimit() {
        return heapLimitBytes;
    }

    /** @return max heap × the desired-size percentage given at construction */
    public long getDesiredModelSize() {
        return desiredModelSize;
    }

    /** @return total bytes currently recorded as consumed */
    public synchronized long getTotalMemoryBytes() {
        return totalMemoryBytes;
    }

    /**
     * Overwrites the recorded per-origin totals with externally measured values and adjusts the
     * aggregate counters by the difference.
     *
     * @param origin category to synchronise
     * @param totalBytes actual total bytes for the origin
     * @param reservedBytes actual reserved bytes for the origin
     * @return {@code true} if the recorded values differed and were updated, {@code false} otherwise
     */
    public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long reservedBytes) {
        long recordedTotal = totalMemoryBytesByOrigin.getOrDefault(origin, 0L);
        long recordedReserved = reservedMemoryBytesByOrigin.getOrDefault(origin, 0L);
        if (totalBytes == recordedTotal && reservedBytes == recordedReserved) {
            return false;
        }
        LOG.info(
            String.format(
                Locale.ROOT,
                "Memory states do not match.  Recorded: total bytes %d, reserved bytes %d. Actual: total bytes %d, reserved bytes: %d",
                recordedTotal,
                recordedReserved,
                totalBytes,
                reservedBytes
            )
        );
        reservedMemoryBytesByOrigin.put(origin, reservedBytes);
        reservedMemoryBytes += reservedBytes - recordedReserved;
        totalMemoryBytesByOrigin.put(origin, totalBytes);
        totalMemoryBytes += totalBytes - recordedTotal;
        return true;
    }

    /** @return the fixed byte amount added to every model-size estimate */
    public int getThresholdModelBytes() {
        return thresholdModelBytes;
    }
}
