/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.schema;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.SortField;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrException;
import org.apache.solr.schema.FloatPointField;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.QParser;
import org.apache.solr.uninverting.UninvertingReader;
import org.apache.solr.util.vector.ByteDenseVectorParser;
import org.apache.solr.util.vector.DenseVectorParser;
import org.apache.solr.util.vector.FloatDenseVectorParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Solr field type for dense numeric vectors that are indexed for K-nearest-neighbours search.
 *
 * <p>Depending on the configured {@code vectorEncoding}, values are indexed as a Lucene
 * {@link KnnFloatVectorField} ({@code FLOAT32}) or {@link KnnByteVectorField} ({@code BYTE}).
 * Schema arguments consumed by {@link #init(IndexSchema, Map)}:
 * <ul>
 *   <li>{@code vectorDimension} (mandatory) - number of components per vector;</li>
 *   <li>{@code similarityFunction} - a {@link VectorSimilarityFunction} name, default EUCLIDEAN;</li>
 *   <li>{@code knnAlgorithm} - default {@code hnsw};</li>
 *   <li>{@code vectorEncoding} - a {@link VectorEncoding} name, default FLOAT32;</li>
 *   <li>{@code hnswMaxConnections} - default 16;</li>
 *   <li>{@code hnswBeamWidth} - default 100.</li>
 * </ul>
 *
 * <p>Fields of this type cannot be multiValued, cannot have docValues, and reject field, range
 * and sort queries; the {@code {!knn}} query parser is the intended way to search them.
 */
public class DenseVectorField extends FloatPointField {
    private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
    public static final String HNSW_ALGORITHM = "hnsw";
    public static final String DEFAULT_KNN_ALGORITHM = "hnsw";
    static final String KNN_VECTOR_DIMENSION = "vectorDimension";
    static final String KNN_ALGORITHM = "knnAlgorithm";
    static final String HNSW_MAX_CONNECTIONS = "hnswMaxConnections";
    static final String HNSW_BEAM_WIDTH = "hnswBeamWidth";
    static final String VECTOR_ENCODING = "vectorEncoding";
    static final VectorEncoding DEFAULT_VECTOR_ENCODING = VectorEncoding.FLOAT32;
    static final String KNN_SIMILARITY_FUNCTION = "similarityFunction";
    static final VectorSimilarityFunction DEFAULT_SIMILARITY = VectorSimilarityFunction.EUCLIDEAN;

    /** Value used for {@code hnswMaxConnections} when the schema does not set it. */
    private static final int DEFAULT_HNSW_MAX_CONNECTIONS = 16;
    /** Value used for {@code hnswBeamWidth} when the schema does not set it. */
    private static final int DEFAULT_HNSW_BEAM_WIDTH = 100;
    /** Dimension above which a "un-tested territory" warning is logged, for every encoding. */
    private static final int LUCENE_DEFAULT_MAX_DIMENSIONS = 1024;

    private int dimension;
    private VectorSimilarityFunction similarityFunction;
    private String knnAlgorithm;
    private int hnswMaxConn;
    private int hnswBeamWidth;
    private VectorEncoding vectorEncoding;

    /** Creates an unconfigured instance; {@link #init(IndexSchema, Map)} is expected to follow. */
    public DenseVectorField() {
    }

    /**
     * Creates a field type with the given dimension, EUCLIDEAN similarity and FLOAT32 encoding.
     *
     * @param dimension number of components per vector
     */
    public DenseVectorField(int dimension) {
        this(dimension, DEFAULT_SIMILARITY, DEFAULT_VECTOR_ENCODING);
    }

    /**
     * Creates a field type with the given dimension and encoding and EUCLIDEAN similarity.
     *
     * @param dimension number of components per vector
     * @param vectorEncoding element encoding used for indexing
     */
    public DenseVectorField(int dimension, VectorEncoding vectorEncoding) {
        this(dimension, DEFAULT_SIMILARITY, vectorEncoding);
    }

    /**
     * Creates a fully specified field type.
     *
     * @param dimension number of components per vector
     * @param similarityFunction similarity used when indexing and searching vectors
     * @param vectorEncoding element encoding used for indexing
     */
    public DenseVectorField(int dimension, VectorSimilarityFunction similarityFunction, VectorEncoding vectorEncoding) {
        this.dimension = dimension;
        this.similarityFunction = similarityFunction;
        this.vectorEncoding = vectorEncoding;
    }

    /**
     * Reads and removes the vector-specific schema arguments from {@code args}, then delegates
     * the remaining arguments to the superclass.
     *
     * @throws SolrException if {@code vectorDimension} is missing
     */
    @Override
    public void init(IndexSchema schema, Map<String, String> args) {
        // Map.remove returns the removed value, so each argument is consumed in a single call.
        this.dimension = Optional.ofNullable(args.remove(KNN_VECTOR_DIMENSION))
                .map(Integer::parseInt)
                .orElseThrow(() -> new SolrException(SolrException.ErrorCode.SERVER_ERROR, "the vector dimension is a mandatory parameter"));
        this.similarityFunction = Optional.ofNullable(args.remove(KNN_SIMILARITY_FUNCTION))
                .map(value -> VectorSimilarityFunction.valueOf(value.toUpperCase(Locale.ROOT)))
                .orElse(DEFAULT_SIMILARITY);
        this.knnAlgorithm = Optional.ofNullable(args.remove(KNN_ALGORITHM)).orElse(DEFAULT_KNN_ALGORITHM);
        this.vectorEncoding = Optional.ofNullable(args.remove(VECTOR_ENCODING))
                .map(value -> VectorEncoding.valueOf(value.toUpperCase(Locale.ROOT)))
                .orElse(DEFAULT_VECTOR_ENCODING);
        this.hnswMaxConn = Optional.ofNullable(args.remove(HNSW_MAX_CONNECTIONS))
                .map(Integer::parseInt)
                .orElse(DEFAULT_HNSW_MAX_CONNECTIONS);
        this.hnswBeamWidth = Optional.ofNullable(args.remove(HNSW_BEAM_WIDTH))
                .map(Integer::parseInt)
                .orElse(DEFAULT_HNSW_BEAM_WIDTH);
        // Clear bit 0x200 (presumably FieldProperties.MULTIVALUED - confirm); multiValued is
        // rejected by checkSchemaField anyway.
        this.properties &= ~0x200;
        // Clear bit 0x80000 (presumably FieldProperties.UNINVERTIBLE - confirm); consistent with
        // getUninversionType returning null for this type.
        this.properties &= ~0x80000;
        super.init(schema, args);
    }

    /** @return the configured number of components per vector */
    public int getDimension() {
        return this.dimension;
    }

    /** @return the configured similarity function */
    public VectorSimilarityFunction getSimilarityFunction() {
        return this.similarityFunction;
    }

    /** @return the configured KNN algorithm name ({@code hnsw} unless overridden) */
    public String getKnnAlgorithm() {
        return this.knnAlgorithm;
    }

    /** @return the configured {@code hnswMaxConnections} value */
    public Integer getHnswMaxConn() {
        return this.hnswMaxConn;
    }

    /** @return the configured {@code hnswBeamWidth} value */
    public Integer getHnswBeamWidth() {
        return this.hnswBeamWidth;
    }

    /** @return the configured vector element encoding */
    public VectorEncoding getVectorEncoding() {
        return this.vectorEncoding;
    }

    /** Dense vector fields do not enable docValues by default. */
    @Override
    protected boolean enableDocValuesByDefault() {
        return false;
    }

    /**
     * Validates a schema field of this type.
     *
     * @throws SolrException if the field is multiValued or has docValues
     */
    @Override
    public void checkSchemaField(SchemaField field) throws SolrException {
        super.checkSchemaField(field);
        if (field.multiValued()) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, this.getClass().getSimpleName() + " fields can not be multiValued: " + field.getName());
        }
        if (field.hasDocValues()) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, this.getClass().getSimpleName() + " fields can not have docValues: " + field.getName());
        }
        // FLOAT32 and BYTE share the same threshold, so a single check covers both encodings.
        if (this.dimension > LUCENE_DEFAULT_MAX_DIMENSIONS && log.isWarnEnabled()) {
            log.warn("The vector dimension {} specified for field {} exceeds the current Lucene default max dimension of {}. It's un-tested territory, extra caution and benchmarks are recommended for production systems.",
                    this.dimension, field.getName(), LUCENE_DEFAULT_MAX_DIMENSIONS);
        }
    }

    /**
     * Builds the Lucene fields for one vector value.
     *
     * <p>If the field is indexed, a KNN vector field is added. If it is stored, the vector is also
     * stored: one stored field per component for FLOAT32, or a single binary stored field for BYTE.
     *
     * @throws SolrException wrapping any runtime failure, e.g. while parsing {@code value}
     */
    @Override
    public List<IndexableField> createFields(SchemaField field, Object value) {
        try {
            ArrayList<IndexableField> fields = new ArrayList<>();
            DenseVectorParser vectorBuilder = this.getVectorBuilder(value, DenseVectorParser.BuilderPhase.INDEX);
            if (field.indexed()) {
                fields.add(this.createField(field, vectorBuilder));
            }
            if (field.stored()) {
                switch (this.vectorEncoding) {
                    case FLOAT32: {
                        float[] vector = vectorBuilder.getFloatVector();
                        fields.ensureCapacity(vector.length + 1);
                        for (float component : vector) {
                            fields.add(this.getStoredField(field, Float.valueOf(component)));
                        }
                        break;
                    }
                    case BYTE: {
                        fields.add(new StoredField(field.getName(), vectorBuilder.getByteVector()));
                        break;
                    }
                }
            }
            return fields;
        }
        catch (RuntimeException e) {
            throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Error while creating field '" + field + "' from value '" + value + "'", e);
        }
    }

    /**
     * Creates the indexed KNN vector field.
     *
     * @param vectorValue a {@link DenseVectorParser}, or {@code null}
     * @return the KNN vector field, or {@code null} when {@code vectorValue} is {@code null}
     * @throws ClassCastException if {@code vectorValue} is not a {@link DenseVectorParser}
     */
    @Override
    public IndexableField createField(SchemaField field, Object vectorValue) {
        if (vectorValue == null) {
            return null;
        }
        DenseVectorParser vectorBuilder = (DenseVectorParser) vectorValue;
        // Only build the Lucene field type once a value is known to be present.
        FieldType denseVectorFieldType = this.getDenseVectorFieldType();
        switch (this.vectorEncoding) {
            case BYTE: {
                return new KnnByteVectorField(field.getName(), vectorBuilder.getByteVector(), denseVectorFieldType);
            }
            case FLOAT32: {
                return new KnnFloatVectorField(field.getName(), vectorBuilder.getFloatVector(), denseVectorFieldType);
            }
        }
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + this.vectorEncoding);
    }

    /**
     * Returns a Lucene {@link FieldType} whose vector dimension, encoding and similarity read this
     * instance's current configuration.
     *
     * <p>NOTE(review): the vector attributes are provided by overriding accessors, not by calling
     * setters. Presumably this avoids any dimension limit a setter might enforce, given the
     * over-1024 warning in checkSchemaField. Confirm before changing.
     */
    private FieldType getDenseVectorFieldType() {
        return new FieldType() {

            public int vectorDimension() {
                return DenseVectorField.this.dimension;
            }

            public VectorEncoding vectorEncoding() {
                return DenseVectorField.this.vectorEncoding;
            }

            public VectorSimilarityFunction vectorSimilarityFunction() {
                return DenseVectorField.this.similarityFunction;
            }
        };
    }

    /**
     * Converts a stored field back to its external form.
     *
     * <p>For BYTE encoding, the binary value becomes a list of its signed byte components as
     * {@link Integer}s. Only the {@code [offset, offset + length)} window of the
     * {@link BytesRef} is read, not the whole backing array. Other encodings delegate to the
     * superclass.
     *
     * @throws AssertionError if a BYTE-encoded field has no binary value
     */
    @Override
    public Object toObject(IndexableField f) {
        if (this.vectorEncoding == VectorEncoding.BYTE) {
            BytesRef bytesRef = f.binaryValue();
            if (bytesRef == null) {
                throw new AssertionError("Unexpected state. Field: '" + f + "'");
            }
            ArrayList<Integer> components = new ArrayList<>(bytesRef.length);
            int end = bytesRef.offset + bytesRef.length;
            for (int i = bytesRef.offset; i < end; i++) {
                components.add((int) bytesRef.bytes[i]);
            }
            return components;
        }
        return super.toObject(f);
    }

    /**
     * Creates the parser matching the configured encoding.
     *
     * @param inputValue raw vector value to parse
     * @param phase whether the value is being parsed for indexing or for a query
     * @throws SolrException if the encoding is neither FLOAT32 nor BYTE
     */
    public DenseVectorParser getVectorBuilder(Object inputValue, DenseVectorParser.BuilderPhase phase) {
        switch (this.vectorEncoding) {
            case FLOAT32: {
                return new FloatDenseVectorParser(this.dimension, inputValue, phase);
            }
            case BYTE: {
                return new ByteDenseVectorParser(this.dimension, inputValue, phase);
            }
        }
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + this.vectorEncoding);
    }

    /** Dense vector fields are not uninvertible; always returns {@code null}. */
    @Override
    public UninvertingReader.Type getUninversionType(SchemaField sf) {
        return null;
    }

    /**
     * Returns a function-query value source over the field's KNN vectors, chosen by encoding.
     *
     * @throws SolrException (BAD_REQUEST) if the encoding is neither FLOAT32 nor BYTE
     */
    @Override
    public ValueSource getValueSource(SchemaField field, QParser parser) {
        switch (this.vectorEncoding) {
            case FLOAT32: {
                return new FloatKnnVectorFieldSource(field.getName());
            }
            case BYTE: {
                return new ByteKnnVectorFieldSource(field.getName());
            }
        }
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Vector encoding not supported for function queries.");
    }

    /**
     * Builds a Lucene KNN query for the top {@code topK} neighbours of the parsed query vector.
     *
     * @param fieldName field to search
     * @param vectorToSearch textual query vector, parsed in the QUERY phase
     * @param topK number of nearest neighbours requested
     * @param filterQuery filter passed through to the Lucene KNN query
     * @throws SolrException if the encoding is neither FLOAT32 nor BYTE
     */
    public Query getKnnVectorQuery(String fieldName, String vectorToSearch, int topK, Query filterQuery) {
        DenseVectorParser vectorBuilder = this.getVectorBuilder(vectorToSearch, DenseVectorParser.BuilderPhase.QUERY);
        switch (this.vectorEncoding) {
            case FLOAT32: {
                return new KnnFloatVectorQuery(fieldName, vectorBuilder.getFloatVector(), topK, filterQuery);
            }
            case BYTE: {
                return new KnnByteVectorQuery(fieldName, vectorBuilder.getByteVector(), topK, filterQuery);
            }
        }
        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Unexpected state. Vector Encoding: " + this.vectorEncoding);
    }

    /** Always throws: field queries are not supported on dense vector fields. */
    @Override
    public Query getFieldQuery(QParser parser, SchemaField field, String externalVal) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Field Queries are not supported for Dense Vector fields. Please use the {!knn} query parser to run K nearest neighbors search queries.");
    }

    /** Always throws: range queries are not supported on dense vector fields. */
    @Override
    public Query getRangeQuery(QParser parser, SchemaField field, String part1, String part2, boolean minInclusive, boolean maxInclusive) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Range Queries are not supported for Dense Vector fields. Please use the {!knn} query parser to run K nearest neighbors search queries.");
    }

    /** Always throws: sorting on dense vector fields is not supported. */
    @Override
    public SortField getSortField(SchemaField field, boolean top) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, "Cannot sort on a Dense Vector field");
    }
}

