package org.jeecg.ai.handler;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.jeecg.ai.assistant.AiChatAssistant;
import org.jeecg.ai.assistant.AiStreamChatAssistant;
import org.jeecg.ai.factory.AiModelFactory;
import org.jeecg.ai.prop.AiChatProperties;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/jeecg/ai/handler/LLMHandler.class */
/**
 * Sends chat messages to a large language model through LangChain4j {@link AiServices},
 * either as a blocking completion ({@link #completions(List, AIParams)}) or as a token stream
 * ({@link #chat(List, AIParams)}).
 *
 * <p>Callers may pass per-request {@link AIParams}. When those are {@code null} or carry no API
 * key, the connection settings are taken from the {@link AiChatProperties} given at construction.
 */
public class LLMHandler {
    private static final Logger log = LoggerFactory.getLogger(LLMHandler.class);

    /** Matches {@code {{name}}} placeholders; group 1 is the text between the braces. */
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{\\{(.*?)}}");

    /** Chat-memory window size used when {@link AIParams#getMaxMsgNumber()} is not set. */
    private static final int DEFAULT_MAX_MESSAGES = 6;

    /**
     * Added to the configured message count when sizing the chat-memory window.
     * NOTE(review): presumably headroom for the current question/answer pair - confirm.
     */
    private static final int MAX_MESSAGES_PADDING = 2;

    /** Default model configuration; {@code null} when the no-arg constructor was used. */
    private final AiChatProperties aiChatProperties;

    /**
     * Result of splitting a message list into the pieces an assistant call needs.
     *
     * <p>The decompiler output notes that this type's access modifier was changed from private;
     * the visibility found in the decompiled file is kept here.
     */
    public static class CollateMsgResp {
        /** All system messages joined with newlines; empty when there are none. */
        public final String systemMessage;
        /** Text of the user message to send; empty when no user message was found. */
        public final String prompt;
        /** Remaining conversation history, or {@code null} when the input list was empty. */
        public final ChatMemory chatMemory;

        public CollateMsgResp(String str, String str2, ChatMemory chatMemory) {
            this.systemMessage = str;
            this.prompt = str2;
            this.chatMemory = chatMemory;
        }

        @Override
        public String toString() {
            return "{systemMessage='" + this.systemMessage + "', prompt='" + this.prompt + "', chatMemory=" + (this.chatMemory != null ? this.chatMemory.messages() : "null") + '}';
        }
    }

    /**
     * Creates a handler that falls back to the given default model configuration.
     *
     * @param aiChatProperties default model settings; may be {@code null}
     */
    public LLMHandler(AiChatProperties aiChatProperties) {
        this.aiChatProperties = aiChatProperties;
    }

    /** Creates a handler without default model settings; every call must supply usable parameters. */
    public LLMHandler() {
        this(null);
    }

    /**
     * Returns parameters that are safe to use for a model call.
     *
     * <p>When {@code params} is {@code null} or has an empty API key, the configured defaults are
     * applied via {@link #getDefaultModel(AIParams)}, which writes them into the passed instance.
     *
     * @param params caller-supplied parameters, may be {@code null}
     * @return non-null parameters
     * @throws IllegalArgumentException if no parameters could be resolved
     */
    private AIParams ensureParams(AIParams params) {
        AIParams resolved = params;
        if (resolved == null || StringUtils.isEmpty(resolved.getApiKey())) {
            resolved = getDefaultModel(resolved);
        }
        if (resolved == null) {
            throw new IllegalArgumentException("大语言模型参数为空");
        }
        return resolved;
    }

    /**
     * Copies the configured default provider, model name, base URL, API key and secret key into
     * {@code params}, creating a new instance when {@code params} is {@code null}.
     *
     * @param params instance to fill, may be {@code null}
     * @return the filled instance, or {@code null} when no default configuration is available
     */
    private AIParams getDefaultModel(AIParams params) {
        if (this.aiChatProperties == null) {
            log.warn("未配置默认大语言模型");
            return null;
        }
        AIParams target = params != null ? params : new AIParams();
        target.setProvider(this.aiChatProperties.getProvider());
        target.setModelName(this.aiChatProperties.getModel());
        target.setBaseUrl(this.aiChatProperties.getApiHost());
        target.setApiKey(this.aiChatProperties.getApiKey());
        // The credential section may be absent from the configuration; skip it instead of failing with an NPE.
        if (this.aiChatProperties.getCredential() != null) {
            target.setSecretKey(this.aiChatProperties.getCredential().getSecretKey());
        }
        return target;
    }

    /**
     * Sends a single user message using the default model configuration.
     *
     * @param str text of the user message
     * @return the model's reply
     * @throws IllegalArgumentException if no default model configuration is available
     */
    public String completions(String str) {
        return completions(Collections.singletonList(UserMessage.from(str)), null);
    }

    /**
     * Sends a conversation to the model and blocks until the full reply is available.
     *
     * <p>The last user message in {@code list} becomes the prompt, system messages are merged into
     * one system message, and the remaining messages are supplied as chat memory.
     *
     * @param list conversation messages, may be {@code null} or empty
     * @param aIParams per-request model parameters, may be {@code null}
     * @return the model's reply
     * @throws IllegalArgumentException if no model parameters could be resolved
     */
    public String completions(List<ChatMessage> list, AIParams aIParams) {
        AIParams params = ensureParams(aIParams);
        ChatLanguageModel model = AiModelFactory.createChatModel(params.toModelOptions());
        CollateMsgResp collated = collateMessage(list, params);

        AiServices<AiChatAssistant> builder = AiServices.builder(AiChatAssistant.class);
        builder.chatLanguageModel(model);
        AiChatAssistant assistant = buildAssistant(builder, params, collated);

        log.info("[LLMHandler] send message to AI server. message: {}", collated);
        String reply = StringUtils.isNotEmpty(collated.systemMessage)
                ? assistant.chat(collated.systemMessage, collated.prompt)
                : assistant.chat(collated.prompt);
        log.info("[LLMHandler] Received the AI's response . message: {}", reply);
        return reply;
    }

    /**
     * Sends a conversation to the model and returns the reply as a token stream.
     *
     * <p>Messages are split the same way as in {@link #completions(List, AIParams)}.
     *
     * @param list conversation messages, may be {@code null} or empty
     * @param aIParams per-request model parameters, may be {@code null}
     * @return the token stream produced by the streaming assistant
     * @throws IllegalArgumentException if no model parameters could be resolved
     */
    public TokenStream chat(List<ChatMessage> list, AIParams aIParams) {
        // ensureParams never returns null, so no further null check is needed here.
        AIParams params = ensureParams(aIParams);
        StreamingChatLanguageModel model = AiModelFactory.createStreamingChatModel(params.toModelOptions());
        CollateMsgResp collated = collateMessage(list, params);

        AiServices<AiStreamChatAssistant> builder = AiServices.builder(AiStreamChatAssistant.class);
        builder.streamingChatLanguageModel(model);
        AiStreamChatAssistant assistant = buildAssistant(builder, params, collated);

        log.info("[LLMHandler] send message to AI server. message: {}", collated);
        return StringUtils.isNotEmpty(collated.systemMessage)
                ? assistant.chat(collated.systemMessage, collated.prompt)
                : assistant.chat(collated.prompt);
    }

    /**
     * Applies the options shared by the blocking and streaming assistants, then builds the assistant.
     *
     * @param builder service builder that already has its model set
     * @param params resolved model parameters
     * @param collated collated messages whose chat memory, if any, is attached
     * @return the built assistant
     */
    private <T> T buildAssistant(AiServices<T> builder, AIParams params, CollateMsgResp collated) {
        if (params.getQueryRouter() != null) {
            builder.retrievalAugmentor(
                    DefaultRetrievalAugmentor.builder().queryRouter(params.getQueryRouter()).build());
        }
        if (collated.chatMemory != null) {
            builder.chatMemory(collated.chatMemory);
        }
        return builder.build();
    }

    /**
     * Splits a conversation into system message, prompt and chat memory.
     *
     * <ul>
     *   <li>All system messages are joined with newlines.</li>
     *   <li>The last user message is removed from the history and becomes the prompt.</li>
     *   <li>When knowledge text is configured it is appended to the prompt.</li>
     *   <li>{@code {{x}}} placeholders in the prompt are replaced by their inner text.</li>
     *   <li>All remaining non-system messages are added to a windowed chat memory.</li>
     * </ul>
     *
     * @param messages conversation messages, may be {@code null} or empty
     * @param params model parameters, may be {@code null}
     * @return the collated pieces; for a {@code null} or empty list both strings are empty and the
     *     chat memory is {@code null}
     */
    @NotNull
    private CollateMsgResp collateMessage(List<ChatMessage> messages, AIParams params) {
        if (messages == null || messages.isEmpty()) {
            return new CollateMsgResp("", "", null);
        }
        AIParams effective = params != null ? params : new AIParams();

        // Work on a copy so the caller's list is untouched, and drop null entries once so that
        // every pass below can dereference the elements safely.
        List<ChatMessage> history = new ArrayList<>(messages);
        history.removeIf(message -> message == null);

        String systemText = history.stream()
                .filter(message -> ChatMessageType.SYSTEM.equals(message.type()))
                .map(message -> ((SystemMessage) message).text())
                .collect(Collectors.joining("\n"));

        // The most recent user message is the question; take it out of the history so it is not sent twice.
        String prompt = "";
        for (int i = history.size() - 1; i >= 0; i--) {
            ChatMessage candidate = history.get(i);
            if (ChatMessageType.USER.equals(candidate.type())) {
                prompt = ((UserMessage) candidate).singleText();
                history.remove(i);
                break;
            }
        }

        if (StringUtils.isNotEmpty(effective.getKnowledgeTxt())) {
            prompt = String.format("%s\n\n用以下信息回答问题:\n%s\n\n", prompt, effective.getKnowledgeTxt());
        }
        // Unwrap {{x}} to x in the prompt text.
        prompt = PLACEHOLDER.matcher(prompt).replaceAll("$1");

        int maxMessages = effective.getMaxMsgNumber() != null
                ? effective.getMaxMsgNumber().intValue() + MAX_MESSAGES_PADDING
                : DEFAULT_MAX_MESSAGES;
        ChatMemory memory = MessageWindowChatMemory.builder().maxMessages(maxMessages).build();
        for (ChatMessage message : history) {
            // System messages were already merged into systemText above.
            if (!ChatMessageType.SYSTEM.equals(message.type())) {
                memory.add(message);
            }
        }
        return new CollateMsgResp(systemText, prompt, memory);
    }
}
