/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.embedding.remote.qianfan;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;

/**
 * Text-embedding model backed by Baidu Qianfan's remote embedding HTTP API.
 *
 * <p>Requests are authenticated with an access token obtained from {@code oauthPath} through the
 * client-credentials grant (API key as {@code client_id}, secret key as {@code client_secret}).
 * The token is sent as the {@code access_token} query parameter of every embedding request and is
 * refreshed when Qianfan reports it as invalid or expired.
 *
 * <p>NOTE(review): {@code accessToken} is reassigned without synchronization; confirm instances are
 * not shared across threads.
 */
public class QianfanModel extends AbstractModel {

    /** OAuth query string appended to {@code oauthPath}; formatted with the API key and secret key. */
    private static final String OAUTH_SUFFIX_PATH =
            "?grant_type=client_credentials&client_id=%s&client_secret=%s";

    /** Connect and socket timeout, in milliseconds, applied to both OAuth and embedding requests. */
    private static final int TIMEOUT_MILLIS = 20000;

    private static final RequestConfig REQUEST_CONFIG =
            RequestConfig.custom()
                    .setConnectTimeout(TIMEOUT_MILLIS)
                    .setSocketTimeout(TIMEOUT_MILLIS)
                    .build();

    /** Qianfan {@code error_code}: access token invalid or no longer valid. */
    private static final int ERROR_TOKEN_INVALID = 110;

    /** Qianfan {@code error_code}: access token expired (per Baidu's common error-code table). */
    private static final int ERROR_TOKEN_EXPIRED = 111;

    private final CloseableHttpClient client;
    private final String apiKey;
    private final String secretKey;
    private final String model;
    private final String apiPath;
    private final String oauthPath;
    private String accessToken;

    /**
     * Creates a model and immediately fetches an access token from {@code oauthPath}.
     *
     * @param apiKey Qianfan API key, sent as {@code client_id}
     * @param secretKey Qianfan secret key, sent as {@code client_secret}
     * @param model model name appended to {@code apiPath}
     * @param apiPath base URL of the embedding endpoint; a trailing {@code /} is added if missing
     * @param oauthPath URL of the OAuth token endpoint
     * @param vectorizedNumber passed through to {@link AbstractModel}
     * @throws IOException if the access token cannot be obtained (the HTTP client is closed first)
     */
    public QianfanModel(
            String apiKey,
            String secretKey,
            String model,
            String apiPath,
            String oauthPath,
            Integer vectorizedNumber)
            throws IOException {
        this(apiKey, secretKey, model, apiPath, vectorizedNumber, oauthPath, null);
        try {
            this.accessToken = getAccessToken();
        } catch (IOException | RuntimeException e) {
            // The half-built instance is unreachable to the caller, so release its connection pool here.
            try {
                this.client.close();
            } catch (IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Creates a model that uses an already-obtained access token; no request is made here.
     *
     * <p>Note the parameter order differs from the other constructor ({@code vectorizedNumber}
     * before {@code oauthPath}); it is kept as-is for existing callers.
     *
     * @param apiKey Qianfan API key, used if the token has to be refreshed
     * @param secretKey Qianfan secret key, used if the token has to be refreshed
     * @param model model name appended to {@code apiPath}
     * @param apiPath base URL of the embedding endpoint; a trailing {@code /} is added if missing
     * @param vectorizedNumber passed through to {@link AbstractModel}
     * @param oauthPath URL of the OAuth token endpoint, used if the token has to be refreshed
     * @param accessToken access token sent with embedding requests
     * @throws IOException declared for compatibility; not thrown by this constructor itself
     */
    public QianfanModel(
            String apiKey,
            String secretKey,
            String model,
            String apiPath,
            Integer vectorizedNumber,
            String oauthPath,
            String accessToken)
            throws IOException {
        super(vectorizedNumber);
        this.apiKey = apiKey;
        this.secretKey = secretKey;
        this.model = model;
        this.apiPath = apiPath;
        this.oauthPath = oauthPath;
        this.client = HttpClients.createDefault();
        this.accessToken = accessToken;
    }

    /**
     * Fetches a fresh access token via the client-credentials grant.
     *
     * @return the {@code access_token} field of the OAuth response
     * @throws IOException on transport failure, a non-200 status, or a response without a token
     */
    private String getAccessToken() throws IOException {
        // Concatenate instead of formatting oauthPath: a '%' in the configured URL must stay literal.
        HttpGet get =
                new HttpGet(this.oauthPath + String.format(OAUTH_SUFFIX_PATH, this.apiKey, this.secretKey));
        get.setConfig(REQUEST_CONFIG);
        try (CloseableHttpResponse response = this.client.execute(get)) {
            String responseStr = readBody(response);
            if (response.getStatusLine().getStatusCode() != 200) {
                throw new IOException("Failed to Oauth for qianfan, response: " + responseStr);
            }
            JsonNode result = OBJECT_MAPPER.readTree(responseStr);
            JsonNode token = result == null ? null : result.get("access_token");
            if (token == null || token.isNull()) {
                throw new IOException("Failed to Oauth for qianfan, response: " + responseStr);
            }
            return token.asText();
        }
    }

    /**
     * Embeds each element of {@code fields}.
     *
     * @param fields values serialized into the request's {@code input} array
     * @return one embedding per entry of the {@code data} array in Qianfan's response
     * @throws IOException on transport failure or an error response from Qianfan
     */
    @Override
    public List<List<Float>> vector(Object[] fields) throws IOException {
        return vectorGeneration(fields);
    }

    /**
     * Determines the embedding dimension by embedding a fixed sample string.
     *
     * @return the length of the sample embedding
     * @throws IOException on request failure or if Qianfan returns no embedding
     */
    @Override
    public Integer dimension() throws IOException {
        List<List<Float>> sample = vectorGeneration(new Object[] {"dimension example"});
        if (sample.isEmpty()) {
            throw new IOException("Failed to get vector dimension from qianfan: empty embedding result");
        }
        return sample.get(0).size();
    }

    /**
     * Posts {@code fields} to {@code apiPath/model} and parses the returned embeddings.
     *
     * <p>If Qianfan reports an invalid or expired token, the token is refreshed before the failure
     * is thrown, so a retry by the caller uses the new token. The failed request is not retried here.
     */
    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
        String baseApiPath = this.apiPath.endsWith("/") ? this.apiPath : this.apiPath + "/";
        // Concatenate instead of formatting apiPath: a '%' in the configured URL must stay literal.
        HttpPost post =
                new HttpPost(
                        baseApiPath + String.format("%s?access_token=%s", this.model, this.accessToken));
        post.setHeader("Content-Type", "application/json");
        post.setConfig(REQUEST_CONFIG);
        post.setEntity(
                new StringEntity(
                        OBJECT_MAPPER.writeValueAsString(createJsonNodeFromData(fields)), "UTF-8"));

        String responseStr;
        try (CloseableHttpResponse response = this.client.execute(post)) {
            responseStr = readBody(response);
            if (response.getStatusLine().getStatusCode() != 200) {
                throw new IOException("Failed to get vector from qianfan, response: " + responseStr);
            }
        }

        JsonNode result = OBJECT_MAPPER.readTree(responseStr);
        if (result == null) {
            throw new IOException("Failed to get vector from qianfan, response: " + responseStr);
        }
        JsonNode errorCode = result.get("error_code");
        if (errorCode != null) {
            IOException failure =
                    new IOException(
                            "Failed to get vector from qianfan, response: " + result.get("error_msg"));
            int code = errorCode.asInt();
            if (code == ERROR_TOKEN_INVALID || code == ERROR_TOKEN_EXPIRED) {
                try {
                    this.accessToken = getAccessToken();
                } catch (IOException refreshFailure) {
                    // Keep the original API error primary; the refresh failure is secondary context.
                    failure.addSuppressed(refreshFailure);
                }
            }
            throw failure;
        }

        JsonNode data = result.get("data");
        if (data == null || !data.isArray()) {
            throw new IOException("Failed to get vector from qianfan, response: " + responseStr);
        }
        List<List<Float>> embeddings = new ArrayList<>(data.size());
        for (JsonNode node : data) {
            JsonNode embeddingNode = node.get("embedding");
            if (embeddingNode == null || !embeddingNode.isArray()) {
                throw new IOException("Failed to get vector from qianfan, response: " + responseStr);
            }
            List<Float> embedding =
                    OBJECT_MAPPER.readValue(embeddingNode.traverse(), new TypeReference<List<Float>>() {});
            embeddings.add(embedding);
        }
        return embeddings;
    }

    /**
     * Reads the response body as text, using UTF-8 when the server omits a charset
     * (HttpCore's own fallback is ISO-8859-1).
     *
     * @return the body, or an empty string when the response has no entity
     */
    private static String readBody(CloseableHttpResponse response) throws IOException {
        return response.getEntity() == null ? "" : EntityUtils.toString(response.getEntity(), "UTF-8");
    }

    /**
     * Builds the request body {@code {"input": [...]}} from {@code data}.
     *
     * @param data values to embed, serialized in order as the {@code input} array
     * @return the JSON request object
     */
    @VisibleForTesting
    public ObjectNode createJsonNodeFromData(Object[] data) {
        ArrayNode input = OBJECT_MAPPER.valueToTree(Arrays.asList(data));
        ObjectNode body = OBJECT_MAPPER.createObjectNode();
        body.set("input", input);
        return body;
    }

    /** Closes the underlying HTTP client and its connection pool. */
    @Override
    public void close() throws IOException {
        if (this.client != null) {
            this.client.close();
        }
    }
}

