/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Static helpers for output field names that may carry a pooling suffix
 * ({@code .meanPooling()} or {@code .maxPooling()}), and for applying the matching
 * element-wise pooling to a list of equal-length vectors.
 */
public class OutputTransformations {
    private static final String MEAN_POOLING_SUFFIX = ".meanPooling()";
    private static final String MAX_POOLING_SUFFIX = ".maxPooling()";

    /**
     * Reports whether the field name ends with a supported pooling suffix.
     *
     * @param outputFieldName field name to inspect; may be {@code null}
     * @return {@code true} if the name ends with {@code .meanPooling()} or {@code .maxPooling()}
     */
    public static boolean hasTransformation(String outputFieldName) {
        return matchingSuffix(outputFieldName) != null;
    }

    /**
     * Strips a trailing pooling suffix from the field name.
     *
     * @param outputFieldName field name, possibly ending in a pooling suffix; may be {@code null}
     * @return the name without its pooling suffix, or the input unchanged (including {@code null})
     *     when no suffix is present
     */
    public static String getBaseFieldName(String outputFieldName) {
        String suffix = matchingSuffix(outputFieldName);
        if (suffix == null) {
            return outputFieldName;
        }
        return outputFieldName.substring(0, outputFieldName.length() - suffix.length());
    }

    /**
     * Computes the element-wise mean of a non-empty list of equal-length vectors.
     * Non-numeric entries (including {@code null}) contribute {@code 0} to the sum.
     *
     * @param value expected to be a {@code List} whose elements are {@code List}s of numbers
     * @return a mutable {@code List<Float>} with one mean per dimension
     * @throws IllegalArgumentException if {@code value} is not a non-empty list of lists, or the
     *     inner lists differ in size
     */
    public static Object applyMeanPooling(Object value) {
        List<List<?>> vectors = requireVectors(value, "Mean pooling");
        int dimensions = vectors.get(0).size();
        // Accumulate in double: a float running sum silently drops small contributions once the
        // total becomes large, skewing the mean.
        double[] sums = new double[dimensions];
        for (List<?> vector : vectors) {
            for (int i = 0; i < dimensions; i++) {
                Object element = vector.get(i);
                sums[i] += element instanceof Number number ? number.doubleValue() : 0.0;
            }
        }
        double count = vectors.size();
        List<Float> result = new ArrayList<>(dimensions);
        for (double sum : sums) {
            result.add((float) (sum / count));
        }
        return result;
    }

    /**
     * Computes the element-wise maximum of a non-empty list of equal-length vectors.
     * Non-numeric entries (including {@code null}) are treated as negative infinity; a NaN entry
     * makes that dimension's result NaN (per {@link Math#max(float, float)}).
     *
     * @param value expected to be a {@code List} whose elements are {@code List}s of numbers
     * @return a mutable {@code List<Float>} with one maximum per dimension
     * @throws IllegalArgumentException if {@code value} is not a non-empty list of lists, or the
     *     inner lists differ in size
     */
    public static Object applyMaxPooling(Object value) {
        List<List<?>> vectors = requireVectors(value, "Max pooling");
        int dimensions = vectors.get(0).size();
        float[] maxima = new float[dimensions];
        Arrays.fill(maxima, Float.NEGATIVE_INFINITY);
        for (List<?> vector : vectors) {
            for (int i = 0; i < dimensions; i++) {
                Object element = vector.get(i);
                float candidate = element instanceof Number number
                        ? number.floatValue()
                        : Float.NEGATIVE_INFINITY;
                maxima[i] = Math.max(maxima[i], candidate);
            }
        }
        List<Float> result = new ArrayList<>(dimensions);
        for (float max : maxima) {
            result.add(max);
        }
        return result;
    }

    /**
     * Applies the pooling selected by the field name's suffix, if any.
     *
     * @param outputFieldName field name that may end with a pooling suffix; may be {@code null}
     * @param value value to transform
     * @return the pooled result, or {@code value} itself when no pooling suffix is present
     * @throws IllegalArgumentException propagated from the selected pooling method
     */
    public static Object applyTransformation(String outputFieldName, Object value) {
        String suffix = matchingSuffix(outputFieldName);
        if (MEAN_POOLING_SUFFIX.equals(suffix)) {
            return applyMeanPooling(value);
        }
        if (MAX_POOLING_SUFFIX.equals(suffix)) {
            return applyMaxPooling(value);
        }
        return value;
    }

    /** Returns the pooling suffix the name ends with, or {@code null} if none (or name is null). */
    private static String matchingSuffix(String outputFieldName) {
        if (outputFieldName == null) {
            return null;
        }
        if (outputFieldName.endsWith(MEAN_POOLING_SUFFIX)) {
            return MEAN_POOLING_SUFFIX;
        }
        if (outputFieldName.endsWith(MAX_POOLING_SUFFIX)) {
            return MAX_POOLING_SUFFIX;
        }
        return null;
    }

    /**
     * Validates that {@code value} is a non-empty list of lists that all share the first list's
     * size, and returns it as a typed list. {@code operation} prefixes the first two messages.
     */
    private static List<List<?>> requireVectors(Object value, String operation) {
        if (!(value instanceof List<?> outer)) {
            throw new IllegalArgumentException(operation + " requires a list input");
        }
        if (outer.isEmpty() || !(outer.get(0) instanceof List)) {
            throw new IllegalArgumentException(operation + " requires nested array structure");
        }
        int dimensions = ((List<?>) outer.get(0)).size();
        List<List<?>> vectors = new ArrayList<>(outer.size());
        for (Object candidate : outer) {
            if (!(candidate instanceof List<?> vector)) {
                throw new IllegalArgumentException("All elements must be vectors (lists)");
            }
            if (vector.size() != dimensions) {
                throw new IllegalArgumentException("All vectors must have the same dimension");
            }
            vectors.add(vector);
        }
        return vectors;
    }
}

