/*
 * Decompiled with CFR 0.152.
 */
package org.jeecg.ai.handler;

import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.ToolExecutionException;
import dev.langchain4j.mcp.McpToolProvider;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import javax.imageio.ImageIO;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.jeecg.ai.enums.QwenImageModelEnum;
import org.jeecg.ai.factory.AiModelFactory;
import org.jeecg.ai.factory.AiModelOptions;
import org.jeecg.ai.handler.AIParams;
import org.jeecg.ai.prop.AiChatProperties;
import org.jeecg.ai.stream.InternalTokenStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

public class LLMHandler {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(LLMHandler.class);
    /** Default model configuration; stays {@code null} when the no-arg constructor is used. */
    private AiChatProperties aiChatProperties;
    /** Parser used by the blocking completion path to turn a chat response into a {@code String}. */
    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();

    /**
     * Creates a handler without default model configuration.
     * Callers must then pass {@link AIParams} carrying a base URL on every call,
     * because no defaults can be resolved.
     */
    public LLMHandler() {
        this((AiChatProperties)null);
    }

    /**
     * Creates a handler that falls back to the given configuration when a call
     * supplies no parameters or parameters without a base URL.
     *
     * @param aiChatProperties default model configuration, may be {@code null}
     */
    public LLMHandler(AiChatProperties aiChatProperties) {
        this.aiChatProperties = aiChatProperties;
    }

    /**
     * Resolves the parameters to use for a model call.
     * When {@code params} is {@code null} or has an empty base URL, the configured
     * defaults are applied through {@link #getDefaultModel(AIParams)}.
     *
     * @param params caller-supplied parameters, may be {@code null}
     * @return the resolved, non-null parameters
     * @throws IllegalArgumentException if no parameters could be resolved
     */
    private AIParams ensureParams(AIParams params) {
        AIParams resolved = params;
        boolean needsDefaults = resolved == null || StringUtils.isEmpty(resolved.getBaseUrl());
        if (needsDefaults) {
            resolved = this.getDefaultModel(resolved);
        }
        if (resolved != null) {
            return resolved;
        }
        throw new IllegalArgumentException("\u5927\u8bed\u8a00\u6a21\u578b\u53c2\u6570\u4e3a\u7a7a");
    }

    /**
     * Copies the configured default model settings (provider, model name, API host,
     * API key, secret key and timeout) onto {@code params}.
     *
     * @param params parameters to fill in; a new instance is created when {@code null}
     * @return the filled parameters, or {@code null} when this handler has no default configuration
     */
    private AIParams getDefaultModel(AIParams params) {
        if (null == this.aiChatProperties) {
            log.warn("\u672a\u914d\u7f6e\u9ed8\u8ba4\u5927\u8bed\u8a00\u6a21\u578b");
            return null;
        }
        AIParams target = params == null ? new AIParams() : params;
        target.setProvider(this.aiChatProperties.getProvider());
        target.setModelName(this.aiChatProperties.getModel());
        target.setBaseUrl(this.aiChatProperties.getApiHost());
        target.setApiKey(this.aiChatProperties.getApiKey());
        // The credential section may be missing from the configuration; do not fail on it.
        target.setSecretKey(this.aiChatProperties.getCredential() == null ? null : this.aiChatProperties.getCredential().getSecretKey());
        target.setTimeout(this.aiChatProperties.getTimeout());
        return target;
    }

    /**
     * Sends a single user message to the default model and returns the reply text.
     *
     * @param message the user message text
     * @return the model's reply as produced by {@link #completions(List, AIParams)}
     */
    public String completions(String message) {
        List<ChatMessage> singleTurn = Collections.singletonList(UserMessage.from(message));
        return this.completions(singleTurn, null);
    }

    /**
     * Runs a blocking chat completion, executing any tools the model asks for until
     * it produces a final answer.
     *
     * <p>Each round sends the current chat memory to the model. If the reply contains
     * tool execution requests, every request is executed and its result is appended to
     * the memory before the next round; otherwise the reply text is returned.
     *
     * @param messages conversation so far; the last entry must be a user message
     * @param params   model parameters, may be {@code null} to use the configured defaults
     * @return the reply text, or an empty string when the model call failed with a
     *         {@link ToolExecutionException}
     * @throws IllegalArgumentException if no model parameters can be resolved
     * @throws IllegalStateException    if the model requests a tool that has no executor
     */
    public String completions(List<ChatMessage> messages, AIParams params) {
        params = this.ensureParams(params);
        AiModelOptions modelOp = params.toModelOptions();
        ChatModel chatModel = AiModelFactory.createChatModel(modelOp);
        CollateMsgResp chatMessage = this.collateMessage(messages, params);

        List<ToolSpecification> toolSpecifications = new ArrayList<>();
        Map<String, ToolExecutor> toolExecutors = new HashMap<>();
        if (null != params.getTools() && !params.getTools().isEmpty()) {
            toolSpecifications.addAll(params.getTools().keySet());
            params.getTools().forEach((tool, executor) -> toolExecutors.put(tool.name(), (ToolExecutor)executor));
        }
        this.fillMcpTools(params, chatMessage, toolSpecifications, toolExecutors);

        String resp = "";
        log.info("[LLMHandler] send message to AI server. message: {}", chatMessage);
        // The model's capabilities do not change between rounds, so resolve them once.
        boolean supportsTools = this.isSupportTools(chatModel.defaultRequestParameters());
        while (true) {
            ChatRequest.Builder requestBuilder = ChatRequest.builder().messages(chatMessage.chatMemory.messages());
            if (supportsTools) {
                requestBuilder = requestBuilder.toolSpecifications(toolSpecifications);
            }
            ChatResponse response;
            try {
                response = chatModel.chat(requestBuilder.build());
            } catch (ToolExecutionException e) {
                log.error("\u5de5\u5177\u8c03\u7528\u5931\u8d25\uff1a{}", e.getMessage(), e);
                break;
            }
            AiMessage aiMessage = response.aiMessage();
            chatMessage.chatMemory.add(aiMessage);
            List<ToolExecutionRequest> toolRequests = aiMessage.toolExecutionRequests();
            if (toolRequests == null || toolRequests.isEmpty()) {
                // No tool calls left: this is the final answer.
                resp = (String)this.serviceOutputParser.parse(response, String.class);
                break;
            }
            for (ToolExecutionRequest toolExecReq : toolRequests) {
                ToolExecutor toolExecutor = toolExecutors.get(toolExecReq.name());
                if (toolExecutor == null) {
                    throw new IllegalStateException("\u672a\u627e\u5230\u5de5\u5177\u6267\u884c\u5668: " + toolExecReq.name());
                }
                log.info("[LLMHandler] Executing tool: {} ", toolExecReq.name());
                String result;
                try {
                    result = toolExecutor.execute(toolExecReq, chatMessage.chatMemory.id());
                } catch (ToolExecutionException e) {
                    log.warn("\u63d2\u4ef6\u8fd0\u884c\u5931\u8d25\uff0c\u539f\u56e0\uff1a{}", e.getMessage(), e);
                    // Report the failure back to the model. Leaving the tool call unanswered
                    // would make the next request carry a tool call without a matching result.
                    result = StringUtils.isBlank(e.getMessage()) ? e.getClass().getName() : e.getMessage();
                }
                chatMessage.chatMemory.add(ToolExecutionResultMessage.from(toolExecReq, result));
            }
        }
        log.info("[LLMHandler] Received the AI's response . message: {}", resp);
        return resp;
    }

    /**
     * Decides whether tool specifications should be sent to the model.
     * Models whose name contains {@code -vl-}, {@code -audio-} or {@code -omni-} are
     * treated as multimodal and are not offered tools.
     *
     * @param parameters the model's default request parameters, may be {@code null}
     * @return {@code false} for multimodal model names; {@code true} otherwise, including
     *         when no model name is available
     */
    private boolean isSupportTools(ChatRequestParameters parameters) {
        String modelName = parameters == null ? null : parameters.modelName();
        if (modelName == null) {
            // Models that expose no default model name cannot be classified; offer tools to them.
            return true;
        }
        boolean isMultimodalModel = modelName.contains("-vl-") || modelName.contains("-audio-") || modelName.contains("-omni-");
        return !isMultimodalModel;
    }

    /**
     * Starts a streaming chat and returns the token stream for it.
     * Tools from {@code params} and from its MCP tool providers are made available
     * to the stream.
     *
     * @param messages conversation so far; the last entry must be a user message
     * @param params   model parameters, may be {@code null} to use the configured defaults
     * @return a token stream backed by the streaming chat model
     * @throws IllegalArgumentException if no model parameters can be resolved
     */
    public TokenStream chat(List<ChatMessage> messages, AIParams params) {
        // ensureParams never returns null; it throws IllegalArgumentException instead.
        AIParams resolved = this.ensureParams(params);
        StreamingChatModel streamingChatModel = AiModelFactory.createStreamingChatModel(resolved.toModelOptions());

        ArrayList<ToolSpecification> toolSpecifications = new ArrayList<ToolSpecification>();
        HashMap<String, ToolExecutor> toolExecutors = new HashMap<String, ToolExecutor>();
        boolean hasLocalTools = resolved.getTools() != null && !resolved.getTools().isEmpty();
        if (hasLocalTools) {
            toolSpecifications.addAll(resolved.getTools().keySet());
            resolved.getTools().forEach((tool, executor) -> toolExecutors.put(tool.name(), (ToolExecutor)executor));
        }

        CollateMsgResp collated = this.collateMessage(messages, resolved);
        this.fillMcpTools(resolved, collated, toolSpecifications, toolExecutors);
        return new InternalTokenStream(
                streamingChatModel,
                toolSpecifications,
                toolExecutors,
                collated.chatMemory,
                collated.augmentationResult == null ? null : collated.augmentationResult.contents());
    }

    /**
     * Adds the tools offered by the configured MCP tool providers to the given collections.
     *
     * @param params             parameters holding the MCP tool providers
     * @param chatMessage        collated conversation; supplies the memory id and user message
     *                           for the provider request
     * @param toolSpecifications list that receives the provided tool specifications
     * @param toolExecutors      map that receives the executors, keyed by tool name
     */
    private void fillMcpTools(AIParams params, CollateMsgResp chatMessage, List<ToolSpecification> toolSpecifications, Map<String, ToolExecutor> toolExecutors) {
        if (params.getMcpToolProviders() != null && !params.getMcpToolProviders().isEmpty()) {
            for (McpToolProvider provider : params.getMcpToolProviders()) {
                ToolProviderRequest request = new ToolProviderRequest(chatMessage.chatMemory.id(), chatMessage.userMessage);
                ToolProviderResult provided = provider.provideTools(request);
                if (provided != null && provided.tools() != null) {
                    provided.tools().forEach((specification, executor) -> {
                        toolSpecifications.add(specification);
                        // A later tool with the same name replaces the earlier executor.
                        toolExecutors.put(specification.name(), executor);
                    });
                }
            }
        }
    }

    /**
     * Builds the chat memory for a request from the raw message list.
     *
     * <p>The last message must be the current user message. All system messages are merged
     * into one (joined with a newline) and placed first, followed by the remaining history
     * and finally the user message. When {@code params} carries a query router, the text of
     * the user message is passed through a retrieval augmentor first and the augmented text
     * replaces the original text parts; non-text parts are kept.
     *
     * @param messages conversation so far, must not be empty and must end with a user message
     * @param params   model parameters, may be {@code null}
     * @return the chat memory, the augmentation result ({@code null} without a query router)
     *         and the possibly augmented user message
     * @throws IllegalArgumentException if {@code messages} is empty or does not end with a user message
     */
    private CollateMsgResp collateMessage(List<ChatMessage> messages, AIParams params) {
        if (messages == null || messages.isEmpty()) {
            // "The message list must not be empty"
            throw new IllegalArgumentException("\u6d88\u606f\u5217\u8868\u4e0d\u80fd\u4e3a\u7a7a");
        }
        if (null == params) {
            params = new AIParams();
        }
        List<ChatMessage> history = new ArrayList<>(messages);
        ChatMessage lastMessage = history.get(history.size() - 1);
        if (lastMessage == null || !ChatMessageType.USER.equals(lastMessage.type())) {
            throw new IllegalArgumentException("\u6700\u540e\u4e00\u6761\u6d88\u606f\u5fc5\u987b\u662f\u7528\u6237\u6d88\u606f");
        }
        UserMessage userMessage = (UserMessage)history.remove(history.size() - 1);

        // Memory window: 6 by default, otherwise the configured number plus 2
        // (presumably room for the system message and the current user message; confirm).
        int maxMsgNumber = 6;
        if (null != params.getMaxMsgNumber()) {
            maxMsgNumber = params.getMaxMsgNumber() + 2;
        }

        // Pull every system message out of the history and merge them into one.
        SystemMessage mergedSystem = null;
        for (Iterator<ChatMessage> it = history.iterator(); it.hasNext();) {
            ChatMessage candidate = it.next();
            if (!ChatMessageType.SYSTEM.equals(candidate.type())) {
                continue;
            }
            SystemMessage current = (SystemMessage)candidate;
            mergedSystem = mergedSystem == null ? current : SystemMessage.from(mergedSystem.text() + "\n" + current.text());
            it.remove();
        }

        MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder().maxMessages(Integer.valueOf(maxMsgNumber)).build();
        if (mergedSystem != null) {
            chatMemory.add(mergedSystem);
        }
        history.forEach(chatMemory::add);

        AugmentationResult augmentationResult = null;
        if (null != params.getQueryRouter()) {
            DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder().queryRouter(params.getQueryRouter()).build();
            // Split the user message: text parts form the retrieval query (in their original
            // order), everything else (e.g. images) is carried over untouched.
            StringBuilder userQuestion = new StringBuilder();
            List<dev.langchain4j.data.message.Content> contents = new ArrayList<>();
            for (dev.langchain4j.data.message.Content content : userMessage.contents()) {
                if (content instanceof TextContent) {
                    userQuestion.append(((TextContent)content).text()).append("\n");
                } else {
                    contents.add(content);
                }
            }
            UserMessage textUserMessage = UserMessage.from(userQuestion.toString());
            Metadata metadata = Metadata.from(textUserMessage, "default", chatMemory.messages());
            AugmentationRequest augmentationRequest = new AugmentationRequest(textUserMessage, metadata);
            augmentationResult = retrievalAugmentor.augment(augmentationRequest);
            textUserMessage = (UserMessage)augmentationResult.chatMessage();
            contents.add(TextContent.from(textUserMessage.singleText()));
            userMessage = UserMessage.from(contents);
        }
        chatMemory.add(userMessage);
        return new CollateMsgResp(chatMemory, augmentationResult, userMessage);
    }

    /**
     * Generates images from a text prompt.
     *
     * <p>The model is called once per requested image. Each result is a map with the keys
     * {@code type} ({@code "base64"} or {@code "http"}) and {@code value} (a data URI or a URL).
     * An image that carries neither base64 data nor a URL yields an empty map.
     *
     * @param prompt text prompt describing the image
     * @param params model parameters, may be {@code null} to use the configured defaults;
     *               its image count is used, with 1 as the minimum
     * @return one map per generated image
     * @throws IllegalArgumentException if no model parameters can be resolved
     * @throws RuntimeException         wrapping any failure raised while generating
     */
    public List<Map<String, Object>> imageGenerate(String prompt, AIParams params) {
        // ensureParams never returns null; it throws IllegalArgumentException instead.
        AIParams resolved = this.ensureParams(params);
        ImageModel imageModel = AiModelFactory.createImageModel(resolved.toModelOptions());
        List<Map<String, Object>> images = new ArrayList<Map<String, Object>>();
        Integer requested = resolved.imageCount;
        int rounds = 1;
        if (requested != null && requested >= 1) {
            rounds = requested;
        }
        try {
            for (int remaining = rounds; remaining > 0; --remaining) {
                Image image = (Image)imageModel.generate(prompt).content();
                Map<String, Object> item = new HashMap<String, Object>();
                String base64 = image.base64Data();
                if (StringUtils.isNotEmpty(base64)) {
                    item.put("type", "base64");
                    // Normalise bare base64 payloads to a data URI.
                    item.put("value", base64.startsWith("data:") ? base64 : "data:image/png;base64," + base64);
                } else if (image.url() != null) {
                    item.put("type", "http");
                    item.put("value", image.url().toString());
                }
                images.add(item);
            }
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
        return images;
    }

    /**
     * Edits images according to a prompt (image-to-image).
     *
     * <p>Only the {@code QWEN} provider is routed to image editing; for any other provider
     * the call falls back to {@link #imageGenerate(String, AIParams)} and the original
     * images are ignored. The two Qwen edit models named in {@link QwenImageModelEnum}
     * receive all original images; any other Qwen model receives only the first one.
     *
     * @param prompt         text prompt describing the edit
     * @param originalImages source images, must not be empty for the Qwen provider
     * @param params         model parameters, may be {@code null} to use the configured defaults
     * @return one map per produced image, with the keys {@code type} and {@code value}
     * @throws IllegalArgumentException if no model parameters can be resolved, or if
     *                                  {@code originalImages} is empty for the Qwen provider
     */
    public List<Map<String, Object>> imageEdit(String prompt, List<String> originalImages, AIParams params) {
        // Resolve defaults first: params may be null, and the provider to check is the resolved one.
        params = this.ensureParams(params);
        if (!"QWEN".equalsIgnoreCase(params.getProvider())) {
            log.info("\u9664\u4e07\u8c61\u6a21\u578b\u5176\u4ed6\u6a21\u578b\u6682\u4e0d\u652f\u6301\u56fe\u751f\u56fe\u6a21\u5f0f\uff0c\u4f7f\u7528\u6587\u751f\u56fe\u6a21\u5f0f\uff0c\u5f53\u524d\u6a21\u578b\uff1a{}", params.modelName);
            return this.imageGenerate(prompt, params);
        }
        if (CollectionUtils.isEmpty(originalImages)) {
            throw new IllegalArgumentException("\u539f\u59cb\u56fe\u7247\u4e0d\u80fd\u4e3a\u7a7a");
        }
        AiModelOptions options = params.toModelOptions();
        Integer imageCount = params.imageCount;
        String modelName = options.getModelName();
        boolean isQwenEditModel = QwenImageModelEnum.WANX_2_1_IMAGE_EDIT.getModelName().equals(modelName)
                || QwenImageModelEnum.WAN_2_5_I2I_PREVIEW.getModelName().equals(modelName);
        if (isQwenEditModel) {
            return this.imageEditQwen(prompt, originalImages, options, imageCount);
        }
        return this.imageEditDefault(prompt, originalImages.get(0), options, imageCount);
    }

    /**
     * Edits a single image through the generic {@link ImageModel#edit} call.
     *
     * <p>The model is called once per requested image. Each result is a map with the keys
     * {@code type} ({@code "base64"} or {@code "http"}) and {@code value}; an image that
     * carries neither base64 data nor a URL yields an empty map.
     *
     * @param prompt              text prompt describing the edit
     * @param originalImageBase64 source image as base64, with or without a {@code ...base64,} prefix
     * @param options             model options used to create the image model
     * @param imageCount          number of images to produce; {@code null} or values below 1 mean 1
     * @return one map per produced image
     * @throws RuntimeException wrapping any failure raised while editing
     */
    private List<Map<String, Object>> imageEditDefault(String prompt, String originalImageBase64, AiModelOptions options, Integer imageCount) {
        // Strip a data-URI style prefix so only the raw base64 payload is sent.
        String rawBase64 = originalImageBase64.contains("base64,")
                ? originalImageBase64.split("base64,")[1]
                : originalImageBase64;
        ImageModel imageModel = AiModelFactory.createImageModel(options);
        List<Map<String, Object>> edited = new ArrayList<Map<String, Object>>();
        Image inputImage = Image.builder().base64Data(rawBase64).build();
        int rounds = 1;
        if (imageCount != null && imageCount >= 1) {
            rounds = imageCount;
        }
        try {
            for (int remaining = rounds; remaining > 0; --remaining) {
                Image image = (Image)imageModel.edit(inputImage, prompt).content();
                Map<String, Object> item = new HashMap<String, Object>();
                String base64 = image.base64Data();
                if (StringUtils.isNotEmpty(base64)) {
                    item.put("type", "base64");
                    item.put("value", base64.startsWith("data:") ? base64 : "data:image/png;base64," + base64);
                } else if (image.url() != null) {
                    item.put("type", "http");
                    item.put("value", image.url().toString());
                }
                edited.add(item);
            }
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
        return edited;
    }

    /**
     * Edits images through the DashScope {@link ImageSynthesis} API.
     *
     * <p>The source images are first normalised by {@link #checkAndResizeImage(List)}.
     * When the options carry no model name, the {@code WAN_2_5_I2I_PREVIEW} model is used.
     * That model receives all images; {@code WANX_2_1_IMAGE_EDIT} receives only the first
     * one as its base image. The size defaults to {@code 1024*1024}.
     *
     * @param prompt         text prompt describing the edit
     * @param originalImages source images
     * @param options        model options (API key, base URL, model name, image size)
     * @param imageCount     number of images to produce; {@code null} or values below 1 mean 1
     * @return one map per returned result with the keys {@code type} and {@code value};
     *         empty when the service returned no output or no results
     * @throws RuntimeException wrapping any failure raised while calling the service
     */
    private List<Map<String, Object>> imageEditQwen(String prompt, List<String> originalImages, AiModelOptions options, Integer imageCount) {
        originalImages = this.checkAndResizeImage(originalImages);
        int count = imageCount == null || imageCount < 1 ? 1 : imageCount;
        List<Map<String, Object>> result = new ArrayList<>();
        try {
            // Resolve the model once so the request and the image wiring below agree on it.
            String modelName = StringUtils.isNotEmpty(options.getModelName())
                    ? options.getModelName()
                    : QwenImageModelEnum.WAN_2_5_I2I_PREVIEW.getModelName();
            ImageSynthesisParam param = ImageSynthesisParam.builder()
                    .apiKey(options.getApiKey())
                    .model(modelName)
                    .prompt(prompt)
                    .function("description_edit")
                    .n(count)
                    .build();
            param.setSize(StringUtils.isNotEmpty(options.getImageSize()) ? options.getImageSize() : "1024*1024");
            if (QwenImageModelEnum.WAN_2_5_I2I_PREVIEW.getModelName().equals(modelName)) {
                param.setImages(originalImages);
            }
            if (QwenImageModelEnum.WANX_2_1_IMAGE_EDIT.getModelName().equals(modelName)) {
                param.setBaseImageUrl(originalImages.get(0));
            }

            ImageSynthesis imageSynthesis = new ImageSynthesis("text2image", options.getBaseUrl());
            ImageSynthesisResult resultResponse = imageSynthesis.call(param);
            if (resultResponse == null || resultResponse.getOutput() == null) {
                // There is no output object to read an error message from.
                log.error("Qwen image edit returned no output");
                return result;
            }
            if (resultResponse.getOutput().getResults() == null) {
                log.error(resultResponse.getOutput().getMessage());
                return result;
            }
            for (Map<?, ?> item : resultResponse.getOutput().getResults()) {
                Map<String, Object> map = new HashMap<>();
                if (item.containsKey("url")) {
                    map.put("type", "http");
                    map.put("value", item.get("url"));
                } else if (item.containsKey("b64_json")) {
                    map.put("type", "base64");
                    String b64 = (String)item.get("b64_json");
                    if (b64 != null && !b64.startsWith("data:")) {
                        b64 = "data:image/png;base64," + b64;
                    }
                    map.put("value", b64);
                }
                result.add(map);
            }
            return result;
        } catch (Exception e) {
            throw new RuntimeException("Qwen image edit failed: " + e.getMessage(), e);
        }
    }

    /**
     * Normalises source images to data URIs and rescales those whose height is outside
     * the range 512..4096 pixels, keeping the aspect ratio.
     *
     * <p>Images within the range, images that cannot be decoded, and images whose processing
     * fails are passed through unchanged apart from gaining a data-URI prefix when they have
     * none. Rescaled images are re-encoded as PNG.
     *
     * <p>NOTE(review): only the height is checked; confirm whether the width needs the same limits.
     *
     * @param base64ImageList source images as base64 strings, with or without a data-URI prefix
     * @return one data URI per input image, in the same order
     */
    private List<String> checkAndResizeImage(List<String> base64ImageList) {
        final int minHeight = 512;
        final int maxHeight = 4096;
        List<String> result = new ArrayList<>();
        for (String base64Image : base64ImageList) {
            try {
                String base64Data = base64Image;
                int marker = base64Image.indexOf("base64,");
                if (marker >= 0) {
                    base64Data = base64Image.substring(marker + "base64,".length());
                }
                base64Data = base64Data.replaceAll("\\s", "");
                byte[] imageBytes = Base64.getDecoder().decode(base64Data);
                BufferedImage image = ImageIO.read(new ByteArrayInputStream(imageBytes));
                if (image == null) {
                    log.warn("ImageIO read failed, use original image");
                    result.add(toDataUri(base64Image));
                    continue;
                }
                int width = image.getWidth();
                int height = image.getHeight();
                if (height >= minHeight && height <= maxHeight) {
                    result.add(toDataUri(base64Image));
                    continue;
                }
                int newHeight = height < minHeight ? minHeight : maxHeight;
                // Keep the aspect ratio; never let the width collapse to zero.
                int newWidth = Math.max(1, (int)((double)width * ((double)newHeight / (double)height)));
                log.info("Resize image from {}x{} to {}x{}", width, height, newWidth, newHeight);
                BufferedImage outputImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
                Graphics2D g2d = outputImage.createGraphics();
                try {
                    // java.awt.Image is fully qualified because this file imports another Image type.
                    g2d.drawImage(image.getScaledInstance(newWidth, newHeight, java.awt.Image.SCALE_SMOOTH), 0, 0, null);
                } finally {
                    g2d.dispose();
                }
                ByteArrayOutputStream bos = new ByteArrayOutputStream();
                ImageIO.write((RenderedImage)outputImage, "png", bos);
                result.add("data:image/png;base64," + Base64.getEncoder().encodeToString(bos.toByteArray()));
            } catch (Exception e) {
                log.error("Check and resize image failed: {}", e.getMessage(), e);
                result.add(toDataUri(base64Image));
            }
        }
        return result;
    }

    /**
     * Returns the image unchanged when it already carries a {@code base64,} prefix,
     * otherwise prepends a PNG data-URI prefix.
     */
    private static String toDataUri(String image) {
        if (image != null && image.contains("base64,")) {
            return image;
        }
        return "data:image/png;base64," + image;
    }

    /**
     * Result of collating a conversation: the chat memory to send, the retrieval
     * augmentation result (if any) and the final user message.
     */
    private static class CollateMsgResp {
        /** Chat memory holding the messages to send to the model. */
        public final ChatMemory chatMemory;
        /** Retrieval augmentation result; {@code null} when no augmentation took place. */
        public final AugmentationResult augmentationResult;
        /** The current user message as added to the chat memory. */
        public final UserMessage userMessage;

        public CollateMsgResp(ChatMemory chatMemory, AugmentationResult augmentationResult, UserMessage userMessage) {
            this.chatMemory = chatMemory;
            this.augmentationResult = augmentationResult;
            this.userMessage = userMessage;
        }

        /** Renders only the messages held in the chat memory, for logging. */
        @Override
        public String toString() {
            Object renderedMessages = "null";
            if (this.chatMemory != null) {
                renderedMessages = this.chatMemory.messages();
            }
            return "{messages=" + renderedMessages + "}";
        }
    }
}

