#!/usr/bin/env python3
"""
GHL MCP Client — HighLevel Model Context Protocol Integration
=============================================================
Connects to the GHL MCP server (https://services.leadconnectorhq.com/mcp/)
and provides a Python interface to the tools it exposes (36 at the time of writing).

Usage:
    # As a library (tool calls must run inside an open session)
    from ghl_mcp_client import GHLMCPClient
    client = GHLMCPClient(pit_token="pit-...", location_id="...")
    async with client.connect():
        tools = await client.list_tools()
        result = await client.call_tool("contacts_get-contacts", {"query_limit": "10"})

    # As a standalone script (the default action is "test")
    python ghl_mcp_client.py --action list-tools
    python ghl_mcp_client.py --action call --tool contacts_get-contacts --params '{"query_limit":"10"}'
    python ghl_mcp_client.py --action interactive
    python ghl_mcp_client.py --action test

Author: Builder Agent
Created: 2026-04-09
Location: ifitVy09cFwlEVgEAFDS
"""

import os
import sys
import json
import asyncio
import argparse
import logging
from typing import Any, Optional
from contextlib import asynccontextmanager

from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DEFAULT_MCP_URL = "https://services.leadconnectorhq.com/mcp/"

# PIT token lookup order: GHL_MCP_PIT_TOKEN, then GHL_AGENCY_PIT, else "".
DEFAULT_PIT_TOKEN = os.environ.get(
    "GHL_MCP_PIT_TOKEN",
    os.environ.get("GHL_AGENCY_PIT", ""),
)
# Location used when none is passed explicitly; overridable via GHL_MCP_LOCATION_ID.
DEFAULT_LOCATION_ID = os.environ.get("GHL_MCP_LOCATION_ID", "ifitVy09cFwlEVgEAFDS")

# NOTE(review): configuring the root logger at import time also affects any
# application that imports this module as a library.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("ghl-mcp")


class GHLMCPClient:
    """
    Async client for the HighLevel MCP Server.

    Manages the session lifecycle and provides tool listing + calling.
    ``list_tools``, ``call_tool`` and ``get_tool_schema`` must be awaited inside
    ``async with client.connect():``; outside a session they raise ``RuntimeError``.
    """

    def __init__(
        self,
        pit_token: str = DEFAULT_PIT_TOKEN,
        location_id: str = DEFAULT_LOCATION_ID,
        mcp_url: str = DEFAULT_MCP_URL,
    ):
        """
        Store connection settings; no network I/O happens here.

        Args:
            pit_token: Token sent as ``Authorization: Bearer <pit_token>``.
            location_id: Value sent in the ``locationId`` request header.
            mcp_url: MCP server endpoint.

        The defaults are the module-level ``DEFAULT_*`` values, which are read
        from the environment once, at import time.
        """
        self.pit_token = pit_token
        self.location_id = location_id
        self.mcp_url = mcp_url
        self._session: Optional[ClientSession] = None
        self._tools_cache: Optional[list] = None

    def _require_session(self) -> "ClientSession":
        """Return the active session or raise ``RuntimeError`` if not connected."""
        if self._session is None:
            raise RuntimeError("Not connected. Use `async with client.connect():` first.")
        return self._session

    @asynccontextmanager
    async def connect(self):
        """
        Open an MCP session for the duration of an ``async with`` block.

        Yields this client with the session attached. The session reference is
        cleared on exit, including when the block raises, so later calls fail
        fast with ``RuntimeError`` instead of using a closed session.
        """
        headers = {
            "Authorization": f"Bearer {self.pit_token}",
            "locationId": self.location_id,
        }
        log.info("Connecting to GHL MCP at %s (location: %s)", self.mcp_url, self.location_id)

        async with streamablehttp_client(self.mcp_url, headers=headers) as (read_stream, write_stream, _):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                self._session = session
                # Drop any tool list cached during a previous session.
                self._tools_cache = None
                log.info("MCP session initialized successfully")
                try:
                    yield self
                finally:
                    self._session = None

    async def list_tools(self) -> list[dict]:
        """
        Return all available GHL MCP tools.

        Each entry is ``{"name", "description", "input_schema"}``; a missing
        description becomes ``""`` and a missing schema becomes ``{}``. The list
        is fetched once per session and cached (the cached list itself is
        returned, so callers should not mutate it).

        Raises:
            RuntimeError: if called outside ``connect()``.
        """
        session = self._require_session()
        # `is not None` so a legitimately empty tool list is cached too.
        if self._tools_cache is not None:
            return self._tools_cache

        result = await session.list_tools()
        self._tools_cache = [
            {
                "name": t.name,
                "description": t.description or "",
                "input_schema": getattr(t, "inputSchema", None) or {},
            }
            for t in result.tools
        ]
        log.info("Loaded %d tools from GHL MCP", len(self._tools_cache))
        return self._tools_cache

    @staticmethod
    def _parse_content_block(block: Any) -> Any:
        """Decode one content block: JSON text -> object, other text -> {"text": ...}."""
        if hasattr(block, "text"):
            try:
                return json.loads(block.text)
            except (json.JSONDecodeError, TypeError):
                # TypeError covers a text attribute that is not a str (e.g. None).
                return {"text": block.text}
        return {"type": str(type(block)), "data": str(block)}

    async def call_tool(self, tool_name: str, arguments: Optional[dict[str, Any]] = None) -> dict:
        """
        Call a GHL MCP tool by name with the given arguments.

        Args:
            tool_name: Tool name as reported by ``list_tools``.
            arguments: Tool arguments; ``None`` is treated as ``{}``.

        Returns:
            ``{"tool": tool_name, "is_error": bool, "results": ...}`` where
            ``results`` is the single decoded content block, a list when the
            server returned several blocks, or ``{}`` when it returned none.

        Raises:
            RuntimeError: if called outside ``connect()``.
        """
        session = self._require_session()

        arguments = arguments or {}
        # default=str keeps logging from failing on values json can't encode.
        log.info("Calling tool: %s with args: %s", tool_name, json.dumps(arguments, indent=2, default=str))

        result = await session.call_tool(tool_name, arguments)
        parsed = [self._parse_content_block(block) for block in result.content]

        if len(parsed) > 1:
            results: Any = parsed
        elif parsed:
            results = parsed[0]
        else:
            results = {}

        return {
            "tool": tool_name,
            "is_error": bool(getattr(result, "isError", False)),
            "results": results,
        }

    async def get_tool_schema(self, tool_name: str) -> Optional[dict]:
        """
        Return the full tool entry (``name``, ``description``, ``input_schema``)
        for *tool_name*, or ``None`` if the server does not expose it.

        Uses the cached tool list, fetching it first if needed.
        """
        tools = await self.list_tools()
        return next((t for t in tools if t["name"] == tool_name), None)


# ---------------------------------------------------------------------------
# Convenience functions for quick one-shot calls
# ---------------------------------------------------------------------------
async def quick_call(tool_name: str, arguments: Optional[dict] = None, **kwargs) -> dict:
    """
    One-shot: connect, call a single tool, and disconnect.

    Args:
        tool_name: Name of the MCP tool to invoke.
        arguments: Tool arguments; ``None`` is passed through and treated as ``{}``
            by ``GHLMCPClient.call_tool``.
        **kwargs: Forwarded to ``GHLMCPClient`` (``pit_token``, ``location_id``, ``mcp_url``).

    Returns:
        The dict produced by ``GHLMCPClient.call_tool``.
    """
    client = GHLMCPClient(**kwargs)
    async with client.connect():
        return await client.call_tool(tool_name, arguments)


async def quick_list_tools(**kwargs) -> list[dict]:
    """
    Open a throwaway session, fetch the tool catalogue, and close it again.

    Keyword arguments are passed straight to ``GHLMCPClient``.
    """
    async with GHLMCPClient(**kwargs).connect() as connected:
        return await connected.list_tools()


# ---------------------------------------------------------------------------
# Interactive mode
# ---------------------------------------------------------------------------
async def interactive_mode(**kwargs):
    """
    Interactive REPL for exploring GHL MCP tools.

    Keyword arguments are forwarded to ``GHLMCPClient`` (e.g. ``pit_token``,
    ``location_id``); with none, the module defaults are used.

    Commands: ``tools``, ``schema <name>``, ``call <name> [json_args]`` and
    ``quit``/``exit``/``q``. Invalid JSON or a failing tool call prints an
    error and keeps the session open.
    """
    client = GHLMCPClient(**kwargs)
    async with client.connect():
        tools = await client.list_tools()
        print(f"\n{'='*60}")
        print("  GHL MCP Interactive Client")
        print(f"  Location: {client.location_id}")
        print(f"  Available tools: {len(tools)}")
        print(f"{'='*60}\n")

        # Group tools by the name prefix before the first "_".
        categories: dict[str, list[str]] = {}
        for t in tools:
            categories.setdefault(t["name"].split("_")[0], []).append(t["name"])

        print("Tool categories:")
        for cat, tool_names in sorted(categories.items()):
            print(f"  {cat} ({len(tool_names)} tools)")
            for name in tool_names:
                print(f"    - {name}")
        print()

        while True:
            try:
                user_input = input("ghl-mcp> ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break

            if not user_input:
                continue
            if user_input in ("quit", "exit", "q"):
                print("Goodbye!")
                break
            if user_input == "tools":
                for t in tools:
                    desc = t["description"] or ""
                    # Only mark truncation when something was actually cut off.
                    shown = desc if len(desc) <= 80 else desc[:80] + "..."
                    print(f"  {t['name']}: {shown}")
                continue
            if user_input.startswith("schema "):
                name = user_input.split(" ", 1)[1].strip()
                schema = await client.get_tool_schema(name)
                if schema:
                    print(json.dumps(schema, indent=2))
                else:
                    print(f"Tool '{name}' not found")
                continue
            if user_input.startswith("call "):
                # Split on whitespace runs; the third part keeps the raw JSON text.
                parts = user_input.split(maxsplit=2)
                tool_name = parts[1]
                try:
                    args = json.loads(parts[2]) if len(parts) > 2 else {}
                except json.JSONDecodeError as e:
                    print(f"Invalid JSON arguments: {e}")
                    continue
                if not isinstance(args, dict):
                    print('JSON arguments must be an object, e.g. {"query_limit": "10"}')
                    continue
                try:
                    result = await client.call_tool(tool_name, args)
                    print(json.dumps(result, indent=2, default=str))
                except Exception as e:  # REPL boundary: report and keep the session alive.
                    print(f"Error: {e}")
                continue

            print("Commands: tools | schema <name> | call <name> [json_args] | quit")


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def _as_dict(value: Any) -> dict:
    """
    Return *value* when it is a dict, otherwise ``{}``.

    ``GHLMCPClient.call_tool`` puts a list under ``"results"`` when the server
    returns several content blocks, and payload fields may be ``None``, so the
    self-test must not assume a dict.
    """
    return value if isinstance(value, dict) else {}


async def _run_connection_test(**kwargs) -> None:
    """
    Smoke-test a live connection: list tools, then read location, contacts and pipelines.

    Tool calls that come back with ``is_error`` set are reported as ``[WARN]``
    lines; connection or transport exceptions still propagate.
    """
    print("Running GHL MCP connection test...")
    client = GHLMCPClient(**kwargs)
    async with client.connect():
        # 1. List tools
        tools = await client.list_tools()
        print(f"  [PASS] Connected to GHL MCP - {len(tools)} tools available")

        # 2. Get location info
        print("  Testing: locations_get-location...")
        result = await client.call_tool("locations_get-location", {
            "path_locationId": client.location_id,
        })
        if not result.get("is_error"):
            loc_data = _as_dict(result.get("results"))
            # The name may sit under a nested "location" object or at the top level.
            name = _as_dict(loc_data.get("location")).get("name", loc_data.get("name", "Unknown"))
            print(f"  [PASS] Location: {name}")
        else:
            print(f"  [WARN] Location call returned: {result}")

        # 3. Get contacts (limit 3)
        print("  Testing: contacts_get-contacts...")
        result = await client.call_tool("contacts_get-contacts", {
            "query_limit": "3",
            "query_locationId": client.location_id,
        })
        if not result.get("is_error"):
            contacts = _as_dict(result.get("results")).get("contacts") or []
            print(f"  [PASS] Contacts retrieved: {len(contacts)}")
            for c in contacts[:3]:
                name = f"{c.get('firstName') or ''} {c.get('lastName') or ''}".strip()
                email = c.get("email", "N/A")
                print(f"         - {name} ({email})")
        else:
            print(f"  [WARN] Contacts call returned: {result}")

        # 4. Get pipelines
        print("  Testing: opportunities_get-pipelines...")
        result = await client.call_tool("opportunities_get-pipelines", {
            "query_locationId": client.location_id,
        })
        if not result.get("is_error"):
            pipelines = _as_dict(result.get("results")).get("pipelines") or []
            print(f"  [PASS] Pipelines retrieved: {len(pipelines)}")
            for p in pipelines:
                print(f"         - {p.get('name', 'Unknown')} (stages: {len(p.get('stages') or [])})")
        else:
            print(f"  [WARN] Pipelines call returned: {result}")

        print(f"\n  {'='*50}")
        print("  GHL MCP Integration Test Complete")
        print(f"  All {len(tools)} tools are accessible!")
        print(f"  {'='*50}")


async def main():
    """CLI entry point: parse arguments and run the selected action."""
    parser = argparse.ArgumentParser(description="GHL MCP Client")
    parser.add_argument("--action", choices=["list-tools", "call", "interactive", "test"],
                        default="test", help="Action to perform")
    parser.add_argument("--tool", help="Tool name (for --action call)")
    parser.add_argument("--params", help="JSON params (for --action call)", default="{}")
    parser.add_argument("--pit-token", default=DEFAULT_PIT_TOKEN)
    parser.add_argument("--location-id", default=DEFAULT_LOCATION_ID)
    args = parser.parse_args()

    # An empty token would produce an "Authorization: Bearer " header; fail fast
    # with an actionable message instead.
    if not args.pit_token:
        parser.error("no PIT token: pass --pit-token or set GHL_MCP_PIT_TOKEN (or GHL_AGENCY_PIT)")

    kwargs = {"pit_token": args.pit_token, "location_id": args.location_id}

    if args.action == "list-tools":
        tools = await quick_list_tools(**kwargs)
        print(f"\n{'='*60}")
        print(f"  GHL MCP Tools ({len(tools)} available)")
        print(f"{'='*60}\n")
        for i, t in enumerate(tools, 1):
            print(f"  {i:2d}. {t['name']}")
            # Show only the first line; a description may be missing or multi-line.
            first_line = (t["description"] or "").split("\n")[0]
            print(f"      {first_line[:100]}")
            print()

    elif args.action == "call":
        # Errors go to stderr so stdout stays clean JSON when piped.
        if not args.tool:
            print("Error: --tool required for call action", file=sys.stderr)
            sys.exit(1)
        try:
            params = json.loads(args.params)
        except json.JSONDecodeError as e:
            print(f"Error: --params is not valid JSON: {e}", file=sys.stderr)
            sys.exit(1)
        if not isinstance(params, dict):
            print("Error: --params must be a JSON object", file=sys.stderr)
            sys.exit(1)
        result = await quick_call(args.tool, params, **kwargs)
        print(json.dumps(result, indent=2, default=str))

    elif args.action == "interactive":
        # NOTE: --pit-token/--location-id are not forwarded here; the REPL
        # builds its client from the module defaults.
        await interactive_mode()

    elif args.action == "test":
        await _run_connection_test(**kwargs)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C during a network call: exit quietly with the conventional
        # SIGINT status (128 + 2) instead of dumping an asyncio traceback.
        print("\nInterrupted.", file=sys.stderr)
        sys.exit(130)
