Skip to content

Commit 0c493f2

Browse files
Merge pull request #700 from loredanacirstea/types-fix
Fix remaining lint:typecheck errors
2 parents 7b208b6 + 63d2eac commit 0c493f2

File tree

13 files changed

+85
-44
lines changed

13 files changed

+85
-44
lines changed

__tests__/feed-forward/end-to-end.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ import {
88
Target,
99
Sigmoid,
1010
arthurFeedForward,
11+
ILayer,
12+
ILayerSettings,
1113
} from '../../src/layer';
1214
import { feedForward as feedForwardLayer } from '../../src/layer/feed-forward';
1315

1416
import { momentumRootMeanSquaredPropagation } from '../../src/praxis';
1517
import { zeros2D } from '../../src/utilities/zeros-2d';
1618
import { setup, teardown } from '../../src/utilities/kernel';
1719
import { mockPraxis } from '../test-utils';
18-
import { ILayer, ILayerSettings } from '../../src/layer';
1920
import { IPraxis } from '../../src/praxis/base-praxis';
2021

2122
const xorTrainingData = [

__tests__/feed-forward/unit.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {
2525
layerTypes,
2626
ILayer,
2727
ILayerJSON,
28-
ILayerSettings
28+
ILayerSettings,
2929
} from '../../src/layer';
3030
import { mockLayer, mockPraxis } from '../test-utils';
3131
import SpyInstance = jest.SpyInstance;

__tests__/layer/lstm-cell.ts

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import {
55
MultiplyElement,
66
Random,
77
RecurrentZeros,
8-
Sigmoid, Tanh,
8+
Sigmoid,
9+
Tanh,
910
Zeros,
1011
} from '../../src/layer';
1112
import { mockLayer, onePlusPlus2D, TestLayer } from '../test-utils';
@@ -165,15 +166,19 @@ describe('lstm Cell', () => {
165166
const layer = lstmCell(settings, input, recurrentInput);
166167
const layers = flattenLayers([layer]);
167168

168-
recurrentInput.weights = onePlusPlus2D(recurrentInput.width, recurrentInput.height);
169-
const memoryLayers = layers.filter((layer: ILayer) => layer instanceof Random);
170-
memoryLayers
171-
.forEach((layer: ILayer) => {
172-
layer.weights = onePlusPlus2D(layer.width, layer.height);
173-
});
174-
175-
layers.forEach(layer => layer.setupKernels());
176-
layers.forEach(layer => layer.predict());
169+
recurrentInput.weights = onePlusPlus2D(
170+
recurrentInput.width,
171+
recurrentInput.height
172+
);
173+
const memoryLayers = layers.filter(
174+
(layer: ILayer) => layer instanceof Random
175+
);
176+
memoryLayers.forEach((layer: ILayer) => {
177+
layer.weights = onePlusPlus2D(layer.width, layer.height);
178+
});
179+
180+
layers.forEach((layer) => layer.setupKernels());
181+
layers.forEach((layer) => layer.predict());
177182

178183
expect(layers[layers.length - 1].weights).toEqual([
179184
Float32Array.from([0.9640275835990906]),

__tests__/layer/rnn-cell.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
import { rnnCell, RecurrentZeros, Add, Random, Zeros, Multiply, Relu } from '../../src/layer/';
1+
import {
2+
rnnCell,
3+
RecurrentZeros,
4+
Add,
5+
Random,
6+
Zeros,
7+
Multiply,
8+
Relu,
9+
} from '../../src/layer/';
210
import { mockLayer, TestLayer } from '../test-utils';
311
import { flattenLayers } from '../../src/utilities/flatten-layers';
412

__tests__/praxis/arthur-deviation-biases/end-to-end.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ describe('ArthurDeviationBiases', () => {
3939
const praxis = new ArthurDeviationBiases(layer1, {
4040
learningRate: net.trainOpts.learningRate,
4141
});
42-
expect(praxis.settings.learningRate).toBe(
43-
net.trainOpts.learningRate
44-
);
42+
expect(praxis.settings.learningRate).toBe(net.trainOpts.learningRate);
4543

4644
net.deltas[0][0] = 1;
4745
net.deltas[0][1] = 2;

__tests__/recurrent/unit.ts

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
import { GPU } from 'gpu.js';
2-
import { add, input, multiply, output, random, rnnCell, IRecurrentInput, ILayer } from '../../src/layer';
2+
import {
3+
add,
4+
input,
5+
multiply,
6+
output,
7+
random,
8+
rnnCell,
9+
IRecurrentInput,
10+
ILayer,
11+
} from '../../src/layer';
312
import { Filter } from '../../src/layer/filter';
413
import { Recurrent } from '../../src/recurrent';
514
import { Matrix } from '../../src/recurrent/matrix';
@@ -334,12 +343,16 @@ describe('Recurrent Class: Unit', () => {
334343
layers[10].toJSON(),
335344
{ ...layers[11].toJSON(), inputLayer1Index: 10, inputLayer2Index: 9 },
336345
layers[12].toJSON(),
337-
{ ...layers[13].toJSON(), inputLayer1Index: 11, inputLayer2Index: 12 },
346+
{
347+
...layers[13].toJSON(),
348+
inputLayer1Index: 11,
349+
inputLayer2Index: 12,
350+
},
338351
{ ...layers[14].toJSON(), inputLayerIndex: 13 },
339352
],
340353
outputLayerIndex: 14,
341354
sizes: [1, 3, 1],
342-
type: 'Recurrent'
355+
type: 'Recurrent',
343356
});
344357
});
345358
});

__tests__/recurrent_deprecated/gru.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { GRU } from '../../src/recurrent/gru';
2-
import { IMatrixJSON } from '../../src/recurrent/matrix';
2+
import { IMatrixJSON, Matrix } from '../../src/recurrent/matrix';
33
import { RNN } from '../../src/recurrent/rnn';
44
import { DataFormatter } from '../../src/utilities/data-formatter';
55

@@ -106,7 +106,7 @@ describe('GRU', () => {
106106
});
107107
const json = net.toJSON();
108108

109-
function compare(left: IMatrixJSON, right: IMatrixJSON) {
109+
function compare(left: IMatrixJSON, right: Matrix) {
110110
left.weights.forEach((value, i) => {
111111
expect(value).toBe(right.weights[i]);
112112
});

__tests__/recurrent_deprecated/rnn-time-step.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ import { LSTMTimeStep } from '../../src/recurrent/lstm-time-step';
33
import { Matrix } from '../../src/recurrent/matrix';
44
import { Equation } from '../../src/recurrent/matrix/equation';
55
import { IRNNStatus } from '../../src/recurrent/rnn';
6-
import { RNNTimeStep } from '../../src/recurrent/rnn-time-step';
6+
import {
7+
RNNTimeStep,
8+
IRNNTimeStepJSON,
9+
} from '../../src/recurrent/rnn-time-step';
710

811
// TODO: break out LSTMTimeStep into its own tests
912

@@ -18,7 +21,7 @@ describe('RNNTimeStep', () => {
1821
fromJSONSpy.mockRestore();
1922
});
2023
it('calls this.fromJSON with this value', () => {
21-
const json = {
24+
const json: IRNNTimeStepJSON = {
2225
type: 'RNNTimeStep',
2326
options: {
2427
inputSize: 1,
@@ -33,21 +36,21 @@ describe('RNNTimeStep', () => {
3336
},
3437
hiddenLayers: [
3538
{
36-
weight: { rows: 1, columns: 1, weights: Float32Array.from([1]) },
39+
weight: { rows: 1, columns: 1, weights: [1] },
3740
transition: {
3841
rows: 1,
3942
columns: 1,
40-
weights: Float32Array.from([1]),
43+
weights: [1],
4144
},
42-
bias: { rows: 1, columns: 1, weights: Float32Array.from([1]) },
45+
bias: { rows: 1, columns: 1, weights: [1] },
4346
},
4447
],
4548
outputConnector: {
4649
rows: 1,
4750
columns: 1,
48-
weights: Float32Array.from([1]),
51+
weights: [1],
4952
},
50-
output: { rows: 1, columns: 1, weights: Float32Array.from([1]) },
53+
output: { rows: 1, columns: 1, weights: [1] },
5154
inputLookup: { a: 0 },
5255
inputLookupLength: 1,
5356
outputLookup: { a: 0 },

__tests__/utilities/layer-from-json.ts

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
import { Add, Convolution, ILayerJSON, RecurrentZeros, Sigmoid, Target } from '../../src/layer';
1+
import {
2+
Add,
3+
Convolution,
4+
ILayerJSON,
5+
RecurrentZeros,
6+
Sigmoid,
7+
Target,
8+
} from '../../src/layer';
29
import { layerFromJSON } from '../../src/utilities/layer-from-json';
310
import { mockLayer } from '../test-utils';
411

@@ -37,7 +44,9 @@ describe('layerFromJSON', () => {
3744
height: 1,
3845
depth: 1,
3946
});
40-
expect(layerFromJSON(jsonLayer, inputLayer)).toEqual(new Convolution(jsonLayer, inputLayer));
47+
expect(layerFromJSON(jsonLayer, inputLayer)).toEqual(
48+
new Convolution(jsonLayer, inputLayer)
49+
);
4150
});
4251
});
4352
describe('when used with a Activation layer json', () => {
@@ -53,7 +62,9 @@ describe('layerFromJSON', () => {
5362
});
5463
it('should return that type instantiated', () => {
5564
const inputLayer1 = mockLayer();
56-
expect(layerFromJSON(jsonLayer, inputLayer1)).toEqual(new Sigmoid(inputLayer1, jsonLayer));
65+
expect(layerFromJSON(jsonLayer, inputLayer1)).toEqual(
66+
new Sigmoid(inputLayer1, jsonLayer)
67+
);
5768
});
5869
});
5970
describe('when used with a Operator layer json', () => {
@@ -81,12 +92,16 @@ describe('layerFromJSON', () => {
8192
});
8293
it('fails if inputLayer2 falsey', () => {
8394
const inputLayer1 = mockLayer();
84-
expect(() => layerFromJSON(jsonLayer, inputLayer1)).toThrow('inputLayer2 missing');
95+
expect(() => layerFromJSON(jsonLayer, inputLayer1)).toThrow(
96+
'inputLayer2 missing'
97+
);
8598
});
8699
it('should return that type instantiated', () => {
87100
const inputLayer1 = mockLayer();
88101
const inputLayer2 = mockLayer();
89-
expect(layerFromJSON(jsonLayer, inputLayer1, inputLayer2)).toEqual(new Add(inputLayer1, inputLayer2, jsonLayer));
102+
expect(layerFromJSON(jsonLayer, inputLayer1, inputLayer2)).toEqual(
103+
new Add(inputLayer1, inputLayer2, jsonLayer)
104+
);
90105
});
91106
});
92107
describe('when used with a TargetType layer json', () => {
@@ -102,7 +117,9 @@ describe('layerFromJSON', () => {
102117
});
103118
it('should return that type instantiated', () => {
104119
const inputLayer = mockLayer();
105-
expect(layerFromJSON(jsonLayer, inputLayer)).toEqual(new Target(jsonLayer, inputLayer));
120+
expect(layerFromJSON(jsonLayer, inputLayer)).toEqual(
121+
new Target(jsonLayer, inputLayer)
122+
);
106123
});
107124
});
108125
});

src/layer/dropout.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ export class Dropout extends Filter {
5959
dropouts: KernelOutput | null;
6060
predictKernelMap: IKernelMapRunShortcut<ISubKernelObject> | null = null;
6161
settings: Partial<IDropoutSettings>;
62-
constructor(inputLayer: ILayer, settings?: Partial<IDropoutSettings>) {
63-
super(inputLayer);
62+
constructor(inputLayer: ILayer, settings: Partial<IDropoutSettings> = {}) {
63+
super(settings, inputLayer);
6464
this.settings = { ...dropoutDefaults, ...settings };
6565
this.dropouts = null;
6666
this.validate();

src/layer/fully-connected.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ export class FullyConnected extends Filter {
174174
compareInputDeltasKernel: IKernelRunShortcut | null = null;
175175
compareBiasesKernel: IKernelRunShortcut | null = null;
176176
constructor(settings: IFullyConnectedDefaultSettings, inputLayer: ILayer) {
177-
super(inputLayer);
177+
super(settings, inputLayer);
178178
this.settings = { ...settings };
179179
this.validate();
180180

src/recurrent/matrix/index.ts

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,7 @@ export class Matrix {
8080
};
8181
}
8282

83-
static fromJSON(json: {
84-
rows: number;
85-
columns: number;
86-
weights: number[];
87-
}): Matrix {
83+
static fromJSON(json: IMatrixJSON): Matrix {
8884
const matrix = new Matrix(json.rows, json.columns);
8985

9086
for (let i = 0, max = json.rows * json.columns; i < max; i++) {

src/recurrent/rnn.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ export interface IRNNJSONOptions {
5858
export interface IRNNTrainingOptions {
5959
iterations: number;
6060
errorThresh: number;
61-
log: boolean | ((status: INeuralNetworkState) => void);
61+
log: boolean | ((status: string) => void);
6262
logPeriod: number;
6363
learningRate: number;
6464
callback?: (status: IRNNStatus) => void;
@@ -69,7 +69,7 @@ export interface IRNNTrainingOptions {
6969
export interface IRNNJSONTrainOptions {
7070
iterations: number;
7171
errorThresh: number;
72-
log: boolean | ((status: INeuralNetworkState) => void);
72+
log: boolean | ((status: string) => void);
7373
logPeriod: number;
7474
learningRate: number;
7575
callback?: (status: IRNNStatus) => void;

0 commit comments

Comments
 (0)