1분 요약
- KV Cache가 필요한 이유
- engine.py 소스 코드 분석
- 토큰 생성 루프 구현
- Temperature와 Top-p 샘플링
- 배치 추론 최적화
- 메모리 vs 속도 트레이드오프
KV Cache란?
언어 모델의 추론 속도를 높이는 핵심 기술 Key, Value 값을 저장해 두고 재사용하여 연산량을 획기적으로 줄인다.
O(n^2) → O(n)으로 최적화 메모리와 속도 트레이드오프가 생길 수 있음
필요성
KV Cache가 없을 경우 매번 전체 시퀀스를 처음부터 다시 계산해야 한다.
def generate_without_cache():
for i in range(max_length):
outputs = model(input_ids)
next_token = outputs.logits[:, -1, :].argmax(dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim = -1)
return input_ids
def generate_with_cache(model, input_ids, max_length):
past_key_values = None
for i in range(max_length):
outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
next_token = outputs.logits[:, -1, :].argmax(dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
return input_idsclass LLMEngine:
def __init__(self, model):
self.model = model
self.kv_cache = {}
def generate(self, prompt, max_tokens=100):
input_ids = self.tokenizer.encode(prompt)
past_kv = None
for _ in range(max_tokens):
logits, past_kv = self.model(
input_ids[-1:]
past_key_values=past_kv
)
next_token = self._sample(logits)
input_ids.append(next_token)KV Cache 최적화 방법
배치 추론 최적화
- KV Cache의 경우 각 시퀀스의 길이가 다를 수 있으므로 동적 배칭과 패딩 전략이 중요하다.
def batch_generate(model, prompts, max_length=100):
batch_input_ids = [tokenizer.encode(p) for p in prompts]
batch_size = len(prompts)
past_key_balues = [None] * batch_size
for step in range(max_length):
batch_outputs = []
for i, input_ids in enumerate(batch_input_ids):
outputs = model(
input_ids = torch.tensor([input_ids[-1]]).unzqueeze(0),
past_key_values = past_key_values[i],
use_cache=True
)