1분 요약

  • KV Cache가 필요한 이유
  • engine.py 소스 코드 분석
  • 토큰 생성 루프 구현
  • Temperature와 Top-p 샘플링
  • 배치 추론 최적화
  • 메모리 vs 속도 트레이드오프

KV Cache란?

언어 모델의 추론 속도를 높이는 핵심 기술 Key, Value 값을 저장해 두고 재사용하여 연산량을 획기적으로 줄인다.

O(n^2) O(n)으로 최적화 메모리와 속도 트레이드오프가 생길 수 있음

필요성

KV Cache가 없을 경우 매번 전체 시퀀스를 처음부터 다시 계산해야 한다.

def generate_without_cache():
	for i in range(max_length):
		outputs = model(input_ids)
		next_token = outputs.logits[:, -1, :].argmax(dim=-1)
		input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim = -1)
	return input_ids
 
def generate_with_cache(model, input_ids, max_length):
	past_key_values = None
	for i in range(max_length):
		outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
		next_token = outputs.logits[:, -1, :].argmax(dim=-1)
		input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
	return input_ids
class LLMEngine:
	def __init__(self, model):
		self.model = model
		self.kv_cache = {}
	
	def generate(self, prompt, max_tokens=100):
		input_ids = self.tokenizer.encode(prompt)
		past_kv = None
		
		for _ in range(max_tokens):
			logits, past_kv = self.model(
				input_ids[-1:]
				past_key_values=past_kv
			)
			next_token = self._sample(logits)
			input_ids.append(next_token)

KV Cache 최적화 방법

배치 추론 최적화

  • KV Cache의 경우 각 시퀀스의 길이가 다를 수 있으므로 동적 배칭과 패딩 전략이 중요하다.
def batch_generate(model, prompts, max_length=100):
	batch_input_ids = [tokenizer.encode(p) for p in prompts]
	
	batch_size = len(prompts)
	past_key_balues = [None] * batch_size
	
	for step in range(max_length):
		batch_outputs = []
		for i, input_ids in enumerate(batch_input_ids):
			outputs = model(
				input_ids = torch.tensor([input_ids[-1]]).unzqueeze(0),
				past_key_values = past_key_values[i],
				use_cache=True
			)