English | 简体中文
A micro deep learning inference engine built with WebGPU from scratch. This project demonstrates AI infrastructure optimization techniques including kernel fusion, memory layout optimization, and Im2Col algorithm implementation.
- WebGPU-based Compute: Hand-written WGSL shaders for all neural network operators
- Core Operators: Conv2d, MaxPool, ReLU, Softmax, Dense, Flatten
- Kernel Fusion: Fused Conv2d+Bias+ReLU operator for reduced memory traffic
- Memory Layout Metadata: Tensor layout metadata and conversion helpers
- Im2Col Utilities: Experimental convolution-to-GEMM helpers for exploration
- Property-Based Testing: Partial fast-check coverage for core tensor/operator behavior
- MNIST Support: End-to-end digit classification demo
```
src/
├── core/        # Core infrastructure (GPUContext, Tensor, errors)
├── operators/   # Neural network operators
├── engine/      # Inference engine and model loader
└── utils/       # Utilities (benchmark, im2col, CPU reference)
tests/
├── core/        # Core component tests
└── operators/   # Operator tests with property-based testing
```
```bash
npm install
npm run build
```

```bash
# Run all tests
npm test

# Run tests in watch mode
npm run test:watch

# Generate coverage report
npm run test:coverage
```

```typescript
import { InferenceEngine, ModelLoader } from 'tiny-dl-inference';

// Initialize engine
const engine = new InferenceEngine();
await engine.initialize();

// Load model
const loader = new ModelLoader();
const model = await loader.loadFromJSON('model.json');
await engine.loadModel(model);

// Run inference
const input = engine.tensorFromArray(imageData, [1, 1, 28, 28]);
const output = await engine.infer(input);
const result = await output.download();
console.log('Predictions:', result);
```

Kernel fusion combines Conv2d + Bias + ReLU into a single GPU kernel, reducing memory traffic by 3x:
- Non-fused: 6 memory operations (3 reads + 3 writes)
- Fused: 2 memory operations (1 read + 1 write)
- NCHW: PyTorch-style layout [Batch, Channel, Height, Width]
- NHWC: TensorFlow-style layout [Batch, Height, Width, Channel]
- Tensor layout metadata and conversion helpers are available, but the current Conv2d / MaxPool execution path supports NCHW only
The repository includes Im2Col utilities for experimentation and reference work. They are not yet wired into the default Conv2d execution path:
```
Input  [N, C, H, W]   → Im2Col  → [N*outH*outW, C*kH*kW]
Weight [K, C, kH, kW] → Reshape → [K, C*kH*kW]
Output = GEMM(Weight, Im2Col(Input))
```
The project uses a dual testing approach:
- Unit Tests: Specific examples and edge cases
- Property Tests: Universal properties across all inputs (100+ iterations each)
- Tensor Data Round-Trip: Upload → Download preserves data
- Layout Conversion Round-Trip: NCHW → NHWC → NCHW preserves data
- ReLU Element-wise Correctness: output[i] = max(0, input[i])
- Softmax Output Validity: All values in [0,1], sum to 1.0
- Softmax Numerical Stability: No NaN/Infinity for large inputs
- MaxPool Output Shape: Correct shape calculation
- MaxPool Correctness: Selects maximum in pooling window
- Conv2d Output Shape: Correct shape with stride/padding
- Conv2d Correctness: Matches CPU reference implementation
- Operator Contract Checks: Invalid layout / shape / axis usage fails fast
- Model Loading Checks: Weight shape metadata and binary parsing are validated
- Engine Context Safety: Input tensors must come from the engine's GPU context
The repository includes a Benchmark helper for exploratory measurements, but its higher-level comparison helpers are still best treated as experimental. In particular, fusion and layout comparisons should be interpreted cautiously unless the compared operators use matching synchronization and input conventions.
```typescript
import { Benchmark } from 'tiny-dl-inference';

const benchmark = new Benchmark();
const result = await benchmark.measureOperator(operator, inputs, params, 100);
console.log(`Execution time: ${result.executionTimeMs}ms`);
```

For now, prefer `measureOperator(...)` for local experiments, and validate any fused/layout benchmark with a correctness test alongside it.
- Zero Dependencies: No TensorFlow.js or ONNX Runtime
- Type-Safe: Full TypeScript implementation
- Modular Architecture: Easy to extend with new operators
- Focused Testing: Property-based, contract, and reference-backed tests cover key correctness paths
- Educational: Clear code demonstrating AI infrastructure concepts
Requires WebGPU support:
- Chrome 113+
- Edge 113+
- Safari 18+ (macOS Sonoma+)
MIT
Built as a demonstration of AI infrastructure and GPU computing expertise.