/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;

/**
 * Planner rule that rewrites {@code CORR(x, y)} calls of an {@link Aggregate} into {@code SUM} and
 * {@code COUNT} aggregates, recombined by a projection on top of the new aggregate.
 *
 * <p>Each {@code CORR(x, y)} becomes
 * <pre>
 *   CASE WHEN n = 0 THEN NULL
 *        ELSE (n * SUM(x*y) - SUM(x) * SUM(y))
 *             / (POWER(SUM(x*x) * n - SUM(x)^2, 0.5) * POWER(SUM(y*y) * n - SUM(y)^2, 0.5))
 *   END
 * </pre>
 * where {@code n} is a {@code COUNT} aggregate with an empty argument list. Every generated aggregate
 * copies the original call's DISTINCT flag and filter argument. Aggregate calls other than
 * {@code CORR} are re-registered unchanged.
 */
public class CorrReduceFunctionRule extends RelOptRule {
    public static final CorrReduceFunctionRule INSTANCE = new CorrReduceFunctionRule(
            operand(Aggregate.class, any()), RelFactories.LOGICAL_BUILDER, "CorrReduceFunctionRule");

    /** Name of the aggregate function this rule rewrites. */
    private static final String CORR = "CORR";

    /** NOT NULL DOUBLE; nullable variants are derived from it where the value may be NULL. */
    private static final RelDataType DOUBLE_TYPE =
            new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createSqlType(SqlTypeName.DOUBLE);

    public CorrReduceFunctionRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory,
            String description) {
        super(operand, relBuilderFactory, description);
    }

    /** Fires only when the aggregate contains at least one {@code CORR} call. */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate) call.rels[0];
        return containsCorrCall(aggregate.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Builds the replacement plan: an optional projection of extra input expressions, a new aggregate
     * holding the reduced calls, and a final projection restoring the original row type and names.
     */
    private void reduceAggs(RelOptRuleCall call, Aggregate oldAggRel) {
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        int groupCount = oldAggRel.getGroupCount();
        int indicatorCount = oldAggRel.getIndicatorCount();

        // Pass the group keys (and indicator columns, if any) straight through.
        List<RexNode> projList = new ArrayList<>();
        for (int i = 0; i < groupCount + indicatorCount; ++i) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        RelBuilder relBuilder = call.builder();
        relBuilder.push(oldAggRel.getInput());
        // Starts as one reference per input field; reduceAgg may append derived expressions
        // (the x*y products) which the new aggregate calls then refer to by ordinal.
        List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
        List<AggregateCall> newCalls = new ArrayList<>();
        Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // Materialize any appended expressions below the new aggregate; they get null names.
        int extraArgCount = inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            relBuilder.project(inputExprs,
                    CompositeList.<String>of(relBuilder.peek().getRowType().getFieldNames(),
                            Collections.<String>nCopies(extraArgCount, null)));
        }

        // indicator = false; presumably getIndicatorCount() is always 0 in this setup — TODO confirm.
        relBuilder.aggregate(
                relBuilder.groupKey(oldAggRel.getGroupSet(), false, oldAggRel.getGroupSets()),
                newCalls);
        relBuilder.project(projList, oldAggRel.getRowType().getFieldNames());
        call.transformTo(relBuilder.build());
    }

    private boolean containsCorrCall(List<AggregateCall> aggCallList) {
        for (AggregateCall aggCall : aggCallList) {
            if (isCorr(aggCall)) {
                return true;
            }
        }
        return false;
    }

    private static boolean isCorr(AggregateCall aggCall) {
        return CORR.equals(aggCall.getAggregation().getName());
    }

    /**
     * Reduces one aggregate call: {@code CORR} is expanded, anything else is re-registered as-is.
     *
     * @return expression over the new aggregate's output that yields the old call's value
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        if (isCorr(oldCall)) {
            return reduceCORR(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
        }
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        List<RelDataType> oldArgTypes =
                SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(oldCall, oldAggRel.getGroupCount(), false, newCalls, aggCallMapping,
                oldArgTypes);
    }

    /**
     * Expands {@code CORR(x, y)} into SUM/COUNT aggregates combined by the closed-form expression in
     * the class documentation, casting the result back to the original call's type.
     *
     * <p>NOTE(review): {@code n} counts all rows and each SUM skips only its own NULLs, so when exactly
     * one of x/y is NULL on some row the result differs from SQL-standard CORR (which uses only
     * rows where both are non-NULL). Adding an IS NOT NULL filter would change the shape of the
     * generated aggregates — TODO confirm whether downstream matching relies on that shape.
     *
     * <p>NOTE(review): a zero variance makes the divisor 0, and rounding could make a variance slightly
     * negative (POWER(negative, 0.5)); the result then depends on the execution engine — verify.
     *
     * @throws IllegalArgumentException if the call has fewer than two arguments
     */
    private RexNode reduceCORR(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls,
            Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        if (oldCall.getArgList() == null || oldCall.getArgList().size() < 2) {
            throw new IllegalArgumentException("CORR must have 2 argument parameters");
        }
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int groupCount = oldAggRel.getGroupCount();
        int iInputX = oldCall.getArgList().get(0);
        int iInputY = oldCall.getArgList().get(1);
        RelDataType argInputXType = getFieldType(oldAggRel.getInput(), iInputX);
        RexNode argX = inputExprs.get(iInputX);
        RexNode argY = inputExprs.get(iInputY);

        // SUM(x), SUM(y)
        RexNode sumX = castToDouble(rexBuilder,
                addSumCall(oldAggRel, oldCall, iInputX, newCalls, aggCallMapping, inputExprs));
        RexNode sumY = castToDouble(rexBuilder,
                addSumCall(oldAggRel, oldCall, iInputY, newCalls, aggCallMapping, inputExprs));

        // SUM(x*x), SUM(y*y), SUM(x*y)
        RexNode sumXX = castToDouble(rexBuilder,
                buildMultiplyRexNode(oldAggRel, oldCall, inputExprs, newCalls, aggCallMapping, argX, argX));
        RexNode sumYY = castToDouble(rexBuilder,
                buildMultiplyRexNode(oldAggRel, oldCall, inputExprs, newCalls, aggCallMapping, argY, argY));
        RexNode sumXY = castToDouble(rexBuilder,
                buildMultiplyRexNode(oldAggRel, oldCall, inputExprs, newCalls, aggCallMapping, argX, argY));

        // n: COUNT with an empty argument list.
        AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), false,
                false, ImmutableList.<Integer>of(), oldCall.filterArg, null, RelCollations.EMPTY, groupCount,
                oldAggRel.getInput(), null, null);
        RexNode count = rexBuilder.addAggCall(countCall, groupCount, false, newCalls, aggCallMapping,
                ImmutableList.of(argInputXType));

        // n * SUM(x*y) - SUM(x) * SUM(y)
        RexNode covariance = rexBuilder.makeCall(SqlStdOperatorTable.MINUS,
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, count, sumXY),
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY));
        // SUM(x*x) * n - SUM(x)^2 and SUM(y*y) * n - SUM(y)^2
        RexNode varianceX = rexBuilder.makeCall(SqlStdOperatorTable.MINUS,
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumXX, count),
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumX));
        RexNode varianceY = rexBuilder.makeCall(SqlStdOperatorTable.MINUS,
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumYY, count),
                rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumY, sumY));

        // sqrt(varianceX) * sqrt(varianceY), expressed as POWER(v, 0.5).
        RexLiteral half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        RexNode divisor = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
                rexBuilder.makeCall(SqlStdOperatorTable.POWER, varianceX, half),
                rexBuilder.makeCall(SqlStdOperatorTable.POWER, varianceY, half));
        RexNode corr = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
                castToDouble(rexBuilder, covariance), castToDouble(rexBuilder, divisor));

        // CASE WHEN n = 0 THEN NULL ELSE corr END. The CASE can yield NULL, so its type must be nullable,
        // and n is compared against a zero of its own (COUNT) type.
        RelDataType nullableDouble = doubleType(rexBuilder, true);
        RexNode countIsZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count,
                rexBuilder.makeZeroLiteral(count.getType()));
        RexNode corrWithCountChecking = rexBuilder.makeCall(nullableDouble, SqlStdOperatorTable.CASE,
                ImmutableList.<RexNode>of(countIsZero, rexBuilder.makeNullLiteral(nullableDouble), corr));
        return rexBuilder.makeCast(oldCall.getType(), corrWithCountChecking);
    }

    /**
     * Registers {@code SUM(input[ordinal])} (with the old call's DISTINCT flag and filter) via
     * {@link RexBuilder#addAggCall} and returns the expression referring to its result.
     */
    private static RexNode addSumCall(Aggregate oldAggRel, AggregateCall oldCall, int ordinal,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        int groupCount = oldAggRel.getGroupCount();
        AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), false, false,
                ImmutableIntList.of(ordinal), oldCall.filterArg, null, RelCollations.EMPTY, groupCount,
                oldAggRel.getInput(), null, null);
        return oldAggRel.getCluster().getRexBuilder().addAggCall(sumCall, groupCount, false, newCalls,
                aggCallMapping, ImmutableList.of(inputExprs.get(ordinal).getType()));
    }

    /** DOUBLE type from the builder's type factory with the requested nullability. */
    private static RelDataType doubleType(RexBuilder rexBuilder, boolean nullable) {
        return rexBuilder.getTypeFactory().createTypeWithNullability(DOUBLE_TYPE, nullable);
    }

    /** Casts {@code node} to DOUBLE, keeping the node's nullability (a NOT NULL cast would hide NULLs). */
    private static RexNode castToDouble(RexBuilder rexBuilder, RexNode node) {
        return rexBuilder.makeCast(doubleType(rexBuilder, node.getType().isNullable()), node);
    }

    private RelDataType getFieldType(RelNode relNode, int i) {
        return relNode.getRowType().getFieldList().get(i).getType();
    }

    /** Returns the index of {@code element} in {@code list}, appending it first if absent. */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal == -1) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Adds the DOUBLE product {@code rexNodeX * rexNodeY} to {@code inputExprs} (reusing an equal
     * existing entry) and registers {@code SUM} over it, returning the reference to that sum.
     * The product is nullable whenever either operand is.
     */
    private RexNode buildMultiplyRexNode(Aggregate oldAggRel, AggregateCall oldCall, List<RexNode> inputExprs,
            List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, RexNode rexNodeX,
            RexNode rexNodeY) {
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        boolean nullable = rexNodeX.getType().isNullable() || rexNodeY.getType().isNullable();
        RexNode product = rexBuilder.makeCall(doubleType(rexBuilder, nullable), SqlStdOperatorTable.MULTIPLY,
                ImmutableList.<RexNode>of(rexNodeX, rexNodeY));
        int productOrdinal = lookupOrAdd(inputExprs, product);
        int groupCount = oldAggRel.getGroupCount();
        List<RelDataType> productTypes = ImmutableList.of(inputExprs.get(productOrdinal).getType());
        Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(oldAggRel.getCluster().getTypeFactory(),
                SqlStdOperatorTable.SUM, productTypes, groupCount, oldCall.filterArg >= 0);
        AggregateCall sumProductCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(),
                ImmutableIntList.of(productOrdinal), oldCall.filterArg,
                SqlStdOperatorTable.SUM.inferReturnType(binding), null);
        return rexBuilder.addAggCall(sumProductCall, groupCount, false, newCalls, aggCallMapping, productTypes);
    }
}

