viz/neural/nn-mlp-visualizer.js

/**
 * @fileoverview NNMLPVisualizer — SVG Data-Flow-Visualisierung eines MLP.
 *
 * Phase 2: Verschmolzenes SVG mit integrierten Input/Output-Boards.
 *
 * Rendert das Netzwerk als links→rechts Datenfluss:
 *   Input (3×3 klickbare Rects) → Hidden Layers (Kreise) → Output (3×3 Heatmap-Rects)
 *
 * Features:
 *   - Input-Neuronen als klickbare <rect> im 3×3-Gitter (X/O/leer)
 *   - Output-Neuronen als <rect> mit Heatmap-Farbe + %-Label
 *   - Hidden-Neuronen als Kreise (vollständig, kein Truncation)
 *   - ALLE Verbindungen gerendert (Stärke ∝ |Gewicht|, Farbe ∝ Vorzeichen)
 *   - Zoom/Pan via <g class="nn-viewport">
 *   - Weight-Filter: applyWeightFilter(percentage)
 *   - Gewichts-Hover (Receptive Field) auf allen Hidden Layers; Input-Highlight nur beim ersten Hidden Layer
 *
 * Convention: Reine Rendering-Klasse, keine Game-Logik, kein Controller-Code.
 *
 * @author Alexander Wolf
 * @see ToDos/NN_PLAYGROUND_PLAN_v2.md
 */

/* global CELL_EMPTY, PLAYER1, PLAYER2 */

class NNMLPVisualizer {
    /**
     * @param {Object} config
     * @param {HTMLElement|string} config.container — Container-Element oder ID
     * @param {Object} [config.callbacks] — Event-Callbacks
     * @param {function(number): void} [config.callbacks.onInputCellClick] — (cellIndex)
     * @param {function(number, number[]): void} [config.callbacks.onHiddenNeuronHover]
     * @param {function(): void} [config.callbacks.onHiddenNeuronLeave]
     */
    constructor(config) {
        const container = typeof config.container === 'string'
            ? document.getElementById(config.container)
            : config.container;

        if (!container) throw new Error('[NNMLPVisualizer] Container not found');

        /** @type {HTMLElement} */
        this.container = container;

        /** @type {Object} */
        this.callbacks = config.callbacks || {};

        /** @type {SVGSVGElement|null} */
        this.svg = null;

        /** @type {SVGGElement|null} Viewport group for zoom/pan */
        this.viewport = null;

        /** @type {number[]} */
        this.topology = [];

        /** @type {Object|null} */
        this.lastSnapshot = null;

        // ── Layout constants ──
        this.RECT_SIZE = 40;
        this.RECT_GAP = 4;
        this.NEURON_RADIUS = 12;
        this.LAYER_GAP = 180;
        this.NEURON_GAP = 32;
        this.PADDING = 60;

        // ── Color scales (Farbtrennung: Spieler vs Gewichte vs Heatmap) ──
        // Spieler: Blau / Rot (unchanged, used for input board)
        // Gewichte: Violett (pos) / Orange (neg) — different from player colors!
        this.WEIGHT_POS_COLOR = '#8b5cf6';   // Violett
        this.WEIGHT_NEG_COLOR = '#f97316';   // Orange
        this.WEIGHT_ZERO_COLOR = '#555';

        // ── Zoom/Pan state ──
        this._scale = 1;
        this._offsetX = 0;
        this._offsetY = 0;
        this._isPanning = false;
        this._panStartX = 0;
        this._panStartY = 0;

        // ── Weight filter ──
        this._weightFilterPct = 100;

        // ── Bound handlers (for removal) ──
        this._onWheel = this._handleWheel.bind(this);
        this._onPointerDown = this._handlePointerDown.bind(this);
        this._onPointerMove = this._handlePointerMove.bind(this);
        this._onPointerUp = this._handlePointerUp.bind(this);
    }

    // ========================================================================
    // INIT
    // ========================================================================

    /**
     * Initialisiert das SVG für eine gegebene Topologie.
     * @param {number[]} topology — z.B. [9, 36, 9]
     */
    init(topology) {
        this.topology = topology;
        this.container.innerHTML = '';

        const { width, height } = this._calculateDimensions();
        const svgNS = 'http://www.w3.org/2000/svg';

        // ── SVG root ──
        this.svg = document.createElementNS(svgNS, 'svg');
        this.svg.setAttribute('viewBox', `0 0 ${width} ${height}`);
        this.svg.setAttribute('preserveAspectRatio', 'xMidYMid meet');
        this.svg.classList.add('nn-mlp-svg');
        this.svg.style.width = '100%';
        this.svg.style.height = '100%';

        // ── Defs ──
        const defs = document.createElementNS(svgNS, 'defs');
        defs.innerHTML = `
            <filter id="nn-glow" x="-50%" y="-50%" width="200%" height="200%">
                <feGaussianBlur in="SourceGraphic" stdDeviation="3" result="blur"/>
                <feMerge>
                    <feMergeNode in="blur"/>
                    <feMergeNode in="SourceGraphic"/>
                </feMerge>
            </filter>
        `;
        this.svg.appendChild(defs);

        // ── Viewport group (zoom/pan target) ──
        this.viewport = document.createElementNS(svgNS, 'g');
        this.viewport.classList.add('nn-viewport');
        this.svg.appendChild(this.viewport);

        // ── Sub-groups inside viewport ──
        this.connectionsGroup = document.createElementNS(svgNS, 'g');
        this.connectionsGroup.classList.add('nn-connections');
        this.viewport.appendChild(this.connectionsGroup);

        this.neuronsGroup = document.createElementNS(svgNS, 'g');
        this.neuronsGroup.classList.add('nn-neurons');
        this.viewport.appendChild(this.neuronsGroup);

        this.labelsGroup = document.createElementNS(svgNS, 'g');
        this.labelsGroup.classList.add('nn-labels');
        this.viewport.appendChild(this.labelsGroup);

        this.tooltipGroup = document.createElementNS(svgNS, 'g');
        this.tooltipGroup.classList.add('nn-tooltip');
        this.tooltipGroup.style.display = 'none';
        this.viewport.appendChild(this.tooltipGroup);

        this.container.appendChild(this.svg);

        // ── Compute & render ──
        this.neuronPositions = this._computeNeuronPositions();
        this._renderConnections();
        this._renderNeurons();
        this._renderLabels();

        // ── Setup zoom/pan ──
        this._scale = 1;
        this._offsetX = 0;
        this._offsetY = 0;
        this._setupZoomPan();
    }

    // ========================================================================
    // DIMENSIONS & LAYOUT
    // ========================================================================

    /**
     * @returns {{ width: number, height: number }}
     * @private
     */
    _calculateDimensions() {
        const numLayers = this.topology.length;
        const maxNeurons = Math.max(...this.topology);

        const gridH = 3 * (this.RECT_SIZE + this.RECT_GAP) + 2 * this.PADDING;
        const columnH = maxNeurons * this.NEURON_GAP + 2 * this.PADDING;
        const height = Math.max(gridH, Math.min(columnH, 800));
        const width = numLayers * this.LAYER_GAP + 2 * this.PADDING;

        return { width, height };
    }

    /**
     * Computes (x, y) for every neuron.
     * Input & Output (count=9) → 3×3 rect grid.
     * Hidden → vertical column of circles.
     * @returns {Object[][]}
     * @private
     */
    _computeNeuronPositions() {
        const positions = [];
        const { height } = this._calculateDimensions();
        const centerY = height / 2;

        for (let l = 0; l < this.topology.length; l++) {
            const count = this.topology[l];
            const x = this.PADDING + l * this.LAYER_GAP + this.LAYER_GAP / 2;
            const layerPositions = [];

            const isIO = (l === 0 || l === this.topology.length - 1) && count === 9;

            if (isIO) {
                // 3×3 grid of rects
                const cellTotal = this.RECT_SIZE + this.RECT_GAP;
                const gridW = 3 * cellTotal;
                const gridH = 3 * cellTotal;

                for (let i = 0; i < 9; i++) {
                    const row = Math.floor(i / 3);
                    const col = i % 3;
                    layerPositions.push({
                        x: x - gridW / 2 + col * cellTotal + this.RECT_SIZE / 2,
                        y: centerY - gridH / 2 + row * cellTotal + this.RECT_SIZE / 2
                    });
                }
            } else {
                // Vertical column — show ALL neurons (no truncation)
                const totalH = (count - 1) * this.NEURON_GAP;
                const startY = centerY - totalH / 2;

                for (let n = 0; n < count; n++) {
                    layerPositions.push({
                        x,
                        y: startY + n * this.NEURON_GAP
                    });
                }
            }

            positions.push(layerPositions);
        }

        return positions;
    }

    // ========================================================================
    // RENDER CONNECTIONS
    // ========================================================================

    /**
     * Renders ALL connection lines between adjacent layers (no sampling).
     * @private
     */
    _renderConnections() {
        const svgNS = 'http://www.w3.org/2000/svg';
        this.connectionElements = [];

        for (let l = 0; l < this.topology.length - 1; l++) {
            const fromPositions = this.neuronPositions[l];
            const toPositions = this.neuronPositions[l + 1];
            const layerConnections = [];

            for (let from = 0; from < fromPositions.length; from++) {
                const neuronConnections = [];
                for (let to = 0; to < toPositions.length; to++) {
                    const line = document.createElementNS(svgNS, 'line');
                    line.setAttribute('x1', fromPositions[from].x);
                    line.setAttribute('y1', fromPositions[from].y);
                    line.setAttribute('x2', toPositions[to].x);
                    line.setAttribute('y2', toPositions[to].y);
                    line.classList.add('nn-connection');
                    line.setAttribute('stroke', this.WEIGHT_ZERO_COLOR);
                    line.setAttribute('stroke-width', '0.5');
                    line.setAttribute('stroke-opacity', '0.12');
                    this.connectionsGroup.appendChild(line);
                    neuronConnections.push(line);
                }
                layerConnections.push(neuronConnections);
            }
            this.connectionElements.push(layerConnections);
        }
    }

    // ========================================================================
    // RENDER NEURONS
    // ========================================================================

    /**
     * Renders all neurons: input rects, hidden circles, output rects.
     * @private
     */
    _renderNeurons() {
        const svgNS = 'http://www.w3.org/2000/svg';
        this.neuronElements = [];
        this._inputRects = [];
        this._inputTexts = [];
        this._outputRects = [];
        this._outputTexts = [];

        for (let l = 0; l < this.topology.length; l++) {
            const layerNeurons = [];
            const positions = this.neuronPositions[l];
            const isInput = l === 0 && this.topology[l] === 9;
            const isOutput = l === this.topology.length - 1 && this.topology[l] === 9;
            const isHidden = !isInput && !isOutput;

            for (let n = 0; n < positions.length; n++) {
                const pos = positions[n];

                if (isInput) {
                    const { group, rect, text } = this._createInputRect(pos, n);
                    this.neuronsGroup.appendChild(group);
                    layerNeurons.push({ group, circle: rect }); // keep .circle for compat
                    this._inputRects.push(rect);
                    this._inputTexts.push(text);
                } else if (isOutput) {
                    const { group, rect, text } = this._createOutputRect(pos, n);
                    this.neuronsGroup.appendChild(group);
                    layerNeurons.push({ group, circle: rect });
                    this._outputRects.push(rect);
                    this._outputTexts.push(text);
                } else {
                    const { group, circle } = this._createHiddenCircle(pos, l, n, isHidden);
                    this.neuronsGroup.appendChild(group);
                    layerNeurons.push({ group, circle });
                }
            }

            this.neuronElements.push(layerNeurons);
        }
    }

    /**
     * Creates a clickable input rect (3×3 board cell).
     * @param {{ x: number, y: number }} pos — center of the rect
     * @param {number} index — 0–8
     * @returns {{ group: SVGGElement, rect: SVGRectElement, text: SVGTextElement }}
     * @private
     */
    _createInputRect(pos, index) {
        const svgNS = 'http://www.w3.org/2000/svg';
        const half = this.RECT_SIZE / 2;

        const g = document.createElementNS(svgNS, 'g');
        g.classList.add('nn-input-group');

        const rect = document.createElementNS(svgNS, 'rect');
        rect.setAttribute('x', pos.x - half);
        rect.setAttribute('y', pos.y - half);
        rect.setAttribute('width', this.RECT_SIZE);
        rect.setAttribute('height', this.RECT_SIZE);
        rect.setAttribute('rx', '4');
        rect.setAttribute('fill', 'rgba(26, 26, 46, 0.6)');
        rect.setAttribute('stroke', 'var(--nn-neuron-stroke, #3498db)');
        rect.setAttribute('stroke-width', '1.5');
        rect.classList.add('nn-input-rect');
        rect.style.cursor = 'pointer';
        g.appendChild(rect);

        // X / O text — use dominant-baseline for reliable vertical centering
        const text = document.createElementNS(svgNS, 'text');
        text.setAttribute('x', pos.x);
        text.setAttribute('y', pos.y);
        text.setAttribute('text-anchor', 'middle');
        text.setAttribute('dominant-baseline', 'central');
        text.setAttribute('font-size', '20');
        text.setAttribute('font-weight', '700');
        text.setAttribute('font-family', 'system-ui, sans-serif');
        text.setAttribute('fill', '#fff');
        text.setAttribute('pointer-events', 'none');
        text.textContent = '';
        g.appendChild(text);

        // Click handler
        g.addEventListener('click', (e) => {
            e.stopPropagation();
            if (this.callbacks.onInputCellClick) {
                this.callbacks.onInputCellClick(index);
            }
        });

        return { group: g, rect, text };
    }

    /**
     * Creates an output rect (heatmap cell with % label).
     * @param {{ x: number, y: number }} pos
     * @param {number} index
     * @returns {{ group: SVGGElement, rect: SVGRectElement, text: SVGTextElement }}
     * @private
     */
    _createOutputRect(pos, index) {
        const svgNS = 'http://www.w3.org/2000/svg';
        const half = this.RECT_SIZE / 2;

        const g = document.createElementNS(svgNS, 'g');
        g.classList.add('nn-output-group');

        const rect = document.createElementNS(svgNS, 'rect');
        rect.setAttribute('x', pos.x - half);
        rect.setAttribute('y', pos.y - half);
        rect.setAttribute('width', this.RECT_SIZE);
        rect.setAttribute('height', this.RECT_SIZE);
        rect.setAttribute('rx', '4');
        rect.setAttribute('fill', 'rgba(46, 204, 113, 0.05)');
        rect.setAttribute('stroke', 'rgba(46, 204, 113, 0.3)');
        rect.setAttribute('stroke-width', '1.5');
        rect.classList.add('nn-output-rect');
        g.appendChild(rect);

        // Percentage text — use dominant-baseline for centering
        const text = document.createElementNS(svgNS, 'text');
        text.setAttribute('x', pos.x);
        text.setAttribute('y', pos.y);
        text.setAttribute('text-anchor', 'middle');
        text.setAttribute('dominant-baseline', 'central');
        text.setAttribute('font-size', '11');
        text.setAttribute('font-weight', '600');
        text.setAttribute('font-family', 'system-ui, sans-serif');
        text.setAttribute('fill', 'rgba(255,255,255,0.4)');
        text.setAttribute('pointer-events', 'none');
        text.textContent = '';
        g.appendChild(text);

        return { group: g, rect, text };
    }

    /**
     * Creates a hidden neuron circle.
     * @param {{ x: number, y: number }} pos
     * @param {number} layerIdx
     * @param {number} neuronIdx
     * @param {boolean} isFirstHidden — Enable receptive-field hover
     * @returns {{ group: SVGGElement, circle: SVGCircleElement }}
     * @private
     */
    _createHiddenCircle(pos, layerIdx, neuronIdx, isHidden) {
        const svgNS = 'http://www.w3.org/2000/svg';

        const g = document.createElementNS(svgNS, 'g');
        g.classList.add('nn-neuron-group');

        const circle = document.createElementNS(svgNS, 'circle');
        circle.setAttribute('cx', pos.x);
        circle.setAttribute('cy', pos.y);
        circle.setAttribute('r', this.NEURON_RADIUS);
        circle.classList.add('nn-neuron');
        circle.setAttribute('fill', 'var(--nn-neuron-fill, #1a1a2e)');
        circle.setAttribute('stroke', 'var(--nn-neuron-stroke, #3498db)');
        circle.setAttribute('stroke-width', '1.5');
        g.appendChild(circle);

        // Weight heatmap hover for ALL hidden layers
        if (isHidden) {
            circle.classList.add('nn-neuron--interactive');
            circle.dataset.layerIndex = layerIdx;
            circle.dataset.neuronIndex = neuronIdx;

            circle.addEventListener('mouseenter', () => {
                this._showReceptiveField(layerIdx, neuronIdx, pos);
            });
            circle.addEventListener('mouseleave', () => {
                this._hideReceptiveField();
            });
        }

        return { group: g, circle };
    }

    // ========================================================================
    // RENDER LABELS
    // ========================================================================

    /**
     * @private
     */
    _renderLabels() {
        const svgNS = 'http://www.w3.org/2000/svg';
        const { height } = this._calculateDimensions();
        const names = this._getLayerNames();

        for (let l = 0; l < this.topology.length; l++) {
            const positions = this.neuronPositions[l];
            const centerX = positions.reduce((s, p) => s + p.x, 0) / positions.length;

            const label = document.createElementNS(svgNS, 'text');
            label.setAttribute('x', centerX);
            label.setAttribute('y', height - 12);
            label.setAttribute('text-anchor', 'middle');
            label.setAttribute('fill', 'var(--nn-label-color, #9ca3af)');
            label.setAttribute('font-size', '11');
            label.setAttribute('font-weight', '600');
            label.textContent = names[l];
            this.labelsGroup.appendChild(label);
        }
    }

    /**
     * @returns {string[]}
     * @private
     */
    _getLayerNames() {
        return this.topology.map((size, i) => {
            if (i === 0) return `Input (${size})`;
            if (i === this.topology.length - 1) return `Output (${size})`;
            return `Hidden ${i} (${size})`;
        });
    }

    // ========================================================================
    // ZOOM / PAN
    // ========================================================================

    /**
     * Sets up wheel-zoom and pointer-drag pan on the SVG.
     * @private
     */
    _setupZoomPan() {
        if (!this.svg) return;

        this.svg.addEventListener('wheel', this._onWheel, { passive: false });
        this.svg.addEventListener('pointerdown', this._onPointerDown);
        this.svg.addEventListener('pointermove', this._onPointerMove);
        this.svg.addEventListener('pointerup', this._onPointerUp);
        this.svg.addEventListener('pointerleave', this._onPointerUp);
    }

    /**
     * @param {WheelEvent} e
     * @private
     */
    _handleWheel(e) {
        e.preventDefault();

        const zoomFactor = e.deltaY < 0 ? 1.1 : 0.9;
        const newScale = Math.max(0.2, Math.min(5, this._scale * zoomFactor));
        const ratio = newScale / this._scale;

        // Zoom toward mouse position
        const pt = this._screenToSVG(e.clientX, e.clientY);
        this._offsetX = pt.x - (pt.x - this._offsetX) * ratio;
        this._offsetY = pt.y - (pt.y - this._offsetY) * ratio;
        this._scale = newScale;

        this._applyViewportTransform();
    }

    /**
     * @param {PointerEvent} e
     * @private
     */
    _handlePointerDown(e) {
        // Only pan on middle-click or when clicking on SVG background (not neurons)
        if (e.button === 1 || (e.button === 0 && e.target === this.svg)) {
            this._isPanning = true;
            this._panStartX = e.clientX;
            this._panStartY = e.clientY;
            this.svg.style.cursor = 'grabbing';
            this.svg.setPointerCapture(e.pointerId);
        }
    }

    /**
     * @param {PointerEvent} e
     * @private
     */
    _handlePointerMove(e) {
        if (!this._isPanning) return;

        // Convert pixel delta to SVG units
        const svgRect = this.svg.getBoundingClientRect();
        const viewBox = this.svg.viewBox.baseVal;
        const kx = viewBox.width / svgRect.width;
        const ky = viewBox.height / svgRect.height;

        const dx = (e.clientX - this._panStartX) * kx / this._scale;
        const dy = (e.clientY - this._panStartY) * ky / this._scale;

        this._offsetX += dx;
        this._offsetY += dy;
        this._panStartX = e.clientX;
        this._panStartY = e.clientY;

        this._applyViewportTransform();
    }

    /**
     * @param {PointerEvent} e
     * @private
     */
    _handlePointerUp(e) {
        if (this._isPanning) {
            this._isPanning = false;
            this.svg.style.cursor = '';
        }
    }

    /**
     * Converts screen coordinates to SVG coordinate space.
     * @param {number} clientX
     * @param {number} clientY
     * @returns {{ x: number, y: number }}
     * @private
     */
    _screenToSVG(clientX, clientY) {
        const rect = this.svg.getBoundingClientRect();
        const viewBox = this.svg.viewBox.baseVal;
        return {
            x: (clientX - rect.left) / rect.width * viewBox.width,
            y: (clientY - rect.top) / rect.height * viewBox.height
        };
    }

    /**
     * Applies the current transform to the viewport group.
     * @private
     */
    _applyViewportTransform() {
        if (!this.viewport) return;
        this.viewport.setAttribute('transform',
            `translate(${this._offsetX}, ${this._offsetY}) scale(${this._scale})`
        );
    }

    // ========================================================================
    // WEIGHT FILTER
    // ========================================================================

    /**
     * Hides connections whose absolute weight falls below the given percentile.
     * 100% = show all, 0% = hide all.
     * @param {number} percentage — 0–100
     */
    applyWeightFilter(percentage) {
        this._weightFilterPct = percentage;

        if (!this.connectionElements || !this.lastSnapshot) return;

        // Per-layer weight filter: compute threshold per transition
        for (let l = 0; l < this.lastSnapshot.layers.length; l++) {
            const layerData = this.lastSnapshot.layers[l];
            if (!this.connectionElements[l]) continue;

            // Collect absolute weights for THIS layer transition
            const layerWeights = [];
            for (const neuron of layerData.neurons) {
                for (let w = 0; w < neuron.weights.length; w++) {
                    layerWeights.push(Math.abs(neuron.weights[w]));
                }
            }
            layerWeights.sort((a, b) => a - b);

            // Per-layer threshold
            const cutoffIdx = Math.floor(layerWeights.length * (1 - percentage / 100));
            const threshold = cutoffIdx < layerWeights.length ? layerWeights[cutoffIdx] : 0;

            for (let from = 0; from < this.connectionElements[l].length; from++) {
                for (let to = 0; to < this.connectionElements[l][from].length; to++) {
                    const line = this.connectionElements[l][from][to];
                    if (!line) continue;

                    const neuronData = layerData.neurons[to];
                    if (!neuronData || from >= neuronData.weights.length) continue;

                    const absW = Math.abs(neuronData.weights[from]);
                    if (absW < threshold) {
                        line.classList.add('nn-connection--hidden');
                    } else {
                        line.classList.remove('nn-connection--hidden');
                    }
                }
            }
        }
    }

    // ========================================================================
    // RECEPTIVE FIELD (Hover Tooltip)
    // ========================================================================

    /**
     * @param {number} layerIdx
     * @param {number} neuronIdx
     * @param {{ x: number, y: number }} pos
     * @private
     */
    _showReceptiveField(layerIdx, neuronIdx, pos) {
        if (!this.lastSnapshot) return;

        const layerData = this.lastSnapshot.layers[layerIdx - 1];
        if (!layerData || !layerData.neurons[neuronIdx]) return;

        const weights = layerData.neurons[neuronIdx].weights;
        if (!weights || weights.length === 0) return;

        // Only highlight input rects for the first hidden layer (weights.length===9)
        if (weights.length === 9 && this.callbacks.onHiddenNeuronHover) {
            this.callbacks.onHiddenNeuronHover(neuronIdx, Array.from(weights));
        }

        this._renderReceptiveFieldTooltip(pos, weights);
    }

    /**
     * @param {{ x: number, y: number }} pos
     * @param {Float64Array|number[]} weights
     * @private
     */
    _renderReceptiveFieldTooltip(pos, weights) {
        const svgNS = 'http://www.w3.org/2000/svg';
        this.tooltipGroup.innerHTML = '';
        this.tooltipGroup.style.display = '';

        const n = weights.length;
        const cellSize = 22;
        const padding = 6;

        // For 9 weights: 3×3 grid. Otherwise: dynamic grid (cols = ceil(sqrt(n)))
        const cols = n === 9 ? 3 : Math.ceil(Math.sqrt(n));
        const rows = Math.ceil(n / cols);

        const gridW = cols * cellSize + 2 * padding;
        const gridH = rows * cellSize + 20 + 2 * padding;
        const tx = pos.x + this.NEURON_RADIUS + 10;
        const ty = pos.y - gridH / 2;

        // Background
        const bg = document.createElementNS(svgNS, 'rect');
        bg.setAttribute('x', tx - padding);
        bg.setAttribute('y', ty - padding);
        bg.setAttribute('width', gridW + padding);
        bg.setAttribute('height', gridH + padding);
        bg.setAttribute('rx', '6');
        bg.setAttribute('fill', 'rgba(15, 15, 30, 0.95)');
        bg.setAttribute('stroke', 'rgba(96, 165, 250, 0.3)');
        bg.setAttribute('stroke-width', '1');
        this.tooltipGroup.appendChild(bg);

        // Title
        const titleLabel = n === 9 ? 'Receptive Field' : `Gewichte (${n})`;
        const title = document.createElementNS(svgNS, 'text');
        title.setAttribute('x', tx + (cols * cellSize) / 2);
        title.setAttribute('y', ty - padding + 12);
        title.setAttribute('text-anchor', 'middle');
        title.setAttribute('fill', '#9ca3af');
        title.setAttribute('font-size', '9');
        title.textContent = titleLabel;
        this.tooltipGroup.appendChild(title);

        let maxAbs = 0;
        for (let i = 0; i < n; i++) maxAbs = Math.max(maxAbs, Math.abs(weights[i]));
        if (maxAbs === 0) maxAbs = 1;

        for (let i = 0; i < n; i++) {
            const row = Math.floor(i / cols);
            const col = i % cols;
            const w = weights[i];
            const normalized = w / maxAbs;
            const intensity = Math.abs(normalized);

            const cx = tx + col * cellSize;
            const cy = ty + 16 + row * cellSize;

            const rect = document.createElementNS(svgNS, 'rect');
            rect.setAttribute('x', cx);
            rect.setAttribute('y', cy);
            rect.setAttribute('width', cellSize - 1);
            rect.setAttribute('height', cellSize - 1);
            rect.setAttribute('rx', '3');

            let color;
            if (normalized > 0) {
                color = `rgba(139, 92, 246, ${0.3 + intensity * 0.7})`;   // Violett (pos)
            } else {
                color = `rgba(249, 115, 22, ${0.3 + intensity * 0.7})`;   // Orange (neg)
            }
            rect.setAttribute('fill', color);
            this.tooltipGroup.appendChild(rect);

            const text = document.createElementNS(svgNS, 'text');
            text.setAttribute('x', cx + cellSize / 2 - 0.5);
            text.setAttribute('y', cy + cellSize / 2 + 3);
            text.setAttribute('text-anchor', 'middle');
            text.setAttribute('fill', intensity > 0.5 ? '#fff' : '#aaa');
            text.setAttribute('font-size', '8');
            text.textContent = w.toFixed(1);
            this.tooltipGroup.appendChild(text);
        }
    }

    /**
     * @private
     */
    _hideReceptiveField() {
        this.tooltipGroup.style.display = 'none';
        if (this.callbacks.onHiddenNeuronLeave) {
            this.callbacks.onHiddenNeuronLeave();
        }
    }

    // ========================================================================
    // UPDATE FROM SNAPSHOT
    // ========================================================================

    /**
     * Updates the visualization from a NeuralNetwork snapshot.
     * @param {NetworkSnapshot} snapshot
     */
    updateFromSnapshot(snapshot) {
        this.lastSnapshot = snapshot;

        if (!this._topologyMatch(snapshot.topology)) {
            this.init(snapshot.topology);
            this.lastSnapshot = snapshot;
        }

        this._updateConnections(snapshot);
        this._updateNeuronActivations(snapshot);
    }

    /**
     * @param {number[]} topo
     * @returns {boolean}
     * @private
     */
    _topologyMatch(topo) {
        if (!this.topology || topo.length !== this.topology.length) return false;
        return topo.every((v, i) => v === this.topology[i]);
    }

    /**
     * Updates connection line widths, colors, opacity from snapshot weights.
     * @param {NetworkSnapshot} snapshot
     * @private
     */
    _updateConnections(snapshot) {
        if (!this.connectionElements) return;

        for (let l = 0; l < snapshot.layers.length; l++) {
            const layerData = snapshot.layers[l];
            if (!this.connectionElements[l]) continue;

            let maxW = 0;
            for (const neuron of layerData.neurons) {
                for (let w = 0; w < neuron.weights.length; w++) {
                    maxW = Math.max(maxW, Math.abs(neuron.weights[w]));
                }
            }
            if (maxW === 0) maxW = 1;

            for (let from = 0; from < this.connectionElements[l].length; from++) {
                for (let to = 0; to < this.connectionElements[l][from].length; to++) {
                    const line = this.connectionElements[l][from][to];
                    if (!line) continue;

                    const neuronData = layerData.neurons[to];
                    if (!neuronData || from >= neuronData.weights.length) continue;

                    const w = neuronData.weights[from];
                    const normalized = Math.abs(w) / maxW;

                    line.setAttribute('stroke-width', (0.3 + normalized * 2.7).toFixed(2));
                    line.setAttribute('stroke', w >= 0 ? this.WEIGHT_POS_COLOR : this.WEIGHT_NEG_COLOR);
                    line.setAttribute('stroke-opacity', (0.08 + normalized * 0.62).toFixed(2));
                }
            }
        }

        // Re-apply weight filter if active
        if (this._weightFilterPct < 100) {
            this.applyWeightFilter(this._weightFilterPct);
        }
    }

    /**
     * Updates hidden neuron fill colors based on activation values.
     * @param {NetworkSnapshot} snapshot
     * @private
     */
    _updateNeuronActivations(snapshot) {
        if (!this.neuronElements) return;

        for (let l = 0; l < snapshot.layers.length; l++) {
            const layerData = snapshot.layers[l];
            const elIdx = l + 1; // neuronElements[0] = input layer
            const elementLayer = this.neuronElements[elIdx];
            if (!elementLayer) continue;

            // Skip output layer (handled by updateOutputHeatmap)
            if (elIdx === this.topology.length - 1) continue;

            let maxAct = 0;
            for (const n of layerData.neurons) {
                maxAct = Math.max(maxAct, Math.abs(n.output));
            }
            if (maxAct === 0) maxAct = 1;

            for (let n = 0; n < elementLayer.length && n < layerData.neurons.length; n++) {
                const activation = layerData.neurons[n].output;
                const norm = Math.abs(activation) / maxAct;

                const brightness = Math.round(20 + norm * 80);
                const hue = activation >= 0 ? 210 : 0;
                const sat = Math.round(30 + norm * 50);

                elementLayer[n].circle.setAttribute('fill',
                    `hsl(${hue}, ${sat}%, ${brightness}%)`
                );
            }
        }
    }

    // ========================================================================
    // INPUT BOARD
    // ========================================================================

    /**
     * Updates input rects from TTTRegularBoard.grid.
     * grid values: CELL_EMPTY=0, PLAYER1=1, PLAYER2=2
     *
     * @param {number[]} grid — 9 values from TTTRegularBoard.grid
     */
    updateInputBoard(grid) {
        if (!this._inputRects || !this._inputTexts) return;

        for (let i = 0; i < 9; i++) {
            const rect = this._inputRects[i];
            const text = this._inputTexts[i];
            if (!rect || !text) continue;

            const cell = grid[i];

            if (cell === PLAYER1) {
                rect.setAttribute('fill', 'rgba(26, 26, 46, 0.6)');
                rect.setAttribute('stroke', '#5dade2');
                text.textContent = '✕';
                text.setAttribute('fill', '#5dade2');
            } else if (cell === PLAYER2) {
                rect.setAttribute('fill', 'rgba(26, 26, 46, 0.6)');
                rect.setAttribute('stroke', '#ec7063');
                text.textContent = '○';
                text.setAttribute('fill', '#ec7063');
            } else {
                rect.setAttribute('fill', 'rgba(26, 26, 46, 0.6)');
                rect.setAttribute('stroke', 'var(--nn-neuron-stroke, #3498db)');
                text.textContent = '';
                text.setAttribute('fill', '#fff');
            }
        }
    }

    // ========================================================================
    // OUTPUT HEATMAP
    // ========================================================================

    /**
     * Updates output rects as a heatmap of NN predictions.
     *
     * @param {Float64Array} predictions — Softmax output (9 values)
     * @param {number[]} grid — Board grid to dim occupied cells
     */
    updateOutputHeatmap(predictions, grid) {
        if (!this._outputRects || !this._outputTexts) return;

        // Find best legal prediction
        let bestIdx = -1;
        let bestVal = -1;
        for (let i = 0; i < 9; i++) {
            if (grid[i] === 0 && predictions[i] > bestVal) {
                bestVal = predictions[i];
                bestIdx = i;
            }
        }

        for (let i = 0; i < 9; i++) {
            const rect = this._outputRects[i];
            const text = this._outputTexts[i];
            if (!rect || !text) continue;

            const p = predictions[i];
            const isOccupied = grid[i] !== 0;

            if (isOccupied) {
                // Dim occupied cells
                rect.setAttribute('fill', 'rgba(40, 40, 60, 0.3)');
                rect.setAttribute('stroke', 'rgba(100, 100, 120, 0.2)');
                rect.classList.remove('nn-output-rect--best');
                text.textContent = '–';
                text.setAttribute('fill', 'rgba(255,255,255,0.15)');
            } else {
                // Heatmap color: green intensity ∝ prediction
                const intensity = Math.min(p * 3, 1);
                rect.setAttribute('fill',
                    `rgba(46, 204, 113, ${(0.05 + intensity * 0.6).toFixed(2)})`
                );
                rect.setAttribute('stroke',
                    `rgba(46, 204, 113, ${(0.2 + intensity * 0.5).toFixed(2)})`
                );

                // Best move highlight
                if (i === bestIdx) {
                    rect.classList.add('nn-output-rect--best');
                } else {
                    rect.classList.remove('nn-output-rect--best');
                }

                // Percentage label
                const pct = (p * 100).toFixed(0);
                text.textContent = pct + '%';
                text.setAttribute('fill',
                    intensity > 0.3 ? '#fff' : 'rgba(255,255,255,0.4)'
                );
            }
        }
    }

    // ========================================================================
    // INPUT HIGHLIGHT (Receptive Field)
    // ========================================================================

    /**
     * Highlights input rects by weight magnitude (receptive field visualization).
     * @param {number[]} weights — 9 weights from a hidden neuron
     */
    highlightInputWeights(weights) {
        if (!this._inputRects) return;

        let maxAbs = 0;
        for (const w of weights) maxAbs = Math.max(maxAbs, Math.abs(w));
        if (maxAbs === 0) maxAbs = 1;

        for (let i = 0; i < 9; i++) {
            const rect = this._inputRects[i];
            if (!rect) continue;

            const w = weights[i];
            const normalized = w / maxAbs;
            const intensity = Math.abs(normalized);

            let fill;
            if (normalized > 0.1) {
                fill = `rgba(52, 152, 219, ${0.2 + intensity * 0.8})`;
            } else if (normalized < -0.1) {
                fill = `rgba(231, 76, 60, ${0.2 + intensity * 0.8})`;
            } else {
                fill = 'var(--nn-neuron-fill, #1a1a2e)';
            }

            rect.setAttribute('fill', fill);
            rect.setAttribute('stroke-width', (1.5 + intensity * 2).toFixed(1));
        }
    }

    /**
     * Resets input highlights back to board state.
     * @param {number[]} grid — TTTRegularBoard.grid
     */
    resetInputHighlight(grid) {
        this.updateInputBoard(grid || Array(9).fill(0));
        if (this._inputRects) {
            for (const rect of this._inputRects) {
                if (rect) rect.setAttribute('stroke-width', '1.5');
            }
        }
    }

    // ========================================================================
    // DESTROY
    // ========================================================================

    /**
     * Destroys the visualizer and cleans up event listeners.
     */
    destroy() {
        if (this.svg) {
            this.svg.removeEventListener('wheel', this._onWheel);
            this.svg.removeEventListener('pointerdown', this._onPointerDown);
            this.svg.removeEventListener('pointermove', this._onPointerMove);
            this.svg.removeEventListener('pointerup', this._onPointerUp);
            this.svg.removeEventListener('pointerleave', this._onPointerUp);
        }
        this.container.innerHTML = '';
        this.svg = null;
        this.viewport = null;
        this.neuronElements = null;
        this.connectionElements = null;
        this._inputRects = null;
        this._inputTexts = null;
        this._outputRects = null;
        this._outputTexts = null;
        this.lastSnapshot = null;
    }
}