From b366185a63f8182b1c9d8e32b16f738a42f7a912 Mon Sep 17 00:00:00 2001
From: Lucas Schwiderski
Date: Tue, 19 Sep 2023 15:29:40 +0200
Subject: [PATCH] sdk: Implement worker pool for word generation

Massive speed improvement. The index generation is really fast, and it
appears that even worker counts well above the core/thread count still
increase the throughput slightly.

The only part still missing is the info output; that is currently
broken.
---
 Cargo.lock                                         |  49 ++
 crates/dtmt/Cargo.toml                             |   1 +
 .../src/cmd/experiment/brute_force_words.rs        | 544 +++++++++++-------
 crates/dtmt/src/cmd/experiment/mod.rs              |   4 +-
 4 files changed, 401 insertions(+), 197 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index dac07e9..3a02b55 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -658,6 +658,20 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "crossbeam"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
+dependencies = [
+ "cfg-if",
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.12"
@@ -667,6 +681,40 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
+dependencies = [
+ "autocfg",
+ "cfg-if",
+ "crossbeam-utils",
+ "memoffset 0.9.1",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.20"
@@ -943,6 +991,7 @@ dependencies = [
  "cli-table",
  "color-eyre",
  "confy",
+ "crossbeam",
  "csv-async",
  "dtmt-shared",
  "futures",
diff --git a/crates/dtmt/Cargo.toml b/crates/dtmt/Cargo.toml
index 8018a8b..e80feaa 100644
--- a/crates/dtmt/Cargo.toml
+++ b/crates/dtmt/Cargo.toml
@@ -35,6 +35,7 @@ luajit2-sys = { path = "../../lib/luajit2-sys", version = "*" }
 shlex = { version = "1.2.0", optional = true }
 atty = "0.2.14"
 itertools = "0.11.0"
+crossbeam = { version = "0.8.2", features = ["crossbeam-deque"] }
 
 [dev-dependencies]
 tempfile = "3.3.0"
diff --git a/crates/dtmt/src/cmd/experiment/brute_force_words.rs b/crates/dtmt/src/cmd/experiment/brute_force_words.rs
index 7e93dcc..aa15003 100644
--- a/crates/dtmt/src/cmd/experiment/brute_force_words.rs
+++ b/crates/dtmt/src/cmd/experiment/brute_force_words.rs
@@ -1,13 +1,16 @@
 use std::collections::HashSet;
+use std::fs;
+use std::io::Write;
 use std::path::PathBuf;
+use std::sync::Arc;
+use std::thread::JoinHandle;
 
 use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
 use color_eyre::eyre::{self, Context};
 use color_eyre::Result;
+use crossbeam::channel::{bounded, unbounded, Receiver, Sender};
 use itertools::Itertools;
 use sdk::murmur::Murmur64;
-use tokio::fs;
-use tokio::io::AsyncWriteExt;
 use tokio::time::Instant;
 
 pub(crate) fn command_definition() -> Command {
@@ -57,6 +60,14 @@ pub(crate) fn command_definition() -> Command {
                 .short('c')
                 .long("continue")
         )
+        .arg(
+            Arg::new("threads")
+                .help("The number of workers to run in parallel.")
+                .long("threads")
+                .short('n')
+                .default_value("6")
+                .value_parser(value_parser!(usize))
+        )
         .arg(
             Arg::new("words")
                 .help("Path to a file containing words line by line.")
@@ -75,36 +86,307 @@ pub(crate) fn command_definition() -> Command {
     )
 }
 
+const LINE_FEED: u8 = 0x0A;
+const UNDERSCORE: u8 = 0x5F;
+const ZERO: u8 = 0x30;
+
+const PREFIXES: [&str; 29] = [
+    "",
+    "content/characters/",
+    "content/debug/",
+    "content/decals/",
+    "content/environment/",
+    "content/fx/",
+    "content/gizmos/",
+    "content/items/",
+    "content/levels/",
+    "content/liquid_area/",
+    "content/localization/",
+    "content/materials/",
+    "content/minion_impact_assets/",
+    "content/pickups/",
+    "content/shading_environments/",
+    "content/textures/",
+    "content/ui/",
+    "content/videos/",
+    "content/vo/",
+    "content/volume_types/",
+    "content/weapons/",
+    "packages/boot_assets/",
+    "packages/content/",
+    "packages/game_scripts/",
+    "packages/strings/",
+    "packages/ui/",
+    "wwise/events/",
+    "wwise/packages/",
+    "wwise/world_sound_fx/",
+];
+
+fn make_info_printer(rx: Receiver<(usize, usize)>, hash_count: usize) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let mut writer = std::io::stderr();
+        let mut total_count = 0;
+        let mut total_found = 0;
+
+        let start = Instant::now();
+
+        while let Ok((count, found)) = rx.recv() {
+            total_count += count;
+            total_found += found;
+
+            let dur = Instant::now() - start;
+            if dur.as_secs() > 1 {
+                let s = format!("\r{total_count} per second | {total_found:6}/{hash_count} found",);
+
+                // let s = String::from_utf8_lossy(&buf);
+                // // The last prefix in the set is the one that will stay in the buffer
+                // // when we're about to print here.
+                // // So we strip that, to show just the generated part.
+                // // We also restrict the length to stay on a single line.
+                // let prefix_len = prefixes[28].len();
+                // let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
+                //     .trim_end()
+                //     .to_string();
+
+                writer.write_all(s.as_bytes()).unwrap();
+
+                total_count = 0;
+            }
+        }
+    })
+}
+
+fn make_stdout_printer(rx: Receiver<Vec<u8>>) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let mut writer = std::io::stdout();
+
+        while let Ok(buf) = rx.recv() {
+            writer.write_all(&buf).unwrap();
+        }
+    })
+}
+
+struct State {
+    delimiter_lists: Arc<Vec<Vec<String>>>,
+    hashes: Arc<HashSet<Murmur64>>,
+    words: Arc<Vec<String>>,
+    delimiters_len: usize,
+    stdout_tx: Sender<Vec<u8>>,
+    info_tx: Sender<(usize, usize)>,
+}
+
+fn make_worker(rx: Receiver<Vec<usize>>, state: State) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let delimiter_lists = &state.delimiter_lists;
+        let hashes = &state.hashes;
+        let words = &state.words;
+        let delimiters_len = state.delimiters_len;
+
+        let mut count = 0;
+        let mut found = 0;
+        let mut buf = Vec::with_capacity(1024);
+
+        // while let Some(indices) = find_task(local, global, &[]) {
+        while let Ok(indices) = rx.recv() {
+            let sequence = indices.iter().map(|i| words[*i].as_str());
+
+            // We only want delimiters between words, so we keep that iterator shorter by
+            // one.
+            let delimiter_count = sequence.len() as u32 - 1;
+
+            for prefix in PREFIXES.iter().map(|p| p.as_bytes()) {
+                buf.clear();
+
+                // We can keep the prefix at the front of the buffer and only
+                // replace the parts after that.
+                let prefix_len = prefix.len();
+                buf.extend_from_slice(prefix);
+
+                for delims in delimiter_lists
+                    .iter()
+                    .take(delimiters_len.pow(delimiter_count))
+                {
+                    buf.truncate(prefix_len);
+
+                    let delims = delims
+                        .iter()
+                        .map(|s| s.as_str())
+                        .take(delimiter_count as usize);
+                    sequence
+                        .clone()
+                        .interleave(delims.clone())
+                        .for_each(|word| buf.extend_from_slice(word.as_bytes()));
+
+                    count += 1;
+
+                    let hash = Murmur64::hash(&buf);
+                    if hashes.contains(&hash) {
+                        found += 1;
+
+                        buf.push(LINE_FEED);
+                        if let Err(_) = state.stdout_tx.send(buf.clone()) {
+                            return;
+                        }
+                    } else {
+                        let word_len = buf.len();
+
+                        // If the regular word itself didn't match, we check
+                        // for numbered suffixes.
+                        // For now, we only check up to `09` to avoid more complex logic
+                        // writing into the buffer.
+                        // Packages that contain files with higher numbers than this
+                        // should hopefully become easier to spot once a good number of
+                        // hashes is found.
+                        for i in 1..=9 {
+                            buf.truncate(word_len);
+                            buf.push(UNDERSCORE);
+                            buf.push(ZERO);
+                            buf.push(ZERO + i);
+
+                            count += 1;
+
+                            let hash = Murmur64::hash(&buf);
+                            if hashes.contains(&hash) {
+                                found += 1;
+
+                                buf.push(LINE_FEED);
+                                if let Err(_) = state.stdout_tx.send(buf.clone()) {
+                                    return;
+                                }
+                            } else {
+                                break;
+                            }
+                        }
+                    }
+                }
+            }
+
+            if count >= 1024 * 1024 {
+                let _ = state.info_tx.send((count, found));
+            }
+
+            // let dur = Instant::now() - start;
+            // if dur.as_secs() >= 1 {
+            //     let hashes_len = hashes.len();
+            //     let s = String::from_utf8_lossy(&buf);
+            //     // The last prefix in the set is the one that will stay in the buffer
+            //     // when we're about to print here.
+            //     // So we strip that, to show just the generated part.
+            //     // We also restrict the length to stay on a single line.
+            //     let prefix_len = prefixes[28].len();
+            //     let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
+            //         .trim_end()
+            //         .to_string();
+            //     info_tx.send(format!(
+            //         "\r{:8} hashes per second | {:6}/{} found | {:<60}",
+            //         count, found, hashes_len, s
+            //     ));
+
+            //     start = Instant::now();
+            //     count = 0;
+            // }
+        }
+    })
+}
+
+fn build_delimiter_lists(delimiters: impl AsRef<[String]>, max_length: usize) -> Vec<Vec<String>> {
+    let delimiters = delimiters.as_ref();
+    let mut indices = vec![0; max_length];
+    let mut list = Vec::new();
+
+    for _ in 0..delimiters.len().pow(max_length as u32) {
+        list.push(
+            indices
+                .iter()
+                .map(|i| delimiters[*i].clone())
+                .collect::<Vec<_>>(),
+        );
+
+        for v in indices.iter_mut() {
+            if *v >= delimiters.len() - 1 {
+                *v = 0;
+                break;
+            } else {
+                *v += 1;
+            }
+        }
+    }
+
+    list
+}
+
+fn build_initial_indices(
+    cont: Option<&String>,
+    delimiters: impl AsRef<[String]>,
+    words: impl AsRef<[String]>,
+) -> Result<Vec<usize>> {
+    if let Some(cont) = cont {
+        let mut splits = vec![cont.clone()];
+
+        for delim in delimiters.as_ref().iter() {
+            splits = splits
+                .iter()
+                .flat_map(|s| s.split(delim))
+                .map(|s| s.to_string())
+                .collect();
+        }
+
+        let indices = splits
+            .into_iter()
+            .map(|s| {
+                words
+                    .as_ref()
+                    .iter()
+                    .enumerate()
+                    .find(|(_, v)| s == **v)
+                    .map(|(i, _)| i)
+                    .ok_or_else(|| eyre::eyre!("'{}' is not in the word list", s))
+            })
+            .collect::<Result<_, _>>()?;
+
+        tracing::info!("Continuing from '{}' -> '{:?}'", cont, &indices);
+
+        Ok(indices)
+    } else {
+        Ok(vec![0])
+    }
+}
+
 #[tracing::instrument(skip_all)]
 #[allow(clippy::mut_range_bound)]
-pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
+pub(crate) fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
     let max_length: usize = matches
         .get_one::<usize>("max-length")
         .copied()
         .expect("parameter has default");
 
-    let words: Vec<String> = {
+    let num_threads: usize = matches
+        .get_one::<usize>("threads")
+        .copied()
+        .expect("parameter has default");
+
+    let words = {
         let path = matches
             .get_one::<PathBuf>("words")
             .expect("missing required parameter");
 
-        let file = fs::read_to_string(&path)
-            .await
+        let file = fs::read_to_string(path)
             .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
 
-        file.lines().map(str::to_string).collect()
-    };
+        let words: Vec<_> = file.lines().map(str::to_string).collect();
 
-    if words.is_empty() {
-        eyre::bail!("Word list must not be empty");
-    }
+        if words.is_empty() {
+            eyre::bail!("Word list must not be empty");
+        }
+
+        Arc::new(words)
+    };
 
     let hashes = {
         let path = matches
             .get_one::<PathBuf>("hashes")
             .expect("missing required argument");
 
-        let content = fs::read_to_string(&path)
-            .await
+        let content = fs::read_to_string(path)
             .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
 
         let hashes: Result<HashSet<_>, _> = content
@@ -116,7 +398,7 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
 
         tracing::trace!("{:?}", hashes);
 
-        hashes
+        Arc::new(hashes)
     };
 
     let mut delimiters: Vec<String> = matches
@@ -132,38 +414,6 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
 
     let delimiters_len = delimiters.len();
 
-    let prefixes = [
-        "",
-        "content/characters/",
-        "content/debug/",
-        "content/decals/",
-        "content/environment/",
-        "content/fx/",
-        "content/gizmos/",
-        "content/items/",
-        "content/levels/",
-        "content/liquid_area/",
-        "content/localization/",
-        "content/materials/",
-        "content/minion_impact_assets/",
-        "content/pickups/",
-        "content/shading_environments/",
-        "content/textures/",
-        "content/ui/",
-        "content/videos/",
-        "content/vo/",
-        "content/volume_types/",
-        "content/weapons/",
-        "packages/boot_assets/",
-        "packages/content/",
-        "packages/game_scripts/",
-        "packages/strings/",
-        "packages/ui/",
-        "wwise/events/",
-        "wwise/packages/",
-        "wwise/world_sound_fx/",
-    ];
-
     let word_count = words.len();
     tracing::info!("{} words to try", word_count);
 
@@ -175,56 +425,43 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
     // So we basically have to implement a smaller version of the iterative algorithm we use later on
     // to build permutations of the actual words.
     let delimiter_lists = {
-        let mut indices = vec![0; max_length - 1];
-        let mut list = Vec::new();
-
-        for _ in 0..delimiters_len.pow(max_length as u32 - 1) {
-            list.push(indices.iter().map(|i| &delimiters[*i]).collect::<Vec<_>>());
-
-            for v in indices.iter_mut() {
-                if *v >= delimiters_len - 1 {
-                    *v = 0;
-                    break;
-                } else {
-                    *v += 1;
-                }
-            }
-        }
-
-        list
+        let lists = build_delimiter_lists(&delimiters, max_length - 1);
+        Arc::new(lists)
     };
 
-    tracing::debug!("{:?}", delimiter_lists);
-
-    let mut indices = if let Some(cont) = matches.get_one::<String>("continue").cloned() {
-        let mut splits = vec![cont.clone()];
+    let (info_tx, info_rx) = unbounded();
+    let (stdout_tx, stdout_rx) = unbounded::<Vec<u8>>();
+    let (task_tx, task_rx) = bounded::<Vec<usize>>(100);
+    let mut handles = Vec::new();
 
-        for delim in delimiters.iter() {
-            splits = splits
-                .iter()
-                .flat_map(|s| s.split(delim))
-                .map(|s| s.to_string())
-                .collect();
-        }
+    for _ in 0..num_threads {
+        let handle = make_worker(
+            task_rx.clone(),
+            State {
+                delimiter_lists: Arc::clone(&delimiter_lists),
+                hashes: Arc::clone(&hashes),
+                words: Arc::clone(&words),
+                delimiters_len,
+                stdout_tx: stdout_tx.clone(),
+                info_tx: info_tx.clone(),
+            },
+        );
+        handles.push(handle);
+    }
+    // These are only used inside the worker threads, but due to the loops above, we had to
+    // clone them one too many times.
+    // So we drop that extra reference immediately, to ensure that the channels can
+    // disconnect properly when the threads finish.
+    drop(stdout_tx);
+    drop(info_tx);
 
-        let indices = splits
-            .into_iter()
-            .map(|s| {
-                words
-                    .iter()
-                    .enumerate()
-                    .find(|(_, v)| s == **v)
-                    .map(|(i, _)| i)
-                    .ok_or_else(|| eyre::eyre!("'{}' is not in the word list", s))
-            })
-            .collect::<Result<_, _>>()?;
+    // handles.push(make_info_printer(info_rx, hashes.len()));
+    handles.push(make_stdout_printer(stdout_rx));
 
-        tracing::info!("Continuing from '{}' -> '{:?}'", cont, &indices);
-
-        indices
-    } else {
-        vec![0]
-    };
+    let mut indices =
+        build_initial_indices(matches.get_one::<String>("continue"), &delimiters, &*words)
+            .wrap_err("Failed to build initial indices")?;
 
     let mut indices_len = indices.len();
     let mut sequence = indices
         .iter()
@@ -235,113 +472,8 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
     indices.reserve(max_length);
     sequence.reserve(max_length);
 
-    let mut count: usize = 0;
-    let mut found: usize = 0;
-    let mut start = Instant::now();
-
-    // let mut writer = BufWriter::new(tokio::io::stdout());
-    let mut writer = tokio::io::stdout();
-    let mut buf = Vec::with_capacity(1024);
-
-    const LINE_FEED: u8 = 0x0A;
-    const UNDERSCORE: u8 = 0x5F;
-    const ZERO: u8 = 0x30;
-
     'outer: loop {
-        // We only want delimiters between words, so we keep that iterator shorter by
-        // one.
-        let delimiter_count = sequence.len() as u32 - 1;
-
-        for prefix in prefixes.iter().map(|p| p.as_bytes()) {
-            buf.clear();
-
-            // We can keep the prefix at the front of the buffer and only
-            // replace the parts after that.
-            let prefix_len = prefix.len();
-            buf.extend_from_slice(prefix);
-
-            for delims in delimiter_lists
-                .iter()
-                .take(delimiters_len.pow(delimiter_count))
-            {
-                buf.truncate(prefix_len);
-
-                let delims = delims
-                    .iter()
-                    .map(|s| s.as_str())
-                    .take(delimiter_count as usize);
-                sequence
-                    .iter()
-                    .copied()
-                    .interleave(delims.clone())
-                    .for_each(|word| buf.extend_from_slice(word.as_bytes()));
-
-                count += 1;
-
-                let hash = Murmur64::hash(&buf);
-                if hashes.contains(&hash) {
-                    found += 1;
-                    buf.push(LINE_FEED);
-                    writer.write_all(&buf).await?;
-                } else {
-                    let word_len = buf.len();
-
-                    // If the regular word itself didn't match, we check
-                    // for numbered suffixes.
-                    // For now, we only check up to `09` to avoid more complex logic
-                    // writing into the buffer.
-                    // Packages that contain files with higher numbers than this
-                    // should hopefully become easier to spot once a good number of
-                    // hashes is found.
-                    for i in 1..=9 {
-                        buf.truncate(word_len);
-                        buf.push(UNDERSCORE);
-                        buf.push(ZERO);
-                        buf.push(ZERO + i);
-
-                        count += 1;
-
-                        let hash = Murmur64::hash(&buf);
-                        if hashes.contains(&hash) {
-                            found += 1;
-                            buf.push(LINE_FEED);
-                            writer.write_all(&buf).await?;
-                        } else {
-                            break;
-                        }
-                    }
-                }
-            }
-        }
-
-        let dur = Instant::now() - start;
-        if dur.as_secs() >= 1 {
-            let hashes_len = hashes.len();
-            let s = String::from_utf8_lossy(&buf);
-            // The last prefix in the set is the one that will stay in the buffer
-            // when we're about to print here.
-            // So we strip that, to show just the generated part.
-            // We also restrict the length to stay on a single line.
-            let prefix_len = prefixes[28].len();
-            let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
-                .trim_end()
-                .to_string();
-            // Don't care when it finishes, don't care if it fails.
-            tokio::spawn(async move {
-                let _ = tokio::io::stderr()
-                    .write_all(
-                        format!(
-                            "\r{:8} hashes per second | {:6}/{} found | {:<60}",
-                            count, found, hashes_len, s
-                        )
-                        .as_bytes(),
-                    )
-                    .await;
-            });
-
-            start = Instant::now();
-            count = 0;
-        }
+        task_tx.send(indices.clone())?;
 
         for i in 0..indices_len {
             let index = indices.get_mut(i).unwrap();
@@ -371,5 +503,25 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
         }
     }
 
+    // Dropping the senders will disconnect the channel,
+    // so that the threads holding the other end will eventually
+    // complete as well.
+    drop(task_tx);
+
+    tracing::debug!("Waiting for workers to finish.");
+
+    for handle in handles {
+        match handle.join() {
+            Ok(_) => {}
+            Err(value) => {
+                if let Some(err) = value.downcast_ref::<String>() {
+                    eyre::bail!("Thread failed: {}", err);
+                } else {
+                    eyre::bail!("Thread failed with unknown error: {:?}", value);
+                }
+            }
+        }
+    }
+
     Ok(())
 }
diff --git a/crates/dtmt/src/cmd/experiment/mod.rs b/crates/dtmt/src/cmd/experiment/mod.rs
index 9ceb3b9..c53d9b5 100644
--- a/crates/dtmt/src/cmd/experiment/mod.rs
+++ b/crates/dtmt/src/cmd/experiment/mod.rs
@@ -15,7 +15,9 @@ pub(crate) fn command_definition() -> Command {
 #[tracing::instrument(skip_all)]
 pub(crate) async fn run(ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
     match matches.subcommand() {
-        Some(("brute-force-words", sub_matches)) => brute_force_words::run(ctx, sub_matches).await,
+        // It's fine to block here, as this is the only thing that's executing on the runtime.
+        // The alternative, `spawn_blocking`, would require the captured values to be Send + Sync.
+ Some(("brute-force-words", sub_matches)) => brute_force_words::run(ctx, sub_matches), Some(("extract-words", sub_matches)) => extract_words::run(ctx, sub_matches).await, _ => unreachable!( "clap is configured to require a subcommand, and they're all handled above"