Skip to content

Simplify Rust code #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 1 commit on Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions rust/gen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::fs::File;
use std::io::Write;
use std::path::Path;

use anyhow::{Context, Result};
use anyhow::{ensure, Context, Result};
use serde::Deserialize;

fn main() -> Result<()> {
Expand Down Expand Up @@ -132,9 +132,9 @@ fn generate_model_config(content_types: &[String], model_config: ModelConfig) ->
writeln!(output, "use crate::ContentType;\n")?;
writeln!(output, "pub(crate) const CONFIG: ModelConfig = ModelConfig {{")?;
writeln!(output, " beg_size: {beg_size},")?;
writeln!(output, " mid_size: {mid_size},")?;
ensure!(mid_size == 0, "unsupported mid_size");
writeln!(output, " end_size: {end_size},")?;
writeln!(output, " use_inputs_at_offsets: {use_inputs_at_offsets},")?;
ensure!(!use_inputs_at_offsets, "unsupported use_inputs_at_offsets");
writeln!(output, " min_file_size_for_dl: {min_file_size_for_dl},")?;
writeln!(output, " padding_token: {padding_token},")?;
writeln!(output, " block_size: {block_size},")?;
Expand Down
1 change: 1 addition & 0 deletions rust/lib/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

### Minor

- Remove the feature extraction logic of older models
- Use the `standard_v3_3` model instead of `standard_v3_2` (see [model changelog])
- Add `OverwriteReason` to document why the inferred content type is overwritten

Expand Down
20 changes: 3 additions & 17 deletions rust/lib/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ use crate::ContentType;
#[derive(Debug)]
pub(crate) struct ModelConfig {
pub(crate) beg_size: usize,
pub(crate) mid_size: usize,
pub(crate) end_size: usize,
pub(crate) use_inputs_at_offsets: bool,
pub(crate) min_file_size_for_dl: usize,
pub(crate) padding_token: i32,
pub(crate) block_size: usize,
Expand All @@ -31,30 +29,18 @@ pub(crate) struct ModelConfig {

pub(crate) struct SplitFeatures<'a> {
pub(crate) beg: &'a mut [i32],
pub(crate) mid: &'a mut [i32],
pub(crate) end: &'a mut [i32],
pub(crate) off: Vec<(usize, &'a mut [i32])>,
}

impl ModelConfig {
pub(crate) fn features_size(&self) -> usize {
let offsets_size = if self.use_inputs_at_offsets { 4 * 8 } else { 0 };
self.beg_size + self.mid_size + self.end_size + offsets_size
self.beg_size + self.end_size
}

pub(crate) fn split_features<'a>(&self, features: &'a mut [i32]) -> SplitFeatures<'a> {
let (beg, features) = features.split_at_mut(self.beg_size);
let (mid, features) = features.split_at_mut(self.mid_size);
let (end, mut features) = features.split_at_mut(self.end_size);
let mut off = Vec::new();
if self.use_inputs_at_offsets {
for offset in [0x8000, 0x8800, 0x9000, 0x9800] {
let (head, tail) = features.split_at_mut(8);
features = tail;
off.push((offset, head));
}
}
let (end, features) = features.split_at_mut(self.end_size);
debug_assert!(features.is_empty());
SplitFeatures { beg, mid, end, off }
SplitFeatures { beg, end }
}
}
34 changes: 10 additions & 24 deletions rust/lib/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ async fn extract_features_async(
config: &ModelConfig, mut file: impl AsyncInputApi, file_len: usize,
) -> Result<(Vec<u8>, Vec<i32>)> {
debug_assert!(config.beg_size < config.block_size);
debug_assert!(config.mid_size < config.block_size);
debug_assert!(config.end_size < config.block_size);
let buffer_size = std::cmp::min(config.block_size, file_len);
let mut content_beg = vec![0; buffer_size];
Expand All @@ -160,31 +159,18 @@ async fn extract_features_async(
let mut end = vec![0; buffer_size];
file.read_at(&mut end, file_len - buffer_size).await?;
let end = strip_suffix(&end);
let mid_len = std::cmp::min(config.mid_size, file_len);
let mid_off = (file_len - mid_len) / 2;
let mut mid = vec![0; mid_len];
file.read_at(&mut mid, mid_off).await?;
let mut features = vec![config.padding_token; config.features_size()];
let split_features = config.split_features(&mut features);
copy_features(split_features.beg, beg, 0);
copy_features(split_features.mid, &mid, 1);
copy_features(split_features.end, end, 2);
for (offset, features) in split_features.off {
let mut buffer = Vec::new();
if offset + features.len() <= file_len {
buffer = vec![0; features.len()];
file.read_at(&mut buffer, offset).await?;
}
copy_features(features, &buffer, 0);
}
copy_features(split_features.end, end, 1);
Ok((content_beg, features))
}

fn copy_features(dst: &mut [i32], src: &[u8], align: usize) {
let len = std::cmp::min(dst.len(), src.len());
let dst_len = dst.len(); // borrowing issue: cannot inline below
let dst = &mut dst[(dst_len - len) * align / 2..][..len];
let src = &src[(src.len() - len) * align / 2..][..len];
let dst = &mut dst[(dst_len - len) * align..][..len];
let src = &src[(src.len() - len) * align..][..len];
for (dst, src) in dst.iter_mut().zip(src.iter()) {
*dst = *src as i32;
}
Expand Down Expand Up @@ -272,23 +258,23 @@ mod tests {
GzDecoder::new(File::open(PATH).unwrap()).read_to_string(&mut tests).unwrap();
let tests: Vec<Test> = serde_json::from_str(&tests).unwrap();
for test in tests {
assert_eq!(test.args.mid_size, 0, "unsupported mid_size");
assert!(!test.args.use_inputs_at_offsets, "unsupported use_inputs_at_offsets");
assert!(test.features.mid.is_empty(), "unsupported mid");
assert!(test.features.offset_0x8000_0x8007.is_empty(), "unsupported offset");
assert!(test.features.offset_0x8800_0x8807.is_empty(), "unsupported offset");
assert!(test.features.offset_0x9000_0x9007.is_empty(), "unsupported offset");
assert!(test.features.offset_0x9800_0x9807.is_empty(), "unsupported offset");
let config = ModelConfig {
beg_size: test.args.beg_size,
mid_size: test.args.mid_size,
end_size: test.args.end_size,
use_inputs_at_offsets: test.args.use_inputs_at_offsets,
padding_token: test.args.padding_token,
block_size: test.args.block_size,
..crate::model::CONFIG
};
let mut expected = Vec::new();
expected.extend_from_slice(&test.features.beg);
expected.extend_from_slice(&test.features.mid);
expected.extend_from_slice(&test.features.end);
expected.extend_from_slice(&test.features.offset_0x8000_0x8007);
expected.extend_from_slice(&test.features.offset_0x8800_0x8807);
expected.extend_from_slice(&test.features.offset_0x9000_0x9007);
expected.extend_from_slice(&test.features.offset_0x9800_0x9807);
let content = BASE64.decode(test.content_base64.as_bytes()).unwrap();
let actual = extract_features_async(&config, content.as_slice(), content.len());
let actual = exec(actual).unwrap().1;
Expand Down
2 changes: 0 additions & 2 deletions rust/lib/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ use crate::ContentType;

pub(crate) const CONFIG: ModelConfig = ModelConfig {
beg_size: 1024,
mid_size: 0,
end_size: 1024,
use_inputs_at_offsets: false,
min_file_size_for_dl: 8,
padding_token: 256,
block_size: 4096,
Expand Down
Loading