/*
 * Decompiled with CFR 0.152.
 */
package biz.k11i.xgboost;

import biz.k11i.xgboost.config.PredictorConfiguration;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.spark.SparkModelParam;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;

public class Predictor
implements Serializable {
    private ModelParam mparam;
    private SparkModelParam sparkModelParam;
    private String name_obj;
    private String name_gbm;
    private ObjFunction obj;
    private GradBooster gbm;

    /**
     * Loads a model from {@code in} using {@link PredictorConfiguration#DEFAULT}.
     *
     * @param in stream positioned at the start of a serialized model
     * @throws IOException if reading the model fails
     */
    public Predictor(InputStream in) throws IOException {
        this(in, null);
    }

    /**
     * Loads a model from {@code in}.
     *
     * <p>Reads the header parameters, resolves the objective function (preferring the one
     * supplied by {@code configuration}), creates the booster named in the model, and then
     * lets the booster load the remainder of the stream.
     *
     * @param in stream positioned at the start of a serialized model
     * @param configuration predictor configuration; {@code null} means
     *     {@link PredictorConfiguration#DEFAULT}
     * @throws IOException if reading the model fails
     */
    public Predictor(InputStream in, PredictorConfiguration configuration) throws IOException {
        if (configuration == null) {
            configuration = PredictorConfiguration.DEFAULT;
        }
        ModelReader reader = new ModelReader(in);
        this.readParam(reader);
        this.initObjFunction(configuration);
        this.initObjGbm();
        this.gbm.loadModel(reader, this.mparam.saved_with_pbuffer != 0);
    }

    /**
     * Reads the model header: base score, feature count, the remaining {@link ModelParam}
     * fields, and the objective / booster names.
     *
     * <p>Three header layouts are recognised from the first 8 bytes:
     * <ul>
     *   <li>{@code "binf"} magic, followed by the base score and then the feature count;</li>
     *   <li>a Spark prefix ({@code 0x00 0x05 "_cls_"} or {@code 0x00 0x05 "_reg_"}), followed by
     *       a length-prefixed features-column name and the {@link SparkModelParam} block, then the
     *       base score and feature count;</li>
     *   <li>otherwise, the 8 bytes already read are the base score and the feature count.</li>
     * </ul>
     *
     * @param reader source of model bytes
     * @throws IOException if reading fails
     */
    void readParam(ModelReader reader) throws IOException {
        float base_score;
        int num_feature;
        byte[] first4Bytes = reader.readByteArray(4);
        byte[] next4Bytes = reader.readByteArray(4);

        boolean hasBinfMagic = first4Bytes[0] == 'b' && first4Bytes[1] == 'i'
                && first4Bytes[2] == 'n' && first4Bytes[3] == 'f';
        String sparkModelType = hasBinfMagic ? null : sparkModelType(first4Bytes, next4Bytes);

        if (hasBinfMagic) {
            base_score = reader.asFloat(next4Bytes);
            num_feature = reader.readUnsignedInt();
        } else if (sparkModelType != null) {
            // The last header byte is the high byte of the features-column length and the next
            // byte read is the low byte. Mask to avoid sign extension of the signed Java byte.
            // NOTE(review): assumes readByteAsInt() yields an unsigned 0..255 value - confirm.
            int len = ((next4Bytes[3] & 0xFF) << 8) + reader.readByteAsInt();
            String featuresCol = reader.readUTF(len);
            this.sparkModelParam = new SparkModelParam(sparkModelType, featuresCol, reader);
            base_score = reader.readFloat();
            num_feature = reader.readUnsignedInt();
        } else {
            // No recognised prefix: the 8 bytes already consumed are the parameters themselves.
            base_score = reader.asFloat(first4Bytes);
            num_feature = reader.asUnsignedInt(next4Bytes);
        }
        this.mparam = new ModelParam(base_score, num_feature, reader);
        this.name_obj = reader.readString();
        this.name_gbm = reader.readString();
    }

    /**
     * Detects the Spark model prefix {@code 0x00 0x05} followed by {@code "_cls_"} or
     * {@code "_reg_"} across the first 7 header bytes.
     *
     * @return {@code "_cls_"}, {@code "_reg_"}, or {@code null} if no Spark prefix is present
     */
    private static String sparkModelType(byte[] first4Bytes, byte[] next4Bytes) {
        if (first4Bytes[0] != 0 || first4Bytes[1] != 5 || first4Bytes[2] != '_') {
            return null;
        }
        if (first4Bytes[3] == 'c' && next4Bytes[0] == 'l' && next4Bytes[1] == 's' && next4Bytes[2] == '_') {
            return "_cls_";
        }
        if (first4Bytes[3] == 'r' && next4Bytes[0] == 'e' && next4Bytes[1] == 'g' && next4Bytes[2] == '_') {
            return "_reg_";
        }
        return null;
    }

    /**
     * Resolves the objective function. The one supplied by {@code configuration} wins;
     * otherwise it is looked up by the objective name stored in the model.
     */
    void initObjFunction(PredictorConfiguration configuration) {
        this.obj = configuration.getObjFunction();
        if (this.obj == null) {
            this.obj = ObjFunction.fromName(this.name_obj);
        }
    }

    /**
     * Creates the gradient booster named in the model and sets its class count.
     *
     * <p>The objective is only resolved from the model name if none has been set yet. This
     * used to reassign it unconditionally, which discarded the objective chosen by
     * {@link #initObjFunction(PredictorConfiguration)}.
     */
    void initObjGbm() {
        if (this.obj == null) {
            this.obj = ObjFunction.fromName(this.name_obj);
        }
        this.gbm = GradBooster.Factory.createGradBooster(this.name_gbm);
        this.gbm.setNumClass(this.mparam.num_class);
    }

    /** Equivalent to {@code predict(feat, false, 0)}. */
    public double[] predict(FVec feat) {
        return this.predict(feat, false);
    }

    /** Equivalent to {@code predict(feat, output_margin, 0)}. */
    public double[] predict(FVec feat, boolean output_margin) {
        return this.predict(feat, output_margin, 0);
    }

    /**
     * Predicts for one feature vector.
     *
     * @param feat feature vector
     * @param output_margin if {@code true}, return raw scores (booster output plus base score);
     *     otherwise apply the objective's {@code predTransform}
     * @param ntree_limit passed through to the booster (the overloads above pass {@code 0};
     *     presumably "use all trees" - confirm in GradBooster)
     * @return prediction values
     */
    public double[] predict(FVec feat, boolean output_margin, int ntree_limit) {
        double[] preds = this.predictRaw(feat, ntree_limit);
        return output_margin ? preds : this.obj.predTransform(preds);
    }

    /** Returns the booster's raw predictions with the base score added to every element in place. */
    double[] predictRaw(FVec feat, int ntree_limit) {
        double[] preds = this.gbm.predict(feat, ntree_limit);
        double baseScore = this.mparam.base_score;
        for (int i = 0; i < preds.length; i++) {
            preds[i] += baseScore;
        }
        return preds;
    }

    /** Equivalent to {@code predictSingle(feat, false, 0)}. */
    public double predictSingle(FVec feat) {
        return this.predictSingle(feat, false);
    }

    /** Equivalent to {@code predictSingle(feat, output_margin, 0)}. */
    public double predictSingle(FVec feat, boolean output_margin) {
        return this.predictSingle(feat, output_margin, 0);
    }

    /**
     * Single-output variant of {@link #predict(FVec, boolean, int)}.
     *
     * @return the raw score if {@code output_margin}, otherwise the transformed score
     */
    public double predictSingle(FVec feat, boolean output_margin, int ntree_limit) {
        double pred = this.predictSingleRaw(feat, ntree_limit);
        return output_margin ? pred : this.obj.predTransform(pred);
    }

    /** Returns the booster's single raw prediction plus the base score. */
    double predictSingleRaw(FVec feat, int ntree_limit) {
        return this.gbm.predictSingle(feat, ntree_limit) + (double) this.mparam.base_score;
    }

    /** Equivalent to {@code predictLeaf(feat, 0)}. */
    public int[] predictLeaf(FVec feat) {
        return this.predictLeaf(feat, 0);
    }

    /**
     * Delegates to the booster's {@code predictLeaf}; presumably returns one leaf index per
     * tree - confirm in GradBooster.
     */
    public int[] predictLeaf(FVec feat, int ntree_limit) {
        return this.gbm.predictLeaf(feat, ntree_limit);
    }

    /**
     * @return the Spark parameters read from the header, or {@code null} if the model had no
     *     Spark prefix
     */
    public SparkModelParam getSparkModelParam() {
        return this.sparkModelParam;
    }

    /** @return the {@code num_class} value read from the model header */
    public int getNumClass() {
        return this.mparam.num_class;
    }

    /** Fixed-size model header parameters. */
    static class ModelParam
    implements Serializable {
        /** Constant added to every raw booster output. */
        final float base_score;
        final int num_feature;
        final int num_class;
        /** Non-zero if the model was saved with a prediction buffer (passed to loadModel). */
        final int saved_with_pbuffer;
        /** 30 reserved ints read from the header; not otherwise used here. */
        final int[] reserved;

        ModelParam(float base_score, int num_feature, ModelReader reader) throws IOException {
            this.base_score = base_score;
            this.num_feature = num_feature;
            this.num_class = reader.readInt();
            this.saved_with_pbuffer = reader.readInt();
            this.reserved = reader.readIntArray(30);
        }
    }
}

