Repository URL to install this package:
|
Version:
0.2.4+cu122 ▾
|
import torch
class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
"""
The window size is the same as the max_seq_len. The window will
automatically roll once max_seq_len is exceeded.
"""
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half()
self.max_seq_len = max_seq_len
def get_kv(self, batch_size, start_pos, seqlen, head_dim):
"""
Gets the key-value store in correct shapes.
"""
xv = (
self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
)
xk = (
self.k[:batch_size, :, :, : start_pos + seqlen, :]
.transpose(2, 3)
.contiguous()
)
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
return xv, xk
def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
"""
Updates the values in the key-value store.
"""
self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
def roll_kv_n_steps(self, start_pos, n=100):
"""
Roll cache n to the left.
"""
n = min(n, self.max_seq_len)
# Roll cache to the left
self.v = torch.roll(self.v, shifts=-n, dims=2)
self.k = torch.roll(self.k, shifts=-n, dims=3)
# Zero out the new part
self.v[:, :, -n:, :] = 0
self.k[:, :, :, -n:, :] = 0
return start_pos - n
def to(self, device):
self.k = self.k.to(device)
self.v = self.v.to(device)
def increase_batch_size(self, to_bsz):
"""Dynamically allocate new kv when batch size changes."""
self.v = torch.zeros(
to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
)
self.k = torch.zeros(
to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
)
def decrease_batch_size(self, to_bsz):
"""Dynamically remove part of cache if batch size changes."""
self.v = self.v[:to_bsz, :, :, :]
self.k = self.k[:to_bsz, :, :, :, :]