Commit 6a047b0

Support deepseek (#41)

1 parent 608e446 commit 6a047b0
5 files changed: +130 −34 lines changed

README.md (+22)

````diff
@@ -113,6 +113,28 @@ AIScript excels in these scenarios:
 
 Check out the [examples](./examples) directory for more sample code.
 
+## Supported AI Models
+
+AIScript supports the following AI models:
+
+- [x] OpenAI (uses the `OPENAI_API_KEY` environment variable by default)
+- [x] DeepSeek
+- [ ] Anthropic
+
+Configure via `project.toml`:
+
+```toml
+# use OpenAI
+[ai.openai]
+api_key = "YOUR_API_KEY"
+model = "gpt-3.5-turbo"
+
+# or use DeepSeek
+[ai.deepseek]
+api_key = "YOUR_API_KEY"
+model = "deepseek-chat"
+```
+
 ## Roadmap
 
 See our [roadmap](https://aiscript.dev/guide/contribution/roadmap) for upcoming features and improvements.
````
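For reference, the fallback described in the OpenAI bullet (config key first, environment variable second) is what the updated `openai_client` in `aiscript-vm/src/ai/mod.rs` below implements. A minimal standalone sketch of the same precedence rule; `resolve_openai_key` is a hypothetical helper name, the commit inlines this logic:

```rust
use std::env;

// Hypothetical helper illustrating the key-resolution order:
// prefer the api_key from project.toml, otherwise fall back to
// the OPENAI_API_KEY environment variable (panics if neither is set).
fn resolve_openai_key(config_key: &str) -> String {
    if config_key.is_empty() {
        env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
    } else {
        config_key.to_string()
    }
}
```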

aiscript-vm/src/ai/agent.rs (+4 −3)

```diff
@@ -8,7 +8,6 @@ use openai_api_rs::v1::{
         ChatCompletionMessage, ChatCompletionMessageForResponse, ChatCompletionRequest, Content,
         MessageRole, Tool, ToolCall, ToolChoiceType, ToolType,
     },
-    common::GPT3_5_TURBO,
     types::{self, FunctionParameters, JSONSchemaDefine},
 };
 use tokio::runtime::Handle;
@@ -278,6 +277,8 @@ pub async fn _run_agent<'gc>(
     mut agent: Gc<'gc, Agent<'gc>>,
     args: Vec<Value<'gc>>,
 ) -> Value<'gc> {
+    use super::default_model;
+
     let message = args[0];
     let debug = args[1].as_boolean();
     let mut history = Vec::new();
@@ -288,11 +289,11 @@ pub async fn _run_agent<'gc>(
         tool_calls: None,
         tool_call_id: None,
     });
-    let mut client = super::openai_client();
+    let mut client = super::openai_client(state.ai_config.as_ref());
     loop {
         let mut messages = vec![agent.get_instruction_message()];
         messages.extend(history.clone());
-        let mut req = ChatCompletionRequest::new(GPT3_5_TURBO.to_string(), messages);
+        let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
         let tools = agent.get_tools();
         if !tools.is_empty() {
             req = req
```
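The agent loop now derives its model from the interpreter's `AiConfig` instead of the hard-coded `GPT3_5_TURBO`. A brief crate-internal sketch (assumed, not part of this commit) of the resulting fallback behavior, using the constants defined in `mod.rs` below:

```rust
// Inside the aiscript-vm crate (default_model is pub(crate)):
use crate::ai::{default_model, AiConfig, ModelConfig};

fn demo() {
    // No config at all: the OpenAI default applies.
    assert_eq!(default_model(None), "gpt-3.5-turbo"); // GPT3_5_TURBO
    // A DeepSeek config without an explicit model: the DeepSeek default applies.
    let ds = AiConfig::DeepSeek(ModelConfig {
        api_key: "YOUR_API_KEY".to_string(),
        model: None,
    });
    assert_eq!(default_model(Some(&ds)), "deepseek-chat"); // DEEPSEEK_V3
}
```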

aiscript-vm/src/ai/mod.rs (+67 −11)

```diff
@@ -4,16 +4,22 @@ mod prompt;
 use std::env;
 
 pub use agent::{Agent, run_agent};
-use openai_api_rs::v1::api::OpenAIClient;
+use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
 pub use prompt::{PromptConfig, prompt_with_config};
 
 use serde::Deserialize;
 
-#[derive(Debug, Clone, Deserialize, Default)]
-pub struct AiConfig {
-    pub openai: Option<ModelConfig>,
-    pub anthropic: Option<ModelConfig>,
-    pub deepseek: Option<ModelConfig>,
+const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
+const DEEPSEEK_V3: &str = "deepseek-chat";
+
+#[derive(Debug, Clone, Deserialize)]
+pub enum AiConfig {
+    #[serde(rename = "openai")]
+    OpenAI(ModelConfig),
+    #[serde(rename = "anthropic")]
+    Anthropic(ModelConfig),
+    #[serde(rename = "deepseek")]
+    DeepSeek(ModelConfig),
 }
 
 #[derive(Debug, Clone, Deserialize)]
@@ -22,10 +28,60 @@ pub struct ModelConfig {
     pub model: Option<String>,
 }
 
+impl AiConfig {
+    pub(crate) fn take_model(&mut self) -> Option<String> {
+        match self {
+            Self::OpenAI(ModelConfig { model, .. }) => model.take(),
+            Self::Anthropic(ModelConfig { model, .. }) => model.take(),
+            Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
+        }
+    }
+
+    pub(crate) fn set_model(&mut self, m: String) {
+        match self {
+            Self::OpenAI(ModelConfig { model, .. }) => model.replace(m),
+            Self::Anthropic(ModelConfig { model, .. }) => model.replace(m),
+            Self::DeepSeek(ModelConfig { model, .. }) => model.replace(m),
+        };
+    }
+}
+
 #[allow(unused)]
-pub(crate) fn openai_client() -> OpenAIClient {
-    OpenAIClient::builder()
-        .with_api_key(env::var("OPENAI_API_KEY").unwrap().to_string())
-        .build()
-        .unwrap()
+pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
+    match config {
+        None => OpenAIClient::builder()
+            .with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
+            .build()
+            .unwrap(),
+        Some(AiConfig::OpenAI(model_config)) => {
+            let api_key = if model_config.api_key.is_empty() {
+                env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
+            } else {
+                model_config.api_key.clone()
+            };
+            OpenAIClient::builder()
+                .with_api_key(api_key)
+                .build()
+                .unwrap()
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
+            .with_endpoint(DEEPSEEK_API_ENDPOINT)
+            .with_api_key(api_key)
+            .build()
+            .unwrap(),
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
+}
+
+pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
+    match config {
+        None => GPT3_5_TURBO.to_string(),
+        Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(GPT3_5_TURBO.to_string())
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(DEEPSEEK_V3.to_string())
+        }
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
 }
```
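Since the enum is externally tagged, each `#[serde(rename = ...)]` variant matches the corresponding `[ai.*]` table from `project.toml`. A standalone sketch (not from this commit; the `ProjectConfig` wrapper and its `ai` field are assumptions for illustration) of how that deserialization plays out with the `toml` crate:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    api_key: String,
    model: Option<String>,
}

// Mirrors the enum above: the variant name, after renaming,
// is the key under the [ai] table.
#[derive(Debug, Deserialize)]
enum AiConfig {
    #[serde(rename = "openai")]
    OpenAI(ModelConfig),
    #[serde(rename = "deepseek")]
    DeepSeek(ModelConfig),
}

// Hypothetical wrapper; the real project.toml loader lives elsewhere.
#[derive(Debug, Deserialize)]
struct ProjectConfig {
    ai: AiConfig,
}

fn main() {
    let src = r#"
        [ai.deepseek]
        api_key = "YOUR_API_KEY"
        model = "deepseek-chat"
    "#;
    let cfg: ProjectConfig = toml::from_str(src).expect("valid project.toml");
    // DeepSeek(ModelConfig { api_key: "YOUR_API_KEY", model: Some("deepseek-chat") })
    println!("{:?}", cfg.ai);
}
```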

aiscript-vm/src/ai/prompt.rs (+26 −15)

```diff
@@ -1,9 +1,11 @@
 use openai_api_rs::v1::common::GPT3_5_TURBO;
 use tokio::runtime::Handle;
 
+use super::{AiConfig, ModelConfig, default_model};
+
 pub struct PromptConfig {
     pub input: String,
-    pub model: Option<String>,
+    pub ai_config: Option<AiConfig>,
     pub max_tokens: Option<i64>,
     pub temperature: Option<f64>,
     pub system_prompt: Option<String>,
@@ -13,27 +15,42 @@ impl Default for PromptConfig {
     fn default() -> Self {
         Self {
             input: String::new(),
-            model: Some(GPT3_5_TURBO.to_string()),
+            ai_config: Some(AiConfig::OpenAI(ModelConfig {
+                api_key: Default::default(),
+                model: Some(GPT3_5_TURBO.to_string()),
+            })),
             max_tokens: Default::default(),
             temperature: Default::default(),
             system_prompt: Default::default(),
         }
     }
 }
 
+impl PromptConfig {
+    fn take_model(&mut self) -> String {
+        self.ai_config
+            .as_mut()
+            .and_then(|config| config.take_model())
+            .unwrap_or_else(|| default_model(self.ai_config.as_ref()))
+    }
+
+    pub(crate) fn set_model(&mut self, model: String) {
+        if let Some(config) = self.ai_config.as_mut() {
+            config.set_model(model);
+        }
+    }
+}
+
 #[cfg(feature = "ai_test")]
 async fn _prompt_with_config(config: PromptConfig) -> String {
     return format!("AI: {}", config.input);
 }
 
 #[cfg(not(feature = "ai_test"))]
 async fn _prompt_with_config(mut config: PromptConfig) -> String {
-    use openai_api_rs::v1::{
-        chat_completion::{self, ChatCompletionRequest},
-        common::GPT3_5_TURBO,
-    };
-
-    let mut client = super::openai_client();
+    use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
+    let mut client = super::openai_client(config.ai_config.as_ref());
+    let model = config.take_model();
 
     // Create system message if provided
     let mut messages = Vec::new();
@@ -57,13 +74,7 @@ async fn _prompt_with_config(mut config: PromptConfig) -> String {
     });
 
     // Build the request
-    let mut req = ChatCompletionRequest::new(
-        config
-            .model
-            .take()
-            .unwrap_or_else(|| GPT3_5_TURBO.to_string()),
-        messages,
-    );
+    let mut req = ChatCompletionRequest::new(model, messages);
 
     if let Some(max_tokens) = config.max_tokens {
         req.max_tokens = Some(max_tokens);
```
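A short crate-internal sketch (assumed, not part of the commit) of the resolution order these helpers establish: a model explicitly stored via `set_model` wins when `take_model` runs, otherwise `default_model` supplies the provider default:

```rust
// Inside aiscript-vm (set_model is pub(crate), take_model is module-private):
use crate::ai::{AiConfig, ModelConfig, PromptConfig};

fn demo() {
    let mut config = PromptConfig {
        input: "Hello".to_string(),
        ai_config: Some(AiConfig::DeepSeek(ModelConfig {
            api_key: "YOUR_API_KEY".to_string(),
            model: None, // nothing set: take_model() would fall back to
                         // default_model(), i.e. "deepseek-chat" here
        })),
        ..Default::default()
    };
    // Pinning the model explicitly stores it inside the AiConfig,
    // so the next take_model() returns it instead of the default.
    config.set_model("deepseek-chat".to_string());
}
```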

aiscript-vm/src/vm/state.rs (+11 −5)

```diff
@@ -1013,15 +1013,21 @@ impl<'gc> State<'gc> {
         let result = match value {
             // Simple string case
             Value::String(s) => {
-                let input = s.to_str().unwrap().to_string();
-                ai::prompt_with_config(PromptConfig {
-                    input,
+                let mut config = PromptConfig {
+                    input: s.to_str().unwrap().to_string(),
                     ..Default::default()
-                })
+                };
+                if let Some(ai_cfg) = &self.ai_config {
+                    config.ai_config = Some(ai_cfg.clone());
+                }
+                ai::prompt_with_config(config)
             }
             // Object config case
             Value::Object(obj) => {
                 let mut config = PromptConfig::default();
+                if let Some(ai_cfg) = &self.ai_config {
+                    config.ai_config = Some(ai_cfg.clone());
+                }
                 let obj_ref = obj.borrow();
 
                 // Extract input (required)
@@ -1039,7 +1045,7 @@ impl<'gc> State<'gc> {
                 if let Some(Value::String(model)) =
                     obj_ref.fields.get(&self.intern(b"model"))
                 {
-                    config.model = Some(model.to_str().unwrap().to_string());
+                    config.set_model(model.to_str().unwrap().to_string());
                 }
 
                 // Extract max_tokens (optional)
```
