Skip to content

Commit 17811b1

Browse files
committed
feat: add oauth to shared
1 parent f695cca commit 17811b1

File tree

4 files changed

+181
-11
lines changed

4 files changed

+181
-11
lines changed

Cargo.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,21 @@ metrics-exporter-prometheus = "0.17.0"
3535
# Slot Calc
3636
chrono = "0.4.40"
3737

38+
# OAuth
39+
oauth2 = { version = "5.0.0", optional = true }
40+
tokio = { version = "1.36.0", optional = true }
41+
3842
# Other
3943
thiserror = "2.0.11"
4044
alloy = { version = "0.12.6", optional = true, default-features = false, features = ["std", "signer-aws", "signer-local", "consensus", "network"] }
4145
serde = { version = "1", features = ["derive"] }
4246
async-trait = { version = "0.1.80", optional = true }
4347

48+
4449
# AWS
4550
aws-config = { version = "1.1.7", optional = true }
4651
aws-sdk-kms = { version = "1.15.0", optional = true }
52+
reqwest = { version = "0.12.15", optional = true }
4753

4854
[dev-dependencies]
4955
ajj = "0.3.1"
@@ -55,4 +61,4 @@ tokio = { version = "1.43.0", features = ["macros"] }
5561
[features]
5662
default = ["alloy"]
5763
alloy = ["dep:alloy", "dep:async-trait", "dep:aws-config", "dep:aws-sdk-kms"]
58-
perms = []
64+
perms = ["dep:oauth2", "dep:tokio", "dep:reqwest"]

src/lib.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,28 @@ pub mod perms;
1818

1919
/// Signet utilities.
2020
pub mod utils {
21-
/// Prometheus metrics utilities.
22-
pub mod metrics;
23-
24-
/// OpenTelemetry utilities.
25-
pub mod otlp;
21+
/// Slot calculator for determining the current slot and timepoint within a
22+
/// slot.
23+
pub mod calc;
2624

2725
/// [`FromEnv`], [`FromEnvVar`] traits and related utilities.
2826
///
2927
/// [`FromEnv`]: from_env::FromEnv
3028
/// [`FromEnvVar`]: from_env::FromEnvVar
3129
pub mod from_env;
3230

33-
/// Tracing utilities.
34-
pub mod tracing;
31+
/// Prometheus metrics utilities.
32+
pub mod metrics;
3533

36-
/// Slot calculator for determining the current slot and timepoint within a
37-
/// slot.
38-
pub mod calc;
34+
/// OpenTelemetry utilities.
35+
pub mod otlp;
3936

4037
#[cfg(feature = "alloy")]
4138
/// Signer using a local private key or AWS KMS key.
4239
pub mod signer;
40+
41+
/// Tracing utilities.
42+
pub mod tracing;
4343
}
4444

4545
/// Re-exports of common dependencies.

src/perms/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ pub use builders::{Builder, BuilderPermissionError, Builders, BuildersEnvError};
33

44
pub(crate) mod config;
55
pub use config::{SlotAuthzConfig, SlotAuthzConfigEnvError};
6+
7+
pub(crate) mod oauth;
8+
pub use oauth::{Authenticator, SharedToken};

src/perms/oauth.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//! Service responsible for authenticating with the cache with Oauth tokens.
2+
//! This authenticator periodically fetches a new token every set amount of seconds.
3+
use crate::{
4+
deps::tracing::{error, info},
5+
utils::from_env::FromEnv,
6+
};
7+
use oauth2::{
8+
basic::{BasicClient, BasicTokenType},
9+
AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, StandardTokenResponse, TokenUrl, EndpointNotSet, EndpointSet, RequestTokenError,StandardErrorResponse, HttpClientError
10+
};
11+
use std::sync::{Arc, Mutex};
12+
use tokio::task::JoinHandle;
13+
14+
type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
15+
16+
17+
type MyOAuthClient = BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;
18+
19+
/// Configuration for the OAuth2 client.
20+
#[derive(Debug, Clone, FromEnv)]
21+
#[from_env(crate)]
22+
pub struct OAuthConfig {
23+
/// OAuth client ID for the builder.
24+
#[from_env(var = "OAUTH_CLIENT_ID", desc = "OAuth client ID for the builder")]
25+
pub oauth_client_id: String,
26+
/// OAuth client secret for the builder.
27+
#[from_env(
28+
var = "OAUTH_CLIENT_SECRET",
29+
desc = "OAuth client secret for the builder"
30+
)]
31+
pub oauth_client_secret: String,
32+
/// OAuth authenticate URL for the builder for performing OAuth logins.
33+
#[from_env(
34+
var = "OAUTH_AUTHENTICATE_URL",
35+
desc = "OAuth authenticate URL for the builder for performing OAuth logins"
36+
)]
37+
pub oauth_authenticate_url: url::Url,
38+
/// OAuth token URL for the builder to get an OAuth2 access token
39+
#[from_env(
40+
var = "OAUTH_TOKEN_URL",
41+
desc = "OAuth token URL for the builder to get an OAuth2 access token"
42+
)]
43+
pub oauth_token_url: url::Url,
44+
/// The oauth token refresh interval in seconds.
45+
#[from_env(
46+
var = "AUTH_TOKEN_REFRESH_INTERVAL",
47+
desc = "The oauth token refresh interval in seconds"
48+
)]
49+
pub oauth_token_refresh_interval: u64,
50+
}
51+
52+
/// A shared token that can be read and written to by multiple threads.
53+
#[derive(Debug, Clone, Default)]
54+
pub struct SharedToken(Arc<Mutex<Option<Token>>>);
55+
56+
impl SharedToken {
57+
/// Read the token from the shared token.
58+
pub fn read(&self) -> Option<Token> {
59+
self.0.lock().unwrap().clone()
60+
}
61+
62+
/// Write a new token to the shared token.
63+
pub fn write(&self, token: Token) {
64+
let mut lock = self.0.lock().unwrap();
65+
*lock = Some(token);
66+
}
67+
68+
/// Check if the token is authenticated.
69+
pub fn is_authenticated(&self) -> bool {
70+
self.0.lock().unwrap().is_some()
71+
}
72+
}
73+
74+
/// A self-refreshing, periodically fetching authenticator for the block
75+
/// builder. This task periodically fetches a new token, and stores it in a
76+
/// [`SharedToken`].
77+
#[derive(Debug)]
78+
pub struct Authenticator {
79+
/// Configuration
80+
pub config: OAuthConfig,
81+
client: MyOAuthClient,
82+
token: SharedToken,
83+
reqwest: reqwest::Client,
84+
}
85+
86+
impl Authenticator {
87+
/// Creates a new Authenticator from the provided builder config.
88+
pub fn new(config: &OAuthConfig) -> Self {
89+
let client = BasicClient::new(ClientId::new(config.oauth_client_id.clone()))
90+
.set_client_secret(ClientSecret::new(config.oauth_client_secret.clone()))
91+
.set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone()))
92+
.set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone()));
93+
94+
let rq_client = reqwest::Client::builder()
95+
.redirect(reqwest::redirect::Policy::none())
96+
.build()
97+
.unwrap();
98+
99+
Self {
100+
config: config.clone(),
101+
client,
102+
token: Default::default(),
103+
reqwest: rq_client,
104+
}
105+
}
106+
107+
/// Requests a new authentication token and, if successful, sets it to as the token
108+
pub async fn authenticate(&self) -> Result<(), RequestTokenError<HttpClientError<reqwest::Error>, StandardErrorResponse<oauth2::basic::BasicErrorResponseType>>> {
109+
let token = self.fetch_oauth_token().await?;
110+
self.set_token(token);
111+
Ok(())
112+
}
113+
114+
/// Returns true if there is Some token set
115+
pub fn is_authenticated(&self) -> bool {
116+
self.token.is_authenticated()
117+
}
118+
119+
/// Sets the Authenticator's token to the provided value
120+
fn set_token(&self, token: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>) {
121+
self.token.write(token);
122+
}
123+
124+
/// Returns the currently set token
125+
pub fn token(&self) -> SharedToken {
126+
self.token.clone()
127+
}
128+
129+
/// Fetches an oauth token
130+
pub async fn fetch_oauth_token(&self) -> Result<Token, RequestTokenError<HttpClientError<reqwest::Error>, StandardErrorResponse<oauth2::basic::BasicErrorResponseType>>> {
131+
let token_result = self
132+
.client
133+
.exchange_client_credentials()
134+
.request_async(&self.reqwest)
135+
.await?;
136+
137+
Ok(token_result)
138+
}
139+
140+
/// Spawns a task that periodically fetches a new token every 300 seconds.
141+
pub fn spawn(self) -> JoinHandle<()> {
142+
let interval = self.config.oauth_token_refresh_interval;
143+
144+
let handle: JoinHandle<()> = tokio::spawn(async move {
145+
loop {
146+
info!("Refreshing oauth token");
147+
match self.authenticate().await {
148+
Ok(_) => {
149+
info!("Successfully refreshed oauth token");
150+
}
151+
Err(e) => {
152+
error!(%e, "Failed to refresh oauth token");
153+
}
154+
};
155+
let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
156+
}
157+
});
158+
159+
handle
160+
}
161+
}

0 commit comments

Comments
 (0)