Skip to content

Commit 608e446

Browse files
authored
Support ai config (#42)
1 parent 7b74dde commit 608e446

File tree

7 files changed

+42
-8
lines changed

7 files changed

+42
-8
lines changed

aiscript-runtime/src/config/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{env, fmt::Display, fs, ops::Deref, path::Path, sync::OnceLock};
33
use auth::AuthConfig;
44
use serde::Deserialize;
55

6+
use aiscript_vm::AiConfig;
67
use db::DatabaseConfig;
78
pub use sso::{SsoConfig, get_sso_fields};
89

@@ -64,6 +65,8 @@ impl AsRef<str> for EnvString {
6465

6566
#[derive(Debug, Deserialize, Default)]
6667
pub struct Config {
68+
#[serde(default)]
69+
pub ai: Option<AiConfig>,
6770
#[serde(default)]
6871
pub database: DatabaseConfig,
6972
#[serde(default)]
@@ -116,9 +119,9 @@ impl Config {
116119
}
117120
}
118121

119-
pub fn load(path: &str) -> &Config {
122+
pub fn load() -> &'static Config {
120123
CONFIG.get_or_init(|| {
121-
Config::new(path).unwrap_or_else(|e| {
124+
Config::new("project.toml").unwrap_or_else(|e| {
122125
eprintln!("Error loading config file: {}", e);
123126
Config::default()
124127
})

aiscript-runtime/src/endpoint.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,13 @@ impl Future for RequestProcessor {
499499
let redis_connection = self.endpoint.redis_connection.clone();
500500
let handle: JoinHandle<Result<ReturnValue, VmError>> =
501501
task::spawn_blocking(move || {
502-
let mut vm =
503-
Vm::new(pg_connection, sqlite_connection, redis_connection);
502+
let ai_config = Config::load().ai.clone();
503+
let mut vm = Vm::new(
504+
pg_connection,
505+
sqlite_connection,
506+
redis_connection,
507+
ai_config,
508+
);
504509
if let Some(fields) = sso_fields {
505510
vm.inject_sso_instance(fields);
506511
}

aiscript-vm/src/ai/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@ pub use agent::{Agent, run_agent};
77
use openai_api_rs::v1::api::OpenAIClient;
88
pub use prompt::{PromptConfig, prompt_with_config};
99

10+
use serde::Deserialize;
11+
12+
#[derive(Debug, Clone, Deserialize, Default)]
13+
pub struct AiConfig {
14+
pub openai: Option<ModelConfig>,
15+
pub anthropic: Option<ModelConfig>,
16+
pub deepseek: Option<ModelConfig>,
17+
}
18+
19+
#[derive(Debug, Clone, Deserialize)]
20+
pub struct ModelConfig {
21+
pub api_key: String,
22+
pub model: Option<String>,
23+
}
24+
1025
#[allow(unused)]
1126
pub(crate) fn openai_client() -> OpenAIClient {
1227
OpenAIClient::builder()

aiscript-vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::collections::HashMap;
1616
use std::fmt::Display;
1717
use std::ops::Deref;
1818

19+
pub use ai::AiConfig;
1920
use aiscript_arena::Collect;
2021
use aiscript_arena::Mutation;
2122
pub(crate) use aiscript_lexer as lexer;

aiscript-vm/src/vm/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub use state::State;
66

77
use crate::{
88
ReturnValue, Value,
9+
ai::AiConfig,
910
ast::ChunkId,
1011
builtins, stdlib,
1112
string::{InternedString, InternedStringSet},
@@ -35,7 +36,7 @@ impl Display for VmError {
3536

3637
impl Default for Vm {
3738
fn default() -> Self {
38-
Self::new(None, None, None)
39+
Self::new(None, None, None, None)
3940
}
4041
}
4142

@@ -48,13 +49,15 @@ impl Vm {
4849
pg_connection: Option<PgPool>,
4950
sqlite_connection: Option<SqlitePool>,
5051
redis_connection: Option<redis::aio::MultiplexedConnection>,
52+
ai_config: Option<AiConfig>,
5153
) -> Self {
5254
let mut vm = Vm {
5355
arena: Arena::<Rootable![State<'_>]>::new(|mc| {
5456
let mut state = State::new(mc);
5557
state.pg_connection = pg_connection;
5658
state.sqlite_connection = sqlite_connection;
5759
state.redis_connection = redis_connection;
60+
state.ai_config = ai_config;
5861
state
5962
}),
6063
};

aiscript-vm/src/vm/state.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use sqlx::{PgPool, SqlitePool};
1515

1616
use crate::{
1717
NativeFn, OpCode, ReturnValue, Value,
18-
ai::{self, PromptConfig},
18+
ai::{self, AiConfig, PromptConfig},
1919
ast::{ChunkId, Visibility},
2020
builtins::BuiltinMethods,
2121
module::{ModuleKind, ModuleManager, ModuleSource},
@@ -110,6 +110,7 @@ pub struct State<'gc> {
110110
pub pg_connection: Option<PgPool>,
111111
pub sqlite_connection: Option<SqlitePool>,
112112
pub redis_connection: Option<redis::aio::MultiplexedConnection>,
113+
pub ai_config: Option<AiConfig>,
113114
}
114115

115116
unsafe impl Collect for State<'_> {
@@ -152,6 +153,7 @@ impl<'gc> State<'gc> {
152153
pg_connection: None,
153154
sqlite_connection: None,
154155
redis_connection: None,
156+
ai_config: None,
155157
}
156158
}
157159

aiscript/src/main.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ enum Commands {
4848
#[tokio::main]
4949
async fn main() {
5050
dotenv::dotenv().ok();
51-
Config::load("project.toml");
51+
let config = Config::load();
5252

5353
let cli = AIScriptCli::parse();
5454
match cli.command {
@@ -69,7 +69,12 @@ async fn main() {
6969
let sqlite_connection = aiscript_runtime::get_sqlite_connection().await;
7070
let redis_connection = aiscript_runtime::get_redis_connection().await;
7171
task::spawn_blocking(move || {
72-
let mut vm = Vm::new(pg_connection, sqlite_connection, redis_connection);
72+
let mut vm = Vm::new(
73+
pg_connection,
74+
sqlite_connection,
75+
redis_connection,
76+
config.ai.clone(),
77+
);
7378
vm.run_file(path);
7479
})
7580
.await // must use await to wait for the thread to finish

0 commit comments

Comments
 (0)