ai/neural/worker/nn-training-worker.js

/**
 * @fileoverview Web Worker für das Neuronale Netz Training.
 * 
 * Dieser Worker läuft in einem separaten Thread und führt das Training
 * (Forward + Backward Pass) durch, ohne den Main-Thread zu blockieren.
 * 
 * Kommunikation:
 * - Main → Worker: Befehle (START_TRAINING, STOP, SET_WEIGHT, etc.)
 * - Worker → Main: Snapshots (SNAPSHOT, TRAINING_COMPLETE, ERROR)
 * 
 * Der Worker sendet Snapshots adaptiv (~60fps-Target), damit der
 * NNOrchestrator im Main-Thread die IframeBridge nicht überlastet.
 * 
 * WICHTIG: Dieser Worker importiert die Core-Dateien via importScripts().
 * Er hat KEINEN Zugriff auf DebugConfig (kein window-Objekt).
 *
 * @author Alexander Wolf
 * @see docs/architecture/NEURAL_NET_ARCHITECTURE.md
 */

// ============================================================================
// WORKER IMPORTS
// ============================================================================

// In einem Worker haben wir kein 'window', daher müssen wir die Abhängigkeiten
// explizit importieren. Der Pfad wird relativ zum HTML-Dokument aufgelöst.
// Die importierten Dateien registrieren ihre Klassen auf 'self' statt 'window'.
try {
    importScripts(
        '../training/activation-functions.js',
        '../training/loss-functions.js',
        '../core/neuron.js',
        '../core/layer.js',
        '../core/network.js'
    );
} catch (e) {
    // Fallback: Pfade relativ zum Worker-Standort
    try {
        importScripts(
            '../../ai/neural/training/activation-functions.js',
            '../../ai/neural/training/loss-functions.js',
            '../../ai/neural/core/neuron.js',
            '../../ai/neural/core/layer.js',
            '../../ai/neural/core/network.js'
        );
    } catch (e2) {
        postMessage({ type: 'ERROR', payload: { message: 'Failed to import dependencies', error: e2.message } });
    }
}

// ============================================================================
// WORKER STATE
// ============================================================================

/**
 * Das aktuelle Netzwerk (wird bei START_TRAINING erstellt).
 * @type {NeuralNetwork|null}
 */
let network = null;

/**
 * Flag: Ist das Training aktiv?
 * @type {boolean}
 */
let isTraining = false;

/**
 * Flag: Wurde ein Stopp angefordert?
 * @type {boolean}
 */
let stopRequested = false;

/**
 * Adaptiver Snapshot-Intervall (ms).
 * Target: ~60fps für die Visualisierung.
 * @type {number}
 */
const SNAPSHOT_INTERVAL_MS = 16;

// ============================================================================
// MESSAGE HANDLER
// ============================================================================

/**
 * Empfängt Befehle vom Main-Thread (NNOrchestrator).
 * 
 * Unterstützte Befehle:
 * - CREATE_NETWORK: Erstellt ein neues Netzwerk
 * - START_TRAINING: Startet das Training mit gegebenem Dataset
 * - STOP_TRAINING: Stoppt das laufende Training
 * - FORWARD_PASS: Einzelner Forward-Pass (für Prediction)
 * - SET_WEIGHT: Setzt ein einzelnes Gewicht
 * - SET_BIAS: Setzt einen Bias
 * - SET_LEARNING_RATE: Ändert die Lernrate
 * - GET_SNAPSHOT: Fordert einen Snapshot an
 * - LOAD_WEIGHTS: Lädt gespeicherte Gewichte
 */
self.onmessage = function(e) {
    const { type, payload } = e.data;

    switch (type) {
        case 'CREATE_NETWORK':
            handleCreateNetwork(payload);
            break;

        case 'START_TRAINING':
            handleStartTraining(payload);
            break;

        case 'STOP_TRAINING':
            stopRequested = true;
            break;

        case 'FORWARD_PASS':
            handleForwardPass(payload);
            break;

        case 'SET_WEIGHT':
            handleSetWeight(payload);
            break;

        case 'SET_BIAS':
            handleSetBias(payload);
            break;

        case 'SET_LEARNING_RATE':
            if (network) {
                network.setLearningRate(payload.learningRate);
            }
            break;

        case 'GET_SNAPSHOT':
            if (network) {
                postMessage({
                    type: 'SNAPSHOT',
                    payload: network.getSnapshot()
                });
            }
            break;

        case 'LOAD_WEIGHTS':
            if (network && payload) {
                network.deserialize(payload);
                postMessage({ type: 'WEIGHTS_LOADED' });
            }
            break;

        default:
            postMessage({
                type: 'ERROR',
                payload: { message: `Unknown command: ${type}` }
            });
    }
};

// ============================================================================
// COMMAND HANDLERS
// ============================================================================

/**
 * Erstellt ein neues Netzwerk mit der gegebenen Topologie.
 *
 * @param {NetworkTopology} topology
 */
function handleCreateNetwork(topology) {
    try {
        network = new NeuralNetwork(topology);
        postMessage({
            type: 'NETWORK_CREATED',
            payload: network.getSnapshot()
        });
    } catch (error) {
        postMessage({
            type: 'ERROR',
            payload: { message: 'Network creation failed', error: error.message }
        });
    }
}

/**
 * Startet den Trainings-Loop.
 * Sendet adaptiv Snapshots (~60fps-Target) an den Main-Thread.
 *
 * @param {Object} config
 * @param {TrainingSample[]} config.dataset - Trainingsdaten
 * @param {number} config.epochs - Anzahl der Epochen
 * @param {number} [config.snapshotInterval=16] - Snapshot-Intervall in ms
 */
function handleStartTraining(config) {
    if (!network) {
        postMessage({
            type: 'ERROR',
            payload: { message: 'No network created. Call CREATE_NETWORK first.' }
        });
        return;
    }

    if (isTraining) {
        postMessage({
            type: 'ERROR',
            payload: { message: 'Training already running.' }
        });
        return;
    }

    const { dataset, epochs, snapshotInterval = SNAPSHOT_INTERVAL_MS } = config;

    // Dataset in Float64Array konvertieren für Performance
    const preparedDataset = dataset.map(sample => ({
        input: sample.input instanceof Float64Array ? sample.input : new Float64Array(sample.input),
        target: sample.target instanceof Float64Array ? sample.target : new Float64Array(sample.target)
    }));

    isTraining = true;
    stopRequested = false;

    postMessage({ type: 'TRAINING_STARTED', payload: { epochs } });

    // Asynchroner Trainings-Loop mit adaptivem Snapshot
    runTrainingLoop(preparedDataset, epochs, snapshotInterval);
}

/**
 * Trainings-Loop mit adaptiver Snapshot-Frequenz.
 * Verwendet setTimeout(0) nach jedem Batch, um dem Message-Handler
 * die Chance zu geben, STOP-Befehle zu verarbeiten.
 *
 * @param {TrainingSample[]} dataset
 * @param {number} totalEpochs
 * @param {number} snapshotInterval
 */
function runTrainingLoop(dataset, totalEpochs, snapshotInterval) {
    let lastSnapshotTime = performance.now();
    let currentEpoch = 0;

    /**
     * Führt einen Batch von Epochen aus (nicht-blockierend via setTimeout).
     * Pro Aufruf werden bis zu EPOCHS_PER_BATCH Epochen berechnet,
     * dann wird mit setTimeout(0) an den Event-Loop zurückgegeben.
     */
    function processBatch() {
        if (stopRequested || currentEpoch >= totalEpochs) {
            // Training beendet
            isTraining = false;
            postMessage({
                type: 'TRAINING_COMPLETE',
                payload: {
                    totalEpochs: currentEpoch,
                    finalLoss: network.lastLoss,
                    snapshot: network.getSnapshot()
                }
            });
            return;
        }

        // Mehrere Epochen pro Batch (amortisiert setTimeout-Overhead)
        const epochsPerBatch = Math.min(10, totalEpochs - currentEpoch);

        for (let b = 0; b < epochsPerBatch && !stopRequested; b++) {
            const result = network.trainEpoch(dataset);
            currentEpoch++;

            // Adaptiver Snapshot: Nur senden wenn genug Zeit vergangen
            const now = performance.now();
            if (now - lastSnapshotTime >= snapshotInterval) {
                postMessage({
                    type: 'SNAPSHOT',
                    payload: {
                        ...network.getCompactSnapshot(),
                        accuracy: result.accuracy,
                        epochsTotal: totalEpochs,
                        progress: currentEpoch / totalEpochs
                    }
                });
                lastSnapshotTime = now;
            }
        }

        // Nächsten Batch über Event-Loop scheduling (erlaubt STOP-Verarbeitung)
        setTimeout(processBatch, 0);
    }

    // Initiales Snapshot vor dem Training
    postMessage({
        type: 'SNAPSHOT',
        payload: {
            ...network.getSnapshot(),
            epochsTotal: totalEpochs,
            progress: 0
        }
    });

    // Training starten
    processBatch();
}

/**
 * Einzelner Forward-Pass für Prediction.
 *
 * @param {Object} payload
 * @param {number[]} payload.input - Eingabedaten
 * @param {number[]} [payload.legalMask] - Maske für legale Züge
 */
function handleForwardPass(payload) {
    if (!network) {
        postMessage({ type: 'ERROR', payload: { message: 'No network created.' } });
        return;
    }

    const { input, legalMask } = payload;
    let output;

    if (legalMask) {
        output = network.forwardMasked(input, legalMask);
    } else {
        output = network.forward(input);
    }

    postMessage({
        type: 'PREDICTION',
        payload: {
            input,
            output: Array.from(output),
            activations: network.layers.map(l => Array.from(l.outputs))
        }
    });
}

/**
 * Setzt ein einzelnes Gewicht.
 * @param {Object} payload
 */
function handleSetWeight(payload) {
    if (!network) return;
    const { layerIndex, neuronIndex, weightIndex, value } = payload;
    network.setWeight(layerIndex, neuronIndex, weightIndex, value);
}

/**
 * Setzt einen Bias.
 * @param {Object} payload
 */
function handleSetBias(payload) {
    if (!network) return;
    const { layerIndex, neuronIndex, value } = payload;
    network.setBias(layerIndex, neuronIndex, value);
}