# Source code for src.cost_tracking.ai_usage_middleware

"""
AI Usage Middleware for Token Tracking.

Intercepts all AI provider calls to track token usage per project.
"""

import asyncio
import functools
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Literal, Optional

from src.cost_tracking.token_tracker import token_tracker
from src.logging.conversation_logger import conversation_logger


class AIUsageMiddleware:
    """
    Middleware that wraps AI provider calls to track token usage.

    This intercepts calls to AI provider methods and records token
    consumption per project via ``token_tracker``.

    Project attribution is resolved per call, in this order:

    1. the context registered with :meth:`set_project_context` for the
       call's agent id, or for ``"system"`` when the call carries no agent id
       (the slot ``track_project_tokens`` uses when no agent is given);
    2. a ``project_id`` keyword argument;
    3. a ``"project_id"`` key in a dict passed as the second positional
       argument;
    4. otherwise ``"unassigned"``.
    """

    # Context slot used when a call carries no agent id.
    _DEFAULT_AGENT = "system"

    # Attribute set on wrappers produced by track_ai_usage so that
    # wrap_ai_provider can skip methods that are already tracked.
    _TRACKED_MARKER = "_ai_usage_tracked"

    def __init__(self) -> None:
        # agent_id -> {"project_id": ..., "task_id": ..., "start_time": ...}
        self.current_project_context: Dict[str, Dict[str, Any]] = {}
        self.token_tracker = token_tracker

    def set_project_context(
        self, agent_id: str, project_id: str, task_id: Optional[str] = None
    ) -> None:
        """
        Set the current project context for an agent.

        This should be called when an agent starts working on a project/task.
        Any existing context for ``agent_id`` is replaced; ``start_time`` is
        recorded as the current UTC time.
        """
        self.current_project_context[agent_id] = {
            "project_id": project_id,
            "task_id": task_id,
            "start_time": datetime.now(timezone.utc),
        }

    def clear_project_context(self, agent_id: str) -> None:
        """Clear project context when agent finishes (no-op if none is set)."""
        self.current_project_context.pop(agent_id, None)

    def get_current_project(self, agent_id: str) -> Optional[str]:
        """Return the project id registered for ``agent_id``, or ``None``."""
        context = self.current_project_context.get(agent_id, {})
        return context.get("project_id")

    @staticmethod
    def _resolve_agent_id(args: tuple, kwargs: Dict[str, Any]) -> Any:
        """Return ``agent_id=`` or the first positional argument's ``agent_id`` attribute."""
        agent_id = kwargs.get("agent_id")
        # Calls may pass keyword arguments only, so ``args`` can be empty.
        # NOTE(review): when the wrapped callable is a bound method (as
        # produced by wrap_ai_provider), args[0] is the first call argument,
        # not the provider instance — confirm this is the intended source.
        if not agent_id and args:
            agent_id = getattr(args[0], "agent_id", None)
        return agent_id

    def _resolve_project_id(
        self, agent_id: Any, args: tuple, kwargs: Dict[str, Any]
    ) -> Optional[str]:
        """Return the project id for a call, or ``None`` if none can be found."""
        # Calls without an agent id fall back to the "system" slot, which is
        # where track_project_tokens stores its context when no agent is given.
        project_id = self.get_current_project(agent_id or self._DEFAULT_AGENT)
        if not project_id:
            project_id = kwargs.get("project_id")
        if not project_id and len(args) > 1 and isinstance(args[1], dict):
            project_id = args[1].get("project_id")
        return project_id

    def track_ai_usage(self, func: Callable[..., Any]) -> Callable[..., Any]:
        """
        Decorate an async AI provider method so its token usage is tracked.

        The returned coroutine function awaits ``func`` and returns its result
        unchanged. When the result is a dict whose ``"usage"`` mapping reports
        a positive ``input_tokens`` or ``output_tokens``, the usage is passed
        to ``self.token_tracker.track_tokens``. Calls totalling more than 1000
        tokens are also logged via ``conversation_logger``.

        Args:
            func: The coroutine function to wrap.

        Returns:
            An async wrapper carrying ``func``'s metadata (``functools.wraps``).
        """

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            agent_id = self._resolve_agent_id(args, kwargs)
            project_id = self._resolve_project_id(agent_id, args, kwargs)

            # Call the original function, timing it for the metadata.
            start_time = datetime.now(timezone.utc)
            result = await func(*args, **kwargs)
            end_time = datetime.now(timezone.utc)

            # Only dict results carry usage information.
            if not isinstance(result, dict):
                return result

            # ``or`` also covers keys that are present but explicitly None.
            usage = result.get("usage") or {}
            input_tokens = usage.get("input_tokens") or 0
            output_tokens = usage.get("output_tokens") or 0
            model = result.get("model", "unknown")
            if input_tokens <= 0 and output_tokens <= 0:
                return result

            project_id = project_id or "unassigned"
            stats = await self.token_tracker.track_tokens(
                project_id=project_id,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                model=model,
                metadata={
                    "agent_id": agent_id,
                    "task_id": self.current_project_context.get(
                        agent_id or self._DEFAULT_AGENT, {}
                    ).get("task_id"),
                    "duration_ms": (end_time - start_time).total_seconds() * 1000,
                    "function": func.__name__,
                },
            )

            # Log significant usage
            total_tokens = input_tokens + output_tokens
            if total_tokens > 1000:
                conversation_logger.log_pm_thinking(
                    f"AI token usage for {project_id}: "
                    f"{total_tokens} tokens "
                    f"(${stats['total_cost']:.2f} total)",
                    {
                        "project_id": project_id,
                        "tokens": total_tokens,
                        "cost": stats["total_cost"],
                        "rate": stats["current_spend_rate"],
                    },
                )

            return result

        # Mark the wrapper so wrap_ai_provider never wraps it a second time.
        setattr(wrapper, self._TRACKED_MARKER, True)
        return wrapper

    def wrap_ai_provider(self, provider_instance: Any) -> Any:
        """
        Wrap an AI provider instance so its AI-calling methods are tracked.

        For each name in a fixed list of method names, if the instance has
        that attribute and it is a coroutine function, it is replaced on the
        instance by a :meth:`track_ai_usage` wrapper. Methods that are already
        tracked are skipped, so repeated calls do not double-count tokens.

        Args:
            provider_instance: Object whose async methods should be tracked.

        Returns:
            The same ``provider_instance``, modified in place.
        """
        # Methods that typically make AI API calls
        ai_methods = (
            "analyze",
            "complete",
            "chat",
            "generate",
            "call_model",
            "generate_task_instructions",
            "analyze_blocker",
            "generate_response",
            "classify",
            "embed",
            "summarize",
        )

        for method_name in ai_methods:
            if not hasattr(provider_instance, method_name):
                continue
            original_method = getattr(provider_instance, method_name)
            # Bound methods forward attribute lookup to their function, so this
            # also catches methods decorated with track_ai_usage at class level.
            if getattr(original_method, self._TRACKED_MARKER, False):
                continue
            if asyncio.iscoroutinefunction(original_method):
                setattr(
                    provider_instance,
                    method_name,
                    self.track_ai_usage(original_method),
                )

        return provider_instance
# Module-level singleton holding per-agent project context; the
# ``track_project_tokens`` helper reads and writes its context.
ai_usage_middleware: AIUsageMiddleware = AIUsageMiddleware()
def track_project_tokens(project_id: str, agent_id: Optional[str] = None) -> Any:
    """
    Context manager to track AI tokens for a specific project.

    Usage:
        with track_project_tokens("project_123", "agent_1"):
            # All AI calls in this block will be tracked to project_123
            await ai_engine.analyze(...)

    Args:
        project_id: Project to attribute AI usage to inside the block.
        agent_id: Agent whose context is set; defaults to ``"system"``.

    Returns:
        A context manager whose ``__enter__`` returns itself. On exit, the
        context that was active for the agent before entering is restored,
        so nested blocks for the same agent do not wipe the outer block's
        project. If no context existed before, it is cleared. Exceptions
        are never suppressed.
    """

    class TokenTrackingContext:
        def __init__(self, project_id: str, agent_id: Optional[str]):
            self.project_id = project_id
            self.agent_id = agent_id or "system"
            # Context active for this agent before __enter__, if any.
            self._previous: Optional[Dict[str, Any]] = None

        def __enter__(self) -> Any:
            # NOTE(review): context lives in a dict on the shared module-level
            # middleware, keyed by agent_id; concurrent tasks using the same
            # agent_id will observe each other's project — confirm acceptable.
            self._previous = ai_usage_middleware.current_project_context.get(
                self.agent_id
            )
            ai_usage_middleware.set_project_context(self.agent_id, self.project_id)
            return self

        def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
            if self._previous is None:
                ai_usage_middleware.clear_project_context(self.agent_id)
            else:
                # Restore the enclosing block's context instead of dropping it.
                ai_usage_middleware.current_project_context[self.agent_id] = (
                    self._previous
                )
            return False

    return TokenTrackingContext(project_id, agent_id)