import numpy as np
import torch
import torch.nn as nn

try:
    import qeft_cuda
except ImportError:
    # The compiled kernels are optional at import time: only the kernel-backed
    # code paths (set_kernel / forward of the quantized layers) reference
    # `qeft_cuda`. Catch ImportError only, so that Ctrl-C and unrelated
    # failures are not silently swallowed.
    print('CUDA extension is not installed.')

def quantize(x, scale, zero, minq, maxq):
    """Fake-quantize ``x``: snap it to the integer grid and map it back.

    The value is divided by ``scale``, rounded, shifted by ``zero``, clamped
    to ``[minq, maxq]`` and finally de-quantized with the same scale and
    zero-point, so the result lives in the same domain as ``x``.
    """
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, minq, maxq)
    return scale * (levels - zero)

def quantize_efficient(x_round, scale, zero, minq, maxq):
    """Fake-quantize from an already rounded ``x / scale``.

    Same as ``quantize`` except that the caller supplies ``round(x / scale)``
    as ``x_round``; only the zero-point shift, the clamp to ``[minq, maxq]``
    and the de-quantization are performed here.
    """
    shifted = x_round + zero
    levels = torch.clamp(shifted, minq, maxq)
    return scale * (levels - zero)

class Quantizer(nn.Module):
    """Uniform quantizer that derives a scale / zero-point from data.

    Args:
        bits: Bit-width; the grid has ``2 ** bits`` levels.
        perchannel: If True, parameters are computed per row of the
            (flattened) input; otherwise a single pair is shared.
        sym: If True the integer range is signed
            (``[-(L-1)//2 - 1, (L-1)//2]`` with ``L = 2 ** bits``) and the
            zero-point is fixed to 0; otherwise the range is ``[0, L - 1]``.
        mse: If True, ``find_params`` grid-searches the clipping range that
            minimises ``lp_loss``; otherwise plain min/max is used.
        norm: Exponent ``p`` of the ``|error| ** p`` loss used by the grid
            search.
        group_size: Stored on the instance; not used by this class itself.
    """

    def __init__(
            self,
            bits, perchannel=False, sym=False,
            mse=False, norm=2.4, group_size=-1,
        ):
        super().__init__()
        self.register_buffer('scale', torch.zeros(1))
        self.register_buffer('zero', torch.zeros(1))
        self.register_buffer('out_ids', torch.zeros(1))

        self.bits = bits
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.perchannel = perchannel
        self.n_levels = 2 ** bits
        self.group_size = group_size

        if self.sym:
            self.minq, self.maxq = -((self.n_levels - 1) // 2 + 1), (self.n_levels - 1) // 2
        else:
            self.minq, self.maxq = 0, self.n_levels - 1

        # Number of grid-search steps (overwritten by find_params) and the
        # lower bound applied to searched scales.
        self.num = 100
        self.eps = torch.tensor(1e-8)

    def lp_loss(self, pred, tgt, p=2.0):
        """Return mean ``|pred - tgt| ** p``: a scalar, or one value per row
        (everything after dim 0 flattened) when ``perchannel`` is set."""
        x = (pred - tgt).abs().pow(p)
        if not self.perchannel:
            return x.mean()
        else:
            y = torch.flatten(x, 1)
            return y.mean(1)

    def append_params(self):
        """Accumulate the current ``scale`` / ``zero`` into ``scale_group`` /
        ``zero_group`` by concatenating along dim 1 (created on first call)."""
        if not hasattr(self, 'scale_group'):
            self.register_buffer('scale_group', self.scale)
            self.register_buffer('zero_group', self.zero)
        else:
            self.scale_group = torch.cat((self.scale_group, self.scale), 1)
            self.zero_group = torch.cat((self.zero_group, self.zero), 1)

    def find_params(self, x, weight=False, num=100):
        """Compute ``self.scale`` and ``self.zero`` for tensor ``x``.

        Args:
            x: Tensor to calibrate on.
            weight: If True, ``x`` is treated as a weight (rows = dim 0) and
                the parameters are reshaped to ``[-1, 1, ...]``. Otherwise
                2-D/3-D/4-D inputs are handled and the parameters are shaped
                to broadcast against the channel dimension used below.
            num: Number of candidate clipping ranges tried when ``mse`` is set.
        """
        self.num = num
        dev = x.device
        minq, maxq = self.minq, self.maxq

        shape = x.shape
        if self.perchannel: # row-wise
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        # The range always includes 0: min is capped at 0 from above, max from below.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.mse:
            if self.perchannel:
                new_shape = [-1] + [1] * (len(x.shape) -  1)

            best_score = torch.zeros_like(xmin) + (1e+10)
            best_min = xmin.clone()
            best_max = xmax.clone()

            if self.sym:
                xrange = torch.max(xmin.abs(), xmax)
                zero = torch.zeros_like(xmin)
                if self.perchannel:
                    zero = zero.reshape(new_shape)
                # Try clipping at i/num of the absolute maximum and keep the best.
                for i in range(1, self.num + 1):
                    tmp_max = xrange / self.num * i
                    scale = torch.max(tmp_max / -minq, self.eps)
                    if self.perchannel:
                        scale = scale.reshape(new_shape)
                    x_round = torch.round(x / scale)
                    x_q = quantize_efficient(x_round, scale, zero, minq, maxq)
                    # Use the configured exponent instead of a hard-coded 2.4.
                    score = self.lp_loss(x, x_q, self.norm)
                    best_max = torch.where(score < best_score, tmp_max, best_max)
                    best_score = torch.min(score, best_score)

                max_val = torch.max(best_max, torch.zeros_like(best_max))

                self.scale = torch.max(max_val / -minq, self.eps)
                self.zero = torch.zeros_like(self.scale)
            else:
                xrange = xmax - xmin
                tmp_min = torch.zeros_like(xmin)
                # For every candidate range width, try every zero-point shift.
                for i in range(1, self.num + 1):
                    tmp_max = xrange / self.num * i
                    scale = torch.max((tmp_max - tmp_min) / (maxq - minq), self.eps)
                    delta = scale.clone()
                    if self.perchannel:
                        scale = scale.reshape(new_shape)
                    x_round = torch.round(x / scale)
                    for zp in range(0, self.n_levels):
                        new_min = tmp_min - zp * delta
                        new_max = tmp_max - zp * delta
                        zero = torch.clamp(minq - torch.round(new_min / delta), minq, maxq)
                        if self.perchannel:
                            zero = zero.reshape(new_shape)
                        x_q = quantize_efficient(x_round, scale, zero, minq, maxq)
                        # Use the configured exponent instead of a hard-coded 2.4.
                        score = self.lp_loss(x, x_q, self.norm)
                        best_min = torch.where(score < best_score, new_min, best_min)
                        best_max = torch.where(score < best_score, new_max, best_max)
                        best_score = torch.min(best_score, score)

                min_val_neg = torch.min(best_min, torch.zeros_like(best_min))
                max_val_pos = torch.max(best_max, torch.zeros_like(best_max))

                self.scale = torch.max((max_val_pos - min_val_neg) / (maxq - minq), self.eps)
                self.zero = torch.clamp(minq - torch.round(min_val_neg / self.scale), minq, maxq)
        else:
            if self.sym:
                xmax = torch.maximum(torch.abs(xmin), xmax)
                tmp = xmin < 0
                if torch.any(tmp):
                    xmin[tmp] = -xmax[tmp]

            # All-zero rows get a dummy [-1, 1] range so the scale is non-zero.
            tmp = (xmin == 0) & (xmax == 0)
            xmin[tmp] = -1
            xmax[tmp] = +1

            if self.sym:
                self.scale = xmax / -minq
                self.zero = torch.zeros_like(self.scale)
            else:
                self.scale = (xmax - xmin) / maxq
                self.zero = torch.round(-xmin / self.scale)

        if not self.perchannel:
            # Replicate the single shared pair once per channel.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        """Fake-quantize ``x`` with the stored parameters; return ``x``
        unchanged while the quantizer is not ``ready()``."""
        if self.ready():
            return quantize(x, self.scale, self.zero, self.minq, self.maxq)
        return x

    def enabled(self):
        """Return whether ``maxq`` is positive."""
        return self.maxq > 0

    def ready(self):
        """Return a boolean tensor that is True when every scale entry is non-zero."""
        return torch.all(self.scale != 0)

def make_quant(module, quantinfos, name=''):
    """Recursively swap attributes listed in ``quantinfos`` for quantized layers.

    Every attribute of ``module`` whose dotted path (``name`` + attribute
    name) is a key of ``quantinfos`` is replaced by a ``QuantLinearReorder``
    (when the entry has a truthy ``reorder`` attribute) or a ``QuantLinear``,
    built from the replaced layer's features / bias / dtype and moved to its
    weight's device. Already-quantized modules are left untouched.
    """
    if isinstance(module, (QuantLinear, QuantLinearReorder)):
        return

    prefix = name + '.' if name != '' else ''

    for attr in dir(module):
        tmp = getattr(module, attr)
        path = prefix + attr
        if path not in quantinfos:
            continue
        info = quantinfos[path]
        layer_cls = QuantLinearReorder if getattr(info, 'reorder', False) else QuantLinear
        replacement = layer_cls(
            info.bits,
            tmp.in_features,
            tmp.out_features,
            tmp.bias is not None,
            tmp.weight.dtype,
            getattr(info, 'n_out', 0),
            getattr(info, 'group_size', -1),
            getattr(info, 'reorder', False),
            path,
        )
        setattr(module, attr, replacement.to(tmp.weight.device))

    for child_name, child in module.named_children():
        make_quant(child, quantinfos, prefix + child_name)

def lm_pack(model, quantinfos, linears=[nn.Linear]):
    """Replace the layers named in ``quantinfos`` by packed quantized layers.

    The original layers are looked up first, ``make_quant`` swaps in the
    quantized modules, and each of those is then packed from its original
    layer and the matching ``quantinfos`` entry (moved to CPU beforehand).
    Grouped scale / zero tensors are preferred when the entry has them.
    Returns ``model``.
    """
    from qeft.utils.misc import find_layers
    from tqdm import tqdm

    found = find_layers(model, linears)
    layers = {key: found[key] for key in quantinfos}
    make_quant(model, quantinfos)
    qlayers = find_layers(model, [QuantLinear, QuantLinearReorder])

    for name in tqdm(qlayers, f"Packing ..."):
        info = quantinfos[name].cpu()
        quantinfos[name] = info
        qlayers[name].pack(
            layers[name],
            scales=getattr(info, 'scale_group', info.scale),
            zeros=getattr(info, 'zero_group', info.zero),
            outlieridx=getattr(info, 'out_ids', None),
            sym=getattr(info, 'sym', False),
        )

    print('Done.')
    return model

class QuantMatMul(torch.autograd.Function):
    """Linear layer on a weight that is de-quantized on the fly.

    The full weight is rebuilt from ``qweight`` by ``fn_dequant``; the rows
    selected by ``outids`` are then overwritten with the full-precision
    ``oweight``. Gradients are produced for the input ``x`` and for
    ``oweight`` only.
    """

    @staticmethod
    def _dequantize(oweight, fn_dequant, qweight, scales, zeros, shape, outids, dtype):
        """Rebuild the weight of shape ``shape``, patch in the outlier rows and
        return it transposed (the layout ``F.linear`` expects)."""
        out = torch.empty(shape, dtype=dtype, device=oweight.device)
        fn_dequant(qweight, out, scales, zeros)
        out[outids, :] = oweight.to(dtype)
        return out.t()

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, x, oweight, fn_dequant, qweight, scales, zeros, shape, outids, bias, name):
        """Compute ``linear(x, W, bias)`` with ``W`` rebuilt from the packed weight.

        ``fn_dequant(qweight, out, scales, zeros)`` must fill ``out`` in place;
        ``shape`` is the shape of that buffer and the compute dtype is taken
        from ``scales``.
        """
        dtype = scales.dtype
        weight = QuantMatMul._dequantize(
            oweight, fn_dequant, qweight, scales, zeros, shape, outids, dtype
        )

        output = torch.nn.functional.linear(x.to(dtype), weight.to(dtype), bias)

        # Keep what backward needs: the recipe to rebuild the weight and only
        # the input columns that multiply the trainable outlier rows.
        ctx.dequant_params = [oweight, fn_dequant, qweight, scales, zeros, shape, outids, dtype, name]
        ctx.tensors = torch.index_select(x, -1, outids)
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        """Return gradients for ``x`` and ``oweight``; all other inputs get None."""
        x_outlier = ctx.tensors
        oweight, fn_dequant, qweight, scales, zeros, shape, outids, dtype, name = ctx.dequant_params

        # The weight is not stored between passes; rebuild it here.
        weight = QuantMatMul._dequantize(
            oweight, fn_dequant, qweight, scales, zeros, shape, outids, dtype
        )

        grad_input = None
        grad_oweight = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.matmul(grad_output, weight.to(grad_output.dtype))
        if ctx.needs_input_grad[1]:
            grad_oweight = torch.matmul(grad_output.transpose(-2,-1), x_outlier.to(grad_output.dtype))
            grad_oweight = grad_oweight.transpose(-1, -2)

        return grad_input, grad_oweight, None, None, None, None, None, None, None, None

class QuantLinear(nn.Module):
    """Linear layer whose weight is stored as packed 3/4-bit integers.

    A subset of input features ("outliers") can be kept in full precision in
    ``oweight`` / ``outlieridx``.

    Args:
        bits: 3 or 4.
        infeatures, outfeatures: Layer dimensions.
        bias: Whether a bias buffer is allocated.
        dtype: dtype of scales, bias and the outlier weight.
        outlierfeatures: Number of input features kept in full precision.
        group_size: Input features per quantization group; values <= 0 mean a
            single group.
        reorder: Whether the outlier features are the last input columns.
        name: Identifier forwarded to ``QuantMatMul``.
    """

    def __init__(self, bits, infeatures, outfeatures, bias, dtype, outlierfeatures, group_size, reorder, name): # TODO
        super().__init__()
        assert bits in [3, 4], "Only 3, 4 bits are supported."
        assert infeatures % group_size == 0

        self.bits = bits
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.outlierfeatures = outlierfeatures

        self.group_size = group_size
        self.reorder = reorder

        # 32 consecutive input rows are packed into `bits` int32 rows.
        self.register_buffer(
            'qweight', torch.empty((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
        )
        numgroup = infeatures // group_size if group_size > 0 else 1

        self.register_buffer('scales', torch.empty((outfeatures, numgroup), dtype=dtype))
        # Two zero-points share one byte (see pack()).
        self.register_buffer('zeros', torch.empty((outfeatures // 2, numgroup), dtype=torch.uint8))

        if bias:
            self.register_buffer('bias', torch.empty(outfeatures, dtype=dtype))
        else:
            self.bias = None

        # for weak columns
        if outlierfeatures > 0:
            self.register_buffer(
                'oweight', torch.empty((outlierfeatures, outfeatures), dtype=dtype)
            )
            self.register_buffer(
                'outlieridx', torch.empty((outlierfeatures), dtype=torch.int)
            )

        self.faster = True
        self.dtype = dtype
        self.name = name

    def pack(self, linear, scales, zeros, outlieridx:torch.Tensor, sym:bool=False):
        """Quantize ``linear.weight`` and store it in packed form.

        Args:
            linear: Source layer providing ``weight`` and ``bias``.
            scales, zeros: Per-output-channel (and per-group) parameters,
                indexed as ``[out_channel, group]``. Must be CPU tensors
                because they are converted with ``.numpy()``.
            outlieridx: Indices of the input features kept in full precision.
            sym: If True the zero-points are shifted by ``2 ** (bits - 1)``.

        The caller's ``zeros`` tensor is not modified.
        """
        dtype = linear.weight.dtype

        self.sym = sym
        if sym:
            # Out-of-place on purpose: `zeros` usually belongs to the
            # quantizer state of the caller and must not be shifted there.
            zeros = zeros + 2**(self.bits - 1)

        if linear.bias is not None:
            self.bias = linear.bias.to(dtype)

        self.outlieridx = outlieridx

        if self.outlierfeatures > 0:
            if self.reorder:
                self.oweight = linear.weight.data[:,-self.outlierfeatures:].t().contiguous()
            else:
                self.oweight = torch.index_select(linear.weight.data, 1, self.outlieridx).t().contiguous()

        # Expand per-group parameters to one value per input feature.
        zeros_interleave = torch.repeat_interleave(zeros, max(self.group_size, 1), dim=1)
        scales_interleave = torch.repeat_interleave(scales, max(self.group_size, 1), dim=1)
        intweight = torch.round((linear.weight.data + zeros_interleave * scales_interleave) / scales_interleave).to(torch.int)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        # Outlier rows are stored as the zero-point, i.e. they de-quantize to 0.
        if self.outlierfeatures > 0:
            if self.reorder:
                if self.group_size > 0:
                    for i in range(self.outlierfeatures):
                        intweight[-self.outlierfeatures + i, :] = zeros[:, (self.infeatures - self.outlierfeatures + i) // self.group_size].numpy().astype(np.uint32).squeeze()
                    self.outlieridx = torch.arange(self.infeatures - self.outlierfeatures, self.infeatures,
                                            device=self.outlieridx.device,
                                            dtype=self.outlieridx.dtype)
                else:
                    # Slice (not a single index) so that every outlier row is
                    # overwritten, matching the other branches.
                    # NOTE(review): unlike the grouped branch, outlieridx is
                    # not reset to the trailing range here - confirm intent.
                    intweight[-self.outlierfeatures:, :] = zeros.numpy().astype(np.uint32).squeeze()
            else:
                if self.group_size > 0:
                    for idx in outlieridx:
                        intweight[idx,:] = zeros[:, (idx // self.group_size)].numpy().astype(np.uint32).squeeze()
                else:
                    for idx in outlieridx:
                        intweight[idx,:] = zeros.numpy().astype(np.uint32).squeeze()
        qweight = np.zeros(
            (self.infeatures // 32 * self.bits, self.outfeatures), dtype=np.uint32
        )

        self.scales = scales.to(dtype)
        # Pack two zero-points per byte: even row in the low nibble, odd row
        # in the high nibble.
        zeros = zeros.to(torch.uint8)
        half = zeros.shape[0] // 2
        self.zeros = zeros[0:2 * half:2] | (zeros[1:2 * half:2] << 4)

        i = 0
        row = 0
        if self.bits == 3:
            # 32 values x 3 bits = 96 bits = three 32-bit words; values 10 and
            # 21 of each run straddle a word boundary.
            while row < qweight.shape[0]:
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
        elif self.bits == 4:
            # Eight 4-bit values per 32-bit word, lowest nibble first.
            while row < qweight.shape[0]:
                for j in range(i, i + 8):
                    qweight[row] |= intweight[j] << (4 * (j - i))
                i += 8
                row += 1
        else:
            raise NotImplementedError

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

    def set_for_tune(self):
        """Freeze ``qweight`` and make the float32 outlier weight trainable."""
        self.qweight = torch.nn.Parameter(self.qweight, requires_grad=False)
        if self.outlierfeatures > 0:
            self.oweight = torch.nn.Parameter(self.oweight.to(dtype=torch.float), requires_grad=True)

    def set_kernel(self, faster):
        """Transpose scales / zeros and bind the ``qeft_cuda`` kernels (4-bit only)."""
        # [oc, 1] -> [1, oc] or [oc, g] -> [g, oc]
        self.scales = self.scales.t().contiguous()
        self.zeros = self.zeros.t().contiguous()

        if faster == False:
            self.oweight = self.oweight.float()
            self.scales = self.scales.float()
        if self.bits == 4:
            self.matvec = qeft_cuda.vecquant4matmul_faster_group
            self.outmatvec = qeft_cuda.vecquant4outliermatmul_faster_group
            self.dequant = qeft_cuda.matquant4dequant_faster_group
        else:
            raise NotImplementedError

        self.matmul = QuantMatMul.apply


    def forward(self, x):
        """Apply the layer; ``set_kernel`` must have been called before.

        A single input row goes through the mat-vec kernel, anything larger
        through ``QuantMatMul``.
        """
        outshape = x.shape[:-1] + (self.outfeatures, )
        x = x.reshape(-1, x.shape[-1])
        if x.shape[-1] == x.numel():
            # NOTE(review): self.outrow and self.cnt are not defined anywhere
            # in this class - presumably set externally; confirm.
            y = self.outmatvec(
                x, self.qweight,
                self.scales, self.zeros,
                self.oweight, self.outlieridx,
                self.outrow, self.cnt
                )
            y = y + self.bias if self.bias is not None else y
        else:
            y = self.matmul(x, self.oweight, self.dequant,
                self.qweight, self.scales,
                self.zeros, (self.infeatures, self.outfeatures),
                self.outlieridx,
                self.bias, self.name)
        return y.reshape(outshape)

def make_divisible(c, divisor):
    """Return ``c`` divided by ``divisor``, rounded up via floor division."""
    padded = c + divisor - 1
    return padded // divisor

def calculate_zeros_width(infeatures, group_size=128, pack_num=8):
    """Return the padded number of ``pack_num``-sized packs of groups.

    The group count ``infeatures // group_size`` is divided by ``pack_num``
    (rounding up) and the result is rounded up to a multiple of 1, 2 or 4
    for group sizes >= 128, 64 and 32 respectively.

    Raises:
        NotImplementedError: for any other group size.
    """
    if group_size >= 128:
        size_multiplier = 1
    else:
        size_multiplier = {64: 2, 32: 4}.get(group_size)
        if size_multiplier is None:
            raise NotImplementedError

    n_groups = infeatures // group_size
    # Ceil-divide by the pack size, then round up to the required multiple.
    packs = (n_groups + pack_num - 1) // pack_num
    aligned_packs = (packs + size_multiplier - 1) // size_multiplier
    return aligned_packs * size_multiplier

def pack_oweight(oweight, interleave=4):
    """Interleave pairs of rows of ``oweight`` element by element.

    Rows are processed in blocks of ``2 * interleave``; inside a block, row
    ``j`` is paired with row ``j + interleave``. For every 32-column chunk
    the two rows are zipped (a0, b0, a1, b1, ...), and the chunks of one pair
    are concatenated into a single output row.
    """
    packed_rows = []
    block = 2 * interleave
    for base in range(0, oweight.shape[0], block):
        for offset in range(interleave):
            top = base + offset
            bottom = top + interleave
            chunks = [
                torch.stack(
                    [oweight[top, col:col + 32], oweight[bottom, col:col + 32]],
                    dim=0,
                ).t().flatten()
                for col in range(0, oweight.shape[1], 32)  # 128
            ]
            packed_rows.append(torch.cat(chunks, dim=0))

    return torch.stack(packed_rows, dim=0)
    
def pack_intweight(unpacked_qweight, interleave, kstride):
    """Permute a ``[N, K]`` integer weight and pack four values per int16.

    The values inside every run of 32 columns are permuted twice, the rows
    are regrouped in blocks of ``interleave`` rows by ``kstride`` columns,
    and four consecutive 4-bit values are then merged into one 16-bit word.
    Returns an int16 tensor of shape ``[N // interleave, K]`` on the input's
    device.
    """
    # unpacked_qweight: [N, K]
    n_rows = unpacked_qweight.shape[0]
    n_cols = unpacked_qweight.shape[1]
    n_chunks = n_cols // 32

    kernel = unpacked_qweight.cpu().numpy().reshape(n_rows, n_chunks, 32)
    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    kernel = kernel.reshape(n_rows, n_chunks, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    kernel = kernel.reshape(n_rows, n_chunks, 32)

    # reorder each 8 weights for fast dequantization
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    kernel = kernel.reshape(n_rows, n_chunks, 4, 8)
    kernel = kernel.reshape(n_rows, n_chunks, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    kernel = kernel.reshape(n_rows, n_cols)

    # regroup: blocks of `interleave` rows x `kstride` columns
    row_blocks = n_rows // interleave
    col_blocks = n_cols // kstride
    kernel = kernel.reshape(row_blocks, interleave, col_blocks, kstride)
    kernel = kernel.transpose(0, 2, 1, 3)
    kernel = kernel.reshape(row_blocks, col_blocks, kstride, interleave)

    # merge the four values of the last axis, lowest nibble first
    merged = kernel[..., 0]
    for lane in (1, 2, 3):
        merged = merged | (kernel[..., lane] << (4 * lane))

    # reshape to (N // interleave, K) and reinterpret as 16-bit words
    merged = merged.reshape(row_blocks, n_cols)
    as_int16 = torch.tensor(merged.astype("int16"))
    return as_int16.to(unpacked_qweight.device).contiguous()
    
class QuantLinearReorder(nn.Module):
    """4-bit fp16 quantized linear layer for the reordered weight layout.

    The outlier input features are expected to be the last
    ``outlierfeatures`` columns of the weight; they are kept in full
    precision in ``oweight``. The forward pass is delegated to the
    ``qeft_cuda`` GEMV / GEMM kernels.

    Args:
        bits: Must be 4.
        infeatures, outfeatures: Layer dimensions.
        bias: Whether a bias buffer is allocated.
        dtype: Must be ``torch.float16``.
        outlierfeatures: Number of trailing input features kept in fp16.
        group_size: Input features per quantization group; ``-1`` means one
            group spanning all input features.
        reorder: Must be True.
        name: Identifier stored on the instance.
    """

    def __init__(self, bits, infeatures, outfeatures, bias, dtype, outlierfeatures, group_size, reorder, name): # TODO
        super().__init__()
        assert bits in [4], "Only 4 bits is supported."
        assert dtype == torch.float16, "Only fp16 is supported."
        assert reorder == True, "Only reordered format is supported."

        self.bits = bits
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.outlierfeatures = outlierfeatures

        self.group_size = group_size if group_size != -1 else infeatures
        self.interleave = 4
        self.reorder = reorder # True

        assert infeatures % group_size == 0
        assert outfeatures % (32 // self.bits) == 0
        pack_num = 32 // self.bits
        int16_pack_num = 16 // self.bits

        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    outfeatures // self.interleave,
                    infeatures // int16_pack_num * self.interleave,
                ),
                dtype=torch.int16
            ),
        )

        if group_size == -1:
            # One group: a single row of scales / scaled zero-points.
            scale_rows = 1
        else:
            # Row count is padded; pack() fills only the leading rows.
            scale_rows = calculate_zeros_width(infeatures, self.group_size) * pack_num

        self.register_buffer(
            "scales",
            torch.zeros((scale_rows, outfeatures), dtype=torch.float16),
        )
        self.register_buffer(
            "scaled_zeros",
            torch.zeros((scale_rows, outfeatures), dtype=torch.float16),
        )

        if bias:
            self.register_buffer(
                "bias", torch.zeros((outfeatures), dtype=torch.float16)
            )
        else:
            self.bias = None

        # for weak columns
        if outlierfeatures > 0:
            self.register_buffer(
                'oweight', torch.zeros((outfeatures // 2, outlierfeatures * 2), dtype=dtype)
            )
            self.register_buffer(
                'outlieridx', torch.zeros((outlierfeatures), dtype=torch.int)
            )

        self.faster = True
        self.dtype = dtype
        self.name = name

    def pack(self, linear, scales, zeros, outlieridx:torch.Tensor, sym:bool=False):
        """Quantize ``linear.weight`` and fill the packed buffers.

        Args:
            linear: Source layer providing ``weight`` and ``bias``.
            scales, zeros: Parameters indexed as ``[out_channel, group]``.
            outlieridx: Accepted for interface parity; not used here because
                the outliers are taken from the trailing columns.
            sym: If True the zero-points are shifted by ``2 ** (bits - 1)``.

        The caller's ``zeros`` tensor is not modified.
        """
        dtype = self.dtype

        self.sym = sym
        if sym:
            # Out-of-place on purpose: `zeros` usually belongs to the
            # quantizer state of the caller and must not be shifted there.
            zeros = zeros + 2**(self.bits - 1)

        if linear.bias is not None:
            self.bias = linear.bias.to(dtype)

        if self.outlierfeatures > 0:
            oweight = linear.weight.data[:,-self.outlierfeatures:].clone()
            self.oweight = pack_oweight(oweight, interleave=4)

        # [OC, IC // g]
        scale_zeros =  zeros * scales
        if self.group_size == self.infeatures:
            # Single group: the [OC, 1] parameters broadcast over the columns.
            scales_interleave = scales
            scale_zeros_interleave = scale_zeros
        else:
            scales_interleave = torch.repeat_interleave(scales, self.group_size, dim=1)
            scale_zeros_interleave = torch.repeat_interleave(scale_zeros, self.group_size, dim=1)

        intweight = torch.round((linear.weight.data + scale_zeros_interleave) / scales_interleave).to(torch.int)
        intweight = intweight.to(dtype=torch.int32)

        # Outlier columns are stored as the zero-point, i.e. they de-quantize to 0.
        if self.outlierfeatures > 0:
            for i in range(self.infeatures - self.outlierfeatures, self.infeatures):
                intweight[:, i] = zeros[:, i // self.group_size].to(torch.int32)
        self.qweight = pack_intweight(intweight, interleave=4, kstride=64)

        # [IC // g, OC]
        self.scales[:scales.shape[1], :] = scales.t().contiguous().to(dtype)
        # Stored negated: -(zero * scale).
        self.scaled_zeros[:scale_zeros.shape[1], :] = -scale_zeros.t().contiguous().to(dtype)

    def set_kernel(self, faster):
        """Left-pad ``oweight`` with zero columns up to a multiple of 64 columns.

        ``faster`` is accepted for interface parity and not used.
        """
        if self.oweight.shape[1] % 64 > 0:
            self.oweight = torch.cat([torch.zeros((self.oweight.shape[0],64 - self.oweight.shape[1] % 64), dtype=self.oweight.dtype, device=self.oweight.device), self.oweight], dim=-1)

    def forward(self, x):
        """Apply the layer: GEMV kernel for fewer than 8 input rows, GEMM otherwise."""
        outshape = x.shape[:-1] + (self.outfeatures, )
        x = x.reshape(-1, x.shape[-1])
        if x.shape[0] < 8:
            y = qeft_cuda.gemv_forward_cuda_new(
                x,
                self.qweight,
                self.scales,
                self.scaled_zeros,
                self.oweight,
                x.shape[0],
                self.outfeatures,
                self.infeatures,
                self.group_size,
            )
        else:
            y = qeft_cuda.gemm_forward_cuda_new(
                x, self.qweight, self.scales, self.scaled_zeros, self.oweight
            )
        y = y + self.bias if self.bias is not None else y
        return y.reshape(outshape)