diff --git a/backend/app/graph.py b/backend/app/graph.py index 22fa2e5..ee50925 100644 --- a/backend/app/graph.py +++ b/backend/app/graph.py @@ -9,13 +9,13 @@ from langgraph.prebuilt import create_react_agent from langgraph_supervisor import create_supervisor from app.agents import get_tools_by_names -from app.intent import ClassificationResult, IntentClassifier if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.graph.state import CompiledStateGraph + from app.intent import ClassificationResult, IntentClassifier from app.registry import AgentRegistry logger = logging.getLogger(__name__) diff --git a/backend/tests/integration/test_phase2_checkpoints.py b/backend/tests/integration/test_phase2_checkpoints.py index 50c3a24..3ed517c 100644 --- a/backend/tests/integration/test_phase2_checkpoints.py +++ b/backend/tests/integration/test_phase2_checkpoints.py @@ -90,7 +90,7 @@ def _agent(name: str, desc: str, perm: str = "read") -> AgentConfig: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint1_OrderQueryRouting: +class TestCheckpoint1OrderQueryRouting: """Verify intent classifier routes order queries to order_lookup.""" @pytest.mark.asyncio @@ -161,7 +161,7 @@ class TestCheckpoint1_OrderQueryRouting: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint2_MultiIntentSequential: +class TestCheckpoint2MultiIntentSequential: """Verify multi-intent classified and hint injected for sequential execution.""" @pytest.mark.asyncio @@ -235,7 +235,7 @@ class TestCheckpoint2_MultiIntentSequential: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint3_AmbiguousClarification: +class TestCheckpoint3AmbiguousClarification: """Verify ambiguous messages trigger clarification prompt.""" @pytest.mark.asyncio @@ -263,7 +263,9 @@ class TestCheckpoint3_AmbiguousClarification: mock_classifier.classify = AsyncMock(return_value=ClassificationResult( intents=(), is_ambiguous=True, - clarification_question="Could you please provide more details about what you need help with?", + clarification_question=( + "Could you please provide more details about what you need help with?" + ), )) graph.intent_classifier = mock_classifier mock_registry = MagicMock() @@ -294,7 +296,7 @@ class TestCheckpoint3_AmbiguousClarification: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint4_InterruptTTLAutoCancel: +class TestCheckpoint4InterruptTTLAutoCancel: """Verify interrupt TTL expiration triggers auto-cancel and retry prompt.""" @pytest.mark.asyncio @@ -361,7 +363,7 @@ class TestCheckpoint4_InterruptTTLAutoCancel: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint5_WebhookEscalation: +class TestCheckpoint5WebhookEscalation: """Verify webhook escalation sends POST and retries on failure.""" @pytest.mark.asyncio @@ -442,7 +444,7 @@ class TestCheckpoint5_WebhookEscalation: # --------------------------------------------------------------------------- @pytest.mark.integration -class TestCheckpoint6_EcommerceTemplate: +class TestCheckpoint6EcommerceTemplate: """Verify e-commerce template loads with correct agents.""" def test_ecommerce_template_loads_4_agents(self) -> None: diff --git a/backend/tests/integration/test_routing.py b/backend/tests/integration/test_routing.py index 4d8f24a..2ec8634 100644 --- a/backend/tests/integration/test_routing.py +++ b/backend/tests/integration/test_routing.py @@ -24,7 +24,6 @@ from app.registry import AgentConfig from app.session_manager import SessionManager from app.ws_handler import dispatch_message - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -78,10 +77,22 @@ def _state(*, interrupt: bool = False, data: dict | None = None): AGENTS = ( - AgentConfig(name="order_lookup", description="Looks up orders", permission="read", tools=["get_order_status", "get_tracking_info"]), - AgentConfig(name="order_actions", description="Modifies orders", permission="write", tools=["cancel_order"]), - AgentConfig(name="discount", description="Applies discounts", permission="write", tools=["apply_discount", "generate_coupon"]), - AgentConfig(name="fallback", description="Handles unclear requests", permission="read", tools=["fallback_respond"]), + AgentConfig( + name="order_lookup", description="Looks up orders", + permission="read", tools=["get_order_status", "get_tracking_info"], + ), + AgentConfig( + name="order_actions", description="Modifies orders", + permission="write", tools=["cancel_order"], + ), + AgentConfig( + name="discount", description="Applies discounts", + permission="write", tools=["apply_discount", "generate_coupon"], + ), + AgentConfig( + name="fallback", description="Handles unclear requests", + permission="read", tools=["fallback_respond"], + ), ) @@ -136,7 +147,9 @@ class TestSingleIntentRouting: @pytest.mark.asyncio async def test_routes_to_order_lookup(self) -> None: result = ClassificationResult( - intents=(IntentTarget(agent_name="order_lookup", confidence=0.95, reasoning="status query"),), + intents=(IntentTarget( + agent_name="order_lookup", confidence=0.95, reasoning="status query", + ),), ) graph = _make_graph(result, [ _tool_chunk("get_order_status", {"order_id": "1042"}, "order_lookup"), diff --git a/backend/tests/integration/test_websocket.py b/backend/tests/integration/test_websocket.py index 74222dd..25e4de0 100644 --- a/backend/tests/integration/test_websocket.py +++ b/backend/tests/integration/test_websocket.py @@ -19,7 +19,6 @@ from app.interrupt_manager import InterruptManager from app.session_manager import SessionManager from app.ws_handler import dispatch_message - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -182,7 +181,8 @@ class TestWebSocketInterruptApproval: @pytest.mark.asyncio async def test_interrupt_then_approve(self) -> None: st_int = _state(interrupt=True, data={"action": "cancel_order", "order_id": "1042"}) - g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order 1042 cancelled.", "order_actions")]) + resume = [_chunk("Order 1042 cancelled.", "order_actions")] + g = _graph(chunks=[], st=st_int, resume_chunks=resume) g_, sm, im, cb, ws = _setup(graph=g) # Send message -> triggers interrupt @@ -209,7 +209,8 @@ class TestWebSocketInterruptApproval: @pytest.mark.asyncio async def test_interrupt_then_reject(self) -> None: st_int = _state(interrupt=True) - g = _graph(chunks=[], st=st_int, resume_chunks=[_chunk("Order remains active.", "order_actions")]) + resume = [_chunk("Order remains active.", "order_actions")] + g = _graph(chunks=[], st=st_int, resume_chunks=resume) g_, sm, im, cb, ws = _setup(graph=g) await _send(ws, g_, sm, im, cb, content="Cancel order 1042") diff --git a/backend/tests/unit/test_intent.py b/backend/tests/unit/test_intent.py index a62da92..3a9a98b 100644 --- a/backend/tests/unit/test_intent.py +++ b/backend/tests/unit/test_intent.py @@ -16,8 +16,10 @@ from app.intent import ( from app.registry import AgentConfig -def _make_agent(name: str, desc: str = "test", permission: str = "read") -> AgentConfig: - return AgentConfig(name=name, description=desc, permission=permission, tools=["fallback_respond"]) +def _make_agent(name: str, desc: str = "test", perm: str = "read") -> AgentConfig: + return AgentConfig( + name=name, description=desc, permission=perm, tools=["fallback_respond"], + ) @pytest.mark.unit diff --git a/backend/tests/unit/test_ws_handler.py b/backend/tests/unit/test_ws_handler.py index 14443b9..eee9661 100644 --- a/backend/tests/unit/test_ws_handler.py +++ b/backend/tests/unit/test_ws_handler.py @@ -222,7 +222,9 @@ class TestHandleUserMessage: im = InterruptManager() sm.touch("t1") - await handle_user_message(ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im) + await handle_user_message( + ws, graph, sm, cb, "t1", "cancel order 1042", interrupt_manager=im, + ) # Interrupt should be registered assert im.has_pending("t1")