/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.action;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;

public class StartTrainedModelDeploymentAction
extends ActionType<CreateTrainedModelAssignmentAction.Response> {
    /** Transport action name under which this action is registered. */
    public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/start";

    /** Shared singleton instance of this action type. */
    public static final StartTrainedModelDeploymentAction INSTANCE = new StartTrainedModelDeploymentAction();

    /** Default value of {@code Request#getTimeout()} when a request does not set one. */
    public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

    /** Fixed amount added to every estimate produced by {@code estimateMemoryUsageBytes(long)}. */
    private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(240);

    /**
     * Creates the action type; responses are deserialized as
     * {@link CreateTrainedModelAssignmentAction.Response}.
     */
    public StartTrainedModelDeploymentAction() {
        super(NAME, CreateTrainedModelAssignmentAction.Response::new);
    }

    /**
     * Estimates the memory required by a deployment: the fixed {@code MEMORY_OVERHEAD}
     * plus twice the given definition length.
     *
     * @param totalDefinitionLength total length of the model definition (presumably bytes — the
     *                              result is summed with a byte count)
     * @return the estimated memory usage in bytes
     */
    public static long estimateMemoryUsageBytes(long totalDefinitionLength) {
        long definitionFootprint = totalDefinitionLength * 2L;
        return MEMORY_OVERHEAD.getBytes() + definitionFootprint;
    }

    /**
     * Marker interface for tasks that belong to a trained model deployment.
     */
    public interface TaskMatcher {

        /**
         * Checks whether {@code task} is a deployment task for the given id.
         *
         * @param task       the task to test
         * @param expectedId a model/deployment id, or an "all"/wildcard expression
         * @return {@code false} if {@code task} is not a {@link TaskMatcher}; {@code true} if
         *         {@code expectedId} is "all" or a wildcard; otherwise whether the task's
         *         description equals the assignment task description built for {@code expectedId}
         */
        static boolean match(Task task, String expectedId) {
            if (!(task instanceof TaskMatcher)) {
                return false;
            }
            if (Strings.isAllOrWildcard(expectedId)) {
                return true;
            }
            String wantedDescription = MlTasks.trainedModelAssignmentTaskDescription(expectedId);
            return wantedDescription.equals(task.getDescription());
        }
    }

    /**
     * Parameters of a trained model deployment task.
     * <p>
     * Serialized over the transport layer via {@link #writeTo(StreamOutput)} /
     * {@link #TaskParams(StreamInput)} and as XContent via {@link #toXContent} /
     * {@link #fromXContent(XContentParser)}. XContent parsing also accepts the legacy
     * {@code model_threads} and {@code inference_threads} fields as fallbacks.
     */
    public static class TaskParams implements MlTaskParams, Writeable, ToXContentObject {

        /** Minimum node version these params may be assigned to; see {@link #mayAssignToNode}. */
        public static final Version VERSION_INTRODUCED = Version.V_8_0_0;

        private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations");
        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation");
        // Legacy names: only consulted when the corresponding current field is absent.
        private static final ParseField LEGACY_MODEL_THREADS = new ParseField("model_threads");
        public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
        public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
        public static final ParseField CACHE_SIZE = new ParseField("cache_size");
        public static final ParseField PRIORITY = new ParseField("priority");

        // Lenient parser (ignoreUnknownFields = true). Constructor-argument positions:
        // [0] model_id, [1] model_bytes, [2] number_of_allocations, [3] threads_per_allocation,
        // [4] queue_capacity, [5] cache_size, [6] model_threads (legacy),
        // [7] inference_threads (legacy), [8] priority
        private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
            "trained_model_deployment_params",
            true,
            a -> new TaskParams(
                (String) a[0],
                (Long) a[1],
                (Integer) a[2],
                (Integer) a[3],
                (Integer) a[4],
                (ByteSizeValue) a[5],
                (Integer) a[6],
                (Integer) a[7],
                a[8] == null ? null : Priority.fromString((String) a[8])
            )
        );

        static {
            // Declaration order must match the positions documented on PARSER.
            PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
            PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
            PARSER.declareField(
                ConstructingObjectParser.optionalConstructorArg(),
                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
                CACHE_SIZE,
                ObjectParser.ValueType.VALUE
            );
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PRIORITY);
        }

        private final String modelId;
        @Nullable
        private final ByteSizeValue cacheSize;
        private final long modelBytes;
        private final int threadsPerAllocation;
        private final int numberOfAllocations;
        private final int queueCapacity;
        private final Priority priority;

        /**
         * Returns whether a task with these params may run on {@code node}: the node must have the
         * ML role and be on or after {@link #VERSION_INTRODUCED}.
         */
        public static boolean mayAssignToNode(DiscoveryNode node) {
            return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE) && node.getVersion().onOrAfter(VERSION_INTRODUCED);
        }

        /** Parses task params from XContent; unknown fields are ignored. */
        public static TaskParams fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /**
         * Parser-only constructor that resolves the legacy field names and the default priority.
         *
         * @throws IllegalArgumentException if neither the current nor the legacy name supplies
         *                                  the number of allocations or threads per allocation
         */
        private TaskParams(
            String modelId,
            long modelBytes,
            Integer numberOfAllocations,
            Integer threadsPerAllocation,
            int queueCapacity,
            ByteSizeValue cacheSizeValue,
            Integer legacyModelThreads,
            Integer legacyInferenceThreads,
            Priority priority
        ) {
            this(
                modelId,
                modelBytes,
                currentOrLegacy(numberOfAllocations, legacyModelThreads, NUMBER_OF_ALLOCATIONS),
                currentOrLegacy(threadsPerAllocation, legacyInferenceThreads, THREADS_PER_ALLOCATION),
                queueCapacity,
                cacheSizeValue,
                priority == null ? Priority.NORMAL : priority
            );
        }

        /**
         * Returns {@code current} if set, otherwise {@code legacy}. Previously a missing pair was
         * unboxed into an {@code int} and failed with a message-less NullPointerException.
         */
        private static int currentOrLegacy(Integer current, Integer legacy, ParseField field) {
            Integer value = current != null ? current : legacy;
            if (value == null) {
                throw new IllegalArgumentException("[" + field.getPreferredName() + "] is required");
            }
            return value;
        }

        /**
         * @param modelId              id of the model; must not be null
         * @param modelBytes           size of the model
         * @param numberOfAllocations  number of allocations
         * @param threadsPerAllocation threads used by each allocation
         * @param queueCapacity        queue capacity
         * @param cacheSize            optional cache size; may be null
         * @param priority             deployment priority; must not be null
         */
        public TaskParams(
            String modelId,
            long modelBytes,
            int numberOfAllocations,
            int threadsPerAllocation,
            int queueCapacity,
            @Nullable ByteSizeValue cacheSize,
            Priority priority
        ) {
            this.modelId = Objects.requireNonNull(modelId);
            this.modelBytes = modelBytes;
            this.threadsPerAllocation = threadsPerAllocation;
            this.numberOfAllocations = numberOfAllocations;
            this.queueCapacity = queueCapacity;
            this.cacheSize = cacheSize;
            this.priority = Objects.requireNonNull(priority);
        }

        /**
         * Reads params from the transport stream. Peers older than 8.4.0 send no cache size
         * (read as null); peers older than 8.6.0 send no priority (read as NORMAL).
         */
        public TaskParams(StreamInput in) throws IOException {
            this.modelId = in.readString();
            this.modelBytes = in.readLong();
            this.threadsPerAllocation = in.readVInt();
            this.numberOfAllocations = in.readVInt();
            this.queueCapacity = in.readVInt();
            if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_4_0)) {
                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::readFrom);
            } else {
                this.cacheSize = null;
            }
            if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_6_0)) {
                this.priority = in.readEnum(Priority.class);
            } else {
                this.priority = Priority.NORMAL;
            }
        }

        public String getModelId() {
            return modelId;
        }

        /**
         * Estimates memory usage from the model size. When a cache size larger than the model is
         * configured, the excess over the model size is added on top.
         */
        public long estimateMemoryUsageBytes() {
            long base = StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
            if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
                return base + (cacheSize.getBytes() - modelBytes);
            }
            return base;
        }

        public Version getMinimalSupportedVersion() {
            return VERSION_INTRODUCED;
        }

        /** Writes the fields in the order read by {@link #TaskParams(StreamInput)}. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(modelId);
            out.writeLong(modelBytes);
            out.writeVInt(threadsPerAllocation);
            out.writeVInt(numberOfAllocations);
            out.writeVInt(queueCapacity);
            if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_4_0)) {
                out.writeOptionalWriteable(cacheSize);
            }
            if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_6_0)) {
                out.writeEnum(priority);
            }
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
            builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
            builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
            if (cacheSize != null) {
                builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
            }
            builder.field(PRIORITY.getPreferredName(), priority);
            builder.endObject();
            return builder;
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize, priority);
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            TaskParams other = (TaskParams) o;
            return Objects.equals(modelId, other.modelId)
                && modelBytes == other.modelBytes
                && threadsPerAllocation == other.threadsPerAllocation
                && numberOfAllocations == other.numberOfAllocations
                && Objects.equals(cacheSize, other.cacheSize)
                && queueCapacity == other.queueCapacity
                && priority == other.priority;
        }

        @Override
        public String getMlId() {
            return modelId;
        }

        public long getModelBytes() {
            return modelBytes;
        }

        public int getThreadsPerAllocation() {
            return threadsPerAllocation;
        }

        public int getNumberOfAllocations() {
            return numberOfAllocations;
        }

        public int getQueueCapacity() {
            return queueCapacity;
        }

        public Optional<ByteSizeValue> getCacheSize() {
            return Optional.ofNullable(cacheSize);
        }

        /** Returns the configured cache size in bytes, or the model size when none is configured. */
        public long getCacheSizeBytes() {
            return cacheSize == null ? modelBytes : cacheSize.getBytes();
        }

        public Priority getPriority() {
            return priority;
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    /**
     * Request to start a trained model deployment.
     * <p>
     * XContent bodies are parsed with {@link #parseRequest(String, XContentParser)}, which
     * reconciles a {@code model_id} given in the body with the one supplied by the caller.
     */
    public static class Request extends MasterNodeRequest<Request> implements ToXContentObject {

        private static final AllocationStatus.State[] VALID_WAIT_STATES = new AllocationStatus.State[] {
            AllocationStatus.State.STARTED,
            AllocationStatus.State.STARTING,
            AllocationStatus.State.FULLY_ALLOCATED };
        private static final int MAX_THREADS_PER_ALLOCATION = 32;
        private static final int MAX_QUEUE_CAPACITY = 1000000;

        public static final ParseField MODEL_ID = new ParseField("model_id");
        public static final ParseField TIMEOUT = new ParseField("timeout");
        public static final ParseField WAIT_FOR = new ParseField("wait_for");
        // The second names are deprecated aliases still accepted by the parser.
        public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
        public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
        public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
        public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
        public static final ParseField PRIORITY = TaskParams.PRIORITY;

        public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

        static {
            PARSER.declareString(Request::setModelId, MODEL_ID);
            PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT);
            PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
            PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
            PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
            PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
            PARSER.declareField(
                Request::setCacheSize,
                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
                CACHE_SIZE,
                ObjectParser.ValueType.VALUE
            );
            PARSER.declareString(Request::setPriority, PRIORITY);
        }

        private String modelId;
        private TimeValue timeout = DEFAULT_TIMEOUT;
        private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
        @Nullable
        private ByteSizeValue cacheSize;
        private int numberOfAllocations = 1;
        private int threadsPerAllocation = 1;
        private int queueCapacity = 1024;
        private Priority priority = Priority.NORMAL;

        /**
         * Parses a request body. If the body has no {@code model_id}, {@code modelId} is used.
         *
         * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if the body's
         *         {@code model_id} differs from a non-empty {@code modelId}
         */
        public static Request parseRequest(String modelId, XContentParser parser) {
            Request request = PARSER.apply(parser, null);
            if (request.getModelId() == null) {
                request.setModelId(modelId);
            } else if (Strings.isNullOrEmpty(modelId) == false && modelId.equals(request.getModelId()) == false) {
                throw ExceptionsHelper.badRequestException(
                    Messages.getMessage(
                        "Inconsistent {0}; ''{1}'' specified in the body differs from ''{2}'' specified as a URL argument",
                        MODEL_ID,
                        request.getModelId(),
                        modelId
                    )
                );
            }
            return request;
        }

        /** Used by {@link #PARSER}; the model id is filled in afterwards. */
        private Request() {}

        public Request(String modelId) {
            setModelId(modelId);
        }

        /**
         * Reads a request from the transport stream. Peers older than 8.4.0 send no cache size;
         * peers older than 8.6.0 send no priority (read as NORMAL).
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            modelId = in.readString();
            timeout = in.readTimeValue();
            waitForState = in.readEnum(AllocationStatus.State.class);
            numberOfAllocations = in.readVInt();
            threadsPerAllocation = in.readVInt();
            queueCapacity = in.readVInt();
            if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_4_0)) {
                cacheSize = in.readOptionalWriteable(ByteSizeValue::readFrom);
            }
            if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_6_0)) {
                priority = in.readEnum(Priority.class);
            } else {
                priority = Priority.NORMAL;
            }
        }

        // final because it is called from a constructor
        public final void setModelId(String modelId) {
            this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
        }

        public String getModelId() {
            return modelId;
        }

        public void setTimeout(TimeValue timeout) {
            this.timeout = ExceptionsHelper.requireNonNull(timeout, TIMEOUT);
        }

        public TimeValue getTimeout() {
            return timeout;
        }

        public AllocationStatus.State getWaitForState() {
            return waitForState;
        }

        public Request setWaitForState(AllocationStatus.State waitForState) {
            this.waitForState = ExceptionsHelper.requireNonNull(waitForState, WAIT_FOR);
            return this;
        }

        public int getNumberOfAllocations() {
            return numberOfAllocations;
        }

        public void setNumberOfAllocations(int numberOfAllocations) {
            this.numberOfAllocations = numberOfAllocations;
        }

        public int getThreadsPerAllocation() {
            return threadsPerAllocation;
        }

        public void setThreadsPerAllocation(int threadsPerAllocation) {
            this.threadsPerAllocation = threadsPerAllocation;
        }

        public int getQueueCapacity() {
            return queueCapacity;
        }

        public void setQueueCapacity(int queueCapacity) {
            this.queueCapacity = queueCapacity;
        }

        /** @return the cache size, or null when not set */
        public ByteSizeValue getCacheSize() {
            return cacheSize;
        }

        public void setCacheSize(ByteSizeValue cacheSize) {
            this.cacheSize = cacheSize;
        }

        public Priority getPriority() {
            return priority;
        }

        public void setPriority(String priority) {
            this.priority = Priority.fromString(priority);
        }

        /** Writes the fields in the order read by {@link #Request(StreamInput)}. */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            out.writeString(modelId);
            out.writeTimeValue(timeout);
            out.writeEnum(waitForState);
            out.writeVInt(numberOfAllocations);
            out.writeVInt(threadsPerAllocation);
            out.writeVInt(queueCapacity);
            if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_4_0)) {
                out.writeOptionalWriteable(cacheSize);
            }
            if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_6_0)) {
                out.writeEnum(priority);
            }
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(MODEL_ID.getPreferredName(), modelId);
            builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep());
            builder.field(WAIT_FOR.getPreferredName(), waitForState);
            builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
            builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
            builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
            if (cacheSize != null) {
                builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
            }
            builder.field(PRIORITY.getPreferredName(), priority);
            builder.endObject();
            return builder;
        }

        /**
         * Validates the request parameters.
         *
         * @return null if valid, otherwise an exception that carries every validation error found
         */
        @Override
        public ActionRequestValidationException validate() {
            ActionRequestValidationException validationException = new ActionRequestValidationException();
            if (waitForState.isAnyOf(VALID_WAIT_STATES) == false) {
                validationException.addValidationError(
                    "invalid [wait_for] state ["
                        + waitForState
                        + "]; must be one of ["
                        + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES)
                        + "]"
                );
            }
            if (numberOfAllocations < 1) {
                validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer");
            }
            // else-if: a non-positive value is already reported; don't also flag it as not a power of 2
            if (threadsPerAllocation < 1) {
                validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer");
            } else if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) {
                validationException.addValidationError(
                    "[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION
                );
            }
            if (queueCapacity < 1) {
                validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be a positive integer");
            }
            if (queueCapacity > MAX_QUEUE_CAPACITY) {
                validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be less than " + MAX_QUEUE_CAPACITY);
            }
            if (timeout.nanos() < 1L) {
                validationException.addValidationError("[" + TIMEOUT + "] must be positive");
            }
            if (priority == Priority.LOW) {
                if (numberOfAllocations > 1) {
                    validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be 1 when [" + PRIORITY + "] is low");
                }
                if (threadsPerAllocation > 1) {
                    validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be 1 when [" + PRIORITY + "] is low");
                }
            }
            return validationException.validationErrors().isEmpty() ? null : validationException;
        }

        /** True when exactly one bit is set (callers only pass positive values). */
        private static boolean isPowerOf2(int value) {
            return Integer.bitCount(value) == 1;
        }

        @Override
        public int hashCode() {
            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize, priority);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            Request other = (Request) obj;
            return Objects.equals(modelId, other.modelId)
                && Objects.equals(timeout, other.timeout)
                && Objects.equals(waitForState, other.waitForState)
                && Objects.equals(cacheSize, other.cacheSize)
                && numberOfAllocations == other.numberOfAllocations
                && threadsPerAllocation == other.threadsPerAllocation
                && queueCapacity == other.queueCapacity
                && priority == other.priority;
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }
}

