#!/usr/bin/env python3
"""
对讲机音频接收器
将接收到的摩斯音频转换回文字
"""

import sys
import numpy as np
import threading
import time
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent / "src"))

from src.morse_protocol_integration import MorseProtocolSystem


class WalkieTalkieAudioReceiver:
    """
    对讲机音频接收器

    功能：
    1. 录音（从麦克风）
    2. 检测声音
    3. 识别点划
    4. 摩斯解码
    5. 文字输出
    """

    def __init__(self,
                 frequency=800,
                 sample_rate=44100,
                 dot_duration=0.1,
                 threshold=0.1):
        """
        初始化接收器

        参数：
        frequency: 期望的音频频率（Hz）
        sample_rate: 采样率（Hz）
        dot_duration: 点的标准时长（秒）
        threshold: 音量阈值（0.0-1.0）
        """
        self.frequency = frequency
        self.sample_rate = sample_rate
        self.dot_duration = dot_duration
        self.dash_duration = dot_duration * 3
        self.threshold = threshold

        # 初始化摩斯系统
        self.morse_system = MorseProtocolSystem(
            enable_ai=True,  # 启用AI纠错
            enable_protocol=False
        )

        # 录音状态
        self.is_recording = False
        self.recorded_audio = []

    def start_recording(self, duration=10):
        """
        开始录音

        参数：
        duration: 录音时长（秒）

        返回：
        音频数据
        """
        try:
            import sounddevice as sd

            print(f"开始录音 {duration} 秒...")
            print("请播放对讲机的摩斯音频...")

            # 录音
            recording = sd.rec(
                int(duration * self.sample_rate),
                samplerate=self.sample_rate,
                channels=1
            )

            sd.wait()  # 等待录音完成

            audio = recording.flatten()
            print(f"✓ 录音完成，采样数: {len(audio)}")

            return audio

        except ImportError:
            print("⚠️  需要安装 sounddevice: pip install sounddevice")
            return None
        except Exception as e:
            print(f"✗ 录音失败: {e}")
            return None

    def load_from_file(self, filename):
        """
        从WAV文件加载音频

        参数：
        filename: WAV文件路径

        返回：
        音频数据和采样率
        """
        try:
            from scipy.io import wavfile

            sample_rate, audio = wavfile.read(filename)

            # 转换为浮点数
            if audio.dtype == np.int16:
                audio = audio.astype(np.float32) / 32768.0
            elif audio.dtype == np.int32:
                audio = audio.astype(np.float32) / 2147483648.0

            print(f"✓ 已加载: {filename}")
            print(f"  采样率: {sample_rate} Hz")
            print(f"  样本数: {len(audio)}")

            return audio, sample_rate

        except ImportError:
            print("⚠️  需要安装 scipy: pip install scipy")
            return None, None
        except Exception as e:
            print(f"✗ 加载失败: {e}")
            return None, None

    def detect_signals(self, audio):
        """
        检测音频中的摩斯信号

        参数：
        audio: 音频数据

        返回：
        信号列表 [(开始时间, 结束时间, 类型), ...]
        类型: 'dot' 或 'dash'
        """
        print("\n检测摩斯信号...")

        # 计算音量包络
        envelope = np.abs(audio)

        # 平滑处理
        window_size = int(self.sample_rate * 0.01)  # 10ms窗口
        if len(envelope) > window_size:
            envelope = np.convolve(envelope,
                                   np.ones(window_size)/window_size,
                                   mode='same')

        # 检测超过阈值的部分
        is_signal = envelope > self.threshold

        # 找到信号边界
        signals = []
        in_signal = False
        signal_start = 0

        for i, signal in enumerate(is_signal):
            if signal and not in_signal:
                # 信号开始
                signal_start = i
                in_signal = True
            elif not signal and in_signal:
                # 信号结束
                signal_end = i
                duration = (signal_end - signal_start) / self.sample_rate

                # 判断是点还是划
                if duration < self.dot_duration * 1.5:
                    signal_type = 'dot'
                elif duration < self.dash_duration * 1.5:
                    signal_type = 'dash'
                else:
                    signal_type = 'dash'  # 默认为划

                signals.append((signal_start, signal_end, signal_type))
                in_signal = False

        print(f"✓ 检测到 {len(signals)} 个信号")

        return signals

    def signals_to_morse(self, signals):
        """
        将信号列表转换为摩斯密码

        参数：
        signals: 信号列表

        返回：
        摩斯密码字符串
        """
        morse_chars = []

        for i, (start, end, signal_type) in enumerate(signals):
            if signal_type == 'dot':
                morse_chars.append('・')
            elif signal_type == 'dash':
                morse_chars.append('-')

            # 检查间隔
            if i < len(signals) - 1:
                gap = signals[i+1][0] - end
                gap_duration = gap / self.sample_rate

                # 判断间隔类型
                if gap_duration > self.dot_duration * 4:
                    # 单词间隔
                    morse_chars.append(' / ')
                elif gap_duration > self.dot_duration * 1.5:
                    # 字符间隔
                    morse_chars.append(' ')

        morse_code = ''.join(morse_chars)

        print(f"✓ 摩斯密码: {morse_code}")

        return morse_code

    def decode_morse(self, morse_code):
        """
        解码摩斯密码为文字

        参数：
        morse_code: 摩斯密码字符串

        返回：
        解码后的文字
        """
        print("\n解码摩斯密码...")

        # 使用AI解码
        result = self.morse_system.decode_morse(
            morse_code,
            use_ai=True
        )

        text = result.morse_code

        print(f"✓ 解码文字: {text}")

        if result.ai_corrections:
            print(f"✓ AI修正: {result.ai_corrections}")

        return text

    def process_audio(self, audio, sample_rate=None):
        """
        完整处理音频

        参数：
        audio: 音频数据
        sample_rate: 采样率（如果提供）

        返回：
        解码后的文字
        """
        if sample_rate:
            self.sample_rate = sample_rate

        # 步骤1: 检测信号
        signals = self.detect_signals(audio)

        if not signals:
            print("✗ 未检测到信号")
            return ""

        # 步骤2: 转换为摩斯
        morse_code = self.signals_to_morse(signals)

        # 步骤3: 解码
        text = self.decode_morse(morse_code)

        return text

    def auto_threshold(self, audio):
        """
        自动计算最佳阈值

        参数：
        audio: 音频数据
        """
        # 使用音量的中位数作为阈值
        envelope = np.abs(audio)
        median_level = np.median(envelope)
        self.threshold = median_level * 2

        print(f"✓ 自动阈值: {self.threshold:.3f}")


def demo_decode():
    """演示解码功能"""
    print("="*60)
    print("  对讲机音频接收器 - 解码演示")
    print("="*60)

    # 创建接收器
    receiver = WalkieTalkieAudioReceiver(
        frequency=800,
        sample_rate=44100,
        dot_duration=0.1,
        threshold=0.1
    )

    print("\n选项:")
    print("  1. 从WAV文件解码")
    print("  2. 实时录音解码")
    print("  3. 自动生成测试音频并解码")

    choice = input("\n选择 (1-3): ").strip()

    if choice == "1":
        # 从文件解码
        filename = input("WAV文件名: ").strip()
        if not filename.endswith('.wav'):
            filename += '.wav'

        audio, sample_rate = receiver.load_from_file(filename)

        if audio is not None:
            text = receiver.process_audio(audio, sample_rate)

    elif choice == "2":
        # 实时录音
        duration = input("录音时长（秒，默认10）: ").strip()
        try:
            duration = int(duration) if duration else 10
        except ValueError:
            duration = 10

        audio = receiver.start_recording(duration)

        if audio is not None:
            receiver.auto_threshold(audio)
            text = receiver.process_audio(audio)

    elif choice == "3":
        # 生成测试音频
        print("\n生成测试音频...")

        from walkietalkie_audio_converter import WalkieTalkieAudioConverter

        # 生成音频
        converter = WalkieTalkieAudioConverter()
        test_text = "SOS HELP ME"
        audio, sample_rate = converter.text_to_audio(test_text)

        # 保存
        test_file = "test_morse.wav"
        converter.save_to_wav(audio, sample_rate, test_file)

        print(f"\n现在解码 {test_file}...")

        # 解码
        audio, sr = receiver.load_from_file(test_file)
        if audio is not None:
            text = receiver.process_audio(audio, sr)


def interactive_receiver():
    """交互式接收器"""
    print("""
╔══════════════════════════════════════════════════════════╗
║        对讲机音频接收器                                 ║
║                                                          ║
║  将接收到的摩斯音频转换回文字                           ║
╚══════════════════════════════════════════════════════════╝
    """)

    receiver = WalkieTalkieAudioReceiver()

    while True:
        print("\n" + "="*60)
        print("菜单:")
        print("  1. 从文件解码")
        print("  2. 实时录音")
        print("  3. 调整设置")
        print("  4. 演示模式")
        print("  5. 退出")

        choice = input("\n选择 (1-5): ").strip()

        if choice == "1":
            filename = input("\nWAV文件名: ").strip()
            if not filename.endswith('.wav'):
                filename += '.wav'

            audio, sr = receiver.load_from_file(filename)
            if audio is not None:
                receiver.auto_threshold(audio)
                text = receiver.process_audio(audio, sr)

        elif choice == "2":
            duration = input("\n录音时长（秒）: ").strip()
            try:
                duration = int(duration)
            except ValueError:
                duration = 10

            audio = receiver.start_recording(duration)
            if audio is not None:
                receiver.auto_threshold(audio)
                text = receiver.process_audio(audio)

        elif choice == "3":
            print("\n当前设置:")
            print(f"  1. 期望频率: {receiver.frequency} Hz")
            print(f"  2. 点时长: {receiver.dot_duration} 秒")
            print(f"  3. 阈值: {receiver.threshold}")

            setting = input("\n调整哪项 (1-3, 其他=跳过): ").strip()

            if setting == "1":
                try:
                    freq = int(input("新频率 (Hz): "))
                    receiver.frequency = freq
                    print(f"✓ 频率已设置为 {freq} Hz")
                except ValueError:
                    print("✗ 无效输入")

            elif setting == "2":
                try:
                    duration = float(input("新的点时长 (秒): "))
                    receiver.dot_duration = duration
                    receiver.dash_duration = duration * 3
                    print(f"✓ 点时长已设置为 {duration} 秒")
                except ValueError:
                    print("✗ 无效输入")

            elif setting == "3":
                try:
                    threshold = float(input("新阈值 (0.0-1.0): "))
                    receiver.threshold = threshold
                    print(f"✓ 阈值已设置为 {threshold}")
                except ValueError:
                    print("✗ 无效输入")

        elif choice == "4":
            demo_decode()

        elif choice == "5":
            print("\n👋 再见！")
            break

        else:
            print("✗ 无效选择")


if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "demo":
        demo_decode()
    else:
        interactive_receiver()
