package com.amazon.opendistroforelasticsearch.ad.ml;

import com.amazon.opendistroforelasticsearch.ad.constant.CommonName;
import com.amazon.opendistroforelasticsearch.ad.indices.ADIndex;
import com.amazon.opendistroforelasticsearch.ad.indices.AnomalyDetectionIndices;
import com.amazon.opendistroforelasticsearch.ad.util.BulkUtil;
import com.amazon.opendistroforelasticsearch.ad.util.ClientUtil;
import com.amazon.randomcutforest.RandomCutForest;
import com.amazon.randomcutforest.serialize.RandomCutForestSerDe;
import com.google.common.util.concurrent.RateLimiter;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.temporal.TemporalAmount;
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.delete.DeleteRequest;
import org.elasticsearch.action.delete.DeleteResponse;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.reindex.BulkByScrollResponse;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.index.reindex.ScrollableHitSource;

/* loaded from: input_file:com/amazon/opendistroforelasticsearch/ad/ml/CheckpointDao.class */
/**
 * Reads, writes and deletes model checkpoint documents.
 *
 * <p>The {@code String}-based checkpoint methods operate on one document at a time through
 * {@link ClientUtil}. Entity model checkpoints passed to {@link #write} are instead queued and
 * sent in bulk by {@link #flush()}. Sending is throttled by a {@link RateLimiter}, and a
 * {@link ReentrantLock} ensures only one thread assembles a bulk request at a time.
 */
public class CheckpointDao {
    private static final Logger logger = LogManager.getLogger(CheckpointDao.class);

    // Log message prefixes used while deleting a detector's checkpoints (package-private,
    // presumably so tests can match on them).
    static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of";
    static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of";
    static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of";
    static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted";
    static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted.  Has nothing to do:";
    static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector";

    // JSON keys inside a serialized entity model.
    public static final String ENTITY_SAMPLE = "sp";
    public static final String ENTITY_RCF = "rcf";
    public static final String ENTITY_THRESHOLD = "th";

    // Field names of a checkpoint document.
    public static final String FIELD_MODEL = "model";
    public static final String TIMESTAMP = "timestamp";
    public static final String DETECTOR_ID = "detectorId";

    private final Client client;
    private final ClientUtil clientUtil;
    // Index used for single-document reads/writes/deletes and for queued bulk writes.
    private final String indexName;
    private Gson gson;
    private RandomCutForestSerDe rcfSerde;
    // Concrete class used when deserializing the threshold part of an entity checkpoint.
    private final Class<? extends ThresholdingModel> thresholdingModelClass;
    // Minimum time between two non-forced checkpoints of the same model.
    private final Duration checkpointInterval;
    private final Clock clock;
    private final AnomalyDetectionIndices indexUtil;
    // Throttles how often flush() may send a bulk request.
    private final RateLimiter bulkRateLimiter;
    // Upper bound on the number of queued requests sent in one bulk request.
    private final int maxBulkRequestSize;
    private final JsonParser parser = new JsonParser();
    // Checkpoint writes waiting for flush(); failed bulk items are re-queued here for retry.
    private ConcurrentLinkedQueue<DocWriteRequest<?>> requests = new ConcurrentLinkedQueue<>();
    // Ensures only one thread assembles a bulk request at a time.
    private final ReentrantLock lock = new ReentrantLock();

    /**
     * Creates a checkpoint DAO.
     *
     * @param client Elasticsearch client used to issue requests
     * @param clientUtil helper used for timed and asynchronous single-document requests
     * @param indexName index that checkpoint documents are read from and written to
     * @param gson serializer for samples and thresholding models
     * @param rcfSerde serializer for random cut forests
     * @param thresholdingModelClass class used to deserialize thresholding models
     * @param clock source of the current time when deciding whether a checkpoint is due
     * @param checkpointInterval minimum time between two non-forced checkpoints of a model
     * @param indexUtil used to check for, create, and read the schema version of the checkpoint index
     * @param maxBulkRequestSize maximum number of queued requests per bulk request; reaching this
     *     queue size in {@link #write} also triggers a flush
     * @param bulkRequestsPerSecond permits per second of the bulk {@link RateLimiter}
     */
    public CheckpointDao(
        Client client,
        ClientUtil clientUtil,
        String indexName,
        Gson gson,
        RandomCutForestSerDe rcfSerde,
        Class<? extends ThresholdingModel> thresholdingModelClass,
        Clock clock,
        Duration checkpointInterval,
        AnomalyDetectionIndices indexUtil,
        int maxBulkRequestSize,
        double bulkRequestsPerSecond
    ) {
        this.client = client;
        this.clientUtil = clientUtil;
        this.indexName = indexName;
        this.gson = gson;
        this.rcfSerde = rcfSerde;
        this.thresholdingModelClass = thresholdingModelClass;
        this.clock = clock;
        this.checkpointInterval = checkpointInterval;
        this.indexUtil = indexUtil;
        this.maxBulkRequestSize = maxBulkRequestSize;
        this.bulkRateLimiter = RateLimiter.create(bulkRequestsPerSecond);
    }

    /**
     * Writes a checkpoint document with a blocking (timed) request.
     *
     * <p>If the checkpoint index does not exist, it is created first. In that case the write is
     * issued from the index-creation callback rather than from the calling thread.
     *
     * @param modelId id of the checkpoint document
     * @param modelCheckpoint serialized model stored under {@value #FIELD_MODEL}
     * @deprecated the {@link ActionListener} overload performs the same write asynchronously
     */
    @Deprecated
    public void putModelCheckpoint(String modelId, String modelCheckpoint) {
        Map<String, Object> source = newCheckpointSource(modelCheckpoint);
        if (indexUtil.doesCheckpointIndexExist()) {
            saveModelCheckpointSync(source, modelId);
        } else {
            onCheckpointNotExist(source, modelId, false, null);
        }
    }

    /**
     * Writes a checkpoint document asynchronously, creating the checkpoint index first if needed.
     *
     * @param modelId id of the checkpoint document
     * @param modelCheckpoint serialized model stored under {@value #FIELD_MODEL}
     * @param listener notified with {@code null} on success, or with the failure (including a
     *     failure to create the checkpoint index)
     */
    public void putModelCheckpoint(String modelId, String modelCheckpoint, ActionListener<Void> listener) {
        Map<String, Object> source = newCheckpointSource(modelCheckpoint);
        if (indexUtil.doesCheckpointIndexExist()) {
            saveModelCheckpointAsync(source, modelId, listener);
        } else {
            onCheckpointNotExist(source, modelId, true, listener);
        }
    }

    /** Builds the document source shared by both putModelCheckpoint overloads. */
    private static Map<String, Object> newCheckpointSource(String modelCheckpoint) {
        Map<String, Object> source = new HashMap<>();
        source.put(FIELD_MODEL, modelCheckpoint);
        source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC));
        return source;
    }

    /** Indexes {@code source} under {@code modelId} with a blocking (timed) request. */
    private void saveModelCheckpointSync(Map<String, Object> source, String modelId) {
        BiConsumer<IndexRequest, ActionListener<IndexResponse>> indexAction = client::index;
        clientUtil.timedRequest(new IndexRequest(indexName).id(modelId).source(source), logger, indexAction);
    }

    /** Indexes {@code source} under {@code modelId} and reports completion to {@code listener}. */
    private void saveModelCheckpointAsync(Map<String, Object> source, String modelId, ActionListener<Void> listener) {
        BiConsumer<IndexRequest, ActionListener<IndexResponse>> indexAction = client::index;
        clientUtil
            .asyncRequest(
                new IndexRequest(indexName).id(modelId).source(source),
                indexAction,
                ActionListener.wrap(response -> listener.onResponse(null), listener::onFailure)
            );
    }

    /** Dispatches to the async or sync save path. */
    private void saveModelCheckpoint(Map<String, Object> source, String modelId, boolean isAsync, ActionListener<Void> listener) {
        if (isAsync) {
            saveModelCheckpointAsync(source, modelId, listener);
        } else {
            saveModelCheckpointSync(source, modelId);
        }
    }

    /**
     * Creates the checkpoint index and then saves the checkpoint.
     *
     * <p>If index creation fails for any reason other than the index already existing, the error
     * is logged. On the async path it is also forwarded to {@code listener}, so the caller is not
     * left waiting forever.
     */
    private void onCheckpointNotExist(Map<String, Object> source, String modelId, boolean isAsync, ActionListener<Void> listener) {
        initCheckpointIndexThen(() -> saveModelCheckpoint(source, modelId, isAsync, listener), exception -> {
            if (isAsync && listener != null) {
                listener.onFailure(exception);
            }
        });
    }

    /**
     * Creates the checkpoint index, then runs {@code onReady}.
     *
     * <p>{@code onReady} runs if creation was acknowledged or the index already existed. Any other
     * outcome, including an unacknowledged creation, is logged and passed to {@code onError}.
     */
    private void initCheckpointIndexThen(Runnable onReady, Consumer<Exception> onError) {
        indexUtil.initCheckpointIndex(ActionListener.wrap(response -> {
            if (!response.isAcknowledged()) {
                // Thrown inside ActionListener.wrap's response handler, so it is routed to the
                // failure handler below.
                throw new RuntimeException("Creating checkpoint with mappings call not acknowledged.");
            }
            onReady.run();
        }, exception -> {
            if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
                onReady.run();
            } else {
                logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception);
                onError.accept(exception);
            }
        }));
    }

    /**
     * Sends up to {@code maxBulkRequestSize} queued checkpoint writes in one bulk request.
     *
     * <p>Returns without doing anything in any of these cases:
     * <ul>
     *   <li>another thread currently holds the flush lock;</li>
     *   <li>the queue is empty;</li>
     *   <li>the rate limiter has no permit available.</li>
     * </ul>
     * If the checkpoint index is missing, it is created first. If creation fails, the polled
     * requests are put back on the queue.
     */
    public void flush() {
        // Another thread is already assembling a bulk request; skip rather than block.
        if (!lock.tryLock()) {
            return;
        }
        try {
            if (requests.isEmpty() || !bulkRateLimiter.tryAcquire()) {
                return;
            }
            BulkRequest bulkRequest = new BulkRequest();
            for (int i = 0; i < maxBulkRequestSize; i++) {
                DocWriteRequest<?> next = requests.poll();
                if (next == null) {
                    break;
                }
                bulkRequest.add(next);
            }
            if (indexUtil.doesCheckpointIndexExist()) {
                flush(bulkRequest);
            } else {
                // Re-queue on failure, mirroring the bulk-failure handling in flush(BulkRequest),
                // instead of silently dropping the polled checkpoints.
                initCheckpointIndexThen(() -> flush(bulkRequest), exception -> requests.addAll(bulkRequest.requests()));
            }
        } finally {
            // Exactly one unlock per successful tryLock keeps the reentrant hold count balanced
            // even when a bulk callback calls flush() on this thread.
            lock.unlock();
        }
    }

    /**
     * Executes {@code bulkRequest} asynchronously.
     *
     * <p>Items reported by {@link BulkUtil#getIndexRequestToRetry} are re-queued when the bulk
     * response has failures. The whole request is re-queued when the call itself fails. On a fully
     * successful bulk, another flush is attempted if at least half of
     * {@code maxBulkRequestSize} requests are still queued.
     */
    private void flush(BulkRequest bulkRequest) {
        clientUtil.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> {
            if (bulkResponse.hasFailures()) {
                requests.addAll(BulkUtil.getIndexRequestToRetry(bulkRequest, bulkResponse));
            } else if (requests.size() >= maxBulkRequestSize / 2) {
                flush();
            }
        }, exception -> {
            logger.error("Failed bulking checkpoints", exception);
            requests.addAll(bulkRequest.requests());
        }));
    }

    /**
     * Queues a checkpoint of the entity model if one is due; equivalent to
     * {@code write(modelState, modelId, false)}.
     *
     * @param modelState model and its last checkpoint time
     * @param modelId id of the checkpoint document
     */
    public void write(ModelState<EntityModel> modelState, String modelId) {
        write(modelState, modelId, false);
    }

    /**
     * Queues a checkpoint of the entity model for the next bulk flush.
     *
     * <p>Nothing is queued when the model is {@code null}. Unless {@code forceWrite} is set,
     * nothing is queued either when the last checkpoint time is {@link Instant#MIN} or
     * {@code checkpointInterval} has not yet elapsed since it. On queueing, the model's last
     * checkpoint time is set to now, meaning the queue time, not the time of persistence. A flush
     * is triggered once the queue holds {@code maxBulkRequestSize} requests. A
     * {@link ConcurrentModificationException} during serialization is logged and the checkpoint is
     * skipped.
     *
     * @param modelState model and its last checkpoint time
     * @param modelId id of the checkpoint document
     * @param forceWrite if true, ignore the checkpoint interval
     */
    public void write(ModelState<EntityModel> modelState, String modelId, boolean forceWrite) {
        EntityModel model = modelState.getModel();
        if (model == null) {
            return;
        }
        if (!forceWrite) {
            Instant lastCheckpointTime = modelState.getLastCheckpointTime();
            // Instant.MIN is compared by value (not reference) and treated as "not yet due".
            boolean notDue = Instant.MIN.equals(lastCheckpointTime)
                || lastCheckpointTime.plus(checkpointInterval).isAfter(clock.instant());
            if (notDue) {
                return;
            }
        }
        try {
            Map<String, Object> source = new HashMap<>();
            source.put(DETECTOR_ID, modelState.getDetectorId());
            source.put(FIELD_MODEL, toCheckpoint(model));
            source.put(TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC));
            source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT));
            requests.add(new IndexRequest(indexName).id(modelId).source(source));
            modelState.setLastCheckpointTime(clock.instant());
            if (requests.size() >= maxBulkRequestSize) {
                flush();
            }
        } catch (ConcurrentModificationException e) {
            logger.info(new ParameterizedMessage("Concurrent modification while serializing models for [{}]", modelId), e);
        }
    }

    /**
     * Reads the serialized model of a checkpoint with a blocking (timed) request.
     *
     * @param modelId id of the checkpoint document
     * @return the {@value #FIELD_MODEL} field, or empty if the request yielded no response or the
     *     document does not exist
     * @deprecated the {@link ActionListener} overload performs the same read asynchronously
     */
    @Deprecated
    public Optional<String> getModelCheckpoint(String modelId) {
        BiConsumer<GetRequest, ActionListener<GetResponse>> getAction = client::get;
        return clientUtil.timedRequest(new GetRequest(indexName, modelId), logger, getAction).flatMap(this::processModelCheckpoint);
    }

    /**
     * Serializes an entity model to the JSON stored in the {@value #FIELD_MODEL} field.
     *
     * <p>Samples are always written under {@value #ENTITY_SAMPLE}. The RCF and threshold models are
     * written (as JSON strings) only when non-null.
     */
    String toCheckpoint(EntityModel model) {
        // NOTE(review): presumably privileged because Gson/RCF serialization relies on reflection
        // restricted by the plugin security manager — confirm.
        return AccessController.doPrivileged((PrivilegedAction<String>) () -> {
            JsonObject json = new JsonObject();
            json.add(ENTITY_SAMPLE, gson.toJsonTree(model.getSamples()));
            if (model.getRcf() != null) {
                json.addProperty(ENTITY_RCF, rcfSerde.toJson(model.getRcf()));
            }
            if (model.getThreshold() != null) {
                json.addProperty(ENTITY_THRESHOLD, gson.toJson(model.getThreshold()));
            }
            return gson.toJson(json);
        });
    }

    /**
     * Deletes a checkpoint document with a blocking (timed) request.
     *
     * @param modelId id of the checkpoint document
     * @deprecated the {@link ActionListener} overload performs the same delete asynchronously
     */
    @Deprecated
    public void deleteModelCheckpoint(String modelId) {
        BiConsumer<DeleteRequest, ActionListener<DeleteResponse>> deleteAction = client::delete;
        clientUtil.timedRequest(new DeleteRequest(indexName, modelId), logger, deleteAction);
    }

    /**
     * Deletes a checkpoint document asynchronously.
     *
     * @param modelId id of the checkpoint document
     * @param listener notified with {@code null} on success, or with the failure
     */
    public void deleteModelCheckpoint(String modelId, ActionListener<Void> listener) {
        BiConsumer<DeleteRequest, ActionListener<DeleteResponse>> deleteAction = client::delete;
        clientUtil
            .asyncRequest(
                new DeleteRequest(indexName, modelId),
                deleteAction,
                ActionListener.wrap(response -> listener.onResponse(null), listener::onFailure)
            );
    }

    /**
     * Deletes, via delete-by-query, every checkpoint whose {@value #DETECTOR_ID} matches.
     *
     * <p>The query runs leniently over open indices, ignores version conflicts, and is throttled
     * to 500 requests per second. Outcomes are only logged. A missing index is logged at info
     * level as nothing to do.
     *
     * @param detectorId detector whose checkpoints are deleted
     */
    public void deleteModelCheckpointByDetectorId(String detectorId) {
        // NOTE(review): targets CommonName.CHECKPOINT_INDEX_NAME rather than indexName, which every
        // other method uses — confirm this is intended.
        DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(CommonName.CHECKPOINT_INDEX_NAME)
            .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorId))
            .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
            .setAbortOnVersionConflict(false)
            .setRequestsPerSecond(500.0f);
        logger.info("Delete checkpoints of detector {}", detectorId);
        client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> {
            if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) {
                logFailure(response, detectorId);
            }
            logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted());
        }, exception -> {
            if (exception instanceof IndexNotFoundException) {
                logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorId);
            } else {
                logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception);
            }
        }));
    }

    /**
     * Logs every problem category present in a delete-by-query response.
     *
     * <p>The categories are timeout, bulk failures, and search failures. Each gets a header line,
     * followed by one line per individual failure.
     */
    private void logFailure(BulkByScrollResponse response, String detectorId) {
        if (response.isTimedOut()) {
            logger.warn(TIMEOUT_LOG_MSG + " {}", detectorId);
        }
        if (!response.getBulkFailures().isEmpty()) {
            logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorId);
            for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) {
                logger.warn(bulkFailure);
            }
        }
        if (!response.getSearchFailures().isEmpty()) {
            logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorId);
            for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) {
                logger.warn(searchFailure);
            }
        }
    }

    /**
     * Rebuilds an entity model and its checkpoint time from a checkpoint document source.
     *
     * <p>This is the inverse of {@link #toCheckpoint}. It assumes {@value #TIMESTAMP} holds an
     * ISO-8601 instant string (TODO confirm against the index mapping). Any
     * {@link RuntimeException} is logged and rethrown.
     */
    private Map.Entry<EntityModel, Instant> fromEntityModelCheckpoint(Map<String, Object> checkpoint, String modelId) {
        try {
            return AccessController.doPrivileged((PrivilegedAction<Map.Entry<EntityModel, Instant>>) () -> {
                JsonObject json = parser.parse((String) checkpoint.get(FIELD_MODEL)).getAsJsonObject();
                ArrayDeque<double[]> samples = new ArrayDeque<>(
                    Arrays.asList(gson.fromJson(json.getAsJsonArray(ENTITY_SAMPLE), double[][].class))
                );
                RandomCutForest rcf = null;
                if (json.has(ENTITY_RCF)) {
                    rcf = rcfSerde.fromJson(json.getAsJsonPrimitive(ENTITY_RCF).getAsString());
                }
                ThresholdingModel threshold = null;
                if (json.has(ENTITY_THRESHOLD)) {
                    threshold = gson.fromJson(json.getAsJsonPrimitive(ENTITY_THRESHOLD).getAsString(), thresholdingModelClass);
                }
                return new AbstractMap.SimpleImmutableEntry<>(
                    new EntityModel(modelId, samples, rcf, threshold),
                    Instant.parse((String) checkpoint.get(TIMESTAMP))
                );
            });
        } catch (RuntimeException e) {
            logger.warn("Exception while deserializing checkpoint", e);
            throw e;
        }
    }

    /**
     * Asynchronously loads and deserializes an entity model checkpoint.
     *
     * @param modelId id of the checkpoint document
     * @param listener receives the model and its checkpoint time, empty if the document does not
     *     exist, or the failure (including deserialization errors)
     */
    public void restoreModelCheckpoint(String modelId, ActionListener<Optional<Map.Entry<EntityModel, Instant>>> listener) {
        BiConsumer<GetRequest, ActionListener<GetResponse>> getAction = client::get;
        clientUtil
            .asyncRequest(
                new GetRequest(indexName, modelId),
                getAction,
                ActionListener
                    .wrap(
                        response -> listener
                            .onResponse(processRawCheckpoint(response).map(source -> fromEntityModelCheckpoint(source, modelId))),
                        listener::onFailure
                    )
            );
    }

    /**
     * Asynchronously reads the serialized model of a checkpoint.
     *
     * @param modelId id of the checkpoint document
     * @param listener receives the {@value #FIELD_MODEL} field, empty if the document does not
     *     exist, or the failure
     */
    public void getModelCheckpoint(String modelId, ActionListener<Optional<String>> listener) {
        BiConsumer<GetRequest, ActionListener<GetResponse>> getAction = client::get;
        clientUtil
            .asyncRequest(
                new GetRequest(indexName, modelId),
                getAction,
                ActionListener.wrap(response -> listener.onResponse(processModelCheckpoint(response)), listener::onFailure)
            );
    }

    /** Extracts the {@value #FIELD_MODEL} field from an existing document, else empty. */
    private Optional<String> processModelCheckpoint(GetResponse response) {
        return processRawCheckpoint(response).map(source -> (String) source.get(FIELD_MODEL));
    }

    /** Returns the source of an existing document; empty for a null response or missing document. */
    private Optional<Map<String, Object>> processRawCheckpoint(GetResponse response) {
        return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource);
    }
}
