Skip to content

Commit 76c3ce4

Browse files
chore: Upgrade net to handle types
1 parent f31dced commit 76c3ce4

18 files changed

+270
-283
lines changed

.babelrc

Lines changed: 0 additions & 21 deletions
This file was deleted.

.eslintrc.json renamed to .eslintrc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"no-underscore-dangle": "off",
4141
"prettier/prettier": "error",
4242
"semi": "off",
43-
"standard/no-callback-literal": "off"
43+
"standard/no-callback-literal": "off",
44+
"no-implied-eval": "off"
4445
}
4546
}

src/cross-validate.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import { LSTMTimeStep } from './recurrent/lstm-time-step';
99

1010
describe('CrossValidate', () => {
1111
describe('.train()', () => {
12-
class FakeNN extends NeuralNetwork {
12+
class FakeNN extends NeuralNetwork<number[], number[]> {
1313
constructor(
1414
options: Partial<
1515
INeuralNetworkOptions & INeuralNetworkTrainOptions
@@ -195,7 +195,7 @@ describe('CrossValidate', () => {
195195
});
196196
});
197197
describe('.fromJSON()', () => {
198-
class FakeNN extends NeuralNetwork {}
198+
class FakeNN extends NeuralNetwork<number[], number[]> {}
199199
it("creates a new instance of constructor from argument's sets.error", () => {
200200
const cv = new CrossValidate(FakeNN);
201201
const options = { inputSize: 1, hiddenLayers: [10], outputSize: 1 };
@@ -241,7 +241,7 @@ describe('CrossValidate', () => {
241241
});
242242
});
243243
describe('.toNeuralNetwork()', () => {
244-
class FakeNN extends NeuralNetwork {}
244+
class FakeNN extends NeuralNetwork<number[], number[]> {}
245245
it('creates a new instance of constructor from top .json sets.error', () => {
246246
const cv = new CrossValidate(FakeNN);
247247
const details = {

src/feed-forward.unit.test.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ describe('FeedForward Class: Unit', () => {
7777
{
7878
filterHeight: 3,
7979
filterWidth: 3,
80+
filterCount: 1,
8081
padding: 2,
8182
stride: 2,
8283
},
@@ -100,6 +101,7 @@ describe('FeedForward Class: Unit', () => {
100101
padding: 2,
101102
filterWidth: 3,
102103
filterHeight: 3,
104+
filterCount: 1,
103105
stride: 3,
104106
},
105107
inputLayer
@@ -171,6 +173,7 @@ describe('FeedForward Class: Unit', () => {
171173
{
172174
filterWidth: 3, // TODO: setting height, width should behave same
173175
filterHeight: 3,
176+
filterCount: 3,
174177
padding: 2,
175178
stride: 3,
176179
},
@@ -187,6 +190,7 @@ describe('FeedForward Class: Unit', () => {
187190
{
188191
filterWidth: 3,
189192
filterHeight: 3,
193+
filterCount: 16,
190194
padding: 2,
191195
stride: 2,
192196
},

src/layer/target.test.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ describe('Target Layer', () => {
5757
});
5858

5959
test('uses compare2D when width > 1', () => {
60-
const target = new Target({}, mockLayer({ height: 10, width: 10 }));
60+
const target = new Target(
61+
{ height: 10, width: 10 },
62+
mockLayer({ height: 10, width: 10 })
63+
);
6164
target.setupKernels();
6265
expect(makeKernel).toHaveBeenCalledWith(compare2D, {
6366
output: [10, 10],

src/likely.test.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ import { NeuralNetwork } from './neural-network';
33

44
/**
55
* Return 0 or 1 for '#'
6-
* @param character
7-
* @returns {number}
86
*/
97
function integer(character: string): number {
108
if (character === '#') return 1;
@@ -13,8 +11,6 @@ function integer(character: string): number {
1311

1412
/**
1513
* Turn the # into 1s and . into 0s. for whole string
16-
* @param string
17-
* @returns {Array}
1814
*/
1915
function character(string: string): number[] {
2016
return string.trim().split('').map(integer);

src/likely.ts

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
1-
import { INumberHash } from './lookup';
2-
import { NeuralNetwork } from './neural-network';
1+
import { INeuralNetworkData, NeuralNetwork } from './neural-network';
32

4-
/**
5-
*
6-
* @param {*} input
7-
* @param {brain.NeuralNetwork} net
8-
* @returns {*}
9-
*/
10-
export function likely<T extends number[] | Float32Array | INumberHash>(
11-
input: T,
12-
net: NeuralNetwork
13-
): T | null {
3+
export function likely<
4+
InputType extends INeuralNetworkData,
5+
OutputType extends INeuralNetworkData
6+
>(
7+
input: InputType,
8+
net: NeuralNetwork<InputType, OutputType>
9+
): OutputType | null {
1410
if (!net) {
1511
throw new TypeError(
1612
`Required parameter 'net' is of type ${typeof net}. Must be of type 'brain.NeuralNetwork'`
1713
);
1814
}
1915

20-
const output = net.run<T>(input);
16+
const output = net.run(input);
2117
let maxProp = null;
2218
let maxValue = -1;
2319

2420
Object.entries(output).forEach(([key, value]) => {
25-
if (value > maxValue) {
21+
if (typeof value !== 'undefined' && value > maxValue) {
2622
maxProp = key;
2723
maxValue = value;
2824
}

src/lookup.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export interface INumberArray {
1010
[index: number]: number;
1111
}
1212

13-
export type InputOutputValue = INumberArray | INumberHash;
13+
export type InputOutputValue = INumberArray | Partial<INumberHash>;
1414

1515
export interface ITrainingDatum {
1616
input: InputOutputValue | InputOutputValue[] | KernelOutput;

src/neural-network-gpu.test.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import { Texture } from 'gpu.js';
2+
13
import { NeuralNetwork } from './neural-network';
24
import { NeuralNetworkGPU } from './neural-network-gpu';
3-
import { Texture } from 'gpu.js';
45

56
describe('NeuralNetworkGPU', () => {
67
const xorTrainingData = [
@@ -33,11 +34,11 @@ describe('NeuralNetworkGPU', () => {
3334
});
3435

3536
it('can serialize from NeuralNetworkGPU & deserialize to NeuralNetwork', () => {
36-
const net = new NeuralNetworkGPU();
37+
const net = new NeuralNetworkGPU<number[], number[]>();
3738
net.train(xorTrainingData, { iterations: 1 });
3839
const target = xorTrainingData.map((datum) => net.run(datum.input));
3940
const json = net.toJSON();
40-
const net2 = new NeuralNetwork();
41+
const net2 = new NeuralNetwork<number[], number[]>();
4142
net2.fromJSON(json);
4243
for (let i = 0; i < xorTrainingData.length; i++) {
4344
// there is a wee bit of loss going from GPU to CPU
@@ -49,11 +50,11 @@ describe('NeuralNetworkGPU', () => {
4950
});
5051

5152
it('can serialize from NeuralNetwork & deserialize to NeuralNetworkGPU', () => {
52-
const net = new NeuralNetwork();
53+
const net = new NeuralNetwork<number[], number[]>();
5354
net.train(xorTrainingData, { iterations: 1 });
5455
const target = xorTrainingData.map((datum) => net.run(datum.input));
5556
const json = net.toJSON();
56-
const net2 = new NeuralNetworkGPU();
57+
const net2 = new NeuralNetworkGPU<number[], number[]>();
5758
net2.fromJSON(json);
5859
for (let i = 0; i < xorTrainingData.length; i++) {
5960
// there is a wee bit of loss going from CPU to GPU

0 commit comments

Comments
 (0)