動かざることバグの如し

近づきたいよ 君の理想に

LiteLLM Proxyで会話数(messages)を制限する方法

環境

  • litellm v1.80

やりたいこと

現在、LiteLLMをLLMプロキシサーバーとして利用し、ロールプレイアプリを実装している。

ロールプレイという性質上、チャットのターン数は急増しやすく、50ターンに達することも珍しくない。しかし、生成に必要な文脈は直近の会話だけで十分なケースが大半である。50ターン分すべてを送信してしまうと、コストが増大するだけでなく、不要な情報によりコンテキストがブレる原因にもなってしまう。

LiteLLM Proxy側で会話数を制限したい場合、本来であればクライアント実装側で過去ログを間引くのが定石である。しかし、今回は諸事情により「アプリ都合でそれができない(クライアント側で会話数を制限できない)」という制約がある。

そこで、指定したターン数以上を切り捨てる処理をProxy側で実装することで対応した。

コード

from litellm.integrations.custom_logger import CustomLogger  
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache  
from typing import Any, Literal  
  
class MessageTruncationHandler(CustomLogger):  
    def __init__(self, max_messages: int = 10):  
        self.max_messages = max_messages  

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ])  -> dict: 
        # メッセージリストを取得  
        messages = data.get("messages", [])  
          
        # メッセージ数が制限を超えている場合、直近のmax_messages件に絞り込む  
        if len(messages) > self.max_messages:  
            # systemメッセージは保持し、直近のmax_messages-1件を追加  
            system_messages = [msg for msg in messages if msg.get("role") == "system"]  
            other_messages = [msg for msg in messages if msg.get("role") != "system"]  
              
            # 直近のメッセージを取得  
            recent_messages = other_messages[-(self.max_messages - len(system_messages)):]  
              
            # systemメッセージと直近メッセージを結合  
            data["messages"] = system_messages + recent_messages  
              
            print(f"Truncated messages from {len(messages)} to {len(data['messages'])}")  
          
        return data  
  
# インスタンスを作成  
message_truncation_handler = MessageTruncationHandler(max_messages=8)

LiteLLM Proxyには callbacks という仕組みがあり、リクエストを上流LLMに投げる直前に async_pre_call_hook() を挟める。ここで data(OpenAI互換のリクエストボディ)が渡されるので、data["messages"] を書き換えれば「直近Nターンだけ送る」が実現できる。

このコードがやっていることは以下である。

  • data["messages"] を取り出す
  • max_messages を超える場合だけ間引く
  • systemは残し、それ以外のメッセージを末尾から必要数だけ残す

ポイントは、LiteLLM Proxyが受け取る /v1/chat/completions のpayloadはだいたい {"messages": [...]} なので、ここを削るだけでprompt tokenが減り、コストとコンテキストブレを同時に抑えられる点だ。

設定

実際にProxyに組み込むには、Pythonファイルとして /app/message_truncation_callback.pyに配置して、proxyの config.yaml から参照する。形式は ファイル名.インスタンス名 である。

litellm_settings:
  # ファイル名.インスタンス名 で指定
  callbacks: ["message_truncation_callback.message_truncation_handler"]