使用TensorSpace可视化TensorFlow.js预训练模型

【免费下载链接】tensorspace Neural network 3D visualization framework, build interactive and intuitive model in browsers, support pre-trained deep learning models from TensorFlow, Keras, TensorFlow.js 【免费下载链接】tensorspace 项目地址: https://gitcode.com/gh_mirrors/te/tensorspace

你是否曾经面对复杂的神经网络模型感到困惑?想要直观理解模型内部的结构和工作原理?TensorSpace.js正是为解决这一痛点而生的革命性3D可视化框架。本文将手把手教你如何使用TensorSpace来可视化TensorFlow.js预训练模型,让你真正"看见"神经网络的运行过程。

什么是TensorSpace?

TensorSpace是一个基于TensorFlow.js、Three.js和Tween.js构建的神经网络3D可视化框架。它提供类似Keras的API来构建深度学习层、加载预训练模型,并在浏览器中生成3D可视化效果。通过TensorSpace,你可以直观地学习模型结构、训练过程以及基于中间信息的预测结果。

核心优势

  • 交互式 - 使用Layer API在浏览器中构建交互式模型
  • 直观性 - 可视化中间推理过程中的信息流
  • 集成性 - 支持TensorFlow、Keras和TensorFlow.js的预训练模型

环境准备与安装

基础安装方式

首先需要下载必要的依赖文件:

<!-- TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.18.0/dist/tf.min.js"></script>

<!-- Three.js -->
<script src="https://cdn.jsdelivr.net/npm/three@0.137.0/build/three.min.js"></script>

<!-- Tween.js -->
<script src="https://cdn.jsdelivr.net/npm/@tweenjs/tween.js@18.6.4/dist/tween.umd.js"></script>

<!-- TrackballControls -->
<script src="https://cdn.jsdelivr.net/npm/three@0.137.0/examples/js/controls/TrackballControls.js"></script>

<!-- TensorSpace -->
<script src="https://cdn.jsdelivr.net/npm/tensorspace@0.8.5/dist/tensorspace.min.js"></script>

使用NPM/Yarn安装

对于现代前端项目,推荐使用包管理器:

# 使用NPM
npm install tensorspace @tensorflow/tfjs three @tweenjs/tween.js

# 或使用Yarn
yarn add tensorspace @tensorflow/tfjs three @tweenjs/tween.js

然后在代码中导入:

import * as TSP from 'tensorspace';
import * as tf from '@tensorflow/tfjs';

TensorFlow.js模型预处理

在使用TensorSpace可视化之前,需要对TensorFlow.js模型进行预处理。这里我们使用TensorSpace-Converter工具。

安装TensorSpace-Converter

pip install tensorspacejs

转换TensorFlow.js模型

假设我们有一个训练好的LeNet模型(mnist.json和mnist.weights.bin),转换命令如下:

tensorspacejs_converter \
    --input_model_from="tfjs" \
    --output_layer_names="myPadding,myConv1,myMaxPooling1,myConv2,myMaxPooling2,myDense1,myDense2,myDense3" \
    ./rawModel/mnist.json \
    ./convertedModel/

参数说明

参数 说明 示例值
--input_model_from 输入模型来源 tfjs
--output_layer_names 需要输出的层名称 myPadding,myConv1,...
输入路径 TensorFlow.js模型拓扑文件 ./rawModel/mnist.json
输出路径 转换后模型保存目录 ./convertedModel/

转换完成后,你将在目标目录看到以下文件结构:

convertedModel/
├── model.json
└── model.weights.bin

构建TensorSpace可视化模型

现在让我们开始构建3D可视化界面。

基本HTML结构

<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>TensorSpace TensorFlow.js可视化</title>
    <style>
        html, body {
            margin: 0;
            padding: 0;
            width: 100%;
            height: 100%;
            overflow: hidden;
        }
        #container {
            width: 100%;
            height: 100%;
            background-color: #000;
        }
    </style>
</head>
<body>
    <div id="container"></div>
    
    <!-- 引入依赖库 -->
    <script src="lib/tf.min.js"></script>
    <script src="lib/three.min.js"></script>
    <script src="lib/tween.min.js"></script>
    <script src="lib/TrackballControls.js"></script>
    <script src="lib/tensorspace.min.js"></script>
    
    <script>
        // 可视化代码将在这里编写
    </script>
</body>
</html>

创建TensorSpace模型实例

// 获取容器元素
let container = document.getElementById("container");

// 创建Sequential模型实例
let model = new TSP.models.Sequential(container, {
    animeTime: 200,        // 动画时长(ms)
    stats: true,           // 显示性能统计
    initRotation: {        // 初始旋转角度
        x: -0.3,
        y: 0,
        z: 0
    }
});

构建模型拓扑结构

根据LeNet的网络结构添加各层:

// 构建LeNet网络结构
model.add(new TSP.layers.GreyscaleInput({
    shape: [28, 28, 1],    // 输入形状
    name: "inputLayer"     // 层名称
}));

model.add(new TSP.layers.Padding2d({
    padding: [2, 2],
    name: "paddingLayer"
}));

model.add(new TSP.layers.Conv2d({
    kernelSize: [5, 5],
    filters: 6,
    strides: [1, 1],
    name: "conv1"
}));

model.add(new TSP.layers.Pooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
    name: "pool1"
}));

model.add(new TSP.layers.Conv2d({
    kernelSize: [5, 5],
    filters: 16,
    strides: [1, 1],
    name: "conv2"
}));

model.add(new TSP.layers.Pooling2d({
    poolSize: [2, 2],
    strides: [2, 2],
    name: "pool2"
}));

model.add(new TSP.layers.Dense({
    units: 120,
    name: "dense1"
}));

model.add(new TSP.layers.Dense({
    units: 84,
    name: "dense2"
}));

model.add(new TSP.layers.Output1d({
    outputs: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
    name: "outputLayer"
}));

加载预处理后的模型

// 加载转换后的TensorFlow.js模型
model.load({
    type: "tfjs",                    // 模型类型
    url: './convertedModel/model.json', // 模型文件路径
    tfjsLoadOption: {                // TensorFlow.js加载选项
        requestInit: {
            credentials: 'same-origin'
        }
    }
});

// 初始化模型
model.init(function() {
    console.log("TensorSpace模型初始化完成!");
    
    // 可以在这里添加预测代码
    // model.predict(inputData);
});

交互功能实现

添加手写输入功能

// 创建画布用于手写输入
let canvas = document.createElement('canvas');
canvas.width = 280;
canvas.height = 280;
canvas.style.position = 'absolute';
canvas.style.top = '20px';
canvas.style.left = '20px';
canvas.style.border = '2px solid #fff';
canvas.style.backgroundColor = '#000';
document.body.appendChild(canvas);

let ctx = canvas.getContext('2d');
let isDrawing = false;

// 鼠标事件处理
canvas.addEventListener('mousedown', startDrawing);
canvas.addEventListener('mousemove', draw);
canvas.addEventListener('mouseup', stopDrawing);
canvas.addEventListener('mouseout', stopDrawing);

function startDrawing(e) {
    isDrawing = true;
    draw(e);
}

function draw(e) {
    if (!isDrawing) return;
    
    ctx.lineWidth = 15;
    ctx.lineCap = 'round';
    ctx.strokeStyle = '#fff';
    
    ctx.lineTo(e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);
    ctx.stroke();
    ctx.beginPath();
    ctx.moveTo(e.clientX - canvas.offsetLeft, e.clientY - canvas.offsetTop);
}

function stopDrawing() {
    isDrawing = false;
    ctx.beginPath();
    
    // 转换为模型输入格式
    let imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    let inputData = processImageData(imageData);
    
    // 进行预测
    model.predict(inputData);
}

function processImageData(imageData) {
    let processedData = [];
    let data = imageData.data;
    
    // 将图像数据转换为28x28的灰度数组
    for (let y = 0; y < 28; y++) {
        for (let x = 0; x < 28; x++) {
            // 采样并归一化
            let pixelIndex = ((y * 10) * 280 + (x * 10)) * 4;
            let r = data[pixelIndex];
            let g = data[pixelIndex + 1];
            let b = data[pixelIndex + 2];
            let gray = (r + g + b) / 3 / 255; // 归一化到0-1
            processedData.push(gray);
        }
    }
    
    return processedData;
}

// 清空画布按钮
let clearBtn = document.createElement('button');
clearBtn.textContent = '清空';
clearBtn.style.position = 'absolute';
clearBtn.style.top = '310px';
clearBtn.style.left = '20px';
clearBtn.addEventListener('click', function() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    model.clear(); // 清空模型显示
});
document.body.appendChild(clearBtn);

模型控制功能

// 添加控制面板
let controlPanel = document.createElement('div');
controlPanel.style.position = 'absolute';
controlPanel.style.top = '20px';
controlPanel.style.right = '20px';
controlPanel.style.background = 'rgba(0,0,0,0.7)';
controlPanel.style.padding = '10px';
controlPanel.style.color = '#fff';
controlPanel.style.borderRadius = '5px';

controlPanel.innerHTML = `
    <h3>模型控制</h3>
    <button id="rotateBtn">自动旋转</button>
    <button id="resetBtn">重置视角</button>
    <button id="wireframeBtn">线框模式</button>
    <br><br>
    <label>动画速度: </label>
    <input type="range" id="speedSlider" min="50" max="500" value="200">
`;

document.body.appendChild(controlPanel);

// 控制功能实现
document.getElementById('rotateBtn').addEventListener('click', function() {
    model.toggleAutoRotate();
});

document.getElementById('resetBtn').addEventListener('click', function() {
    model.resetCamera();
});

document.getElementById('wireframeBtn').addEventListener('click', function() {
    model.toggleWireframe();
});

document.getElementById('speedSlider').addEventListener('input', function(e) {
    model.setAnimeTime(parseInt(e.target.value));
});

高级功能与最佳实践

1. 模型性能优化

// 配置性能优化选项
let model = new TSP.models.Sequential(container, {
    animeTime: 200,
    stats: true,
    renderer: {
        antialias: false,    // 关闭抗锯齿提升性能
        precision: 'mediump' // 使用中等精度
    },
    layerConfig: {
        lodDistance: 50,     // 层次细节距离
        maxInstance: 1000    // 最大实例数
    }
});

2. 自定义层样式

// 自定义卷积层样式
model.add(new TSP.layers.Conv2d({
    kernelSize: [5, 5],
    filters: 6,
    styles: {
        color: 0x3498db,     // 蓝色主题
        opacity: 0.8,
        wireframe: {
            visible: true,
            color: 0xffffff
        }
    },
    name: "customConv1"
}));

3. 事件监听与回调

// 添加模型事件监听
model.on('loadComplete', function() {
    console.log('模型加载完成');
});

model.on('predictStart', function() {
    console.log('预测开始');
});

model.on('predictComplete', function(results) {
    console.log('预测完成', results);
    
    // 显示预测结果
    displayResults(results);
});

model.on('layerActive', function(layerName) {
    console.log('激活层:', layerName);
});

function displayResults(results) {
    let resultDiv = document.createElement('div');
    resultDiv.style.position = 'absolute';
    resultDiv.style.bottom = '20px';
    resultDiv.style.left = '20px';
    resultDiv.style.background = 'rgba(0,0,0,0.7)';
    resultDiv.style.padding = '10px';
    resultDiv.style.color = '#fff';
    resultDiv.style.borderRadius = '5px';
    
    let html = '<h3>预测结果</h3>';
    results.forEach((score, index) => {
        html += `<div>数字 ${index}: ${(score * 100).toFixed(2)}%</div>`;
    });
    
    resultDiv.innerHTML = html;
    document.body.appendChild(resultDiv);
    
    // 3秒后自动移除
    setTimeout(() => {
        document.body.removeChild(resultDiv);
    }, 3000);
}

故障排除与常见问题

1. 模型加载失败

问题: 控制台显示"Failed to load model"错误 解决方案:

model.load({
    type: "tfjs",
    url: './convertedModel/model.json',
    onError: function(error) {
        console.error('模型加载错误:', error);
        // 可以在这里添加重试逻辑或错误处理
    },
    onProgress: function(progress) {
        console.log('加载进度:', progress);
    }
});

2. 跨域问题

问题: 从不同域加载模型时出现CORS错误 解决方案:

  • 配置服务器允许跨域请求
  • 或使用中间服务
  • 或将模型文件放在同域下

3. 性能问题

问题: 模型渲染卡顿 解决方案:

  • 减少模型复杂度
  • 使用层次细节(LOD)
  • 优化层配置参数

完整示例代码

以下是一个完整的TensorSpace可视化示例:

<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>TensorFlow.js模型可视化</title>
    <style>
        body { margin: 0; padding: 0; overflow: hidden; background: #000; }
        #container { width: 100vw; height: 100vh; }
        .control { position: absolute; top: 20px; right: 20px; background: rgba(0,0,0,0.8); 
                  padding: 15px; color: #fff; border-radius: 8px; z-index: 100; }
        .control button { margin: 5px; padding: 8px 12px; background: #3498db; 
                         border: none; border-radius: 4px; color: white; cursor: pointer; }
        .control button:hover { background: #2980b9; }
        canvas { position: absolute; top: 20px; left: 20px; border: 2px solid #fff; 
                background: #000; cursor: crosshair; }
    </style>
</head>
<body>
    <div id="container"></div>
    <canvas id="drawCanvas" width="280" height="280"></canvas>
    
    <div class="control">
        <h3>控制面板</h3>
        <button id="clearBtn">清空画布</button>
        <button id="rotateBtn">旋转</button>
        <button id="resetBtn">重置</button>
        <br>
        <label>动画速度: <input type="range" id="speedSlider" min="50" max="500" value="200"></label>
    </div>

    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.18.0/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/three@0.137.0/build/three.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tweenjs/tween.js@18.6.4/dist/tween.umd.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/three@0.137.0/examples/js/controls/TrackballControls.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/tensorspace@0.8.5/dist/tensorspace.min.js"></script>

    <script>
        // 初始化绘图画布
        const canvas = document.getElementById('drawCanvas');
        const ctx = canvas.getContext('2d');
        let isDrawing = false;

        // 绘图事件处理
        canvas.addEventListener('mousedown', startDrawing);
        canvas.addEventListener('mousemove', draw);
        canvas.addEventListener('mouseup', stopDrawing);
        canvas.addEventListener('mouseout', stopDrawing);

        function startDrawing(e) {
            isDrawing = true;
            draw(e);
        }

        function draw(e) {
            if (!isDrawing) return;
            ctx.lineWidth = 15;
            ctx.lineCap = 'round';
            ctx.strokeStyle = '#fff';
            ctx.lineTo(e.offsetX,

【免费下载链接】tensorspace Neural network 3D visualization framework, build interactive and intuitive model in browsers, support pre-trained deep learning models from TensorFlow, Keras, TensorFlow.js 【免费下载链接】tensorspace 项目地址: https://gitcode.com/gh_mirrors/te/tensorspace

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐