Skip to content

Commit 27d6b26

Browse files
authored
Merge pull request #287 from weaviate/modules/nvidia-reranker
Adds `baseURL` parameter to module configurations and an `reranker-nvidia` reranker.
2 parents 51b1737 + 626b5a9 commit 27d6b26

File tree

5 files changed

+136
-0
lines changed

5 files changed

+136
-0
lines changed

src/collections/config/types/generative.ts

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ export type GenerativeAnthropicConfig = {
2525
};
2626

2727
export type GenerativeAnyscaleConfig = {
28+
baseURL?: string;
2829
model?: string;
2930
temperature?: number;
3031
};
@@ -54,6 +55,7 @@ export type GenerativeFriendliAIConfig = {
5455
};
5556

5657
export type GenerativeMistralConfig = {
58+
baseURL?: string;
5759
maxTokens?: number;
5860
model?: string;
5961
temperature?: number;

src/collections/config/types/reranker.ts

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ export type RerankerCohereConfig = {
55
};
66

77
export type RerankerVoyageAIConfig = {
8+
baseURL?: string;
89
model?: 'rerank-lite-1' | string;
910
};
1011

@@ -18,9 +19,15 @@ export type RerankerJinaAIConfig = {
1819
| string;
1920
};
2021

22+
export type RerankerNvidiaConfig = {
23+
baseURL?: string;
24+
model?: 'nvidia/rerank-qa-mistral-4b' | string;
25+
};
26+
2127
export type RerankerConfig =
2228
| RerankerCohereConfig
2329
| RerankerJinaAIConfig
30+
| RerankerNvidiaConfig
2431
| RerankerTransformersConfig
2532
| RerankerVoyageAIConfig
2633
| Record<string, any>
@@ -29,6 +36,7 @@ export type RerankerConfig =
2936
export type Reranker =
3037
| 'reranker-cohere'
3138
| 'reranker-jinaai'
39+
| 'reranker-nvidia'
3240
| 'reranker-transformers'
3341
| 'reranker-voyageai'
3442
| 'none'
@@ -38,6 +46,8 @@ export type RerankerConfigType<R> = R extends 'reranker-cohere'
3846
? RerankerCohereConfig
3947
: R extends 'reranker-jinaai'
4048
? RerankerJinaAIConfig
49+
: R extends 'reranker-nvidia'
50+
? RerankerNvidiaConfig
4151
: R extends 'reranker-transformers'
4252
? RerankerTransformersConfig
4353
: R extends 'reranker-voyageai'

src/collections/config/types/vectorizer.ts

+2
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,8 @@ export type Text2VecNvidiaConfig = {
390390
* See the [documentation](https://weaviate.io/developers/weaviate/model-providers/mistral/embeddings) for detailed usage.
391391
*/
392392
export type Text2VecMistralConfig = {
393+
/** The base URL to use where API requests should go. */
394+
baseURL?: string;
393395
/** The model to use. */
394396
model?: 'mistral-embed' | string;
395397
/** Whether to vectorize the collection name. */

src/collections/configure/reranker.ts

+17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import {
22
ModuleConfig,
33
RerankerCohereConfig,
44
RerankerJinaAIConfig,
5+
RerankerNvidiaConfig,
56
RerankerVoyageAIConfig,
67
} from '../config/types/index.js';
78

@@ -38,6 +39,22 @@ export default {
3839
config: config,
3940
};
4041
},
42+
/**
43+
* Create a `ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig>` object for use when reranking using the `reranker-nvidia` module.
44+
*
45+
* See the [documentation](https://weaviate.io/developers/weaviate/model-providers/nvidia/reranker) for detailed usage.
46+
*
47+
* @param {RerankerNvidiaConfig} [config] The configuration for the `reranker-nvidia` module.
48+
* @returns {ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined>} The configuration object.
49+
*/
50+
nvidia: (
51+
config?: RerankerNvidiaConfig
52+
): ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined> => {
53+
return {
54+
name: 'reranker-nvidia',
55+
config: config,
56+
};
57+
},
4158
/**
4259
* Create a `ModuleConfig<'reranker-transformers', Record<string, never>>` object for use when reranking using the `reranker-transformers` module.
4360
*

src/collections/configure/unit.test.ts

+105
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ import {
1212
GenerativeOpenAIConfig,
1313
GenerativeXAIConfig,
1414
ModuleConfig,
15+
RerankerCohereConfig,
16+
RerankerJinaAIConfig,
17+
RerankerNvidiaConfig,
18+
RerankerTransformersConfig,
19+
RerankerVoyageAIConfig,
1520
VectorConfigCreate,
1621
} from '../types/index.js';
1722
import { configure } from './index.js';
@@ -1220,6 +1225,7 @@ describe('Unit testing of the vectorizer factory class', () => {
12201225

12211226
it('should create the correct Text2VecMistralConfig type with all values', () => {
12221227
const config = configure.vectorizer.text2VecMistral({
1228+
baseURL: 'base-url',
12231229
name: 'test',
12241230
model: 'model',
12251231
vectorizeCollectionName: true,
@@ -1233,6 +1239,7 @@ describe('Unit testing of the vectorizer factory class', () => {
12331239
vectorizer: {
12341240
name: 'text2vec-mistral',
12351241
config: {
1242+
baseURL: 'base-url',
12361243
model: 'model',
12371244
vectorizeCollectionName: true,
12381245
},
@@ -1567,12 +1574,14 @@ describe('Unit testing of the generative factory class', () => {
15671574

15681575
it('should create the correct GenerativeAnyscaleConfig type with all values', () => {
15691576
const config = configure.generative.anyscale({
1577+
baseURL: 'base-url',
15701578
model: 'model',
15711579
temperature: 0.5,
15721580
});
15731581
expect(config).toEqual<ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfig | undefined>>({
15741582
name: 'generative-anyscale',
15751583
config: {
1584+
baseURL: 'base-url',
15761585
model: 'model',
15771586
temperature: 0.5,
15781587
},
@@ -1749,13 +1758,15 @@ describe('Unit testing of the generative factory class', () => {
17491758

17501759
it('should create the correct GenerativeMistralConfig type with all values', () => {
17511760
const config = configure.generative.mistral({
1761+
baseURL: 'base-url',
17521762
maxTokens: 100,
17531763
model: 'model',
17541764
temperature: 0.5,
17551765
});
17561766
expect(config).toEqual<ModuleConfig<'generative-mistral', GenerativeMistralConfig | undefined>>({
17571767
name: 'generative-mistral',
17581768
config: {
1769+
baseURL: 'base-url',
17591770
maxTokens: 100,
17601771
model: 'model',
17611772
temperature: 0.5,
@@ -1909,3 +1920,97 @@ describe('Unit testing of the generative factory class', () => {
19091920
});
19101921
});
19111922
});
1923+
1924+
describe('Unit testing of the reranker factory class', () => {
1925+
it('should create the correct RerankerCohereConfig type using required & default values', () => {
1926+
const config = configure.reranker.cohere();
1927+
expect(config).toEqual<ModuleConfig<'reranker-cohere', RerankerCohereConfig | undefined>>({
1928+
name: 'reranker-cohere',
1929+
config: undefined,
1930+
});
1931+
});
1932+
1933+
it('should create the correct RerankerCohereConfig type with all values', () => {
1934+
const config = configure.reranker.cohere({
1935+
model: 'model',
1936+
});
1937+
expect(config).toEqual<ModuleConfig<'reranker-cohere', RerankerCohereConfig | undefined>>({
1938+
name: 'reranker-cohere',
1939+
config: {
1940+
model: 'model',
1941+
},
1942+
});
1943+
});
1944+
1945+
it('should create the correct RerankerJinaAIConfig type using required & default values', () => {
1946+
const config = configure.reranker.jinaai();
1947+
expect(config).toEqual<ModuleConfig<'reranker-jinaai', RerankerJinaAIConfig | undefined>>({
1948+
name: 'reranker-jinaai',
1949+
config: undefined,
1950+
});
1951+
});
1952+
1953+
it('should create the correct RerankerJinaAIConfig type with all values', () => {
1954+
const config = configure.reranker.jinaai({
1955+
model: 'model',
1956+
});
1957+
expect(config).toEqual<ModuleConfig<'reranker-jinaai', RerankerJinaAIConfig | undefined>>({
1958+
name: 'reranker-jinaai',
1959+
config: {
1960+
model: 'model',
1961+
},
1962+
});
1963+
});
1964+
1965+
it('should create the correct RerankerNvidiaConfig type with required & default values', () => {
1966+
const config = configure.reranker.nvidia();
1967+
expect(config).toEqual<ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined>>({
1968+
name: 'reranker-nvidia',
1969+
config: undefined,
1970+
});
1971+
});
1972+
1973+
it('should create the correct RerankerNvidiaConfig type with all values', () => {
1974+
const config = configure.reranker.nvidia({
1975+
baseURL: 'base-url',
1976+
model: 'model',
1977+
});
1978+
expect(config).toEqual<ModuleConfig<'reranker-nvidia', RerankerNvidiaConfig | undefined>>({
1979+
name: 'reranker-nvidia',
1980+
config: {
1981+
baseURL: 'base-url',
1982+
model: 'model',
1983+
},
1984+
});
1985+
});
1986+
1987+
it('should create the correct RerankerTransformersConfig type using required & default values', () => {
1988+
const config = configure.reranker.transformers();
1989+
expect(config).toEqual<ModuleConfig<'reranker-transformers', RerankerTransformersConfig>>({
1990+
name: 'reranker-transformers',
1991+
config: {},
1992+
});
1993+
});
1994+
1995+
it('should create the correct RerankerVoyageAIConfig with required & default values', () => {
1996+
const config = configure.reranker.voyageAI();
1997+
expect(config).toEqual<ModuleConfig<'reranker-voyageai', RerankerVoyageAIConfig | undefined>>({
1998+
name: 'reranker-voyageai',
1999+
config: undefined,
2000+
});
2001+
});
2002+
2003+
it('should create the correct RerankerVoyageAIConfig type with all values', () => {
2004+
const config = configure.reranker.voyageAI({
2005+
baseURL: 'base-url',
2006+
model: 'model',
2007+
});
2008+
expect(config).toEqual<ModuleConfig<'reranker-voyageai', RerankerVoyageAIConfig | undefined>>({
2009+
name: 'reranker-voyageai',
2010+
config: {
2011+
baseURL: 'base-url',
2012+
model: 'model',
2013+
},
2014+
});
2015+
});
2016+
});

0 commit comments

Comments
 (0)