/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.stats;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.GregorianCalendar;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUnknownAs;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.datasketches.kll.KllFloatsSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.quantilescommon.QuantileSearchCriteria;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveConfPlannerContext;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.SearchTransformer;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIn;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.stats.HiveRelMdDistinctRowCount;
import org.apache.hadoop.hive.ql.plan.ColStatistics;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Estimates the selectivity (fraction of input rows that pass) of a filter predicate evaluated on
 * top of a given input relational expression.
 *
 * <p>The estimator visits the predicate tree and returns, for each node, a selectivity that is
 * normally in {@code [0, 1]}, or {@code null} when no estimate can be derived (for example when a
 * distinct-value count is unknown). Predicates that reference only partition columns get
 * selectivity {@code 1.0}. When the input is a {@link HiveTableScan} whose column statistics carry a
 * KLL histogram, comparisons and BETWEEN predicates between a column and literals are estimated from
 * that histogram; otherwise fixed defaults or NDV-based estimates are used.
 */
public class FilterSelectivityEstimator extends RexVisitorImpl<Double> {
    protected static final Logger LOG = LoggerFactory.getLogger(FilterSelectivityEstimator.class);

    /** Fallback selectivity for range predicates and for inconsistent IS NOT NULL statistics. */
    private static final double DEFAULT_SELECTIVITY = 1.0 / 3.0;
    /** Selectivity assumed for a disjunct that cannot be estimated. */
    private static final double UNKNOWN_DISJUNCT_SELECTIVITY = 0.99;
    /** Selectivity used for an IN predicate whose scaled estimate is not positive. */
    private static final double MIN_IN_SELECTIVITY = 0.1;

    private final RelNode childRel;
    private final double childCardinality;
    private final RelMetadataQuery mq;
    private final RexBuilder rexBuilder;

    /**
     * Creates an estimator for predicates evaluated on top of {@code childRel}.
     *
     * @param childRel the filter input; its row count is read once here
     * @param mq metadata query used for row counts and distinct-value counts
     */
    public FilterSelectivityEstimator(RelNode childRel, RelMetadataQuery mq) {
        super(true);
        this.mq = mq;
        this.childRel = childRel;
        this.childCardinality = mq.getRowCount(childRel);
        this.rexBuilder = childRel.getCluster().getRexBuilder();
    }

    /**
     * Estimates the selectivity of {@code predicate} over the input.
     *
     * @param predicate the filter condition
     * @return the estimated selectivity, or {@code null} if it cannot be estimated
     */
    public Double estimateSelectivity(RexNode predicate) {
        return predicate.accept(this);
    }

    /** A bare boolean column is assumed to be true for half of the rows; other columns have no estimate. */
    @Override
    public Double visitInputRef(RexInputRef inputRef) {
        if (inputRef.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) {
            return 0.5;
        }
        return null;
    }

    /** Dispatches on the operator kind of {@code call} to the matching selectivity estimate. */
    @Override
    public Double visitCall(RexCall call) {
        if (!deep) {
            return 1.0;
        }
        // Predicates on partition columns only are treated as not filtering anything.
        if (isPartitionPredicate(call, childRel)) {
            return 1.0;
        }
        switch (call.getKind()) {
            case AND:
                return computeConjunctionSelectivity(call);
            case OR:
                return computeDisjunctionSelectivity(call);
            case SEARCH:
                // Expand SEARCH/Sarg into plain comparisons and estimate those instead.
                return new SearchTransformer(rexBuilder, call, RexUnknownAs.FALSE).transform().accept(this);
            case NOT:
                return computeNegationSelectivity(call);
            case NOT_EQUALS:
                return computeNotEqualitySelectivity(call);
            case IS_NOT_NULL:
                return computeIsNotNullSelectivity(call);
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN_OR_EQUAL:
            case LESS_THAN:
            case GREATER_THAN:
                return computeRangePredicateSelectivity(call, call.getKind());
            case BETWEEN:
                return computeBetweenPredicateSelectivity(call);
            default:
                if (HiveIn.INSTANCE.equals(call.getOperator())) {
                    return computeInSelectivity(call);
                }
                return computeFunctionSelectivity(call);
        }
    }

    /**
     * Selectivity of {@code NOT operand}: one minus the operand's selectivity, or {@code null} when
     * the operand itself cannot be estimated (instead of failing with a NullPointerException).
     */
    private Double computeNegationSelectivity(RexCall call) {
        Double operandSelectivity = call.getOperands().get(0).accept(this);
        if (operandSelectivity == null) {
            return null;
        }
        assert operandSelectivity >= 0.0 && operandSelectivity <= 1.0 : operandSelectivity;
        return 1.0 - operandSelectivity;
    }

    /**
     * Selectivity of IS NOT NULL. On a table scan it is derived from the largest null count among the
     * referenced columns; if that count exceeds the row count, a warning is emitted and
     * {@link #DEFAULT_SELECTIVITY} is used. Other inputs fall back to the NDV-based {@code <>} estimate.
     */
    private Double computeIsNotNullSelectivity(RexCall call) {
        if (!(childRel instanceof HiveTableScan)) {
            return computeNotEqualitySelectivity(call);
        }
        double noOfNulls = getMaxNulls(call, (HiveTableScan) childRel);
        double totalNoOfTuples = mq.getRowCount(childRel);
        if (totalNoOfTuples >= noOfNulls) {
            return (totalNoOfTuples - noOfNulls) / Math.max(totalNoOfTuples, 1.0);
        }
        HiveConfPlannerContext ctx =
            childRel.getCluster().getPlanner().getContext().unwrap(HiveConfPlannerContext.class);
        String msg = "Invalid statistics: Number of null values > number of tuples. Consider recomputing statistics for table: "
            + ((RelOptHiveTable) childRel.getTable()).getHiveTableMD().getFullyQualifiedName();
        // The planner context may not carry Hive configuration; only EXPLAIN prints to the console.
        if (ctx != null && ctx.isExplainPlan()) {
            SessionState.getConsole().printError("WARNING: " + msg);
        }
        LOG.warn(msg);
        return DEFAULT_SELECTIVITY;
    }

    /**
     * Selectivity of a HiveIn predicate: the per-value estimate (1 / NDV) scaled by the number of
     * IN-list values (all operands but the first), then bounded to {@code (0, 1]}.
     */
    private Double computeInSelectivity(RexCall call) {
        Double perValueSelectivity = computeFunctionSelectivity(call);
        if (perValueSelectivity == null) {
            return null;
        }
        double selectivity = perValueSelectivity * (call.getOperands().size() - 1);
        if (selectivity <= 0.0) {
            return MIN_IN_SELECTIVITY;
        }
        return Math.min(selectivity, 1.0);
    }

    /**
     * Estimates {@code column op literal} (or {@code literal op column}) for the four range operators.
     * Uses the column's KLL histogram when the input is a table scan and one is available; otherwise
     * returns {@link #DEFAULT_SELECTIVITY}.
     *
     * @param call the comparison call (two operands)
     * @param op one of LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL
     */
    private double computeRangePredicateSelectivity(RexCall call, SqlKind op) {
        RexNode left = call.getOperands().get(0);
        RexNode right = call.getOperands().get(1);
        boolean literalLeft = left.getKind() == SqlKind.LITERAL && right.getKind() == SqlKind.INPUT_REF;
        boolean literalRight = left.getKind() == SqlKind.INPUT_REF && right.getKind() == SqlKind.LITERAL;
        if (!(childRel instanceof HiveTableScan) || !(literalLeft || literalRight)) {
            return DEFAULT_SELECTIVITY;
        }
        HiveTableScan scan = (HiveTableScan) childRel;
        RexInputRef column = (RexInputRef) (literalLeft ? right : left);
        RexLiteral literal = (RexLiteral) (literalLeft ? left : right);
        KllFloatsSketch kll = loadHistogram(scan, column.getIndex());
        Object boundValue = literal.getValue();
        // A NULL bound cannot be mapped onto the histogram; keep the non-histogram estimate.
        if (kll == null || boundValue == null) {
            return DEFAULT_SELECTIVITY;
        }
        float value = extractLiteral(column.getType().getSqlTypeName(), boundValue);
        // "literal op column" is "column reverse(op) literal", e.g. 5 < x  <=>  x > 5.
        SqlKind columnOp = literalLeft ? op.reverse() : op;
        double nonNullSelectivity = switch (columnOp) {
            case LESS_THAN -> lessThanSelectivity(kll, value);
            case LESS_THAN_OR_EQUAL -> lessThanOrEqualSelectivity(kll, value);
            case GREATER_THAN -> greaterThanSelectivity(kll, value);
            case GREATER_THAN_OR_EQUAL -> greaterThanOrEqualSelectivity(kll, value);
            default -> throw new IllegalArgumentException("Unexpected range comparison: " + op);
        };
        return scaleToTableRows(kll, scan, nonNullSelectivity);
    }

    /**
     * Estimates Hive's four-operand BETWEEN ({@code [invert flag, column, lower, upper]}) from the
     * column's KLL histogram when possible. Equal bounds and non-histogram cases fall back to the
     * NDV-based estimates.
     */
    private Double computeBetweenPredicateSelectivity(RexCall call) {
        List<RexNode> operands = call.getOperands();
        boolean histogramCandidate = childRel instanceof HiveTableScan
            && operands.get(0).getKind() == SqlKind.LITERAL
            && operands.get(1).getKind() == SqlKind.INPUT_REF
            && operands.get(2).getKind() == SqlKind.LITERAL
            && operands.get(3).getKind() == SqlKind.LITERAL;
        if (histogramCandidate) {
            HiveTableScan scan = (HiveTableScan) childRel;
            RexInputRef column = (RexInputRef) operands.get(1);
            KllFloatsSketch kll = loadHistogram(scan, column.getIndex());
            Object leftBound = ((RexLiteral) operands.get(2)).getValue();
            Object rightBound = ((RexLiteral) operands.get(3)).getValue();
            if (kll != null && leftBound != null && rightBound != null) {
                SqlTypeName typeName = column.getType().getSqlTypeName();
                boolean inverse = Boolean.parseBoolean(String.valueOf(((RexLiteral) operands.get(0)).getValue()));
                float leftValue = extractLiteral(typeName, leftBound);
                float rightValue = extractLiteral(typeName, rightBound);
                if (inverse) {
                    if (rightValue == leftValue) {
                        return computeNotEqualitySelectivity(call);
                    }
                    // NOT BETWEEN keeps the non-null values outside the range (NULLs never pass);
                    // betweenSelectivity is 0 for an empty range (upper < lower).
                    return scaleToTableRows(kll, scan, 1.0 - betweenSelectivity(kll, leftValue, rightValue));
                }
                if (Double.compare(leftValue, rightValue) != 0) {
                    return scaleToTableRows(kll, scan, betweenSelectivity(kll, leftValue, rightValue));
                }
            }
        }
        return computeFunctionSelectivity(call);
    }

    /**
     * Loads the KLL histogram of column {@code columnIndex} of {@code scan}.
     *
     * @return the sketch, or {@code null} when there is no histogram or it summarizes no values
     */
    private static KllFloatsSketch loadHistogram(HiveTableScan scan, int columnIndex) {
        List<ColStatistics> colStats = scan.getColStat(Collections.singletonList(columnIndex));
        if (colStats.isEmpty() || !isHistogramAvailable(colStats.get(0))) {
            return null;
        }
        KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram()));
        // An empty sketch has no min/max items and cannot answer CDF queries.
        return kll.isEmpty() ? null : kll;
    }

    /**
     * Rescales a selectivity computed over the histogram's items to the table's total row count
     * (multiply by the sketch item count N, divide by the row count), clamped to {@code [0, 1]}.
     * NOTE(review): assumes N counts only non-null values, so NULL rows are excluded by this scaling.
     */
    private static double scaleToTableRows(KllFloatsSketch kll, HiveTableScan scan, double histogramSelectivity) {
        double rowCount = scan.getTable().getRowCount();
        if (!(rowCount > 0.0)) {
            // No usable row count: avoid dividing by zero and keep the histogram-only estimate.
            return Math.max(0.0, Math.min(1.0, histogramSelectivity));
        }
        double selectivity = kll.getN() * histogramSelectivity / rowCount;
        // Table and column statistics may be out of sync (N > rowCount); never report more than 100%.
        return Math.max(0.0, Math.min(1.0, selectivity));
    }

    /**
     * Converts a literal value to the float domain of the histogram for a column of type
     * {@code typeName}. DATE/TIMESTAMP values become epoch seconds.
     *
     * @throws IllegalStateException if the column type is not supported for histogram estimation
     */
    private float extractLiteral(SqlTypeName typeName, Object boundValueObject) {
        String boundValueString = boundValueObject.toString();
        return switch (typeName) {
            case TINYINT -> Byte.parseByte(boundValueString);
            case SMALLINT -> Short.parseShort(boundValueString);
            case INTEGER -> Integer.parseInt(boundValueString);
            case BIGINT -> Long.parseLong(boundValueString);
            case FLOAT -> Float.parseFloat(boundValueString);
            case DOUBLE -> (float) Double.parseDouble(boundValueString);
            case DECIMAL -> new BigDecimal(boundValueString).floatValue();
            case DATE, TIMESTAMP -> ((GregorianCalendar) boundValueObject).toInstant().getEpochSecond();
            default -> throw new IllegalStateException("Unsupported type for comparator selectivity evaluation using histogram: " + String.valueOf(typeName));
        };
    }

    /** Selectivity of {@code <>}: {@code (NDV - 1) / NDV}, 1.0 when NDV <= 1, or null if NDV is unknown. */
    private Double computeNotEqualitySelectivity(RexCall call) {
        Double tmpNDV = getMaxNDV(call);
        if (tmpNDV == null) {
            return null;
        }
        if (tmpNDV > 1.0) {
            return (tmpNDV - 1.0) / tmpNDV;
        }
        return 1.0;
    }

    /** Generic (equality-like) selectivity {@code 1 / NDV}, or null if NDV is unknown. */
    private Double computeFunctionSelectivity(RexCall call) {
        Double tmpNDV = getMaxNDV(call);
        if (tmpNDV == null) {
            return null;
        }
        return 1.0 / tmpNDV;
    }

    /**
     * Selectivity of an OR: {@code 1 - prod(1 - card_i / childCardinality)}. Unknown disjuncts use
     * {@link #UNKNOWN_DISJUNCT_SELECTIVITY}.
     * NOTE(review): disjuncts estimated at <= 1 row, or at >= the whole input, contribute a factor of
     * 1.0 and are therefore ignored; kept unchanged to preserve existing plans.
     */
    private Double computeDisjunctionSelectivity(RexCall call) {
        double noneSelectivity = 1.0;
        for (RexNode disjunct : call.getOperands()) {
            Double disjunctSelectivity = disjunct.accept(this);
            if (disjunctSelectivity == null) {
                disjunctSelectivity = UNKNOWN_DISJUNCT_SELECTIVITY;
            }
            double disjunctCardinality = childCardinality * disjunctSelectivity;
            if (disjunctCardinality > 1.0 && disjunctCardinality < childCardinality) {
                noneSelectivity *= 1.0 - disjunctCardinality / childCardinality;
            }
        }
        if (noneSelectivity < 0.0) {
            noneSelectivity = 0.0;
        }
        return 1.0 - noneSelectivity;
    }

    /** Selectivity of an AND: product of the conjuncts' selectivities, skipping unknown (null) ones. */
    private Double computeConjunctionSelectivity(RexCall call) {
        double selectivity = 1.0;
        for (RexNode conjunct : call.getOperands()) {
            Double conjunctSelectivity = conjunct.accept(this);
            if (conjunctSelectivity != null) {
                selectivity *= conjunctSelectivity;
            }
        }
        return selectivity;
    }

    /** Largest null count among the columns referenced by {@code call}, per the scan's column stats. */
    private long getMaxNulls(RexCall call, HiveTableScan t) {
        Set<Integer> iRefSet = HiveCalciteUtil.getInputRefs(call);
        List<ColStatistics> colStats = t.getColStat(new ArrayList<Integer>(iRefSet));
        long maxNoNulls = 0L;
        for (ColStatistics cs : colStats) {
            maxNoNulls = Math.max(maxNoNulls, cs.getNumNulls());
        }
        return maxNoNulls;
    }

    /**
     * Largest distinct-value count (at least 1.0) among the input columns referenced by the operands
     * of {@code call}, or {@code null} if any referenced column's NDV is unknown.
     */
    private Double getMaxNDV(RexCall call) {
        double maxNDV = 1.0;
        for (RexNode operand : call.getOperands()) {
            if (operand instanceof RexInputRef) {
                Double ndv = HiveRelMdDistinctRowCount.getDistinctRowCount(childRel, mq, ((RexInputRef) operand).getIndex());
                if (ndv == null) {
                    return null;
                }
                if (ndv > maxNDV) {
                    maxNDV = ndv;
                }
                continue;
            }
            // Non-trivial operand: consider every input column it references.
            RelOptUtil.InputReferencedVisitor irv = new RelOptUtil.InputReferencedVisitor();
            irv.apply(operand);
            for (Integer childProjIndx : irv.inputPosReferenced) {
                Double ndv = HiveRelMdDistinctRowCount.getDistinctRowCount(childRel, mq, childProjIndx);
                if (ndv == null) {
                    return null;
                }
                if (ndv > maxNDV) {
                    maxNDV = ndv;
                }
            }
        }
        return maxNDV;
    }

    /**
     * Returns whether {@code expr}, pushed through any Project/Filter chain above a table scan,
     * references only partition columns of that table.
     */
    private boolean isPartitionPredicate(RexNode expr, RelNode r) {
        if (r instanceof Project) {
            RexNode pushed = RelOptUtil.pushFilterPastProject(expr, (Project) r);
            return isPartitionPredicate(pushed, ((Project) r).getInput());
        }
        if (r instanceof Filter) {
            return isPartitionPredicate(expr, ((Filter) r).getInput());
        }
        if (r instanceof HiveTableScan) {
            RelOptHiveTable table = (RelOptHiveTable) ((HiveTableScan) r).getTable();
            ImmutableBitSet cols = RelOptUtil.InputFinder.bits(expr);
            return table.containsPartitionColumnsOnly(cols);
        }
        return false;
    }

    /** FALSE/NULL literals select nothing, TRUE selects everything; other literals are unexpected. */
    @Override
    public Double visitLiteral(RexLiteral literal) {
        if (literal.isAlwaysFalse() || RexUtil.isNull(literal)) {
            return 0.0;
        }
        if (literal.isAlwaysTrue()) {
            return 1.0;
        }
        assert false;
        return null;
    }

    /** Fraction of sketch items in {@code [val1, val2)} (difference of two exclusive-rank CDF points). */
    private static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) {
        float[] splitPoints = new float[]{val1, val2};
        double[] boundaries = kll.getCDF(splitPoints, QuantileSearchCriteria.EXCLUSIVE);
        return boundaries[1] - boundaries[0];
    }

    /** Fraction of sketch items strictly greater than {@code value}, i.e. in {@code [nextUp(value), max]}. */
    public static double greaterThanSelectivity(KllFloatsSketch kll, float value) {
        float max = kll.getMaxItem();
        // Nothing is strictly above the maximum (this also keeps the CDF split points strictly increasing).
        if (value >= max) {
            return 0.0;
        }
        // Includes items equal to max even when nextUp(value) == max.
        return rangedSelectivity(kll, Math.nextUp(value), Math.nextUp(max));
    }

    /** Fraction of sketch items greater than or equal to {@code value}, i.e. in {@code [value, max]}. */
    public static double greaterThanOrEqualSelectivity(KllFloatsSketch kll, float value) {
        if (value > kll.getMaxItem()) {
            return 0.0;
        }
        return rangedSelectivity(kll, value, Math.nextUp(kll.getMaxItem()));
    }

    /** Fraction of sketch items less than or equal to {@code value} (exclusive rank of nextUp(value)). */
    public static double lessThanOrEqualSelectivity(KllFloatsSketch kll, float value) {
        if (value < kll.getMinItem()) {
            return 0.0;
        }
        return kll.getCDF(new float[]{Math.nextUp(value)}, QuantileSearchCriteria.EXCLUSIVE)[0];
    }

    /** Fraction of sketch items strictly less than {@code value}. */
    public static double lessThanSelectivity(KllFloatsSketch kll, float value) {
        float min = kll.getMinItem();
        if (value <= min) {
            return 0.0;
        }
        return kll.getCDF(new float[]{value}, QuantileSearchCriteria.EXCLUSIVE)[0];
    }

    /**
     * Fraction of sketch items in the closed range {@code [leftValue, rightValue]}; 0 when the range is
     * empty.
     *
     * @throws IllegalArgumentException if the two bounds coincide (callers estimate that as equality)
     */
    public static double betweenSelectivity(KllFloatsSketch kll, float leftValue, float rightValue) {
        if (rightValue < leftValue) {
            return 0.0;
        }
        if (Double.compare(leftValue, rightValue) == 0) {
            throw new IllegalArgumentException("Selectivity for BETWEEN leftValue AND rightValue when the two values coincide is not supported, found: leftValue = " + leftValue + " and rightValue = " + rightValue);
        }
        // Exclusive CDF: [leftValue, nextUp(rightValue)) == [leftValue, rightValue].
        return rangedSelectivity(kll, leftValue, Math.nextUp(rightValue));
    }

    /** Returns whether {@code colStats} carries a non-empty serialized histogram. */
    public static boolean isHistogramAvailable(ColStatistics colStats) {
        return colStats != null && colStats.getHistogram() != null && colStats.getHistogram().length > 0;
    }
}

