Skip to content

Commit 3164cfd

Browse files
authored
Merge pull request #249 from weaviate/#222/support-updating-generative-and-reranker
Improve generative/reranker config UX
2 parents ef950c4 + 892ec12 commit 3164cfd

File tree

6 files changed

+160
-14
lines changed

6 files changed

+160
-14
lines changed

ci/docker-compose.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ services:
2121
AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
2222
PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
2323
DEFAULT_VECTORIZER_MODULE: 'text2vec-contextionary'
24-
ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural
24+
ENABLE_MODULES: text2vec-contextionary,backup-filesystem,img2vec-neural,generative-cohere,reranker-cohere
2525
BACKUP_FILESYSTEM_PATH: "/tmp/backups"
2626
CLUSTER_GOSSIP_BIND_PORT: "7100"
2727
CLUSTER_DATA_BIND_PORT: "7101"

src/collections/config/classes.ts

+43-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { WeaviateInvalidInputError } from '../../errors.js';
33
import {
44
WeaviateClass,
55
WeaviateInvertedIndexConfig,
6+
WeaviateModuleConfig,
67
WeaviateMultiTenancyConfig,
78
WeaviateReplicationConfig,
89
WeaviateVectorIndexConfig,
@@ -17,7 +18,15 @@ import {
1718
VectorIndexConfigFlatUpdate,
1819
VectorIndexConfigHNSWUpdate,
1920
} from '../configure/types/index.js';
20-
import { CollectionConfigUpdate, VectorIndexType } from './types/index.js';
21+
import {
22+
CollectionConfigUpdate,
23+
GenerativeConfig,
24+
GenerativeSearch,
25+
ModuleConfig,
26+
Reranker,
27+
RerankerConfig,
28+
VectorIndexType,
29+
} from './types/index.js';
2130

2231
export class MergeWithExisting {
2332
static schema(
@@ -27,6 +36,8 @@ export class MergeWithExisting {
2736
): WeaviateClass {
2837
if (update === undefined) return current;
2938
if (update.description !== undefined) current.description = update.description;
39+
if (update.generative !== undefined)
40+
current.moduleConfig = MergeWithExisting.generative(current.moduleConfig, update.generative);
3041
if (update.invertedIndex !== undefined)
3142
current.invertedIndexConfig = MergeWithExisting.invertedIndex(
3243
current.invertedIndexConfig,
@@ -42,6 +53,8 @@ export class MergeWithExisting {
4253
current.replicationConfig!,
4354
update.replication
4455
);
56+
if (update.reranker !== undefined)
57+
current.moduleConfig = MergeWithExisting.reranker(current.moduleConfig, update.reranker);
4558
if (update.vectorizers !== undefined) {
4659
if (Array.isArray(update.vectorizers)) {
4760
current.vectorConfig = MergeWithExisting.vectors(current.vectorConfig, update.vectorizers);
@@ -61,6 +74,35 @@ export class MergeWithExisting {
6174
return current;
6275
}
6376

77+
static generative(
78+
current: WeaviateModuleConfig,
79+
update: ModuleConfig<GenerativeSearch, GenerativeConfig>
80+
): WeaviateModuleConfig {
81+
if (current === undefined) throw Error('Module config is missing from the class schema.');
82+
if (update === undefined) return current;
83+
const generative = update.name === 'generative-azure-openai' ? 'generative-openai' : update.name;
84+
const currentGenerative = current[generative] as Record<string, any>;
85+
current[generative] = {
86+
...currentGenerative,
87+
...update.config,
88+
};
89+
return current;
90+
}
91+
92+
static reranker(
93+
current: WeaviateModuleConfig,
94+
update: ModuleConfig<Reranker, RerankerConfig>
95+
): WeaviateModuleConfig {
96+
if (current === undefined) throw Error('Module config is missing from the class schema.');
97+
if (update === undefined) return current;
98+
const reranker = current[update.name] as Record<string, any>;
99+
current[update.name] = {
100+
...reranker,
101+
...update.config,
102+
};
103+
return current;
104+
}
105+
64106
static invertedIndex(
65107
current: WeaviateInvertedIndexConfig,
66108
update: InvertedIndexConfigUpdate

src/collections/config/integration.test.ts

+52
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
import { WeaviateUnsupportedFeatureError } from '../../errors.js';
33
import weaviate, { WeaviateClient, weaviateV2 } from '../../index.js';
44
import {
5+
GenerativeCohereConfig,
6+
ModuleConfig,
57
MultiTenancyConfig,
68
PropertyConfig,
9+
RerankerCohereConfig,
710
VectorIndexConfigDynamic,
811
VectorIndexConfigHNSW,
912
} from './types/index.js';
@@ -621,4 +624,53 @@ describe('Testing of the collection.config namespace', () => {
621624
expect(config.vectorizers.default.indexType).toEqual('hnsw');
622625
expect(config.vectorizers.default.vectorizer.name).toEqual('none');
623626
});
627+
628+
it('should be able to update the generative & reranker configs of a collection', async () => {
629+
if ((await client.getWeaviateVersion()).isLowerThan(1, 25, 0)) {
630+
console.warn('Skipping test because Weaviate version is lower than 1.25.0');
631+
return;
632+
}
633+
const collectionName = 'TestCollectionConfigUpdateGenerative';
634+
const collection = client.collections.get(collectionName);
635+
await client.collections.create({
636+
name: collectionName,
637+
vectorizers: weaviate.configure.vectorizer.none(),
638+
});
639+
let config = await collection.config.get();
640+
expect(config.generative).toBeUndefined();
641+
642+
await collection.config.update({
643+
generative: weaviate.reconfigure.generative.cohere({
644+
model: 'model',
645+
}),
646+
});
647+
648+
config = await collection.config.get();
649+
expect(config.generative).toEqual<ModuleConfig<'generative-cohere', GenerativeCohereConfig>>({
650+
name: 'generative-cohere',
651+
config: {
652+
model: 'model',
653+
},
654+
});
655+
656+
await collection.config.update({
657+
reranker: weaviate.reconfigure.reranker.cohere({
658+
model: 'model',
659+
}),
660+
});
661+
662+
config = await collection.config.get();
663+
expect(config.generative).toEqual<ModuleConfig<'generative-cohere', GenerativeCohereConfig>>({
664+
name: 'generative-cohere',
665+
config: {
666+
model: 'model',
667+
},
668+
});
669+
expect(config.reranker).toEqual<ModuleConfig<'reranker-cohere', RerankerCohereConfig>>({
670+
name: 'reranker-cohere',
671+
config: {
672+
model: 'model',
673+
},
674+
});
675+
});
624676
});

src/collections/config/types/index.ts

+6-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import {
99
ReplicationConfigUpdate,
1010
VectorConfigUpdate,
1111
} from '../../configure/types/index.js';
12-
import { GenerativeConfig } from './generative.js';
13-
import { RerankerConfig } from './reranker.js';
12+
import { GenerativeConfig, GenerativeSearch } from './generative.js';
13+
import { Reranker, RerankerConfig } from './reranker.js';
1414
import { VectorIndexType } from './vectorIndex.js';
1515
import { VectorConfig } from './vectorizer.js';
1616

@@ -93,22 +93,24 @@ export type ShardingConfig = {
9393
export type CollectionConfig = {
9494
name: string;
9595
description?: string;
96-
generative?: GenerativeConfig;
96+
generative?: ModuleConfig<GenerativeSearch, GenerativeConfig>;
9797
invertedIndex: InvertedIndexConfig;
9898
multiTenancy: MultiTenancyConfig;
9999
properties: PropertyConfig[];
100100
references: ReferenceConfig[];
101101
replication: ReplicationConfig;
102-
reranker?: RerankerConfig;
102+
reranker?: ModuleConfig<Reranker, RerankerConfig>;
103103
sharding: ShardingConfig;
104104
vectorizers: VectorConfig;
105105
};
106106

107107
export type CollectionConfigUpdate = {
108108
description?: string;
109+
generative?: ModuleConfig<GenerativeSearch, GenerativeConfig>;
109110
invertedIndex?: InvertedIndexConfigUpdate;
110111
multiTenancy?: MultiTenancyConfigUpdate;
111112
replication?: ReplicationConfigUpdate;
113+
reranker?: ModuleConfig<Reranker, RerankerConfig>;
112114
vectorizers?:
113115
| VectorConfigUpdate<undefined, VectorIndexType>
114116
| VectorConfigUpdate<string, VectorIndexType>[];

src/collections/config/unit.test.ts

+56-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import {
22
WeaviateInvertedIndexConfig,
3+
WeaviateModuleConfig,
34
WeaviateMultiTenancyConfig,
45
WeaviateVectorsConfig,
56
} from '../../openapi/types';
67
import { MergeWithExisting } from './classes';
8+
import { GenerativeCohereConfig, RerankerCohereConfig } from './types';
79

810
describe('Unit testing of the MergeWithExisting class', () => {
11+
const deepCopy = (config: any) => JSON.parse(JSON.stringify(config));
12+
913
const invertedIndex: WeaviateInvertedIndexConfig = {
1014
bm25: {
1115
b: 0.8,
@@ -62,7 +66,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
6266
};
6367

6468
it('should merge a full invertedIndexUpdate with existing schema', () => {
65-
const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), {
69+
const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), {
6670
bm25: {
6771
b: 0.9,
6872
k1: 1.4,
@@ -122,8 +126,20 @@ describe('Unit testing of the MergeWithExisting class', () => {
122126
autoTenantCreation: false,
123127
};
124128

129+
const moduleConfig: WeaviateModuleConfig = {
130+
'generative-cohere': {
131+
kProperty: 0.1,
132+
model: 'model',
133+
maxTokensProperty: '5',
134+
returnLikelihoodsProperty: 'likelihoods',
135+
stopSequencesProperty: ['and'],
136+
temperatureProperty: 5.2,
137+
},
138+
'reranker-cohere': {},
139+
};
140+
125141
it('should merge a partial invertedIndexUpdate with existing schema', () => {
126-
const merged = MergeWithExisting.invertedIndex(JSON.parse(JSON.stringify(invertedIndex)), {
142+
const merged = MergeWithExisting.invertedIndex(deepCopy(invertedIndex), {
127143
bm25: {
128144
b: 0.9,
129145
},
@@ -147,7 +163,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
147163
});
148164

149165
it('should merge a no quantizer HNSW vectorIndexConfig with existing schema', () => {
150-
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
166+
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
151167
{
152168
name: 'name',
153169
vectorIndex: {
@@ -196,7 +212,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
196212
});
197213

198214
it('should merge a PQ quantizer HNSW vectorIndexConfig with existing schema', () => {
199-
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
215+
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
200216
{
201217
name: 'name',
202218
vectorIndex: {
@@ -245,7 +261,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
245261
});
246262

247263
it('should merge a BQ quantizer HNSW vectorIndexConfig with existing schema', () => {
248-
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
264+
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
249265
{
250266
name: 'name',
251267
vectorIndex: {
@@ -280,7 +296,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
280296
});
281297

282298
it('should merge a SQ quantizer HNSW vectorIndexConfig with existing schema', () => {
283-
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(hnswVectorConfig)), [
299+
const merged = MergeWithExisting.vectors(deepCopy(hnswVectorConfig), [
284300
{
285301
name: 'name',
286302
vectorIndex: {
@@ -317,7 +333,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
317333
});
318334

319335
it('should merge a BQ quantizer Flat vectorIndexConfig with existing schema', () => {
320-
const merged = MergeWithExisting.vectors(JSON.parse(JSON.stringify(flatVectorConfig)), [
336+
const merged = MergeWithExisting.vectors(deepCopy(flatVectorConfig), [
321337
{
322338
name: 'name',
323339
vectorIndex: {
@@ -353,7 +369,7 @@ describe('Unit testing of the MergeWithExisting class', () => {
353369
});
354370

355371
it('should merge full multi tenancy config with existing schema', () => {
356-
const merged = MergeWithExisting.multiTenancy(JSON.parse(JSON.stringify(multiTenancyConfig)), {
372+
const merged = MergeWithExisting.multiTenancy(deepCopy(multiTenancyConfig), {
357373
autoTenantActivation: true,
358374
autoTenantCreation: true,
359375
});
@@ -363,4 +379,36 @@ describe('Unit testing of the MergeWithExisting class', () => {
363379
autoTenantCreation: true,
364380
});
365381
});
382+
383+
it('should merge a generative config with existing schema', () => {
384+
const merged = MergeWithExisting.generative(deepCopy(moduleConfig), {
385+
name: 'generative-cohere',
386+
config: {
387+
kProperty: 0.2,
388+
} as GenerativeCohereConfig,
389+
});
390+
expect(merged).toEqual({
391+
...moduleConfig,
392+
'generative-cohere': {
393+
...(moduleConfig['generative-cohere'] as any),
394+
kProperty: 0.2,
395+
} as GenerativeCohereConfig,
396+
});
397+
});
398+
399+
it('should merge a reranker config with existing schema', () => {
400+
const merged = MergeWithExisting.reranker(deepCopy(moduleConfig), {
401+
name: 'reranker-cohere',
402+
config: {
403+
model: 'other',
404+
} as RerankerCohereConfig,
405+
});
406+
expect(merged).toEqual({
407+
...moduleConfig,
408+
'reranker-cohere': {
409+
...(moduleConfig['reranker-cohere'] as any),
410+
model: 'other',
411+
} as RerankerCohereConfig,
412+
});
413+
});
366414
});

src/collections/configure/index.ts

+2
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ const reconfigure = {
261261
autoTenantCreation: options.autoTenantCreation,
262262
};
263263
},
264+
generative: configure.generative,
265+
reranker: configure.reranker,
264266
};
265267

266268
export {

0 commit comments

Comments
 (0)