# Copyright 2020 The TensorFlow Probability Authors. All Rights Reserved.
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# THIS FILE IS AUTO-GENERATED BY `gen_linear_operators.py`.
# DO NOT MODIFY DIRECTLY.
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# pylint: disable=g-import-not-at-top
# pylint: disable=g-direct-tensorflow-import
# pylint: disable=g-bad-import-order
# pylint: disable=unused-import
# pylint: disable=line-too-long
# pylint: disable=reimported
# pylint: disable=g-bool-id-comparison
# pylint: disable=g-statement-before-imports
# pylint: disable=bad-continuation
# pylint: disable=useless-import-alias
# pylint: disable=property-with-parameters
# pylint: disable=trailing-whitespace

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registrations for LinearOperator.cholesky."""

from __future__ import absolute_import
from __future__ import division
# [internal] enable type annotations
from __future__ import print_function

from tensorflow_probability.python.internal.backend.jax import linalg_impl as linalg_ops
from tensorflow_probability.python.internal.backend.jax import numpy_math as math_ops
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_algebra
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_block_diag
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_diag
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_identity
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_kronecker
from tensorflow_probability.python.internal.backend.jax.gen import linear_operator_lower_triangular


# By default, compute the Cholesky of the dense matrix, and return a
# LowerTriangular operator. Methods below specialize this registration.
@linear_operator_algebra.RegisterCholesky(linear_operator.LinearOperator)
def _cholesky_linear_operator(linop):
  """Fallback Cholesky: densify the operator, factor it, wrap the result.

  Args:
    linop: The operator to factor. It is materialized via `to_dense()`.

  Returns:
    A `LinearOperatorLowerTriangular` built from the dense Cholesky factor,
    hinted as non-singular, square and not self-adjoint.
  """
  dense_factor = linalg_ops.cholesky(linop.to_dense())
  factor_hints = dict(
      is_non_singular=True,
      is_self_adjoint=False,
      is_square=True)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      dense_factor, **factor_hints)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_diag.LinearOperatorDiag)
def _cholesky_diag(diag_operator):
  """Cholesky of a diagonal operator: elementwise square root of `diag`.

  Args:
    diag_operator: The `LinearOperatorDiag` to factor.

  Returns:
    A `LinearOperatorDiag` whose diagonal is `sqrt(diag_operator.diag)`,
    hinted as non-singular, self-adjoint, positive definite and square.
  """
  # NOTE(review): no sign check is made on the diagonal here; presumably the
  # caller only dispatches for positive-definite operators -- confirm.
  sqrt_diag = math_ops.sqrt(diag_operator.diag)
  factor_hints = dict(
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)
  return linear_operator_diag.LinearOperatorDiag(sqrt_diag, **factor_hints)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorIdentity)
def _cholesky_identity(identity_operator):
  """Cholesky of an identity operator: a fresh identity of the same spec.

  Args:
    identity_operator: The `LinearOperatorIdentity` to factor.

  Returns:
    A new `LinearOperatorIdentity` built from the input's `_num_rows`,
    `batch_shape` and `dtype`, hinted as non-singular, self-adjoint, positive
    definite and square.
  """
  # Read the construction arguments up front (same order as they are passed).
  n_rows = identity_operator._num_rows  # pylint: disable=protected-access
  batch = identity_operator.batch_shape
  out_dtype = identity_operator.dtype
  return linear_operator_identity.LinearOperatorIdentity(
      num_rows=n_rows,
      batch_shape=batch,
      dtype=out_dtype,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_identity.LinearOperatorScaledIdentity)
def _cholesky_scaled_identity(identity_operator):
  """Cholesky of a scaled identity: same size, square-rooted multiplier.

  Args:
    identity_operator: The `LinearOperatorScaledIdentity` to factor.

  Returns:
    A `LinearOperatorScaledIdentity` with the input's `_num_rows` and
    multiplier `sqrt(identity_operator.multiplier)`, hinted as non-singular,
    self-adjoint, positive definite and square.
  """
  n_rows = identity_operator._num_rows  # pylint: disable=protected-access
  sqrt_multiplier = math_ops.sqrt(identity_operator.multiplier)
  factor_hints = dict(
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True,
      is_square=True)
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=n_rows, multiplier=sqrt_multiplier, **factor_hints)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _cholesky_block_diag(block_diag_operator):
  """Cholesky of a block-diagonal operator, computed block by block.

  Args:
    block_diag_operator: The `LinearOperatorBlockDiag` to factor; its
      `operators` attribute supplies the diagonal blocks.

  Returns:
    A `LinearOperatorBlockDiag` whose blocks are `operator.cholesky()` of each
    input block, in the original order, hinted as non-singular and square.
  """
  # We take the cholesky of each block on the diagonal.
  block_factors = [
      operator.cholesky() for operator in block_diag_operator.operators]
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=block_factors,
      is_non_singular=True,
      # Leave self-adjointness to be derived from the blocks. Factors of
      # diagonal / identity blocks are themselves flagged self-adjoint (see
      # the Diag and Identity registrations in this file), so asserting
      # `False` here would contradict the blocks' own hints; the BlockDiag
      # constructor is understood to reject that combination with ValueError.
      is_self_adjoint=None,
      is_square=True)


@linear_operator_algebra.RegisterCholesky(
    linear_operator_kronecker.LinearOperatorKronecker)
def _cholesky_kronecker(kronecker_operator):
  """Cholesky of a Kronecker product, computed factor by factor.

  Args:
    kronecker_operator: The `LinearOperatorKronecker` to factor; its
      `operators` attribute supplies the Kronecker factors.

  Returns:
    A `LinearOperatorKronecker` whose factors are `operator.cholesky()` of
    each input factor, in the original order, hinted as non-singular and
    square.
  """
  # Cholesky decomposition of a Kronecker product is the Kronecker product
  # of cholesky decompositions.
  factor_choleskys = [
      operator.cholesky() for operator in kronecker_operator.operators]
  return linear_operator_kronecker.LinearOperatorKronecker(
      operators=factor_choleskys,
      is_non_singular=True,
      # Leave self-adjointness to be derived from the factors. Choleskys of
      # diagonal / identity factors are themselves flagged self-adjoint (see
      # the Diag and Identity registrations in this file), so asserting
      # `False` here would contradict the factors' own hints; the Kronecker
      # constructor is understood to reject that combination with ValueError.
      is_self_adjoint=None,
      is_square=True)

import numpy as np; onp = np
from tensorflow_probability.python.internal.backend.jax import linalg_impl as _linalg
from tensorflow_probability.python.internal.backend.jax import ops as _ops
from tensorflow_probability.python.internal.backend.jax.gen import tensor_shape

from tensorflow_probability.python.internal.backend.jax import private
# Module handles created through `private.LazyLoader` rather than plain
# imports (presumably to defer the import until first use -- see
# `private.LazyLoader` for the exact semantics). Neither name is referenced by
# the registrations above.
# NOTE(review): both target paths name the `numpy` substrate although this
# file lives under the `jax` backend -- confirm this is what the generator
# intends.
distribution_util = private.LazyLoader(
    "distribution_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.distribution_util")
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")


