package org.jeecg.ai.handler;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.ServiceOutputParser;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.jeecg.ai.assistant.AiStreamChatAssistant;
import org.jeecg.ai.factory.AiModelFactory;
import org.jeecg.ai.prop.AiChatProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/jeecg/ai/handler/LLMHandler.class */
public class LLMHandler {

    /** Memory id passed both to the RAG metadata and to the streaming token stream. */
    private static final String DEFAULT_MEMORY_ID = "default";

    /** Chat-memory window size used when {@code AIParams#getMaxMsgNumber()} is not set. */
    private static final int DEFAULT_MAX_MESSAGES = 6;

    /**
     * Extra window slots added on top of {@code maxMsgNumber}; presumably reserved for the
     * merged system message and the current user message — TODO confirm intent.
     */
    private static final int RESERVED_MESSAGE_SLOTS = 2;

    @Generated
    private static final Logger log = LoggerFactory.getLogger(LLMHandler.class);

    /** Fallback model configuration; may be {@code null} when built with the no-arg constructor. */
    private AiChatProperties aiChatProperties;

    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();

    /**
     * Result of {@link #collateMessage}: the chat memory holding the messages to send and,
     * when retrieval augmentation ran, its result (otherwise {@code null}).
     */
    public static class CollateMsgResp {
        public final ChatMemory chatMemory;
        public final AugmentationResult augmentationResult;

        public CollateMsgResp(ChatMemory chatMemory, AugmentationResult augmentationResult) {
            this.chatMemory = chatMemory;
            this.augmentationResult = augmentationResult;
        }

        @Override
        public String toString() {
            return "{messages=" + (this.chatMemory != null ? this.chatMemory.messages() : "null") + "}";
        }
    }

    /**
     * Creates a handler that falls back to {@code aiChatProperties} whenever the caller does not
     * supply usable model parameters.
     *
     * @param aiChatProperties default model configuration
     */
    public LLMHandler(AiChatProperties aiChatProperties) {
        this.aiChatProperties = aiChatProperties;
    }

    /** Creates a handler without default model configuration; callers must pass full parameters. */
    public LLMHandler() {
    }

    /**
     * Returns usable model parameters. When {@code aIParams} is {@code null} or has an empty base
     * URL, the connection fields are filled from the configured defaults.
     *
     * @param aIParams caller-supplied parameters, may be {@code null}
     * @return non-null parameters
     * @throws IllegalArgumentException if no parameters could be resolved
     */
    private AIParams ensureParams(AIParams aIParams) {
        AIParams params = aIParams;
        if (null == params || StringUtils.isEmpty(params.getBaseUrl())) {
            params = getDefaultModel(params);
        }
        if (null == params) {
            throw new IllegalArgumentException("大语言模型参数为空");
        }
        return params;
    }

    /**
     * Copies provider, model, host, API key, secret key and timeout from the configured defaults
     * into {@code aIParams}, creating a new instance if needed. Existing values of those fields
     * are overwritten.
     *
     * @param aIParams parameters to fill, may be {@code null}
     * @return the filled parameters, or {@code null} if no defaults are configured
     */
    private AIParams getDefaultModel(AIParams aIParams) {
        if (null == this.aiChatProperties) {
            log.warn("未配置默认大语言模型");
            return null;
        }
        AIParams params = aIParams != null ? aIParams : new AIParams();
        params.setProvider(this.aiChatProperties.getProvider());
        params.setModelName(this.aiChatProperties.getModel());
        params.setBaseUrl(this.aiChatProperties.getApiHost());
        params.setApiKey(this.aiChatProperties.getApiKey());
        // Guard against a missing credential section instead of failing with an NPE.
        params.setSecretKey(this.aiChatProperties.getCredential() != null
                ? this.aiChatProperties.getCredential().getSecretKey()
                : null);
        params.setTimeout(Integer.valueOf(this.aiChatProperties.getTimeout()));
        return params;
    }

    /**
     * Sends a single user prompt to the default model and returns the text reply.
     *
     * @param str user prompt
     * @return the model's reply text
     */
    public String completions(String str) {
        return completions(Collections.singletonList(UserMessage.from(str)), null);
    }

    /**
     * Sends a conversation to the model in a blocking call and returns the text reply.
     *
     * @param list     conversation; the last message must be a user message
     * @param aIParams model parameters, may be {@code null} to use the defaults
     * @return the model's reply parsed as a {@code String}
     * @throws IllegalArgumentException if parameters are missing or the messages are invalid
     */
    public String completions(List<ChatMessage> list, AIParams aIParams) {
        AIParams params = ensureParams(aIParams);
        // Validate and collate input before building the model client so bad input fails fast.
        CollateMsgResp collated = collateMessage(list, params);
        ChatLanguageModel chatModel = AiModelFactory.createChatModel(params.toModelOptions());
        log.info("[LLMHandler] send message to AI server. message: {}", collated);
        String response = (String) this.serviceOutputParser.parse(
                chatModel.generate(collated.chatMemory.messages()), String.class);
        log.info("[LLMHandler] Received the AI's response . message: {}", response);
        return response;
    }

    /**
     * Sends a conversation to the model and returns a {@link TokenStream} of the streamed reply.
     *
     * @param list     conversation; the last message must be a user message
     * @param aIParams model parameters, may be {@code null} to use the defaults
     * @return a token stream over the model's reply
     * @throws IllegalArgumentException if parameters are missing or the messages are invalid
     */
    public TokenStream chat(List<ChatMessage> list, AIParams aIParams) {
        AIParams params = ensureParams(aIParams);
        CollateMsgResp collated = collateMessage(list, params);
        StreamingChatLanguageModel streamingModel =
                AiModelFactory.createStreamingChatModel(params.toModelOptions());
        AiServiceContext context = new AiServiceContext(AiStreamChatAssistant.class);
        context.streamingChatModel = streamingModel;
        log.info("[LLMHandler] send message to AI server. message: {}", collated);
        return new AiServiceTokenStream(
                collated.chatMemory.messages(),
                (List) null,
                (Map) null,
                collated.augmentationResult != null ? collated.augmentationResult.contents() : null,
                context,
                DEFAULT_MEMORY_ID);
    }

    /**
     * Builds the chat memory to send.
     *
     * <p>All system messages are merged into one (texts joined with {@code "\n"}) and added
     * first. The remaining history follows, and the final user message is added last. The memory
     * is a {@link MessageWindowChatMemory} sized {@code maxMsgNumber + 2}, or
     * {@value #DEFAULT_MAX_MESSAGES} when unset.
     *
     * <p>When a query router is configured and the user message contains text, the text parts
     * are sent (in order) as a retrieval query. The user message is then rebuilt from its
     * non-text parts plus the augmented text.
     *
     * @param list     conversation; must be non-empty and end with a user message
     * @param aIParams parameters (max message count, optional query router), may be {@code null}
     * @return the collated memory and optional augmentation result
     * @throws IllegalArgumentException if {@code list} is null/empty or does not end with a user message
     */
    private CollateMsgResp collateMessage(List<ChatMessage> list, AIParams aIParams) {
        if (null == list || list.isEmpty()) {
            throw new IllegalArgumentException("消息列表不能为空");
        }
        AIParams params = null != aIParams ? aIParams : new AIParams();

        int lastIndex = list.size() - 1;
        ChatMessage lastMessage = list.get(lastIndex);
        if (!ChatMessageType.USER.equals(lastMessage.type())) {
            throw new IllegalArgumentException("最后一条消息必须是用户消息");
        }
        UserMessage userMessage = (UserMessage) lastMessage;

        int maxMessages = null != params.getMaxMsgNumber()
                ? params.getMaxMsgNumber() + RESERVED_MESSAGE_SLOTS
                : DEFAULT_MAX_MESSAGES;

        // Separate system prompts (merged into one) from the ordinary history.
        List<String> systemTexts = new ArrayList<>();
        List<ChatMessage> history = new ArrayList<>(lastIndex);
        for (ChatMessage message : list.subList(0, lastIndex)) {
            if (ChatMessageType.SYSTEM.equals(message.type())) {
                systemTexts.add(((SystemMessage) message).text());
            } else {
                history.add(message);
            }
        }

        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
                .maxMessages(maxMessages)
                .build();
        if (!systemTexts.isEmpty()) {
            chatMemory.add(SystemMessage.from(String.join("\n", systemTexts)));
        }
        history.forEach(chatMemory::add);

        AugmentationResult augmentationResult = null;
        if (null != params.getQueryRouter()) {
            // Collect text parts in their original order; keep non-text parts (e.g. images) aside.
            StringBuilder query = new StringBuilder();
            List<Content> contents = new ArrayList<>();
            for (Content content : userMessage.contents()) {
                if (content instanceof TextContent) {
                    query.append(((TextContent) content).text()).append("\n");
                } else {
                    contents.add(content);
                }
            }
            // Only run retrieval when there is text to query with.
            if (query.length() > 0) {
                DefaultRetrievalAugmentor augmentor = DefaultRetrievalAugmentor.builder()
                        .queryRouter(params.getQueryRouter())
                        .build();
                UserMessage queryMessage = UserMessage.from(query.toString());
                augmentationResult = augmentor.augment(new AugmentationRequest(
                        queryMessage, Metadata.from(queryMessage, DEFAULT_MEMORY_ID, chatMemory.messages())));
                contents.add(TextContent.from(augmentationResult.chatMessage().singleText()));
                userMessage = UserMessage.from(contents);
            }
        }
        chatMemory.add(userMessage);
        return new CollateMsgResp(chatMemory, augmentationResult);
    }
}
