133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
# Pack pairs of int4 values into int8, in row major order; first int4
|
||
|
# value goes into lower order bits, and second int4 value into higher
|
||
|
# order bits of resulting int8 value.
|
||
|
def pack_int4_to_int8(weight):
    """Pack adjacent pairs of 4-bit values along each row into single bytes.

    Args:
        weight: 2-D ``torch.int8`` tensor with an even number of columns.
            Only the low 4 bits of every element are used.

    Returns:
        A tensor with half as many columns; in each output byte the
        even-indexed input column occupies the low 4 bits and the following
        odd-indexed column occupies the high 4 bits.

    Raises:
        AssertionError: If ``weight`` is not 2-D, has an odd number of
            columns, or is not of dtype ``torch.int8``.
    """
    assert weight.dim() == 2
    assert weight.shape[1] % 2 == 0
    assert weight.dtype == torch.int8

    # Even columns supply the low nibble, odd columns the high nibble.
    low_nibbles = weight[:, 0::2] & 0xF
    high_nibbles = weight[:, 1::2] & 0xF
    return (high_nibbles << 4) | low_nibbles
|
||
|
|
||
|
|
||
|
# Unpack quadruples of bits in int8 values into int4 values, in row
|
||
|
# major order; lower 4 bits go into first int4 value, and upper 4
|
||
|
# bits go into second int4 value.
|
||
|
def unpack_int8_to_int4(weight):
    """Split every byte of a 2-D int8 tensor into its two 4-bit halves.

    Args:
        weight: 2-D ``torch.int8`` tensor whose bytes each hold two 4-bit
            values.

    Returns:
        A tensor with twice as many columns; for each input byte the low
        4 bits come first and the high 4 bits second. Every output element
        is in the range 0..15.

    Raises:
        AssertionError: If ``weight`` is not 2-D or is not of dtype
            ``torch.int8``.
    """
    assert weight.dim() == 2
    assert weight.dtype == torch.int8

    n_rows, n_packed_cols = weight.shape
    low_nibbles = weight & 0xF
    # Masking after the shift drops the sign bits an arithmetic shift drags in.
    high_nibbles = (weight >> 4) & 0xF

    # Stacking on a new trailing axis puts (low, high) side by side, so
    # flattening the last two axes yields the row-major interleaved layout.
    nibble_pairs = torch.stack((low_nibbles, high_nibbles), dim=2)
    return nibble_pairs.view(n_rows, 2 * n_packed_cols)
|
||
|
|
||
|
|
||
|
# Transpose the weight matrix, and then reorder its elements according
|
||
|
# to underlying requirements of CUTLASS library, so that it could be
|
||
|
# used for CUTLASS-based mixed datatypes linear operation.
|
||
|
def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
    weight, dtypeq, transpose=False
):
    """Reorder a quantized weight matrix for the CUTLASS mixed-dtypes linear op.

    The work is done in four stages, labelled by the section comments in the
    body: an optional (sub-byte aware) transpose, a column permutation within
    groups of 16 columns, a scatter that interleaves the data viewed as int32
    words, and a final bias plus byte/nibble interleave.

    Args:
        weight: 2-D ``torch.int8`` tensor on a CUDA device. When ``dtypeq``
            is ``torch.quint4x2`` every byte is treated as two packed 4-bit
            values.
        dtypeq: Either ``torch.int8`` or ``torch.quint4x2``.
        transpose: When true the initial transpose stage is skipped and
            ``weight`` is used as is (presumably because the caller already
            holds the transposed layout -- confirm against callers).

    Returns:
        A ``torch.uint8`` tensor. With ``(n_cols, n_rows)`` being the shape
        after the transpose stage, the result has shape ``(n_rows, n_cols)``
        for int8 and ``(2 * n_rows, n_cols // 2)`` for quint4x2.

    Raises:
        AssertionError: If ``weight`` is not a 2-D int8 CUDA tensor, if
            ``dtypeq`` is not one of the two supported dtypes, or if the
            post-transpose shape violates the divisibility requirements
            (``n_cols % 64 == 0``; ``n_rows`` divisible by 32 for quint4x2,
            by 64 otherwise).
    """
    assert weight.dim() == 2
    assert weight.dtype == torch.int8
    assert dtypeq == torch.int8 or dtypeq == torch.quint4x2
    assert weight.device.type == "cuda"

    is_int8 = dtypeq == torch.int8
    is_int4 = dtypeq == torch.quint4x2
    dev = weight.device

    # subbyte_transpose
    if transpose:
        reordered = weight
    elif is_int8:
        reordered = weight.T
    elif is_int4:
        # Packed nibbles cannot be transposed bytewise: unpack, transpose,
        # then pack again.
        nibbles = unpack_int8_to_int4(weight.view(torch.int8))
        reordered = pack_int4_to_int8(nibbles.T)

    n_cols, n_rows = reordered.shape  # type: ignore[possibly-undefined]
    assert n_rows % (32 if is_int4 else 64) == 0
    assert n_cols % 64 == 0

    # permute_B_rows_for_mixed_gemm
    # (permute cols actually, as transpose is applied first here)
    col_index = _cutlass_column_permutation(n_rows, is_int4, dev)
    # index_copy writes source column j to destination column col_index[j].
    reordered = reordered.index_copy(1, col_index, reordered)

    # interleave_column_major_tensor
    word_offsets = _cutlass_interleave_offsets(n_rows, n_cols, is_int4, dev)
    # Four consecutive bytes are moved together as one int32 word.
    words = reordered.view(-1).view(torch.int32)
    interleaved = torch.zeros_like(words)
    interleaved.scatter_(0, word_offsets, words)

    # add_bias_and_interleave_quantized_tensor_inplace
    flat = interleaved.view(weight.dtype).view(-1)
    result = torch.empty_like(flat)
    if is_int8:
        # Add 128 in a wider type, then wrap back into the original dtype.
        biased = (flat.to(torch.int) + 128).to(flat.dtype)
        # Within every group of 4 bytes the two middle bytes swap places.
        for dst, src in enumerate((0, 2, 1, 3)):
            result[dst::4] = biased[src::4]
    elif is_int4:
        # Add 8 (mod 16) to each nibble, then re-pack the low nibbles of two
        # neighbouring bytes into one byte and likewise the high nibbles.
        low = ((flat & 0xF) + 8) & 0xF
        low = (low[1::2] << 4) | low[0::2]
        high = (((flat >> 4) & 0xF) + 8) & 0xF
        high = (high[1::2] << 4) | high[0::2]
        pieces = (low[0::2], low[1::2], high[0::2], high[1::2])
        for dst, piece in enumerate(pieces):
            result[dst::4] = piece

    if is_int4:
        out_shape = (n_rows * 2, n_cols // 2)
    else:
        out_shape = (n_rows, n_cols)

    return result.view(*out_shape).view(torch.uint8)


def _cutlass_column_permutation(n_rows, is_int4, device):
    """Build the flat destination-column index for the 16-column permutation.

    The same 16-entry pattern is applied to every consecutive group of 16
    columns; ``n_rows`` is the number of columns of the matrix being
    permuted (its second dimension).
    """
    if is_int4:
        pattern = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
    else:
        pattern = [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15]

    n_groups = n_rows // 16
    # First column of each group, as a column vector, so that adding the
    # pattern broadcasts to an (n_groups, 16) table of absolute indices.
    group_start = torch.arange(0, n_groups, device=device).reshape(-1, 1) * 16
    index = torch.tensor(pattern, device=device) + group_start.expand(n_groups, 16)
    return index.view(-1)


def _cutlass_interleave_offsets(n_rows, n_cols, is_int4, device):
    """Compute the scatter destination of every int32 word of the matrix.

    The matrix is taken as ``n_cols`` rows of ``n_rows // 4`` int32 words in
    row-major order. Groups of ``rows_interleaved`` consecutive rows are
    merged by alternating chunks of ``words_per_chunk`` words from each row
    of the group.
    """
    rows_interleaved = 4 if is_int4 else 2
    words_per_chunk = 32 // rows_interleaved
    words_per_row = n_rows // 4
    words_per_group = words_per_row * rows_interleaved

    # Start of the row group the word belongs to.
    group_base = (
        (torch.arange(0, n_cols // rows_interleaved, device=device) * words_per_group)
        .view(-1, 1)
        .repeat(1, words_per_group)
        .view(-1)
    )
    # Start, inside the group, of the interleaved block for the word's chunk.
    chunk_base = (
        (
            torch.arange(0, words_per_row // words_per_chunk, device=device)
            * (rows_interleaved * words_per_chunk)
        )
        .view(-1, 1)
        .repeat(1, words_per_chunk)
        .view(-1)
        .repeat(n_cols)
    )
    # Slot of the word's row inside that interleaved block.
    row_slot = (
        (torch.arange(0, rows_interleaved, device=device) * words_per_chunk)
        .view(-1, 1)
        .repeat(1, words_per_row)
        .view(-1)
        .repeat(n_cols // rows_interleaved)
    )
    # Position of the word inside its chunk.
    word_in_chunk = torch.arange(0, words_per_chunk, device=device).repeat(
        words_per_row * n_cols // words_per_chunk
    )

    return group_base + chunk_base + row_slot + word_in_chunk
|