"""
AWS IoT Fleet Provisioning for SecureLink Pi Gateway.

Uses a batch-specific claim certificate to register with AWS IoT Core,
exchange for a unique device certificate, and receive configuration.
"""

import json
import logging
import os
import threading
import time
from pathlib import Path

from awscrt import mqtt
from awsiot import iotidentity, mqtt_connection_builder

from config import (
    CLAIM_CERT, CLAIM_KEY, AWS_ROOT_CA,
    DEVICE_CERT, DEVICE_KEY, DEVICE_CONFIG, DEVICE_DIR,
    IOT_ENDPOINT, IOT_PROVISIONING_TEMPLATE,
)

# Logger used by this module for provisioning and VPN-config messages.
log = logging.getLogger('securelink.iot')


class IoTProvisioner:
    """Handles AWS IoT fleet provisioning to get unique device credentials.

    ``provision()`` connects with the batch claim certificate, obtains a
    device-unique certificate/key pair, registers the thing through the
    provisioning template and persists the results.
    ``get_config_via_mqtt()`` later connects with those device credentials
    and requests the VPN configuration.
    """

    def __init__(self, serial: str, mac: str):
        self.serial = serial
        self.mac = mac
        self.thing_name = None  # set by provision() after RegisterThing
        # NOTE(review): the two attributes below are not read anywhere in
        # this class; kept in case external code relies on them.
        self._provisioning_done = threading.Event()
        self._error = None

    def is_provisioned(self) -> bool:
        """Return True if both the device certificate and key exist."""
        return DEVICE_CERT.exists() and DEVICE_KEY.exists()

    def provision(self) -> dict:
        """
        Run the fleet provisioning flow and persist device credentials.

        1. Connect to IoT Core with the claim certificate.
        2. Request new keys and a certificate (CreateKeysAndCertificate).
        3. Register the thing via the provisioning template (RegisterThing).
        4. Store the private key, device config and certificate.

        Returns:
            The device config dict written to ``DEVICE_CONFIG``
            (``thing_name``, ``certificate_id``, ``serial``,
            ``device_configuration``).

        Raises:
            RuntimeError: missing endpoint/claim files, or a rejected or
                timed-out provisioning request. Connection/acknowledgement
                timeouts from the underlying futures also propagate.
        """
        if not IOT_ENDPOINT:
            raise RuntimeError('IOT_ENDPOINT not configured')
        if not CLAIM_CERT.exists():
            raise RuntimeError(f'Claim certificate not found: {CLAIM_CERT}')
        if not CLAIM_KEY.exists():
            raise RuntimeError(f'Claim key not found: {CLAIM_KEY}')

        log.info(f'Starting fleet provisioning (serial={self.serial})')

        connection = mqtt_connection_builder.mtls_from_path(
            endpoint=IOT_ENDPOINT,
            cert_filepath=str(CLAIM_CERT),
            pri_key_filepath=str(CLAIM_KEY),
            ca_filepath=str(AWS_ROOT_CA),
            client_id=f'securelink-provision-{self.serial}',
            clean_session=True,
        )

        # The claim-cert connection must be closed on every path, including
        # failures, so it is not left open on the broker side.
        try:
            connection.connect().result(timeout=30)
            log.info('Connected to IoT Core with claim certificate')

            identity_client = iotidentity.IotIdentityClient(connection)

            # Step 1: Create keys and certificate
            keys_response = self._create_keys_and_certificate(identity_client)
            certificate_id = keys_response['certificateId']
            certificate_pem = keys_response['certificatePem']
            private_key = keys_response['privateKey']
            ownership_token = keys_response['certificateOwnershipToken']

            log.info(f'Received device certificate: {certificate_id[:12]}...')

            # Step 2: Register thing with provisioning template
            thing_response = self._register_thing(
                identity_client, ownership_token
            )
        finally:
            self._disconnect(connection)

        # The 'thingName' key is always present but may be None, so a
        # dict.get() default would never apply; fall back explicitly.
        self.thing_name = (thing_response.get('thingName')
                           or f'securelink-{self.serial}')
        config = thing_response.get('deviceConfiguration') or {}
        log.info(f'Thing registered: {self.thing_name}')

        device_config = {
            'thing_name': self.thing_name,
            'certificate_id': certificate_id,
            'serial': self.serial,
            'device_configuration': config,
        }

        # Step 3: Store device credentials. The key is created 0600 from the
        # start; the certificate is written last so that is_provisioned()
        # (cert + key) only turns True once the config file exists too.
        DEVICE_DIR.mkdir(parents=True, exist_ok=True)
        self._write_private_file(DEVICE_KEY, private_key)
        DEVICE_CONFIG.write_text(json.dumps(device_config, indent=2))
        DEVICE_CERT.write_text(certificate_pem)

        log.info('Fleet provisioning complete')
        return device_config

    def _create_keys_and_certificate(self, client) -> dict:
        """Request a new key pair and certificate from IoT Core.

        Returns a dict with ``certificateId``, ``certificatePem``,
        ``privateKey`` and ``certificateOwnershipToken``.

        Raises:
            RuntimeError: the request was rejected, the response could not
                be decoded, or no response arrived within 30 seconds.
        """
        result = {}
        error = None
        done = threading.Event()

        def on_accepted(response):
            nonlocal result, error
            if response is None:
                # The awsiot service client may pass None when it cannot
                # decode the payload; report it instead of crashing here
                # (which would leave the caller waiting for the timeout).
                error = ('CreateKeysAndCertificate accepted response '
                         'could not be parsed')
            else:
                result = {
                    'certificateId': response.certificate_id,
                    'certificatePem': response.certificate_pem,
                    'privateKey': response.private_key,
                    'certificateOwnershipToken':
                        response.certificate_ownership_token,
                }
            done.set()

        def on_rejected(response):
            nonlocal error
            detail = getattr(response, 'error_message', None)
            error = f'CreateKeysAndCertificate rejected: {detail}'
            done.set()

        # Subscribe to both outcomes before publishing so no reply is missed.
        self._wait_for_ack(client.subscribe_to_create_keys_and_certificate_accepted(
            request=iotidentity.CreateKeysAndCertificateSubscriptionRequest(),
            qos=mqtt.QoS.AT_LEAST_ONCE,
            callback=on_accepted,
        ))

        self._wait_for_ack(client.subscribe_to_create_keys_and_certificate_rejected(
            request=iotidentity.CreateKeysAndCertificateSubscriptionRequest(),
            qos=mqtt.QoS.AT_LEAST_ONCE,
            callback=on_rejected,
        ))

        self._wait_for_ack(client.publish_create_keys_and_certificate(
            request=iotidentity.CreateKeysAndCertificateRequest(),
            qos=mqtt.QoS.AT_LEAST_ONCE,
        ))

        if not done.wait(timeout=30):
            raise RuntimeError('Timeout waiting for certificate creation')
        if error:
            raise RuntimeError(error)

        return result

    def _register_thing(self, client, ownership_token: str) -> dict:
        """Register the thing using the provisioning template.

        Returns a dict with ``thingName`` (may be None) and
        ``deviceConfiguration`` (``{}`` when absent).

        Raises:
            RuntimeError: the request was rejected, the response could not
                be decoded, or no response arrived within 30 seconds.
        """
        result = {}
        error = None
        done = threading.Event()

        def on_accepted(response):
            nonlocal result, error
            if response is None:
                # See _create_keys_and_certificate: undecodable payload.
                error = 'RegisterThing accepted response could not be parsed'
            else:
                result = {
                    'thingName': response.thing_name,
                    'deviceConfiguration': response.device_configuration or {},
                }
            done.set()

        def on_rejected(response):
            nonlocal error
            detail = getattr(response, 'error_message', None)
            error = f'RegisterThing rejected: {detail}'
            done.set()

        self._wait_for_ack(client.subscribe_to_register_thing_accepted(
            request=iotidentity.RegisterThingSubscriptionRequest(
                template_name=IOT_PROVISIONING_TEMPLATE,
            ),
            qos=mqtt.QoS.AT_LEAST_ONCE,
            callback=on_accepted,
        ))

        self._wait_for_ack(client.subscribe_to_register_thing_rejected(
            request=iotidentity.RegisterThingSubscriptionRequest(
                template_name=IOT_PROVISIONING_TEMPLATE,
            ),
            qos=mqtt.QoS.AT_LEAST_ONCE,
            callback=on_rejected,
        ))

        self._wait_for_ack(client.publish_register_thing(
            request=iotidentity.RegisterThingRequest(
                template_name=IOT_PROVISIONING_TEMPLATE,
                certificate_ownership_token=ownership_token,
                parameters={
                    'SerialNumber': self.serial,
                    'MacAddress': self.mac,
                    'HardwareId': 'raspberry-pi',
                },
            ),
            qos=mqtt.QoS.AT_LEAST_ONCE,
        ))

        if not done.wait(timeout=30):
            raise RuntimeError('Timeout waiting for thing registration')
        if error:
            raise RuntimeError(error)

        return result

    def get_config_via_mqtt(self) -> dict:
        """
        Connect with device cert, subscribe to config topic,
        receive VPN configuration (activation code, vpn_server, etc.).

        Publishes a ``get_config`` request to
        ``securelink/<thing>/config/request`` and waits up to 60 seconds for
        a JSON object on ``securelink/<thing>/config``.

        Raises:
            RuntimeError: endpoint not configured, device not provisioned,
                invalid config payload, or timeout waiting for the config.
        """
        if not IOT_ENDPOINT:
            raise RuntimeError('IOT_ENDPOINT not configured')
        if not DEVICE_CONFIG.exists():
            raise RuntimeError('Device not provisioned')

        device_config = json.loads(DEVICE_CONFIG.read_text())
        thing_name = device_config['thing_name']

        log.info(f'Connecting as {thing_name} to receive VPN config')

        connection = mqtt_connection_builder.mtls_from_path(
            endpoint=IOT_ENDPOINT,
            cert_filepath=str(DEVICE_CERT),
            pri_key_filepath=str(DEVICE_KEY),
            ca_filepath=str(AWS_ROOT_CA),
            client_id=thing_name,
            clean_session=True,
        )

        config_topic = f'securelink/{thing_name}/config'
        vpn_config = {}
        error = None
        received = threading.Event()

        def on_config(topic, payload, **kwargs):
            nonlocal vpn_config, error
            if received.is_set():
                return  # only the first delivery is used
            try:
                parsed = json.loads(payload)
            except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
                error = f'Invalid VPN config payload on {topic}: {exc}'
            else:
                if isinstance(parsed, dict):
                    vpn_config = parsed
                    log.info(f'Received VPN config from {topic}')
                else:
                    error = f'VPN config on {topic} is not a JSON object'
            # Always wake the waiter, so a bad payload fails fast instead
            # of surfacing as a timeout.
            received.set()

        try:
            connection.connect().result(timeout=30)

            self._wait_for_ack(connection.subscribe(
                topic=config_topic,
                qos=mqtt.QoS.AT_LEAST_ONCE,
                callback=on_config,
            ))

            # Publish a request to trigger config delivery
            self._wait_for_ack(connection.publish(
                topic=f'securelink/{thing_name}/config/request',
                payload=json.dumps({'serial': self.serial, 'action': 'get_config'}),
                qos=mqtt.QoS.AT_LEAST_ONCE,
            ))

            if not received.wait(timeout=60):
                raise RuntimeError('Timeout waiting for VPN config')
        finally:
            self._disconnect(connection)

        if error:
            raise RuntimeError(error)
        return vpn_config

    @staticmethod
    def _wait_for_ack(op_result, timeout: float = 10):
        """Block until an MQTT operation completes and return its result.

        ``awscrt.mqtt.Connection.subscribe``/``publish`` and the
        ``IotIdentityClient.subscribe_to_*`` methods return a
        ``(future, packet_id_or_topic)`` tuple, while the
        ``IotIdentityClient.publish_*`` methods return a bare future.
        Both shapes are accepted here.
        """
        future = op_result[0] if isinstance(op_result, tuple) else op_result
        return future.result(timeout=timeout)

    @staticmethod
    def _disconnect(connection, timeout: float = 10) -> None:
        """Disconnect *connection*, logging (not raising) on failure.

        Best-effort, so a failed disconnect neither masks an in-flight
        exception nor fails a provisioning run whose results are valid.
        """
        try:
            connection.disconnect().result(timeout=timeout)
        except Exception as exc:
            log.warning(f'MQTT disconnect failed: {exc}')

    @staticmethod
    def _write_private_file(path: Path, text: str) -> None:
        """Write *text* to *path* with owner-only (0600) permissions.

        The file is created with mode 0600, so the secret is never present
        with umask-default permissions. An already-existing file is
        truncated and tightened to 0600 before the new contents are written.
        """
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w', encoding='utf-8') as fh:
            os.chmod(path, 0o600)
            fh.write(text)
