Skip to content

Commit c8a62f1

Browse files
authored
Add a general-purpose autoencoder class AE to serve as a building block for other autoencoder models (#932)
* Port the AE from `@voidvoxel/auto-encoder` * Rewrite class `AutoEncoder` from scratch * Remove unused property * Use shallow clone * Add unit tests * Replace `shallowClone` with `deepClone` * Rename private properties * Remove the space in the word "autoencoder" [Wikipedia said "Autoencoder" and not "Auto encoder" or "Auto-encoder"](https://en.wikipedia.org/wiki/Autoencoder) * Rename class `Autoencoder` to `AE` The other classes in this library use their respective acronyms with the exceptions of `FeedForward`, `NeuralNetwork`, and `NeuralNetworkGPU`. Furthermore, `AE` variants of existing classes will likely be made, so an acronym equivalent would be desirable. I'm considering the naming conventions for classes such as `LSTMTimeStep`. For example maybe a future `VAE` class could be made to represent variational autoencoders. * Add `AE` usage to README.md * Update references to `AE` * Use Partial<T> instead of nullable properties * Update autoencoder.ts * Minor improvements * Fix "@rollup/plugin-typescript TS2807" * Choose a more accurate name for `includesAnomalies` (`likelyIncludesAnomalies`, as it makes no guarantees that anomalies are truly present and only provides an intuitive guess)
1 parent d7b4b03 commit c8a62f1

File tree

5 files changed

+328
-1
lines changed

5 files changed

+328
-1
lines changed

README.md

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ GPU accelerated Neural networks in JavaScript for Browsers and Node.js
1717
![CI](https://github.com/BrainJS/brain.js/workflows/CI/badge.svg)
1818
[![codecov](https://codecov.io/gh/BrainJS/brain.js/branch/master/graph/badge.svg?token=3SJIBJ1679)](https://codecov.io/gh/BrainJS/brain.js)
1919
<a href="https://twitter.com/brainjsfnd"><img src="https://img.shields.io/twitter/follow/brainjsfnd?label=Twitter&style=social" alt="Twitter"></a>
20-
20+
2121
[![NPM](https://nodei.co/npm/brain.js.png?compact=true)](https://nodei.co/npm/brain.js/)
2222

2323
</p>
@@ -43,6 +43,7 @@ GPU accelerated Neural networks in JavaScript for Browsers and Node.js
4343
- [For training with NeuralNetwork](#for-training-with-neuralnetwork)
4444
- [For training with `RNNTimeStep`, `LSTMTimeStep` and `GRUTimeStep`](#for-training-with-rnntimestep-lstmtimestep-and-grutimestep)
4545
- [For training with `RNN`, `LSTM` and `GRU`](#for-training-with-rnn-lstm-and-gru)
46+
- [For training with `AE`](#for-training-with-ae)
4647
- [Training Options](#training-options)
4748
- [Async Training](#async-training)
4849
- [Cross Validation](#cross-validation)
@@ -317,6 +318,54 @@ net.train([
317318
const output = net.run('I feel great about the world!'); // 'happy'
318319
```
319320

321+
#### For training with `AE`
322+
323+
Each training pattern can either:
324+
325+
- Be an array of numbers
326+
- Be an array of arrays of numbers
327+
328+
Training an autoencoder to compress the values of an XOR calculation:
329+
330+
```javascript
331+
const net = new brain.AE(
332+
{
333+
hiddenLayers: [ 5, 2, 5 ]
334+
}
335+
);
336+
337+
net.train([
338+
[ 0, 0, 0 ],
339+
[ 0, 1, 1 ],
340+
[ 1, 0, 1 ],
341+
[ 1, 1, 0 ]
342+
]);
343+
```
344+
345+
Encoding/decoding:
346+
347+
```javascript
348+
const input = [ 0, 1, 1 ];
349+
350+
const encoded = net.encode(input);
351+
const decoded = net.decode(encoded);
352+
```
353+
354+
Denoise noisy data:
355+
356+
```javascript
357+
const noisyData = [ 0, 1, 0 ];
358+
359+
const data = net.denoise(noisyData);
360+
```
361+
362+
Test for anomalies in data samples:
363+
364+
```javascript
365+
const shouldBeFalse = net.likelyIncludesAnomalies([0, 1, 1]);
366+
const shouldBeTrue = net.likelyIncludesAnomalies([0, 1, 0]);
367+
```
368+
320369
### Training Options
321370

322371
`train()` takes a hash of options as its second argument:
@@ -595,6 +644,7 @@ The user interface used:
595644

596645
- [`brain.NeuralNetwork`](src/neural-network.ts) - [Feedforward Neural Network](https://en.wikipedia.org/wiki/Feedforward_neural_network) with backpropagation
597646
- [`brain.NeuralNetworkGPU`](src/neural-network-gpu.ts) - [Feedforward Neural Network](https://en.wikipedia.org/wiki/Feedforward_neural_network) with backpropagation, GPU version
647+
- [`brain.AE`](src/autoencoder.ts) - [Autoencoder or "AE"](https://en.wikipedia.org/wiki/Autoencoder) with backpropagation and GPU support
598648
- [`brain.recurrent.RNNTimeStep`](src/recurrent/rnn-time-step.ts) - [Time Step Recurrent Neural Network or "RNN"](https://en.wikipedia.org/wiki/Recurrent_neural_network)
599649
- [`brain.recurrent.LSTMTimeStep`](src/recurrent/lstm-time-step.ts) - [Time Step Long Short Term Memory Neural Network or "LSTM"](https://en.wikipedia.org/wiki/Long_short-term_memory)
600650
- [`brain.recurrent.GRUTimeStep`](src/recurrent/gru-time-step.ts) - [Time Step Gated Recurrent Unit or "GRU"](https://en.wikipedia.org/wiki/Gated_recurrent_unit)

src/autoencoder.test.ts

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import AE from "./autoencoder";
2+
3+
// XOR truth table laid out as [a, b, a XOR b]; each row is both the input and
// the expected reconstruction for the autoencoder.
const trainingData = [
  [0, 0, 0],
  [0, 1, 1],
  [1, 0, 1],
  [1, 1, 0]
];

// Training stops once the reported error drops to this value (or after the
// iteration cap below); every test re-checks that the target was reached.
const errorThresh = 0.011;

const xornet = new AE<number[], number[]>(
  {
    decodedSize: 3,
    hiddenLayers: [ 5, 2, 5 ]
  }
);

// Train once at module load so the three tests share a single trained network.
// This is deliberately kept out of a `beforeAll` hook so that a slow training
// run is not subject to the per-hook timeout.
const result = xornet.train(
  trainingData, {
    iterations: 100000,
    errorThresh
  }
);

test(
  "denoise a data sample",
  () => {
    expect(result.error).toBeLessThanOrEqual(errorThresh);

    // Denoising a clean sample must reproduce its XOR column (index 2).
    for (const sample of trainingData) {
      const denoised = xornet.denoise(sample);

      expect(Math.round(denoised[2])).toBe(sample[2]);
    }
  }
);

test(
  "encode and decode a data sample",
  () => {
    expect(result.error).toBeLessThanOrEqual(errorThresh);

    // Round-trip two known samples through encode -> decode and compare every
    // rounded component with the original input.
    const samples = [
      [0, 0, 0],
      [0, 1, 1]
    ];

    for (const sample of samples) {
      const encoded = xornet.encode(sample);
      const decoded = xornet.decode(encoded);

      for (let i = 0; i < 3; i++) expect(Math.round(decoded[i])).toBe(sample[i]);
    }
  }
);

test(
  "test a data sample for anomalies",
  () => {
    expect(result.error).toBeLessThanOrEqual(errorThresh);

    // Samples the network was trained on must not be flagged as anomalous.
    function expectNoAnomalies(...args: number[]) {
      expect(xornet.likelyIncludesAnomalies(args)).toBe(false);
    }

    for (const sample of trainingData) expectNoAnomalies(...sample);
  }
);

src/autoencoder.ts

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
import { KernelOutput, Texture, TextureArrayOutput } from "gpu.js";
2+
import { IJSONLayer, INeuralNetworkData, INeuralNetworkDatum, INeuralNetworkTrainOptions } from "./neural-network";
3+
import { INeuralNetworkGPUOptions, NeuralNetworkGPU } from "./neural-network-gpu";
4+
import { INeuralNetworkState } from "./neural-network-types";
5+
import { UntrainedNeuralNetworkError } from "./errors/untrained-neural-network-error";
6+
7+
export interface IAEOptions {
8+
binaryThresh: number;
9+
decodedSize: number;
10+
hiddenLayers: number[];
11+
}
12+
13+
/**
14+
* An autoencoder learns to compress input data down to relevant features and reconstruct input data from its compressed representation.
15+
*/
16+
export class AE<DecodedData extends INeuralNetworkData, EncodedData extends INeuralNetworkData> {
17+
private decoder?: NeuralNetworkGPU<EncodedData, DecodedData>;
18+
private denoiser: NeuralNetworkGPU<DecodedData, DecodedData>;
19+
20+
constructor (
21+
options?: Partial<IAEOptions>
22+
) {
23+
// Create default options for the autoencoder.
24+
options ??= {};
25+
26+
// Create default options for the autoencoder's denoiser subnet.
27+
const denoiserOptions: Partial<INeuralNetworkGPUOptions> = {};
28+
29+
// Inherit the binary threshold of the parent autoencoder.
30+
denoiserOptions.binaryThresh = options.binaryThresh;
31+
// Inherit the hidden layers of the parent autoencoder.
32+
denoiserOptions.hiddenLayers = options.hiddenLayers;
33+
34+
// Define the denoiser subnet's input and output sizes.
35+
if (options.decodedSize) denoiserOptions.inputSize = denoiserOptions.outputSize = options.decodedSize;
36+
37+
// Create the denoiser subnet of the autoencoder.
38+
this.denoiser = new NeuralNetworkGPU<DecodedData, DecodedData>(options);
39+
}
40+
41+
/**
42+
* Denoise input data, removing any anomalies from the data.
43+
* @param {DecodedData} input
44+
* @returns {DecodedData}
45+
*/
46+
denoise(input: DecodedData): DecodedData {
47+
// Run the input through the generic denoiser.
48+
// This isn't the best denoiser implementation, but it's efficient.
49+
// Efficiency is important here because training should focus on
50+
// optimizing for feature extraction as quickly as possible rather than
51+
// denoising and anomaly detection; there are other specialized topologies
52+
// better suited for these tasks anyways, many of which can be implemented
53+
// by using an autoencoder.
54+
return this.denoiser.run(input);
55+
}
56+
57+
/**
58+
* Decode `EncodedData` into an approximation of its original form.
59+
*
60+
* @param {EncodedData} input
61+
* @returns {DecodedData}
62+
*/
63+
decode(input: EncodedData): DecodedData {
64+
// If the decoder has not been trained yet, throw an error.
65+
if (!this.decoder) throw new UntrainedNeuralNetworkError(this);
66+
67+
// Decode the encoded input.
68+
return this.decoder.run(input);
69+
}
70+
71+
/**
72+
* Encode data to extract features, reduce dimensionality, etc.
73+
*
74+
* @param {DecodedData} input
75+
* @returns {EncodedData}
76+
*/
77+
encode(input: DecodedData): EncodedData {
78+
// If the decoder has not been trained yet, throw an error.
79+
if (!this.denoiser) throw new UntrainedNeuralNetworkError(this);
80+
81+
// Process the input.
82+
this.denoiser.run(input);
83+
84+
// Get the auto-encoded input.
85+
let encodedInput: TextureArrayOutput = this.encodedLayer as TextureArrayOutput;
86+
87+
// If the encoded input is a `Texture`, convert it into an `Array`.
88+
if (encodedInput instanceof Texture) encodedInput = encodedInput.toArray();
89+
else encodedInput = encodedInput.slice(0);
90+
91+
// Return the encoded input.
92+
return encodedInput as EncodedData;
93+
}
94+
95+
/**
96+
* Test whether or not a data sample likely contains anomalies.
97+
* If anomalies are likely present in the sample, returns `true`.
98+
* Otherwise, returns `false`.
99+
*
100+
* @param {DecodedData} input
101+
* @returns {boolean}
102+
*/
103+
likelyIncludesAnomalies(input: DecodedData, anomalyThreshold: number = 0.2): boolean {
104+
// Create the anomaly vector.
105+
const anomalies: number[] = [];
106+
107+
// Attempt to denoise the input.
108+
const denoised = this.denoise(input);
109+
110+
// Calculate the anomaly vector.
111+
for (let i = 0; i < (input.length ?? 0); i++) {
112+
anomalies[i] = Math.abs((input as number[])[i] - (denoised as number[])[i]);
113+
}
114+
115+
// Calculate the sum of all anomalies within the vector.
116+
const sum = anomalies.reduce(
117+
(previousValue, value) => previousValue + value
118+
);
119+
120+
// Calculate the mean anomaly.
121+
const mean = sum / (input as number[]).length;
122+
123+
// Return whether or not the mean anomaly rate is greater than the anomaly threshold.
124+
return mean > anomalyThreshold;
125+
}
126+
127+
/**
128+
* Train the auto encoder.
129+
*
130+
* @param {DecodedData[]} data
131+
* @param {Partial<INeuralNetworkTrainOptions>} options
132+
* @returns {INeuralNetworkState}
133+
*/
134+
train(data: DecodedData[], options?: Partial<INeuralNetworkTrainOptions>): INeuralNetworkState {
135+
const preprocessedData: INeuralNetworkDatum<Partial<DecodedData>, Partial<DecodedData>>[] = [];
136+
137+
for (let datum of data) {
138+
preprocessedData.push( { input: datum, output: datum } );
139+
}
140+
141+
const results = this.denoiser.train(preprocessedData, options);
142+
143+
this.decoder = this.createDecoder();
144+
145+
return results;
146+
}
147+
148+
/**
149+
* Create a new decoder from the trained denoiser.
150+
*
151+
* @returns {NeuralNetworkGPU<EncodedData, DecodedData>}
152+
*/
153+
private createDecoder() {
154+
const json = this.denoiser.toJSON();
155+
156+
const layers: IJSONLayer[] = [];
157+
const sizes: number[] = [];
158+
159+
for (let i = this.encodedLayerIndex; i < this.denoiser.sizes.length; i++) {
160+
layers.push(json.layers[i]);
161+
sizes.push(json.sizes[i]);
162+
}
163+
164+
json.layers = layers;
165+
json.sizes = sizes;
166+
167+
json.options.inputSize = json.sizes[0];
168+
169+
const decoder = new NeuralNetworkGPU().fromJSON(json);
170+
171+
return decoder as unknown as NeuralNetworkGPU<EncodedData, DecodedData>;
172+
}
173+
174+
/**
175+
* Get the layer containing the encoded representation.
176+
*/
177+
private get encodedLayer(): KernelOutput {
178+
return this.denoiser.outputs[this.encodedLayerIndex];
179+
}
180+
181+
/**
182+
* Get the offset of the encoded layer.
183+
*/
184+
private get encodedLayerIndex(): number {
185+
return Math.round(this.denoiser.outputs.length * 0.5) - 1;
186+
}
187+
}
188+
189+
export default AE;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/**
 * Error thrown when a network is asked to run before it has been trained.
 *
 * The message names the kind of network involved, taken from the constructor
 * name of the value passed in.
 */
export class UntrainedNeuralNetworkError extends Error {
  /**
   * @param neuralNetwork The network instance that was used before training.
   *   Only its constructor name is read; any value is accepted.
   */
  constructor (
    neuralNetwork: unknown
  ) {
    super(`Cannot run a ${UntrainedNeuralNetworkError.describe(neuralNetwork)} before it is trained.`);

    // Without this the error reports itself as a plain "Error" in logs and
    // stack traces.
    this.name = "UntrainedNeuralNetworkError";
  }

  /**
   * Resolve a readable type name for the message. Falls back to a generic label
   * so that constructing this error can never itself throw (for example when
   * given `null`, `undefined` or an object without a constructor).
   */
  private static describe(target: unknown): string {
    if (target === null || target === undefined) return "neural network";

    const name = (target as { constructor?: { name?: unknown } }).constructor?.name;

    return typeof name === "string" && name.length > 0 ? name : "neural network";
  }
}

src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as activation from './activation';
2+
import { AE } from './autoencoder';
23
import CrossValidate from './cross-validate';
34
import { FeedForward } from './feed-forward';
45
import * as layer from './layer';
@@ -53,6 +54,7 @@ const utilities = {
5354

5455
export {
5556
activation,
57+
AE,
5658
CrossValidate,
5759
likely,
5860
layer,

0 commit comments

Comments
 (0)