/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.modality.nlp.generate.SearchConfig;
import ai.djl.modality.nlp.generate.SeqBatcher;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class SeqBatchScheduler {
    private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class);
    Predictor<NDList, CausalLMOutput> predictor;
    SeqBatcher seqBatcher;
    NDManager manager;
    SearchConfig config;
    Map<Long, NDArray> results;

    public SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
        this.predictor = lmBlock;
        this.config = config;
        this.results = new ConcurrentHashMap<Long, NDArray>();
    }

    public abstract SeqBatcher initForward(NDArray var1, NDArray var2) throws TranslateException;

    public boolean incrementForward(int count) throws TranslateException {
        int i = 0;
        while (i++ < count) {
            if (this.seqBatcher == null || this.seqBatcher.getData() == null) {
                logger.info("seqBatcher not set or is empty. Please call addBatch. Current inference times is " + i);
                return true;
            }
            this.inferenceCall();
            if (!this.seqBatcher.sequenceComplete()) continue;
            this.results.putAll(this.seqBatcher.collectAndTrim());
        }
        return false;
    }

    protected abstract NDArray inferenceCall() throws TranslateException;

    public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException {
        SeqBatcher seqBatcherNew = this.initForward(inputIds, batchUids);
        if (this.seqBatcher == null) {
            this.seqBatcher = seqBatcherNew;
        } else {
            this.seqBatcher.addBatch(seqBatcherNew);
        }
    }

    public Map<Long, NDArray> collectResults() {
        Map<Long, NDArray> output = this.results;
        this.results = new ConcurrentHashMap<Long, NDArray>();
        return output;
    }

    static NDArray computeOffSets(NDArray inputIds, SearchConfig config) {
        int numBatch = Math.toIntExact(inputIds.getShape().get(0));
        int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));
        long[] offSetsArray = new long[numBatch];
        for (int i = 0; i < numBatch; ++i) {
            int idx;
            long[] aSequence = inputIds.get("{},:", i).toLongArray();
            for (idx = 0; idx < initSeqSize && aSequence[idx] == config.getPadTokenId(); ++idx) {
            }
            offSetsArray[i] = idx;
        }
        NDManager manager = inputIds.getManager();
        return manager.create(offSetsArray).reshape(-1L, 1L);
    }

    static NDArray computeAttentionMask(NDArray inputIds, SearchConfig config) {
        int numBatch = Math.toIntExact(inputIds.getShape().get(0));
        int initSeqSize = Math.toIntExact(inputIds.getShape().get(1));
        NDManager manager = inputIds.getManager();
        NDArray attentionMask = manager.ones(new Shape(1L, inputIds.getShape().getLastDimension()), DataType.INT64).reshape(1L, -1L).repeat(0, numBatch);
        for (int i = 0; i < numBatch; ++i) {
            int idx;
            long[] aSequence = inputIds.get("{},:", i).toLongArray();
            for (idx = 0; idx < initSeqSize && aSequence[idx] == config.getPadTokenId(); ++idx) {
            }
            attentionMask.set(new NDIndex("{},{}:{}", i, 0, idx), (Number)0);
        }
        return attentionMask;
    }

    static NDArray computePositionIds(NDArray inputIds, NDArray offSets, long pastSeqLength, int repeat) {
        NDManager manager = inputIds.getManager();
        NDArray positionIds = manager.arange(pastSeqLength, pastSeqLength + inputIds.getShape().getLastDimension(), 1.0f, DataType.INT64).expandDims(0).repeat(0, inputIds.getShape().get(0));
        NDArray positionIdsShifted = positionIds.subi(offSets.reshape(-1L, 1L).repeat(0, repeat));
        positionIds = positionIdsShifted.maximum(positionIdsShifted.zerosLike());
        return positionIds;
    }
}

