Source code for src.ai.providers.llm_abstraction

"""
LLM Abstraction Layer for Marcus AI.

Provides a unified interface across different LLM providers
(Anthropic, OpenAI, local models)
with intelligent fallback and provider switching capabilities.

This module implements the strategy pattern for LLM providers, allowing seamless
switching between providers and automatic fallback on failures.

Classes
-------
LLMAbstraction
    Multi-provider LLM abstraction with intelligent fallback

Notes
-----
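The primary provider is read from ``config.ai.provider`` and defaults to
``"anthropic"`` when that setting is empty. The MARCUS_LLM_PROVIDER environment
variable is not read in this module; any effect it has comes through the
config loader.
When a provider is configured explicitly, only that provider is initialized and
no cross-provider fallback happens. Automatic fallback to alternative providers
applies only when no provider is configured and more than one initializes.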

Examples
--------
>>> llm = LLMAbstraction()
>>> analysis = await llm.analyze_task_semantics(task, context)
>>> if analysis.fallback_used:
...     print("Using fallback provider")
"""

import asyncio
import functools
import logging
import os
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar

from src.core.models import Priority, Task, TaskStatus

from .base_provider import (
    BaseLLMProvider,
    EffortEstimate,
    SemanticAnalysis,
    SemanticDependency,
)

# Module-level logger, named after this module's import path (``__name__``) so
# its records can be filtered separately from the rest of the application.
logger = logging.getLogger(__name__)


# Result type of the coroutine being decorated; lets the decorator's signature
# say "returns whatever the wrapped coroutine returns".
_T = TypeVar("_T")


def _tagged_operation(
    operation: str,
) -> Callable[[Callable[..., Awaitable[_T]]], Callable[..., Awaitable[_T]]]:
    """Build a decorator that runs a coroutine method inside an operation context.

    The decorated method is awaited while
    ``get_recorder().operation_context(operation)`` is active, so the label
    applies to everything the method does until it returns or raises.

    Background (Kaia review on PR #517): the five high-level methods on
    :class:`LLMAbstraction` (``analyze_task_semantics``,
    ``infer_dependencies_semantic``, ``generate_enhanced_description``,
    ``estimate_effort_intelligently``,
    ``analyze_blocker_and_suggest_solutions``) each repeated the same
    preamble to push an ``operation_context``. This decorator replaces that
    boilerplate.

    Parameters
    ----------
    operation : str
        Operation key (see :mod:`src.cost_tracking.operations`) passed to
        ``operation_context`` for the duration of the wrapped call.

    Returns
    -------
    Callable
        A decorator for ``async`` methods. The returned coroutine function
        keeps the original's name and docstring (via ``functools.wraps``).

    Notes
    -----
    ``get_recorder`` is imported when the wrapped method is *called*, not
    when this module is imported, so importing ``llm_abstraction`` does not
    pull in the cost-tracking module.
    """

    def decorate(method: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]:
        async def tagged(self: Any, *args: Any, **kwargs: Any) -> _T:
            # Deferred import: keeps cost tracking off the module import path.
            from src.cost_tracking.cost_recorder import get_recorder

            recorder = get_recorder()
            # The return stays inside the ``with`` so the context is still
            # active while the wrapped coroutine is being awaited.
            with recorder.operation_context(operation):
                return await method(self, *args, **kwargs)

        # Equivalent to decorating ``tagged`` with ``@functools.wraps(method)``.
        return functools.wraps(method)(tagged)

    return decorate


class LLMAbstraction:
    """
    Multi-provider LLM abstraction with intelligent fallback.

    Supports multiple LLM providers with automatic fallback when primary fails.
    Provides a unified interface for all AI operations in Marcus.

    Attributes
    ----------
    providers : dict
        Available LLM provider instances, keyed by provider name. Empty until
        the first call that triggers lazy initialization.
    current_provider : str
        Name of the primary provider to use
    fallback_providers : list of str
        Ordered list of providers to try on failure
    provider_stats : dict
        Performance statistics for each provider (``requests``, ``failures``,
        ``avg_response_time``)

    Methods
    -------
    analyze_task_semantics(task, context)
        Analyze task meaning and intent
    infer_dependencies_semantic(tasks)
        Infer logical dependencies between tasks
    generate_enhanced_description(task, context)
        Create improved task descriptions
    estimate_effort_intelligently(task, context)
        AI-powered effort estimation
    analyze_blocker_and_suggest_solutions(task, blocker_description, severity, agent)
        Analyze blockers and provide solutions
    analyze(prompt, context, operation=None)
        Free-form completion for a prompt
    switch_provider(provider_name)
        Change the primary provider
    get_provider_stats()
        Snapshot of provider statistics
    get_best_provider()
        Provider with the highest recorded success rate
    health_check()
        Probe every loaded provider

    Notes
    -----
    Providers are initialized lazily to avoid circular imports.
    Statistics are tracked for intelligent provider selection.
    """
def __init__(self) -> None:
    """
    Set up empty provider registries and record the primary provider name.

    No provider is constructed here; the registries are filled later by
    ``_initialize_providers``. The primary provider name is taken from
    ``config.ai.provider`` and falls back to ``"anthropic"`` when that
    setting is empty.
    """
    # Registries start empty and stay that way until providers are loaded.
    self.providers: Dict[str, BaseLLMProvider] = {}
    self.fallback_providers: List[str] = []
    self.provider_stats: Dict[str, Dict[str, Any]] = {}
    self._providers_initialized = False

    # Imported at call time rather than at module top
    # (presumably to avoid an import cycle; confirm).
    from src.config.marcus_config import get_config

    # Only the config value is consulted here; no environment variable is read.
    ai_settings = get_config().ai
    self.current_provider = ai_settings.provider or "anthropic"

    logger.info(
        f"LLM abstraction initialized with primary provider: "
        f"{self.current_provider}"
    )
def _initialize_providers(self) -> None:
    """
    Initialize available LLM providers.

    Uses lazy loading to avoid circular imports. Providers are only
    initialized when first needed; once ``_providers_initialized`` is set,
    later calls return immediately.

    Providers are attempted in a fixed order (Anthropic, OpenAI, cloud,
    local). Each one that constructs successfully is stored in
    ``self.providers`` and appended to ``self.fallback_providers``, so this
    order is also the fallback order.

    Raises
    ------
    RuntimeError
        If ``config.ai.provider`` is set but that provider did not end up in
        ``self.providers``, or if no provider at all could be initialized.

    Notes
    -----
    A failure while constructing an individual provider is logged as a
    warning and does not by itself abort initialization; the checks at the
    end of the method decide whether the overall result is acceptable.
    When ``config.ai.provider`` is non-empty, only that provider is eligible
    to initialize ("provider lockdown", Marcus #531).
    """
    # Idempotence guard: a successful run below sets this flag.
    if self._providers_initialized:
        return

    logger.debug("Starting provider initialization...")

    # Load config to get API keys directly
    try:
        from src.config.marcus_config import get_config

        config = get_config()
    except Exception as e:
        logger.warning(f"Failed to load config, falling back to env vars: {e}")
        # Create minimal config with defaults
        from src.config.marcus_config import MarcusConfig

        config = MarcusConfig()

    # Provider lockdown (Marcus #531). When the user explicitly sets
    # ``config.ai.provider``, ONLY that provider initializes. Other
    # providers never enter ``self.providers`` and never become
    # fallback candidates — even if their credentials happen to be
    # present in config or in the environment.
    #
    # The earlier code gated only the ENV-VAR fallback by
    # ``configured_provider``, but left the init block gated only
    # on key validity. Config substitution (``"openai_api_key":
    # "${OPENAI_API_KEY}"`` in config_marcus.json) put a real OpenAI
    # key into ``config.ai.openai_api_key`` whenever the env var was
    # exported in the user's shell — so OpenAI silently joined the
    # fallback chain even when ``provider: anthropic`` was set, and
    # cascaded billing to OpenAI when Anthropic momentarily failed.
    configured_provider = config.ai.provider or ""

    def _allowed(name: str) -> bool:
        """Return True iff ``name`` may initialize under the current config.

        When ``configured_provider`` is set, only the matching name is
        allowed. When it's empty (legacy auto-discovery), every
        provider with valid credentials is allowed.
        """
        return not configured_provider or configured_provider == name

    # ----- Anthropic -----------------------------------------------------
    if _allowed("anthropic"):
        # Config value wins; the CLAUDE_API_KEY env var is only a fallback.
        anthropic_key = config.ai.anthropic_api_key or ""
        if not anthropic_key:
            anthropic_key = os.getenv("CLAUDE_API_KEY", "").strip()

        # Accept only keys with the expected prefix that are not the
        # template placeholder value.
        if (
            anthropic_key
            and anthropic_key.startswith("sk-ant-")
            and len(anthropic_key) > 10
            and anthropic_key != "sk-ant-your-api-key-here"
        ):
            try:
                from .anthropic_provider import AnthropicProvider

                # Pass key directly to the provider — never write into
                # os.environ. ANTHROPIC_API_KEY in the env would force
                # Claude Code subprocesses (Epictetus, project creator,
                # workers, monitor) to bill the API instead of using the
                # user's Claude Code subscription.
                self.providers["anthropic"] = AnthropicProvider(
                    api_key=anthropic_key
                )
                self.fallback_providers.append("anthropic")
                logger.info("Successfully initialized Anthropic provider")
            except Exception as e:
                logger.warning(f"Failed to initialize Anthropic provider: {e}")
        else:
            # Only the presence of a key is logged, never the key itself.
            logger.debug(
                f"Skipping Anthropic provider - no valid API key configured "
                f"(key present: {bool(anthropic_key)})"
            )

    # ----- OpenAI --------------------------------------------------------
    if _allowed("openai"):
        # Config value wins; the OPENAI_API_KEY env var is only a fallback.
        openai_key = config.ai.openai_api_key or ""
        if not openai_key:
            openai_key = os.getenv("OPENAI_API_KEY", "").strip()

        if (
            openai_key
            and openai_key.startswith("sk-")
            and len(openai_key) > 10
            and openai_key != "sk-your-openai-key-here"
        ):
            try:
                from .openai_provider import OpenAIProvider

                # Export the key through the environment; OpenAIProvider is
                # constructed without arguments below.
                # NOTE(review): nothing in this method restores or removes
                # the variable afterwards, so the export is not actually
                # temporary: it stays in os.environ for the rest of the
                # process, which is what the Anthropic branch above avoids
                # on purpose. Confirm whether OpenAIProvider can take the
                # key directly.
                os.environ["OPENAI_API_KEY"] = openai_key
                self.providers["openai"] = OpenAIProvider()
                self.fallback_providers.append("openai")
                logger.info("Successfully initialized OpenAI provider")
            except Exception as e:
                logger.warning(f"Failed to initialize OpenAI provider: {e}")
        else:
            logger.debug(
                f"Skipping OpenAI provider - no valid API key configured "
                f"(key present: {bool(openai_key)})"
            )
    elif config.ai.openai_api_key or os.getenv("OPENAI_API_KEY", "").strip():
        # Diagnostic only: user has an OpenAI key available somewhere
        # but `config.ai.provider` excludes openai. Tell them loudly
        # so they know we deliberately ignored the key.
        logger.info(
            "OpenAI key present but provider=%r — OpenAI deliberately "
            "NOT initialized. To use OpenAI, set ai.provider='openai' "
            "in config_marcus.json.",
            configured_provider,
        )

    # ----- Cloud ---------------------------------------------------------
    # Cloud was already gated correctly (only inits when explicitly
    # configured), so the existing check stays. Comment kept for
    # consistency with the rewritten anthropic/openai blocks above.
    # Unlike the other providers, cloud never joins legacy auto-discovery:
    # the second condition requires it to be the configured provider.
    if _allowed("cloud") and configured_provider == "cloud":
        cloud_key = config.ai.cloud_api_key or ""
        if not cloud_key:
            cloud_key = os.getenv("MARCUS_CLOUD_LLM_KEY", "").strip()
        cloud_url = config.ai.cloud_url or ""
        if not cloud_url:
            cloud_url = os.getenv("MARCUS_CLOUD_LLM_URL", "").strip()
        # The model name comes from config only; there is no env fallback.
        cloud_model = config.ai.model or ""
        if cloud_key and cloud_url and cloud_model:
            try:
                from .cloud_provider import CloudLLMProvider

                self.providers["cloud"] = CloudLLMProvider(
                    model=cloud_model,
                    api_key=cloud_key,
                    url=cloud_url,
                )
                self.fallback_providers.append("cloud")
                logger.info(
                    "Successfully initialized cloud LLM provider: "
                    "model=%s url=%s",
                    cloud_model,
                    cloud_url,
                )
            except Exception as e:
                logger.warning("Failed to initialize cloud LLM provider: %s", e)
        else:
            logger.debug(
                "Skipping cloud provider — missing key=%s url=%s model=%s",
                bool(cloud_key),
                bool(cloud_url),
                bool(cloud_model),
            )

    # ----- Local ---------------------------------------------------------
    if _allowed("local"):
        local_model_path = config.ai.local_model or ""
        if not local_model_path:
            local_model_path = os.getenv("MARCUS_LOCAL_LLM_PATH", "").strip()

        if local_model_path:
            try:
                from .local_provider import LocalLLMProvider

                self.providers["local"] = LocalLLMProvider(local_model_path)
                self.fallback_providers.append("local")
                logger.info(
                    f"Successfully initialized local LLM provider "
                    f"with model: {local_model_path}"
                )
            except Exception as e:
                logger.warning(f"Failed to initialize local LLM provider: {e}")

    # Hard-fail when the user explicitly set a provider and it didn't
    # initialize. The earlier code logged a warning and silently
    # cascaded to whichever provider happened to be available — that
    # caused real cost rows to land under the "wrong" provider after a
    # silent fallback. Surface the gap immediately. (Marcus #531)
    if configured_provider and configured_provider not in self.providers:
        raise RuntimeError(
            f"config.ai.provider={configured_provider!r} is set but the "
            f"provider failed to initialize. Refusing to silently fall "
            f"back to another provider. Check that the corresponding "
            f"credentials are present and valid in config_marcus.json "
            f"or the matching environment variable, then restart Marcus."
        )

    # Initialize provider stats only for successfully loaded providers
    self.provider_stats = {
        provider: {"requests": 0, "failures": 0, "avg_response_time": 0.0}
        for provider in self.providers.keys()
    }

    # Ensure we have at least one provider
    if not self.providers:
        raise RuntimeError(
            "No LLM providers could be initialized. "
            "Please check your AI configuration."
        )

    # Ensure current provider is available, otherwise use first available
    if self.current_provider not in self.providers:
        self.current_provider = list(self.providers.keys())[0]
        logger.warning(
            f"Requested provider not available, using {self.current_provider}"
        )

    # The flag is set only after every check above has passed, so a failed
    # attempt is retried on the next call.
    self._providers_initialized = True
    logger.info(f"Initialized providers: {list(self.providers.keys())}")
@_tagged_operation("analyze_task_semantics")
async def analyze_task_semantics(
    self, task: Task, context: Dict[str, Any]
) -> SemanticAnalysis:
    """
    Run semantic analysis of a task through the provider chain.

    Parameters
    ----------
    task : Task
        Task to analyze for semantic meaning
    context : dict
        Project context including related tasks

    Returns
    -------
    SemanticAnalysis
        Whatever the answering provider's ``analyze_task`` method returned

    Notes
    -----
    Dispatches to ``_execute_with_fallback``, so alternative providers are
    tried when the primary one fails. The call runs under the
    ``analyze_task_semantics`` operation context set up by
    :func:`_tagged_operation`.
    """
    request = {"task": task, "context": context}
    return await self._execute_with_fallback("analyze_task", **request)  # type: ignore
@_tagged_operation("infer_dependencies")
async def infer_dependencies_semantic(
    self, tasks: List[Task]
) -> List[SemanticDependency]:
    """
    Ask the provider chain for semantic dependencies between tasks.

    Parameters
    ----------
    tasks : list of Task
        All tasks to analyze for dependencies

    Returns
    -------
    list of SemanticDependency
        Whatever the answering provider's ``infer_dependencies`` method
        returned

    Notes
    -----
    Intended to complement rule-based dependency detection. The call runs
    under the ``infer_dependencies`` operation context set up by
    :func:`_tagged_operation`, and goes through ``_execute_with_fallback``.
    """
    request = {"tasks": tasks}
    return await self._execute_with_fallback("infer_dependencies", **request)  # type: ignore
@_tagged_operation("enrich_task")
async def generate_enhanced_description(
    self, task: Task, context: Dict[str, Any]
) -> str:
    """
    Produce an enhanced description for a task.

    Parameters
    ----------
    task : Task
        Task needing clearer description
    context : dict
        Project context for better understanding

    Returns
    -------
    str
        Whatever the answering provider's ``generate_enhanced_description``
        method returned

    Notes
    -----
    Runs under the ``enrich_task`` operation context set up by
    :func:`_tagged_operation`, and goes through ``_execute_with_fallback``.
    """
    provider_method = "generate_enhanced_description"
    request = {"task": task, "context": context}
    return await self._execute_with_fallback(provider_method, **request)  # type: ignore
@_tagged_operation("estimate_effort")
async def estimate_effort_intelligently(
    self, task: Task, context: Dict[str, Any]
) -> EffortEstimate:
    """
    Obtain an effort estimate for a task from the provider chain.

    Parameters
    ----------
    task : Task
        Task to estimate completion time for
    context : dict
        Project context with historical performance data

    Returns
    -------
    EffortEstimate
        Whatever the answering provider's ``estimate_effort`` method returned

    Notes
    -----
    Runs under the ``estimate_effort`` operation context set up by
    :func:`_tagged_operation`, and goes through ``_execute_with_fallback``.
    """
    request = {"task": task, "context": context}
    return await self._execute_with_fallback("estimate_effort", **request)  # type: ignore
@_tagged_operation("analyze_blocker")
async def analyze_blocker_and_suggest_solutions(
    self,
    task: Task,
    blocker_description: str,
    severity: str,
    agent: Optional[Dict[str, Any]],
) -> List[str]:
    """
    Ask the provider chain to analyze a blocker and propose solutions.

    Parameters
    ----------
    task : Task
        The blocked task
    blocker_description : str
        Detailed description of the blocker
    severity : str
        Severity level: 'low', 'medium', or 'high'
    agent : dict, optional
        Agent information for context

    Returns
    -------
    list of str
        Whatever the answering provider's ``analyze_blocker`` method returned

    Notes
    -----
    ``blocker_description`` is sent twice: as the ``blocker`` argument and,
    together with ``severity`` and ``agent``, inside the ``context`` dict.
    How severity affects the analysis is decided by the provider, not here.
    Runs under the ``analyze_blocker`` operation context set up by
    :func:`_tagged_operation`.
    """
    blocker_context = dict(
        blocker_description=blocker_description,
        severity=severity,
        agent=agent,
    )
    request = {
        "task": task,
        "blocker": blocker_description,
        "context": blocker_context,
    }
    return await self._execute_with_fallback("analyze_blocker", **request)  # type: ignore
async def _execute_with_fallback(self, method_name: str, **kwargs: Any) -> Any:
    """
    Execute method with automatic provider fallback.

    Tries the primary provider first, then falls back to alternatives
    in order if the primary fails.

    Parameters
    ----------
    method_name : str
        Name of the provider method to call
    **kwargs
        Arguments to pass to the method

    Returns
    -------
    Any
        Result from the first successful provider

    Raises
    ------
    Exception
        If every attempted provider fails. The exception raised by the last
        provider is attached as ``__cause__`` so its traceback is kept.
    RuntimeError
        Propagated from ``_initialize_providers`` when providers cannot be
        loaded.

    Notes
    -----
    Updates the ``requests`` and ``failures`` counters in
    ``self.provider_stats``. A provider without a stats entry gets one on
    first use, so bookkeeping can never abort the fallback loop.
    Results that expose a ``fallback_used`` attribute have it set to True
    when a non-primary provider answered and False otherwise.
    NOTE(review): ``avg_response_time`` exists in the stats entries but is
    not updated by this method.
    """
    # Ensure providers are initialized
    self._initialize_providers()

    # Primary first, then the remaining fallbacks in their registered order.
    providers_to_try = [self.current_provider] + [
        p for p in self.fallback_providers if p != self.current_provider
    ]

    last_exception: Optional[Exception] = None

    for provider_name in providers_to_try:
        if provider_name not in self.providers:
            continue

        provider = self.providers[provider_name]

        # Fetch (or create) the stats entry up front, outside the try block.
        # Indexing a missing entry inside the except handler would raise
        # KeyError and end the loop instead of moving on to the next provider.
        stats = self.provider_stats.setdefault(
            provider_name,
            {"requests": 0, "failures": 0, "avg_response_time": 0.0},
        )

        try:
            logger.debug(f"Trying {method_name} with provider: {provider_name}")

            # Track request
            stats["requests"] = stats.get("requests", 0) + 1

            # Execute method
            method = getattr(provider, method_name)
            result = await method(**kwargs)

            # Mark fallback usage if not primary
            if hasattr(result, "fallback_used"):
                result.fallback_used = provider_name != self.current_provider

            logger.debug(
                f"Successfully executed {method_name} with {provider_name}"
            )
            return result

        except Exception as e:
            # Boundary handler: any provider failure moves on to the next one.
            logger.warning(
                f"Provider {provider_name} failed for {method_name}: {e}"
            )
            stats["failures"] = stats.get("failures", 0) + 1
            last_exception = e
            continue

    # All providers failed
    available_providers = list(self.providers.keys())
    logger.error(
        f"All available providers {available_providers} failed for {method_name}"
    )

    # Provide a more helpful error message
    if not available_providers:
        raise Exception(
            "No AI providers are configured. "
            "Please check your API keys in config_marcus.json. "
            "Make sure keys start with 'sk-ant-' for Anthropic or 'sk-' for OpenAI."
        )

    if len(available_providers) == 1:
        provider_name = available_providers[0]
        failure_text = str(last_exception)
        error_msg = f"{provider_name.capitalize()} API error: {last_exception}"
        if "401" in failure_text:
            error_msg += (
                f". Please check that your {provider_name} API key in "
                f"config_marcus.json is valid and not expired."
            )
        elif "API key" in failure_text:
            error_msg = (
                f"Invalid {provider_name} API key. Please check that your "
                f"API key in config_marcus.json is correct and starts with "
                f"'{'sk-ant-' if provider_name == 'anthropic' else 'sk-'}'."
            )
    else:
        error_msg = f"All LLM providers failed. Last error: {last_exception}"

    # Chain the provider's exception so the root cause stays in the traceback.
    raise Exception(error_msg) from last_exception
async def analyze(
    self, prompt: str, context: Any, *, operation: Optional[str] = None
) -> str:
    """
    Send a free-form prompt to the provider chain and return its completion.

    Parameters
    ----------
    prompt : str
        The prompt to analyze.
    context : Any
        Analysis context. Only its ``max_tokens`` attribute is read, and
        only when the attribute exists.
    operation : str, optional
        Logical operation label for the cost event. When given (and
        non-empty), the provider call runs inside
        ``get_recorder().operation_context(operation)`` so the call is
        recorded under that label rather than the provider's default.
        See ``src/cost_tracking/operations.py`` for the canonical taxonomy.

    Returns
    -------
    str
        Result of the provider's ``complete`` method.
    """
    # Providers load lazily; make sure they exist before dispatching.
    self._initialize_providers()

    # ``max_tokens`` is forwarded only when the caller attached it to
    # ``context``. Otherwise the provider keeps its own configured default
    # (a hardcoded 2000 here used to override the user's configuration).
    call_args: Dict[str, Any] = {"prompt": prompt}
    if hasattr(context, "max_tokens"):
        call_args["max_tokens"] = context.max_tokens

    if not operation:
        # No label requested: call straight through.
        return await self._execute_with_fallback("complete", **call_args)  # type: ignore

    # Imported here so cost tracking is not loaded at module import time.
    from src.cost_tracking.cost_recorder import get_recorder

    with get_recorder().operation_context(operation):
        return await self._execute_with_fallback("complete", **call_args)  # type: ignore
async def switch_provider(self, provider_name: str) -> bool:
    """
    Switch to a different provider.

    Parameters
    ----------
    provider_name : str
        Name of provider to switch to

    Returns
    -------
    bool
        True if switch successful, False otherwise

    Notes
    -----
    Providers are loaded lazily, so ``self.providers`` is empty until the
    first LLM call. If the requested provider is unknown and providers have
    not been loaded yet, they are loaded here before deciding. Without that
    step every switch made before the first LLM call would be rejected as
    "not available". A ``RuntimeError`` from that initialization is logged
    and reported as an unsuccessful switch.
    """
    # Only initialize when it could change the answer; an instance whose
    # providers are already known is left untouched.
    if provider_name not in self.providers and not self._providers_initialized:
        try:
            self._initialize_providers()
        except RuntimeError as e:
            logger.error(
                f"Cannot switch to {provider_name}: "
                f"provider initialization failed: {e}"
            )
            return False

    if provider_name not in self.providers:
        logger.error(f"Provider {provider_name} not available")
        return False

    old_provider = self.current_provider
    self.current_provider = provider_name

    logger.info(f"Switched LLM provider from {old_provider} to {provider_name}")
    return True
def get_provider_stats(self) -> Dict[str, Any]:
    """
    Get performance statistics for all providers.

    Returns
    -------
    dict
        ``current_provider`` (str), ``available_providers`` (list of str)
        and ``stats`` (per-provider counter dicts).

    Notes
    -----
    The returned ``stats`` is a snapshot: each per-provider dict is copied,
    so callers may modify the result without changing the live counters.
    A plain ``dict.copy()`` of the outer mapping would still share the inner
    dicts with ``self.provider_stats``.
    """
    stats_snapshot = {
        name: dict(counters) for name, counters in self.provider_stats.items()
    }
    return {
        "current_provider": self.current_provider,
        "available_providers": list(self.providers.keys()),
        "stats": stats_snapshot,
    }
def get_best_provider(self) -> str:
    """
    Pick the provider with the highest recorded success rate.

    Success rate is ``1 - failures / requests``. Providers that are not
    currently loaded, or that have no recorded requests, are not considered.

    Returns
    -------
    str
        Name of the best performing provider. The current provider is
        returned when no candidate has a success rate above zero. When
        several candidates share the top rate, the first one in
        ``provider_stats`` order is returned.
    """
    scored = [
        (1.0 - (stats["failures"] / stats["requests"]), name)
        for name, stats in self.provider_stats.items()
        if name in self.providers and stats["requests"] != 0
    ]

    if scored:
        # max() returns the first maximal element, so ties go to the
        # earliest candidate.
        top_rate, top_name = max(scored, key=lambda pair: pair[0])
        if top_rate > 0.0:
            return str(top_name)

    return str(self.current_provider)
async def health_check(self) -> Dict[str, Any]:
    """
    Probe every loaded provider with a small test request.

    For each entry in ``self.providers`` a throwaway task is built and sent
    to the provider's ``analyze_task`` with a 10-second timeout. Providers
    are probed one after another.

    Returns
    -------
    dict
        Provider name mapped to a status dict: ``"healthy"`` when the probe
        returned, ``"timeout"`` when it exceeded 10 seconds, ``"error"``
        (with the exception text) for any other failure. Empty when no
        providers are loaded; this method does not trigger provider
        initialization.
    """
    report: Dict[str, Any] = {}

    for name, backend in self.providers.items():
        try:
            # A fresh probe task is built for each provider.
            probe_task = Task(
                id="health-check",
                name="Test task",
                description="Health check test",
                status=TaskStatus.TODO,
                priority=Priority.LOW,
                assigned_to=None,
                created_at=datetime.now(timezone.utc),
                updated_at=datetime.now(timezone.utc),
                due_date=None,
                estimated_hours=1.0,
            )
            probe_context = {"project_type": "test"}

            await asyncio.wait_for(
                backend.analyze_task(probe_task, probe_context),
                timeout=10.0,
            )
        except asyncio.TimeoutError:
            report[name] = {
                "status": "timeout",
                "error": "Request timed out after 10s",
            }
        except Exception as exc:
            # One unhealthy provider must not stop the remaining probes.
            report[name] = {"status": "error", "error": str(exc)}
        else:
            # The response time is not measured; the values are fixed labels.
            report[name] = {
                "status": "healthy",
                "response_time": "< 10s",
                "last_check": "now",
            }

    return report