import asyncio
from collections.abc import AsyncGenerator

import pytest
from freezegun import freeze_time
from graphql.utilities import get_introspection_query

import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.extensions.tracing.apollo import (
    ApolloTracingExtension,
    ApolloTracingExtensionSync,
)


@freeze_time("20120114 12:00:01")
def test_tracing_sync(mocker):
    """Check the tracing payload emitted by the sync extension for one field.

    The wall clock is frozen and ``perf_counter_ns`` is patched to return 0, so
    the timestamps, offsets and durations in the payload are fully fixed.
    """
    perf_counter_target = "strawberry.extensions.tracing.apollo.time.perf_counter_ns"
    mocker.patch(perf_counter_target, return_value=0)

    @strawberry.type
    class Person:
        name: str = "Jess"

    @strawberry.type
    class Query:
        @strawberry.field
        def person(self) -> Person:
            return Person()

    # Only the top-level ``person`` field is expected to show up as a resolver.
    frozen_timestamp = "2012-01-14T12:00:01.000000Z"
    person_resolver = {
        "path": ["person"],
        "field_name": "person",
        "parentType": "Query",
        "returnType": "Person!",
        "startOffset": 0,
        "duration": 0,
    }
    expected_extensions = {
        "tracing": {
            "version": 1,
            "startTime": frozen_timestamp,
            "endTime": frozen_timestamp,
            "duration": 0,
            "execution": {"resolvers": [person_resolver]},
            "validation": {"startOffset": 0, "duration": 0},
            "parsing": {"startOffset": 0, "duration": 0},
        }
    }

    document = """
        query {
            person {
                name
            }
        }
    """
    traced_schema = strawberry.Schema(
        query=Query, extensions=[ApolloTracingExtensionSync]
    )

    outcome = traced_schema.execute_sync(document)

    assert not outcome.errors
    assert outcome.extensions == expected_extensions


@pytest.mark.asyncio
@freeze_time("20120114 12:00:01")
async def test_tracing_async(mocker):
    """Check the tracing payload emitted by the async extension.

    The schema mixes a sync resolver (``example``) with an async one
    (``person``); both must be reported, in that order. Time is frozen and
    ``perf_counter_ns`` is patched to 0 so every timing value is fixed.
    """
    perf_counter_target = "strawberry.extensions.tracing.apollo.time.perf_counter_ns"
    mocker.patch(perf_counter_target, return_value=0)

    @strawberry.type
    class Person:
        name: str = "Jess"

    @strawberry.type
    class Query:
        @strawberry.field
        def example(self) -> str:
            return "Hi"

        @strawberry.field
        async def person(self) -> Person:
            return Person()

    frozen_timestamp = "2012-01-14T12:00:01.000000Z"
    example_resolver = {
        "duration": 0,
        "field_name": "example",
        "parentType": "Query",
        "path": ["example"],
        "returnType": "String!",
        "startOffset": 0,
    }
    person_resolver = {
        "path": ["person"],
        "field_name": "person",
        "parentType": "Query",
        "returnType": "Person!",
        "startOffset": 0,
        "duration": 0,
    }
    expected_extensions = {
        "tracing": {
            "version": 1,
            "startTime": frozen_timestamp,
            "endTime": frozen_timestamp,
            "duration": 0,
            # List comparison is order-sensitive: ``example`` comes first.
            "execution": {"resolvers": [example_resolver, person_resolver]},
            "validation": {"startOffset": 0, "duration": 0},
            "parsing": {"startOffset": 0, "duration": 0},
        }
    }

    document = """
        query {
            example
            person {
                name
            }
        }
    """
    traced_schema = strawberry.Schema(query=Query, extensions=[ApolloTracingExtension])

    outcome = await traced_schema.execute(document)

    assert not outcome.errors
    assert outcome.extensions == expected_extensions


@freeze_time("20120114 12:00:01")
def test_should_not_trace_introspection_sync_queries(mocker):
    """Check that a sync introspection query reports no resolver traces.

    The tracing envelope (version, timestamps, parsing and validation spans) is
    still expected; only the ``resolvers`` list must be empty.
    """
    perf_counter_target = "strawberry.extensions.tracing.apollo.time.perf_counter_ns"
    mocker.patch(perf_counter_target, return_value=0)

    @strawberry.type
    class Person:
        name: str = "Jess"

    @strawberry.type
    class Query:
        @strawberry.field
        def person(self) -> Person:
            return Person()

    frozen_timestamp = "2012-01-14T12:00:01.000000Z"
    expected_extensions = {
        "tracing": {
            "version": 1,
            "startTime": frozen_timestamp,
            "endTime": frozen_timestamp,
            "duration": 0,
            "execution": {"resolvers": []},
            "validation": {"startOffset": 0, "duration": 0},
            "parsing": {"startOffset": 0, "duration": 0},
        }
    }

    traced_schema = strawberry.Schema(
        query=Query, extensions=[ApolloTracingExtensionSync]
    )
    introspection_document = get_introspection_query()

    outcome = traced_schema.execute_sync(introspection_document)

    assert not outcome.errors
    assert outcome.extensions == expected_extensions


@pytest.mark.asyncio
@freeze_time("20120114 12:00:01")
async def test_should_not_trace_introspection_async_queries(mocker):
    """Check that an async introspection query reports no resolver traces.

    The tracing envelope (version, timestamps, parsing and validation spans) is
    still expected; only the ``resolvers`` list must be empty.
    """
    perf_counter_target = "strawberry.extensions.tracing.apollo.time.perf_counter_ns"
    mocker.patch(perf_counter_target, return_value=0)

    @strawberry.type
    class Person:
        name: str = "Jess"

    @strawberry.type
    class Query:
        @strawberry.field
        async def person(self) -> Person:
            return Person()

    frozen_timestamp = "2012-01-14T12:00:01.000000Z"
    expected_extensions = {
        "tracing": {
            "version": 1,
            "startTime": frozen_timestamp,
            "endTime": frozen_timestamp,
            "duration": 0,
            "execution": {"resolvers": []},
            "validation": {"startOffset": 0, "duration": 0},
            "parsing": {"startOffset": 0, "duration": 0},
        }
    }

    traced_schema = strawberry.Schema(query=Query, extensions=[ApolloTracingExtension])
    introspection_document = get_introspection_query()

    outcome = await traced_schema.execute(introspection_document)

    assert not outcome.errors
    assert outcome.extensions == expected_extensions


@pytest.mark.asyncio
async def test_tracing_resolvers_populated_on_multiple_executions():
    """Run the same query three times on one schema with async tracing.

    Every run, not only the first, must report exactly one resolver trace, and
    that trace must be for the ``node`` field.
    """

    @strawberry.type
    class Query:
        @strawberry.field
        def node(self) -> str:
            return ""

    schema = strawberry.Schema(Query, extensions=[ApolloTracingExtension])
    document = "{ node }"

    for i in range(3):
        result = await schema.execute(document)
        assert not result.errors

        extensions = result.extensions
        assert extensions is not None

        resolvers = extensions["tracing"]["execution"]["resolvers"]
        assert len(resolvers) == 1, (
            f"Expected 1 resolver on execution {i}, got {len(resolvers)}"
        )
        first_resolver = resolvers[0]
        assert first_resolver["field_name"] == "node"


def test_tracing_resolvers_populated_on_multiple_sync_executions():
    """Run the same query three times on one schema with sync tracing.

    Every run, not only the first, must report exactly one resolver trace, and
    that trace must be for the ``node`` field.
    """

    @strawberry.type
    class Query:
        @strawberry.field
        def node(self) -> str:
            return ""

    schema = strawberry.Schema(Query, extensions=[ApolloTracingExtensionSync])
    document = "{ node }"

    for i in range(3):
        result = schema.execute_sync(document)
        assert not result.errors

        extensions = result.extensions
        assert extensions is not None

        resolvers = extensions["tracing"]["execution"]["resolvers"]
        assert len(resolvers) == 1, (
            f"Expected 1 resolver on execution {i}, got {len(resolvers)}"
        )
        first_resolver = resolvers[0]
        assert first_resolver["field_name"] == "node"


@pytest.mark.asyncio
async def test_concurrent_execution_context_result():
    """Regression test: ``execution_context.result`` is set for concurrent requests.

    An extension records ``execution_context.result`` after its ``on_operation``
    hook resumes; with two executions gathered together, both recorded values
    must be non-``None``.
    """
    recorded = []

    class RecordResultExtension(SchemaExtension):
        async def on_operation(self):
            yield
            # Code after the yield runs when the hook resumes; capture the result.
            recorded.append(self.execution_context.result)

    @strawberry.type
    class Query:
        @strawberry.field
        def node(self) -> str:
            return ""

    schema = strawberry.Schema(Query, extensions=[RecordResultExtension])
    document = "query { node }"

    await asyncio.gather(*(schema.execute(document) for _ in range(2)))

    assert len(recorded) == 2
    for captured in recorded:
        assert captured is not None


@pytest.mark.asyncio
async def test_concurrent_subscribe_has_isolated_extensions():
    """Regression test: concurrent subscriptions must get isolated extension instances.

    Two subscriptions are started together on the same schema. An extension
    records the ``execution_context`` it sees in ``on_execute``; the two recorded
    contexts must be distinct objects.
    """
    # Store the context objects themselves rather than their ``id()``: an id may
    # be reused once an object has been freed, so two distinct contexts with
    # non-overlapping lifetimes could otherwise compare equal. Holding the
    # references keeps both objects alive until the identity check below.
    contexts: list[object] = []

    class RecordContextExtension(SchemaExtension):
        async def on_execute(self):
            contexts.append(self.execution_context)
            yield

    @strawberry.type
    class Query:
        @strawberry.field
        def node(self) -> str:
            return ""

    @strawberry.type
    class Subscription:
        @strawberry.subscription
        async def node(self) -> AsyncGenerator[str, None]:
            yield ""

    schema = strawberry.Schema(
        query=Query, subscription=Subscription, extensions=[RecordContextExtension]
    )

    async def run() -> None:
        # Consume only the first event of the subscription, then stop.
        sub = await schema.subscribe("subscription { node }")
        async for result in sub:  # type: ignore[union-attr]
            assert result.data == {"node": ""}
            break

    await asyncio.gather(run(), run())
    assert len(contexts) == 2
    assert contexts[0] is not contexts[1], (
        "each subscription must have its own context"
    )
