# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization ∘ Conv         -> Conv
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
- BatchNormalization ∘ Gemm         -> Gemm

Approach:
    Given an inbound operation whose output is: Y = W * X + B
    And a BatchNormalization whose output is: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)

    The fusion updates the inbound weights as follows:
        - W_fused = W * (gamma / std)
        - B_fused = (B - μ) * (gamma / std) + β
"""

from abc import ABC
from typing import ClassVar, Mapping

import numpy as np

from onnxscript import ir
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
    # Build shape: 1s everywhere except -1 at the target axis
    broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
    return np.reshape(x, broadcast_shape)


class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion.

    Subclasses provide ``op_type`` (the inbound operator fused into), a ``pattern``
    matching ``BatchNormalization(<op_type>(x, ...))``, and either implement
    :meth:`get_filters_axis` or override :meth:`_scale_weights` to describe how the
    per-channel BatchNorm scale applies to the inbound weights.
    """

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted.

        Args:
            attributes: Attributes of the inbound node.

        Raises:
            NotImplementedError: Always, in this base class. Subclasses that rely on
                the default :meth:`_scale_weights` must override it.
        """
        raise NotImplementedError()

    def _scale_weights(
        self,
        weights: np.ndarray,
        scale_factor: np.ndarray,
        attributes: Mapping[str, ir.Attr],
    ) -> np.ndarray:
        """Return ``weights`` multiplied by ``scale_factor`` along the filters axis.

        Args:
            weights: Constant weight tensor of the inbound node (its input 1).
            scale_factor: Per-channel factor ``gamma / sqrt(var + eps)``.
            attributes: Attributes of the inbound node, used to select the axis.
        """
        axis = self.get_filters_axis(attributes)
        return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)

    @staticmethod
    def _get_bias_input(inbound_node: ir.Node) -> "ir.Value | None":
        """Return the bias (input 2) of ``inbound_node``, or ``None`` if it has none.

        An optional input is absent either when the input list has fewer than three
        entries or when the slot exists but is empty (represented as ``None``).
        """
        if len(inbound_node.inputs) > 2:
            return inbound_node.inputs[2]
        return None

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        """Build the fused inbound node that replaces ``BatchNormalization(inbound(x))``.

        The weights are scaled per output channel by ``gamma / sqrt(var + eps)`` and
        the new bias is ``(B - mean) * scale_factor + beta``, with ``B`` taken as
        zeros when the inbound node has no bias.
        """
        batchnorm_node = batchnorm_out.producer()
        # Get BatchNorm parameters (`check` ensures they are constant initializers)
        gamma, beta, input_mean, input_var = [
            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
        ]

        # 1e-5 is the default value for epsilon according to
        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()

        # Compute the scale_factor to update the inbound weights and bias
        scale_factor = gamma / np.sqrt(input_var + eps)

        inbound_node = inbound_out.producer()
        weights_value = inbound_node.inputs[1]
        weights = weights_value.const_value.numpy()
        # BatchNormalization may store its parameters in a different floating-point
        # type than its input, in which case NumPy promotes the results below. The
        # inbound op requires weight and bias to match its input type, so cast the
        # fused tensors back to the original weight dtype.
        target_dtype = weights.dtype

        # Update inbound weights
        fused_weights = ir.tensor(
            self._scale_weights(weights, scale_factor, inbound_node.attributes).astype(
                target_dtype, copy=False
            )
        )

        # Update bias
        bias_value = self._get_bias_input(inbound_node)
        if bias_value is not None:
            original_bias = bias_value.const_value.numpy()
            bias_name = bias_value.name
        else:
            original_bias = np.zeros_like(input_mean)
            # Use inbound input 1 (should be weight) to derive a name for the bias
            # to avoid name collision on initializer creation when there are multiple patterns
            # sharing the same parent nodes.
            bias_name = weights_value.name + "_bias"
        fused_bias = ir.tensor(
            ((original_bias - input_mean) * scale_factor + beta).astype(
                target_dtype, copy=False
            )
        )

        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=weights_value.name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=inbound_node.attributes,
        )

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Return whether the matched pattern can be fused.

        Fusion requires:
        - the BatchNormalization is not in training mode;
        - the inbound weights (and bias, if any) and the BatchNorm parameters are
          constant initializers that are not graph inputs;
        - the inbound weights/bias are not used by nodes outside the match.
        """
        del context  # Unused
        check_result = MatchResult()

        inbound_node = inbound_out.producer()
        batchnorm_node = batchnorm_out.producer()

        # With training_mode=1 BatchNormalization normalizes with the statistics of
        # the current batch instead of input_mean/input_var, so folding those
        # constants into the weights would change the result.
        training_mode = batchnorm_node.attributes.get("training_mode")
        if training_mode is not None and training_mode.as_int() != 0:
            return check_result.fail("BatchNormalization in training mode cannot be fused.")

        bias_value = self._get_bias_input(inbound_node)

        # Check that inbound weights + (inbound bias) + batchnorm params are initializers
        # and that they are not graph inputs
        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
        if bias_value is not None:
            initializers.append(bias_value)

        for initializer in initializers:
            if initializer is None:
                return check_result.fail("A required input of the matched pattern is missing.")
            if not initializer.is_initializer() or initializer.const_value is None:
                return check_result.fail(f"{initializer.name} is not a constant initializer.")
            if initializer.is_graph_input():
                return check_result.fail(f"{initializer.name} is a graph input.")

        # Check that the inbound node's weight and bias initializers are not shared
        # with other nodes outside this matched pattern. When the fusion creates new
        # initializers with the same name as the original shared weights, it overwrites
        # the original initializer in the graph, leaving other nodes that reference the
        # original value with an invalid (unregistered) input.
        matched_nodes = {inbound_node, batchnorm_node}
        inbound_initializers = [inbound_node.inputs[1]]
        if bias_value is not None:
            inbound_initializers.append(bias_value)
        for init_value in inbound_initializers:
            for user, _ in init_value.uses():
                if user not in matched_nodes:
                    return check_result.fail(
                        f"Initializer '{init_value.name}' is used by another node "
                        f"'{user.name}' outside the matched pattern."
                    )

        return check_result


class FuseBatchNormIntoConv(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

    op_type: ClassVar = "Conv"

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # ONNX Conv weights are laid out as (out_channels, in_channels / group, *kernel),
        # so the channels BatchNorm normalizes sit on axis 0 regardless of attributes.
        del attributes  # Unused
        return 0

    def pattern(self, op, x):
        conv_out = op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            conv_out, _allow_other_inputs=True, _outputs=["batchnorm_out"]
        )


class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

    op_type: ClassVar = "ConvTranspose"

    def _scale_weights(
        self,
        weights: np.ndarray,
        scale_factor: np.ndarray,
        attributes: Mapping[str, ir.Attr],
    ) -> np.ndarray:
        """Scale ConvTranspose filters by the per-output-channel BatchNorm factor.

        ConvTranspose weights are laid out as ``(in_channels, out_channels / group, *kernel)``,
        so the output channels are split across groups. The weights are viewed as
        ``(group, in_channels / group, out_channels / group, *kernel)`` and the
        ``(out_channels,)`` scale as ``(group, 1, out_channels / group, 1, ..., 1)``
        so that each group is multiplied by its own slice of the scale.
        """
        n_groups = attributes.get("group", ir.AttrInt64("group", 1)).as_int()
        n_in = weights.shape[0]
        n_out_per_group = weights.shape[1]
        kernel_dims = weights.shape[2:]
        trailing_ones = (1,) * len(kernel_dims)

        grouped_weights = weights.reshape(n_groups, n_in // n_groups, n_out_per_group, *kernel_dims)
        grouped_scale = scale_factor.reshape((n_groups, 1, n_out_per_group) + trailing_ones)

        scaled = grouped_weights * grouped_scale
        return scaled.reshape(weights.shape)

    def pattern(self, op, x):
        conv_t_out = op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            conv_t_out, _allow_other_inputs=True, _outputs=["batchnorm_out"]
        )

    def check(self, context, x, inbound_out, batchnorm_out):
        result = super().check(context, x, inbound_out, batchnorm_out)
        if not result:
            return result

        conv_t_node = inbound_out.producer()
        in_channels = conv_t_node.inputs[1].const_value.numpy().shape[0]
        group = conv_t_node.attributes.get("group", ir.AttrInt64("group", 1)).as_int()

        # The ONNX checker accepts a ConvTranspose whose in_channels is not a multiple
        # of `group`, but such a model is invalid and `_scale_weights` could not
        # regroup its weights.
        if in_channels % group == 0:
            return result
        return result.fail(
            f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
        )


class FuseBatchNormIntoGemm(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.

    Only Gemm nodes whose ``beta`` attribute is 1 (the default) are fused; see
    :meth:`check`.
    """

    op_type: ClassVar = "Gemm"

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis of Gemm's B input that holds the output features.

        Gemm computes ``alpha * A' @ B' + beta * C``. The output features (the
        channels BatchNorm normalizes) are the columns of ``B'``. These are axis 1
        of ``B`` when ``transB`` is 0 or absent, and axis 0 when ``transB`` is 1.
        """
        trans_b = attributes.get("transB")
        return 0 if trans_b is not None and trans_b.as_int() else 1

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Run the shared checks, then reject Gemm nodes with ``beta != 1``.

        The fused node is rebuilt with the original Gemm attributes and receives the
        fused bias as its C input. Gemm multiplies C by its ``beta`` attribute, so
        any ``beta`` other than 1 would rescale the fused bias and give a wrong
        result. This applies whether or not the original Gemm had a C input.
        """
        check_result = super().check(context, x, inbound_out, batchnorm_out)
        if not check_result:
            return check_result

        gemm_node = inbound_out.producer()
        beta_attr = gemm_node.attributes.get("beta")
        if beta_attr is not None and beta_attr.as_float() != 1.0:
            return check_result.fail(
                f"Gemm beta ({beta_attr.as_float()}) must be 1.0 to fuse BatchNormalization."
            )

        return check_result


# One rule per supported inbound operator. Each rule is exported on its own and
# also bundled into the ``rules`` set below.
fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_conv_transpose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()


# Rule set applying all BatchNormalization fusions above, in this order.
rules = RewriteRuleSet(
    [fuse_batchnorm_into_conv_rule, fuse_batchnorm_into_conv_transpose_rule, fuse_batchnorm_into_gemm_rule]
)
