/*
 * Decompiled with CFR 0.152.
 */
package org.jeecg.ai.handler;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.ToolExecutionException;
import dev.langchain4j.mcp.McpToolProvider;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.jeecg.ai.factory.AiModelFactory;
import org.jeecg.ai.factory.AiModelOptions;
import org.jeecg.ai.handler.AIParams;
import org.jeecg.ai.prop.AiChatProperties;
import org.jeecg.ai.stream.InternalTokenStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Sends conversations to a large language model (LLM) built by {@link AiModelFactory}.
 *
 * <p>Model parameters come from the caller's {@link AIParams}. When those are {@code null} or carry
 * no base URL, the defaults from {@link AiChatProperties} are used instead. Two entry points exist:
 * {@link #completions(List, AIParams)} blocks and runs the tool-calling loop locally, while
 * {@link #chat(List, AIParams)} returns a streaming {@link TokenStream} built on
 * {@link InternalTokenStream}.
 */
public class LLMHandler {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(LLMHandler.class);

    /** Default model configuration; {@code null} when created through the no-arg constructor. */
    private AiChatProperties aiChatProperties;

    /** Converts a final {@link ChatResponse} into a Java value (only {@code String} is requested here). */
    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();

    /**
     * Creates a handler that falls back to {@code aiChatProperties} when a call has no usable params.
     *
     * @param aiChatProperties default model configuration; may be {@code null}
     */
    public LLMHandler(AiChatProperties aiChatProperties) {
        this.aiChatProperties = aiChatProperties;
    }

    /** Creates a handler without defaults; every call must then supply params with a base URL. */
    public LLMHandler() {
    }

    /**
     * Returns usable model parameters, filling in the configured defaults when needed.
     *
     * @param params caller-supplied parameters; may be {@code null}
     * @return {@code params}, or default-populated params when {@code params} is {@code null} or
     *     has an empty base URL
     * @throws IllegalArgumentException if defaults are needed but none are configured
     */
    private AIParams ensureParams(AIParams params) {
        if (null == params || StringUtils.isEmpty((String) params.getBaseUrl())) {
            params = this.getDefaultModel(params);
        }
        if (null == params) {
            // "LLM parameters are empty"
            throw new IllegalArgumentException("\u5927\u8bed\u8a00\u6a21\u578b\u53c2\u6570\u4e3a\u7a7a");
        }
        return params;
    }

    /**
     * Overwrites the provider, model, endpoint, credentials and timeout of {@code params} with the
     * configured defaults. A non-null {@code params} object is mutated in place.
     *
     * @param params parameters to populate, or {@code null} to create a fresh instance
     * @return the populated params, or {@code null} when no default configuration is present
     */
    private AIParams getDefaultModel(AIParams params) {
        if (null == this.aiChatProperties) {
            // "Default large language model is not configured"
            log.warn("\u672a\u914d\u7f6e\u9ed8\u8ba4\u5927\u8bed\u8a00\u6a21\u578b");
            return null;
        }
        if (params == null) {
            params = new AIParams();
        }
        params.setProvider(this.aiChatProperties.getProvider());
        params.setModelName(this.aiChatProperties.getModel());
        params.setBaseUrl(this.aiChatProperties.getApiHost());
        params.setApiKey(this.aiChatProperties.getApiKey());
        // A missing credential section must not abort the whole call with an NPE.
        if (this.aiChatProperties.getCredential() != null) {
            params.setSecretKey(this.aiChatProperties.getCredential().getSecretKey());
        }
        params.setTimeout(this.aiChatProperties.getTimeout());
        return params;
    }

    /**
     * Sends a single user message using the default model configuration.
     *
     * @param message the user's text
     * @return the model's final text reply
     */
    public String completions(String message) {
        return this.completions(Collections.singletonList(UserMessage.from(message)), null);
    }

    /**
     * Sends a conversation to the model and blocks until a final text answer is produced.
     *
     * <p>While the model replies with tool-execution requests, each requested tool is executed
     * locally and its result is appended to the conversation before the model is asked again.
     *
     * @param messages conversation history; the last element must be a user message
     * @param params   model parameters; {@code null} or an empty base URL selects the defaults
     * @return the model's final text reply, or {@code ""} if the model call itself failed with a
     *     {@link ToolExecutionException}
     * @throws IllegalArgumentException if params cannot be resolved or {@code messages} is invalid
     * @throws IllegalStateException    if the model requests a tool with no registered executor
     */
    public String completions(List<ChatMessage> messages, AIParams params) {
        params = this.ensureParams(params);
        AiModelOptions modelOp = params.toModelOptions();
        ChatModel chatModel = AiModelFactory.createChatModel(modelOp);
        CollateMsgResp chatMessage = this.collateMessage(messages, params);
        List<ToolSpecification> toolSpecifications = new ArrayList<ToolSpecification>();
        Map<String, ToolExecutor> toolExecutors = new HashMap<String, ToolExecutor>();
        this.fillParamTools(params, toolSpecifications, toolExecutors);
        this.fillMcpTools(params, chatMessage, toolSpecifications, toolExecutors);
        // Tool support depends only on the model's default parameters, so decide it once.
        boolean supportTools = this.isSupportTools(chatModel.defaultRequestParameters());
        String resp = "";
        log.info("[LLMHandler] send message to AI server. message: {}", chatMessage);
        while (true) {
            ChatRequest.Builder requestBuilder = ChatRequest.builder().messages(chatMessage.chatMemory.messages());
            if (supportTools) {
                requestBuilder = requestBuilder.toolSpecifications(toolSpecifications);
            }
            ChatResponse response;
            try {
                response = chatModel.chat(requestBuilder.build());
            } catch (ToolExecutionException e) {
                // "Tool call failed: {}"
                log.error("\u5de5\u5177\u8c03\u7528\u5931\u8d25\uff1a{}", e.getMessage(), e);
                break;
            }
            AiMessage aiMessage = response.aiMessage();
            chatMessage.chatMemory.add(aiMessage);
            List<ToolExecutionRequest> toolRequests = aiMessage.toolExecutionRequests();
            if (toolRequests == null || toolRequests.isEmpty()) {
                resp = (String) this.serviceOutputParser.parse(response, String.class);
                break;
            }
            for (ToolExecutionRequest toolExecReq : toolRequests) {
                chatMessage.chatMemory.add(this.executeTool(toolExecReq, toolExecutors, chatMessage.chatMemory.id()));
            }
        }
        log.info("[LLMHandler] Received the AI's response . message: {}", resp);
        return resp;
    }

    /**
     * Runs one tool requested by the model and wraps its output as a result message.
     *
     * <p>A {@link ToolExecutionException} is reported to the model as the tool's result. This keeps
     * every tool request in the conversation answered by a matching result message.
     *
     * @param toolExecReq   the model's tool-execution request
     * @param toolExecutors executors keyed by tool name
     * @param memoryId      id of the chat memory, forwarded to the executor
     * @return the result message to append to the conversation
     * @throws IllegalStateException if no executor is registered under the requested tool name
     */
    private ToolExecutionResultMessage executeTool(ToolExecutionRequest toolExecReq,
                                                   Map<String, ToolExecutor> toolExecutors,
                                                   Object memoryId) {
        ToolExecutor executor = toolExecutors.get(toolExecReq.name());
        if (executor == null) {
            // "Tool executor not found: "
            throw new IllegalStateException("\u672a\u627e\u5230\u5de5\u5177\u6267\u884c\u5668: " + toolExecReq.name());
        }
        log.info("[LLMHandler] Executing tool: {} ", toolExecReq.name());
        String result;
        try {
            result = executor.execute(toolExecReq, memoryId);
        } catch (ToolExecutionException e) {
            // "Plugin execution failed, reason: {}"
            log.warn("\u63d2\u4ef6\u8fd0\u884c\u5931\u8d25\uff0c\u539f\u56e0\uff1a{}", e.getMessage(), e);
            result = "Tool execution failed: " + e.getMessage();
        }
        return ToolExecutionResultMessage.from(toolExecReq, result);
    }

    /**
     * Decides whether tool specifications may be sent to the model.
     *
     * <p>Model names containing {@code -vl-}, {@code -audio-} or {@code -omni-} are treated as
     * multimodal and get no tools. An unknown (null) model name is handled like any other name
     * that does not contain those markers: tools are sent.
     *
     * @param parameters the model's default request parameters; may be {@code null}
     * @return {@code true} if tools should be attached to requests
     */
    private boolean isSupportTools(ChatRequestParameters parameters) {
        String modelName = parameters == null ? null : parameters.modelName();
        if (modelName == null) {
            return true;
        }
        boolean isMultimodalModel = modelName.contains("-vl-") || modelName.contains("-audio-") || modelName.contains("-omni-");
        return !isMultimodalModel;
    }

    /**
     * Starts a streaming conversation with the model.
     *
     * @param messages conversation history; the last element must be a user message
     * @param params   model parameters; {@code null} or an empty base URL selects the defaults
     * @return an {@link InternalTokenStream} that receives the collected tools, the prepared chat
     *     memory and the RAG contents (or {@code null} when no retrieval was performed)
     * @throws IllegalArgumentException if params cannot be resolved or {@code messages} is invalid
     */
    public TokenStream chat(List<ChatMessage> messages, AIParams params) {
        params = this.ensureParams(params);
        AiModelOptions modelOp = params.toModelOptions();
        StreamingChatModel streamingChatModel = AiModelFactory.createStreamingChatModel(modelOp);
        ArrayList<ToolSpecification> toolSpecifications = new ArrayList<ToolSpecification>();
        HashMap<String, ToolExecutor> toolExecutors = new HashMap<String, ToolExecutor>();
        this.fillParamTools(params, toolSpecifications, toolExecutors);
        CollateMsgResp chatMessage = this.collateMessage(messages, params);
        this.fillMcpTools(params, chatMessage, toolSpecifications, toolExecutors);
        return new InternalTokenStream(streamingChatModel, toolSpecifications, toolExecutors, chatMessage.chatMemory,
                chatMessage.augmentationResult != null ? chatMessage.augmentationResult.contents() : null);
    }

    /**
     * Adds the tools declared directly on {@code params} to the given collections.
     *
     * @param params             resolved model parameters
     * @param toolSpecifications receives the tool specifications
     * @param toolExecutors      receives the executors, keyed by tool name
     */
    private void fillParamTools(AIParams params, List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors) {
        if (null == params.getTools() || params.getTools().isEmpty()) {
            return;
        }
        toolSpecifications.addAll(params.getTools().keySet());
        params.getTools().forEach((tool, executor) -> toolExecutors.put(tool.name(), (ToolExecutor) executor));
    }

    /**
     * Queries every configured MCP tool provider and adds the tools it offers.
     *
     * <p>A provider that returns {@code null} or no tool map is skipped. A tool whose name is
     * already registered has its executor replaced by the later one.
     *
     * @param params             resolved model parameters
     * @param chatMessage        prepared conversation (memory id and user message go to the provider)
     * @param toolSpecifications receives the tool specifications
     * @param toolExecutors      receives the executors, keyed by tool name
     */
    private void fillMcpTools(AIParams params, CollateMsgResp chatMessage, List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors) {
        if (params.getMcpToolProviders() == null || params.getMcpToolProviders().isEmpty()) {
            return;
        }
        for (McpToolProvider toolProvider : params.getMcpToolProviders()) {
            ToolProviderRequest request = new ToolProviderRequest(chatMessage.chatMemory.id(), chatMessage.userMessage);
            ToolProviderResult result = toolProvider.provideTools(request);
            if (result == null || result.tools() == null) {
                continue;
            }
            for (Map.Entry<ToolSpecification, ToolExecutor> entry : result.tools().entrySet()) {
                toolSpecifications.add(entry.getKey());
                toolExecutors.put(entry.getKey().name(), entry.getValue());
            }
        }
    }

    /**
     * Prepares the chat memory for one request.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Merge all system messages, newline-separated, into one system message added first.</li>
     *   <li>Load the remaining history into a {@link MessageWindowChatMemory}.</li>
     *   <li>Append the final user message last.</li>
     * </ol>
     * When a query router is configured, the user message's text parts are passed through a
     * {@link DefaultRetrievalAugmentor} and replaced by the augmented text. Non-text parts are kept.
     *
     * @param messages conversation history; must be non-empty and end with a user message
     * @param params   model parameters; {@code null} is treated as empty params
     * @return the prepared memory, the augmentation result (or {@code null}) and the final user message
     * @throws IllegalArgumentException if {@code messages} is null/empty or does not end with a user message
     */
    private CollateMsgResp collateMessage(List<ChatMessage> messages, AIParams params) {
        if (null == params) {
            params = new AIParams();
        }
        if (messages == null || messages.isEmpty()) {
            // "Message list must not be empty"
            throw new IllegalArgumentException("\u6d88\u606f\u5217\u8868\u4e0d\u80fd\u4e3a\u7a7a");
        }
        List<ChatMessage> messagesCopy = new ArrayList<ChatMessage>(messages);
        ChatMessage lastMessage = messagesCopy.get(messagesCopy.size() - 1);
        if (lastMessage.type() != ChatMessageType.USER) {
            // "The last message must be a user message"
            throw new IllegalArgumentException("\u6700\u540e\u4e00\u6761\u6d88\u606f\u5fc5\u987b\u662f\u7528\u6237\u6d88\u606f");
        }
        UserMessage userMessage = (UserMessage) messagesCopy.remove(messagesCopy.size() - 1);
        int maxMsgNumber = 6;
        if (null != params.getMaxMsgNumber()) {
            // NOTE(review): +2 presumably leaves room for the system message and the current user message — confirm.
            maxMsgNumber = params.getMaxMsgNumber() + 2;
        }
        AtomicReference<SystemMessage> systemMessageAto = new AtomicReference<SystemMessage>();
        messagesCopy.removeIf(tempMsg -> {
            if (tempMsg.type() != ChatMessageType.SYSTEM) {
                return false;
            }
            SystemMessage current = systemMessageAto.get();
            if (current == null) {
                systemMessageAto.set((SystemMessage) tempMsg);
            } else {
                systemMessageAto.set(SystemMessage.from(current.text() + "\n" + ((SystemMessage) tempMsg).text()));
            }
            return true;
        });
        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder().maxMessages(maxMsgNumber).build();
        if (null != systemMessageAto.get()) {
            chatMemory.add(systemMessageAto.get());
        }
        for (ChatMessage historyMessage : messagesCopy) {
            chatMemory.add(historyMessage);
        }
        AugmentationResult augmentationResult = null;
        if (null != params.getQueryRouter()) {
            DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder().queryRouter(params.getQueryRouter()).build();
            List<Content> contents = new ArrayList<Content>(userMessage.contents());
            // Join the text parts in their original order; each part is followed by a newline.
            StringBuilder userQuestion = new StringBuilder();
            for (Content content : contents) {
                if (content instanceof TextContent) {
                    userQuestion.append(((TextContent) content).text()).append("\n");
                }
            }
            contents.removeIf(content -> content instanceof TextContent);
            UserMessage textUserMessage = UserMessage.from(userQuestion.toString());
            Metadata metadata = Metadata.from((ChatMessage) textUserMessage, (Object) "default", chatMemory.messages());
            AugmentationRequest augmentationRequest = new AugmentationRequest((ChatMessage) textUserMessage, metadata);
            augmentationResult = retrievalAugmentor.augment(augmentationRequest);
            textUserMessage = (UserMessage) augmentationResult.chatMessage();
            // Non-text parts (if any) come first, followed by the augmented text.
            contents.add(TextContent.from(textUserMessage.singleText()));
            userMessage = UserMessage.from(contents);
        }
        chatMemory.add(userMessage);
        return new CollateMsgResp(chatMemory, augmentationResult, userMessage);
    }

    /** Result of {@link #collateMessage}: the prepared memory plus the pieces callers need separately. */
    private static class CollateMsgResp {
        /** Memory holding the system message, the history and the final user message. */
        public final ChatMemory chatMemory;
        /** RAG result, or {@code null} when no query router was configured. */
        public final AugmentationResult augmentationResult;
        /** The final (possibly augmented) user message. */
        public final UserMessage userMessage;

        public CollateMsgResp(ChatMemory chatMemory, AugmentationResult augmentationResult, UserMessage userMessage) {
            this.chatMemory = chatMemory;
            this.augmentationResult = augmentationResult;
            this.userMessage = userMessage;
        }

        @Override
        public String toString() {
            return "{messages=" + (this.chatMemory != null ? this.chatMemory.messages() : "null") + "}";
        }
    }
}

