/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.seatunnel.transform.nlpmodel.embedding.remote.AbstractModel;

/**
 * Embedding model backed by a remote Doubao HTTP endpoint.
 *
 * <p>Each request POSTs a JSON body of the form {@code {"model": <model>, "input": [...]}} to the
 * configured API path with a bearer token, and reads embeddings from the {@code data[].embedding}
 * entries of the JSON response.
 */
public class DoubaoModel extends AbstractModel {
    /** Connect timeout applied to every request, in milliseconds. */
    private static final int CONNECT_TIMEOUT_MS = 20000;

    /** Socket (read) timeout applied to every request, in milliseconds. */
    private static final int SOCKET_TIMEOUT_MS = 20000;

    /** Charset used for the request body and as the fallback when decoding the response. */
    private static final String CHARSET_UTF_8 = "UTF-8";

    /** Sample text sent to the endpoint when probing for the embedding dimension. */
    private static final String DIMENSION_PROBE_TEXT = "dimension example";

    private final CloseableHttpClient client;
    private final String apiKey;
    private final String model;
    private final String apiPath;

    /**
     * Creates a model client with its own default HTTP client.
     *
     * @param apiKey token sent as {@code Authorization: Bearer <apiKey>}
     * @param model model identifier placed in the {@code model} field of the request body
     * @param apiPath URL the embedding request is POSTed to
     * @param vectorizedNumber forwarded unchanged to the {@link AbstractModel} constructor
     */
    public DoubaoModel(String apiKey, String model, String apiPath, Integer vectorizedNumber) {
        super(vectorizedNumber);
        this.apiKey = apiKey;
        this.model = model;
        this.apiPath = apiPath;
        this.client = HttpClients.createDefault();
    }

    /**
     * Requests embeddings for the given fields from the remote endpoint.
     *
     * @param fields values sent as the {@code input} array of the request
     * @return the embeddings found in the response, in response order
     * @throws IOException if the request fails, the status is not 200, or the response is malformed
     */
    @Override
    protected List<List<Float>> vector(Object[] fields) throws IOException {
        return this.vectorGeneration(fields);
    }

    /**
     * Determines the embedding dimension by embedding a sample text and measuring the length of
     * the returned vector.
     *
     * @return the number of components in an embedding produced by the endpoint
     * @throws IOException if the request fails or no embedding is returned
     */
    @Override
    public Integer dimension() throws IOException {
        List<List<Float>> vectors = this.vectorGeneration(new Object[] {DIMENSION_PROBE_TEXT});
        if (vectors.isEmpty()) {
            throw new IOException(
                    "Failed to determine vector dimension from doubao: no embedding returned");
        }
        // The outer list holds one entry per returned embedding; the dimension is the length of a
        // single embedding, not the number of embeddings.
        return vectors.get(0).size();
    }

    /**
     * Sends the embedding request and parses the response.
     *
     * @param fields values sent as the {@code input} array of the request
     * @return the embeddings found in the response; empty if {@code data} is not a JSON array
     * @throws IOException if the request fails, the status is not 200, or the response is malformed
     */
    private List<List<Float>> vectorGeneration(Object[] fields) throws IOException {
        HttpPost post = new HttpPost(this.apiPath);
        post.setHeader("Authorization", "Bearer " + this.apiKey);
        post.setHeader("Content-Type", "application/json");
        post.setConfig(
                RequestConfig.custom()
                        .setConnectTimeout(CONNECT_TIMEOUT_MS)
                        .setSocketTimeout(SOCKET_TIMEOUT_MS)
                        .build());
        post.setEntity(
                new StringEntity(
                        OBJECT_MAPPER.writeValueAsString(this.createJsonNodeFromData(fields)),
                        CHARSET_UTF_8));

        String responseStr;
        // Close the response on every path so the underlying connection is released.
        try (CloseableHttpResponse response = this.client.execute(post)) {
            // UTF-8 is only the fallback; a charset declared by the response still takes priority.
            responseStr =
                    response.getEntity() == null
                            ? ""
                            : EntityUtils.toString(response.getEntity(), CHARSET_UTF_8);
            if (response.getStatusLine().getStatusCode() != 200) {
                throw new IOException(
                        "Failed to get vector from doubao, response: " + responseStr);
            }
        }
        return this.parseEmbeddings(responseStr);
    }

    /**
     * Extracts the embeddings from a response body.
     *
     * @param responseStr JSON response body
     * @return one list of floats per element of the {@code data} array
     * @throws IOException if {@code data} or an element's {@code embedding} is missing
     */
    private List<List<Float>> parseEmbeddings(String responseStr) throws IOException {
        JsonNode root = OBJECT_MAPPER.readTree(responseStr);
        JsonNode data = root == null ? null : root.get("data");
        if (data == null) {
            throw new IOException(
                    "Failed to get vector from doubao, response has no 'data' field: "
                            + responseStr);
        }

        List<List<Float>> embeddings = new ArrayList<>();
        if (data.isArray()) {
            for (JsonNode node : data) {
                JsonNode embeddingNode = node.get("embedding");
                if (embeddingNode == null || embeddingNode.isNull()) {
                    throw new IOException(
                            "Failed to get vector from doubao, response item has no 'embedding' field: "
                                    + responseStr);
                }
                List<Float> embedding =
                        OBJECT_MAPPER.readValue(
                                embeddingNode.traverse(), new TypeReference<List<Float>>() {});
                embeddings.add(embedding);
            }
        }
        return embeddings;
    }

    /**
     * Builds the JSON request body {@code {"model": <model>, "input": [fields...]}}.
     *
     * @param fields values to place in the {@code input} array
     * @return the request body as a JSON object node
     */
    @VisibleForTesting
    public ObjectNode createJsonNodeFromData(Object[] fields) {
        ArrayNode inputs = OBJECT_MAPPER.valueToTree(Arrays.asList(fields));
        ObjectNode request = OBJECT_MAPPER.createObjectNode();
        request.put("model", this.model);
        request.set("input", inputs);
        return request;
    }

    /**
     * Closes the underlying HTTP client.
     *
     * @throws IOException if closing the client fails
     */
    @Override
    public void close() throws IOException {
        if (this.client != null) {
            this.client.close();
        }
    }
}

