sdk: Implement worker pool for word generation

This gives a massive speed improvement. Index generation is fast
enough that even worker counts well above the core/thread count still
increase throughput slightly.

The only missing piece is the info output, which is currently broken.
Lucas Schwiderski 2023-09-19 15:29:40 +02:00
parent 951a7f82c0
commit b366185a63
Signed by: lucas
GPG key ID: AA12679AAA6DF4D8
4 changed files with 401 additions and 197 deletions
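
The core of the change is a hand-rolled worker pool: a bounded crossbeam channel feeds word-index tasks to a fixed set of OS threads, and dropping the sender is what eventually shuts the pool down. A minimal sketch of that shape, assuming only the crossbeam channel API (the thread count, task type, and per-task work are illustrative, not the committed code):

use crossbeam::channel::bounded;
use std::thread;

fn main() {
    // A small buffer gives the producer room to run ahead of the workers
    // without queueing an unbounded amount of work.
    let (task_tx, task_rx) = bounded::<u64>(100);

    // Fixed set of workers; the receiver is cheaply cloneable.
    let handles: Vec<_> = (0..6)
        .map(|_| {
            let rx = task_rx.clone();
            thread::spawn(move || {
                // `recv` returns Err only once all senders are dropped and
                // the queue is drained, so it doubles as the shutdown signal.
                while let Ok(n) = rx.recv() {
                    // Stand-in for the real per-task work (hashing candidates).
                    let _ = n.wrapping_mul(0x9E37_79B9_7F4A_7C15);
                }
            })
        })
        .collect();

    for n in 0..1_000_000u64 {
        task_tx.send(n).expect("workers stopped unexpectedly");
    }

    // Disconnect the channel so the workers exit after draining the queue.
    drop(task_tx);

    for handle in handles {
        handle.join().expect("worker panicked");
    }
}

Because the queue is bounded, the producer simply blocks when the workers fall behind; this may also explain the observation above that oversubscribing the core count still helps, since a stalled worker leaves queued work for the others.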

Cargo.lock (generated): 49 changes

@@ -658,6 +658,20 @@ dependencies = [
  "cfg-if",
 ]
 
+[[package]]
+name = "crossbeam"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
+dependencies = [
+ "cfg-if",
+ "crossbeam-channel",
+ "crossbeam-deque",
+ "crossbeam-epoch",
+ "crossbeam-queue",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-channel"
 version = "0.5.12"
@@ -667,6 +681,40 @@ dependencies = [
  "crossbeam-utils",
 ]
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
+dependencies = [
+ "autocfg",
+ "cfg-if",
+ "crossbeam-utils",
+ "memoffset 0.9.1",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-queue"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
+dependencies = [
+ "cfg-if",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-utils"
 version = "0.8.20"
@@ -943,6 +991,7 @@ dependencies = [
  "cli-table",
  "color-eyre",
  "confy",
+ "crossbeam",
  "csv-async",
  "dtmt-shared",
  "futures",


@@ -35,6 +35,7 @@ luajit2-sys = { path = "../../lib/luajit2-sys", version = "*" }
 shlex = { version = "1.2.0", optional = true }
 atty = "0.2.14"
 itertools = "0.11.0"
+crossbeam = { version = "0.8.2", features = ["crossbeam-deque"] }
 
 [dev-dependencies]
 tempfile = "3.3.0"


@@ -1,13 +1,16 @@
 use std::collections::HashSet;
+use std::fs;
+use std::io::Write;
 use std::path::PathBuf;
+use std::sync::Arc;
+use std::thread::JoinHandle;
 
 use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
 use color_eyre::eyre::{self, Context};
 use color_eyre::Result;
+use crossbeam::channel::{bounded, unbounded, Receiver, Sender};
 use itertools::Itertools;
 use sdk::murmur::Murmur64;
-use tokio::fs;
-use tokio::io::AsyncWriteExt;
 use tokio::time::Instant;
 
 pub(crate) fn command_definition() -> Command {
@@ -57,6 +60,14 @@ pub(crate) fn command_definition() -> Command {
                 .short('c')
                 .long("continue")
         )
+        .arg(
+            Arg::new("threads")
+                .help("The number of workers to run in parallel.")
+                .long("threads")
+                .short('n')
+                .default_value("6")
+                .value_parser(value_parser!(usize))
+        )
         .arg(
             Arg::new("words")
                 .help("Path to a file containing words line by line.")
@@ -75,64 +86,11 @@ pub(crate) fn command_definition() -> Command {
         )
 }
 
-#[tracing::instrument(skip_all)]
-#[allow(clippy::mut_range_bound)]
-pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
-    let max_length: usize = matches
-        .get_one::<usize>("max-length")
-        .copied()
-        .expect("parameter has default");
-
-    let words: Vec<String> = {
-        let path = matches
-            .get_one::<PathBuf>("words")
-            .expect("missing required parameter");
-        let file = fs::read_to_string(&path)
-            .await
-            .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
-
-        file.lines().map(str::to_string).collect()
-    };
-
-    if words.is_empty() {
-        eyre::bail!("Word list must not be empty");
-    }
-
-    let hashes = {
-        let path = matches
-            .get_one::<PathBuf>("hashes")
-            .expect("missing required argument");
-        let content = fs::read_to_string(&path)
-            .await
-            .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
-        let hashes: Result<HashSet<_>, _> = content
-            .lines()
-            .map(|s| u64::from_str_radix(s, 16).map(Murmur64::from))
-            .collect();
-        let hashes = hashes?;
-
-        tracing::trace!("{:?}", hashes);
-
-        hashes
-    };
-
-    let mut delimiters: Vec<String> = matches
-        .get_many::<String>("delimiter")
-        .unwrap_or_default()
-        .cloned()
-        .collect();
-
-    if delimiters.is_empty() {
-        delimiters.push(String::from("/"));
-        delimiters.push(String::from("_"));
-    }
-
-    let delimiters_len = delimiters.len();
-
-    let prefixes = [
+const LINE_FEED: u8 = 0x0A;
+const UNDERSCORE: u8 = 0x5F;
+const ZERO: u8 = 0x30;
+
+const PREFIXES: [&str; 29] = [
     "",
     "content/characters/",
     "content/debug/",
@@ -162,97 +120,81 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
     "wwise/events/",
     "wwise/packages/",
     "wwise/world_sound_fx/",
 ];
 
-    let word_count = words.len();
-    tracing::info!("{} words to try", word_count);
-
-    // To be able to easily combine the permutations of words and delimiters,
-    // we turn the latter into a pre-defined list of all permutations of delimiters
-    // that are possible at the given amount of words.
-    // Combining `Iterator::cycle` with `Itertools::permutations` works, but
-    // with a high `max_length`, it runs OOM.
-    // So we basically have to implement a smaller version of the iterative algorithm we use later on
-    // to build permutations of the actual words.
-    let delimiter_lists = {
-        let mut indices = vec![0; max_length - 1];
-        let mut list = Vec::new();
-
-        for _ in 0..delimiters_len.pow(max_length as u32 - 1) {
-            list.push(indices.iter().map(|i| &delimiters[*i]).collect::<Vec<_>>());
-
-            for v in indices.iter_mut() {
-                if *v >= delimiters_len - 1 {
-                    *v = 0;
-                    break;
-                } else {
-                    *v += 1;
-                }
-            }
-        }
-
-        list
-    };
-    tracing::debug!("{:?}", delimiter_lists);
-
-    let mut indices = if let Some(cont) = matches.get_one::<String>("continue").cloned() {
-        let mut splits = vec![cont.clone()];
-
-        for delim in delimiters.iter() {
-            splits = splits
-                .iter()
-                .flat_map(|s| s.split(delim))
-                .map(|s| s.to_string())
-                .collect();
-        }
-
-        let indices = splits
-            .into_iter()
-            .map(|s| {
-                words
-                    .iter()
-                    .enumerate()
-                    .find(|(_, v)| s == **v)
-                    .map(|(i, _)| i)
-                    .ok_or_else(|| eyre::eyre!("'{}' is not in the word list", s))
-            })
-            .collect::<Result<_>>()?;
-
-        tracing::info!("Continuing from '{}' -> '{:?}'", cont, &indices);
-
-        indices
-    } else {
-        vec![0]
-    };
-    let mut indices_len = indices.len();
-    let mut sequence = indices
-        .iter()
-        .map(|index| words[*index].as_str())
-        .collect::<Vec<_>>();
-
-    // Prevent re-allocation by reserving as much as we need upfront
-    indices.reserve(max_length);
-    sequence.reserve(max_length);
-
-    let mut count: usize = 0;
-    let mut found: usize = 0;
-    let mut start = Instant::now();
-
-    // let mut writer = BufWriter::new(tokio::io::stdout());
-    let mut writer = tokio::io::stdout();
-    let mut buf = Vec::with_capacity(1024);
-
-    const LINE_FEED: u8 = 0x0A;
-    const UNDERSCORE: u8 = 0x5F;
-    const ZERO: u8 = 0x30;
-
-    'outer: loop {
+fn make_info_printer(rx: Receiver<(usize, usize)>, hash_count: usize) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let mut writer = std::io::stderr();
+        let mut total_count = 0;
+        let mut total_found = 0;
+
+        let start = Instant::now();
+
+        while let Ok((count, found)) = rx.recv() {
+            total_count += count;
+            total_found += found;
+
+            let dur = Instant::now() - start;
+            if dur.as_secs() > 1 {
+                let s = format!("\r{total_count} per second | {total_found:6}/{hash_count} found",);
+
+                // let s = String::from_utf8_lossy(&buf);
+                // // The last prefix in the set is the one that will stay in the buffer
+                // // when we're about to print here.
+                // // So we strip that, to show just the generated part.
+                // // We also restrict the length to stay on a single line.
+                // let prefix_len = prefixes[28].len();
+                // let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
+                //     .trim_end()
+                //     .to_string();
+
+                writer.write_all(s.as_bytes()).unwrap();
+
+                total_count = 0;
+            }
+        }
+    })
+}
+
+fn make_stdout_printer(rx: Receiver<Vec<u8>>) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let mut writer = std::io::stdout();
+
+        while let Ok(buf) = rx.recv() {
+            writer.write_all(&buf).unwrap();
+        }
+    })
+}
+
+struct State {
+    delimiter_lists: Arc<Vec<Vec<String>>>,
+    hashes: Arc<HashSet<Murmur64>>,
+    words: Arc<Vec<String>>,
+    delimiters_len: usize,
+    stdout_tx: Sender<Vec<u8>>,
+    info_tx: Sender<(usize, usize)>,
+}
+
+fn make_worker(rx: Receiver<Vec<usize>>, state: State) -> JoinHandle<()> {
+    std::thread::spawn(move || {
+        let delimiter_lists = &state.delimiter_lists;
+        let hashes = &state.hashes;
+        let words = &state.words;
+        let delimiters_len = state.delimiters_len;
+
+        let mut count = 0;
+        let mut found = 0;
+        let mut buf = Vec::with_capacity(1024);
+
+        // while let Some(indices) = find_task(local, global, &[]) {
+        while let Ok(indices) = rx.recv() {
+            let sequence = indices.iter().map(|i| words[*i].as_str());
+
             // We only want delimiters between words, so we keep that iterator shorter by
             // one.
             let delimiter_count = sequence.len() as u32 - 1;
 
-        for prefix in prefixes.iter().map(|p| p.as_bytes()) {
+            for prefix in PREFIXES.iter().map(|p| p.as_bytes()) {
                 buf.clear();
 
                 // We can keep the prefix at the front of the buffer and only
@@ -271,8 +213,7 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
                    .map(|s| s.as_str())
                    .take(delimiter_count as usize);
 
                sequence
-                    .iter()
-                    .copied()
+                    .clone()
                    .interleave(delims.clone())
                    .for_each(|word| buf.extend_from_slice(word.as_bytes()));
@@ -281,8 +222,11 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
                let hash = Murmur64::hash(&buf);
                if hashes.contains(&hash) {
                    found += 1;
                    buf.push(LINE_FEED);
-                    writer.write_all(&buf).await?;
+
+                    if let Err(_) = state.stdout_tx.send(buf.clone()) {
+                        return;
+                    }
                } else {
                    let word_len = buf.len();
@@ -304,8 +248,11 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
                        let hash = Murmur64::hash(&buf);
                        if hashes.contains(&hash) {
                            found += 1;
                            buf.push(LINE_FEED);
-                            writer.write_all(&buf).await?;
+
+                            if let Err(_) = state.stdout_tx.send(buf.clone()) {
+                                return;
+                            }
                        } else {
                            break;
                        }
@@ -314,35 +261,220 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
                }
            }
 
-        let dur = Instant::now() - start;
-        if dur.as_secs() >= 1 {
-            let hashes_len = hashes.len();
-            let s = String::from_utf8_lossy(&buf);
-            // The last prefix in the set is the one that will stay in the buffer
-            // when we're about to print here.
-            // So we strip that, to show just the generated part.
-            // We also restrict the length to stay on a single line.
-            let prefix_len = prefixes[28].len();
-            let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
-                .trim_end()
-                .to_string();
-
-            // Don't care when it finishes, don't care if it fails.
-            tokio::spawn(async move {
-                let _ = tokio::io::stderr()
-                    .write_all(
-                        format!(
-                            "\r{:8} hashes per second | {:6}/{} found | {:<60}",
-                            count, found, hashes_len, s
-                        )
-                        .as_bytes(),
-                    )
-                    .await;
-            });
-
-            start = Instant::now();
-            count = 0;
-        }
+            if count >= 1024 * 1024 {
+                let _ = state.info_tx.send((count, found));
+            }
+
+            // let dur = Instant::now() - start;
+            // if dur.as_secs() >= 1 {
+            //     let hashes_len = hashes.len();
+            //     let s = String::from_utf8_lossy(&buf);
+            //     // The last prefix in the set is the one that will stay in the buffer
+            //     // when we're about to print here.
+            //     // So we strip that, to show just the generated part.
+            //     // We also restrict the length to stay on a single line.
+            //     let prefix_len = prefixes[28].len();
+            //     let s = s[prefix_len..std::cmp::min(s.len(), prefix_len + 60)]
+            //         .trim_end()
+            //         .to_string();
+            //     info_tx.send(format!(
+            //         "\r{:8} hashes per second | {:6}/{} found | {:<60}",
+            //         count, found, hashes_len, s
+            //     ));
+            //     start = Instant::now();
+            //     count = 0;
+            // }
+        }
+    })
+}
+
+fn build_delimiter_lists(delimiters: impl AsRef<[String]>, max_length: usize) -> Vec<Vec<String>> {
+    let delimiters = delimiters.as_ref();
+    let mut indices = vec![0; max_length];
+    let mut list = Vec::new();
+
+    for _ in 0..delimiters.len().pow(max_length as u32) {
+        list.push(
+            indices
+                .iter()
+                .map(|i| delimiters[*i].clone())
+                .collect::<Vec<_>>(),
+        );
+
+        for v in indices.iter_mut() {
+            if *v >= delimiters.len() - 1 {
+                *v = 0;
+                break;
+            } else {
+                *v += 1;
+            }
+        }
+    }
+
+    list
+}
+
+fn build_initial_indices(
+    cont: Option<&String>,
+    delimiters: impl AsRef<[String]>,
+    words: impl AsRef<[String]>,
+) -> Result<Vec<usize>> {
+    if let Some(cont) = cont {
+        let mut splits = vec![cont.clone()];
+
+        for delim in delimiters.as_ref().iter() {
+            splits = splits
+                .iter()
+                .flat_map(|s| s.split(delim))
+                .map(|s| s.to_string())
+                .collect();
+        }
+
+        let indices = splits
+            .into_iter()
+            .map(|s| {
+                words
+                    .as_ref()
+                    .iter()
+                    .enumerate()
+                    .find(|(_, v)| s == **v)
+                    .map(|(i, _)| i)
+                    .ok_or_else(|| eyre::eyre!("'{}' is not in the word list", s))
+            })
+            .collect::<Result<_>>()?;
+
+        tracing::info!("Continuing from '{}' -> '{:?}'", cont, &indices);
+
+        Ok(indices)
+    } else {
+        Ok(vec![0])
+    }
+}
+
+#[tracing::instrument(skip_all)]
+#[allow(clippy::mut_range_bound)]
+pub(crate) fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
+    let max_length: usize = matches
+        .get_one::<usize>("max-length")
+        .copied()
+        .expect("parameter has default");
+
+    let num_threads: usize = matches
+        .get_one::<usize>("threads")
+        .copied()
+        .expect("parameter has default");
+
+    let words = {
+        let path = matches
+            .get_one::<PathBuf>("words")
+            .expect("missing required parameter");
+        let file = fs::read_to_string(path)
+            .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
+
+        let words: Vec<_> = file.lines().map(str::to_string).collect();
+        if words.is_empty() {
+            eyre::bail!("Word list must not be empty");
+        }
+
+        Arc::new(words)
+    };
+
+    let hashes = {
+        let path = matches
+            .get_one::<PathBuf>("hashes")
+            .expect("missing required argument");
+        let content = fs::read_to_string(path)
+            .wrap_err_with(|| format!("Failed to read file '{}'", path.display()))?;
+        let hashes: Result<HashSet<_>, _> = content
+            .lines()
+            .map(|s| u64::from_str_radix(s, 16).map(Murmur64::from))
+            .collect();
+        let hashes = hashes?;
+
+        tracing::trace!("{:?}", hashes);
+
+        Arc::new(hashes)
+    };
+
+    let mut delimiters: Vec<String> = matches
+        .get_many::<String>("delimiter")
+        .unwrap_or_default()
+        .cloned()
+        .collect();
+
+    if delimiters.is_empty() {
+        delimiters.push(String::from("/"));
+        delimiters.push(String::from("_"));
+    }
+
+    let delimiters_len = delimiters.len();
+
+    let word_count = words.len();
+    tracing::info!("{} words to try", word_count);
+
+    // To be able to easily combine the permutations of words and delimiters,
+    // we turn the latter into a pre-defined list of all permutations of delimiters
+    // that are possible at the given amount of words.
+    // Combining `Iterator::cycle` with `Itertools::permutations` works, but
+    // with a high `max_length`, it runs OOM.
+    // So we basically have to implement a smaller version of the iterative algorithm we use later on
+    // to build permutations of the actual words.
+    let delimiter_lists = {
+        let lists = build_delimiter_lists(&delimiters, max_length - 1);
+        Arc::new(lists)
+    };
+    tracing::debug!("{:?}", delimiter_lists);
+
+    let (info_tx, info_rx) = unbounded();
+    let (stdout_tx, stdout_rx) = unbounded::<Vec<u8>>();
+    let (task_tx, task_rx) = bounded::<Vec<usize>>(100);
+
+    let mut handles = Vec::new();
+
+    for _ in 0..num_threads {
+        let handle = make_worker(
+            task_rx.clone(),
+            State {
+                delimiter_lists: Arc::clone(&delimiter_lists),
+                hashes: Arc::clone(&hashes),
+                words: Arc::clone(&words),
+                delimiters_len,
+                stdout_tx: stdout_tx.clone(),
+                info_tx: info_tx.clone(),
+            },
+        );
+        handles.push(handle);
+    }
+
+    // These are only used inside the worker threads, but due to the loops above, we had to
+    // clone them one too many times.
+    // So we drop that extra reference immediately, to ensure that the channels can
+    // disconnect properly when the threads finish.
+    drop(stdout_tx);
+    drop(info_tx);
+
+    // handles.push(make_info_printer(info_rx, hashes.len()));
+    handles.push(make_stdout_printer(stdout_rx));
+
+    let mut indices =
+        build_initial_indices(matches.get_one::<String>("continue"), &delimiters, &*words)
+            .wrap_err("Failed to build initial indices")?;
+    let mut indices_len = indices.len();
+    let mut sequence = indices
+        .iter()
+        .map(|index| words[*index].as_str())
+        .collect::<Vec<_>>();
+
+    // Prevent re-allocation by reserving as much as we need upfront
+    indices.reserve(max_length);
+    sequence.reserve(max_length);
+
+    'outer: loop {
+        task_tx.send(indices.clone())?;
+
        for i in 0..indices_len {
            let index = indices.get_mut(i).unwrap();
            let word = sequence.get_mut(i).unwrap();
@@ -371,5 +503,25 @@ pub(crate) async fn run(_ctx: sdk::Context, matches: &ArgMatches) -> Result<()>
            }
        }
    }
 
+    // Dropping the senders will disconnect the channel,
+    // so that the threads holding the other end will eventually
+    // complete as well.
+    drop(task_tx);
+
+    tracing::debug!("Waiting for workers to finish.");
+
+    for handle in handles {
+        match handle.join() {
+            Ok(_) => {}
+            Err(value) => {
+                if let Some(err) = value.downcast_ref::<String>() {
+                    eyre::bail!("Thread failed: {}", err);
+                } else {
+                    eyre::bail!("Thread failed with unknown error: {:?}", value);
+                }
+            }
+        }
+    }
+
    Ok(())
 }
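
The comment above `build_delimiter_lists` explains why the delimiter permutations are built iteratively rather than with `Iterator::cycle` plus `Itertools::permutations`. Reduced to a standalone sketch (a hypothetical helper over `&str` slices instead of the `String` vectors used above), the trick is to treat the index vector as a counter with one digit per delimiter slot:

// `indices` acts as a base-N counter, N being the number of delimiters.
// Each pass pushes the current combination and then steps the counter,
// so memory use stays flat instead of materializing iterator products.
fn delimiter_combinations(delimiters: &[&str], slots: usize) -> Vec<Vec<String>> {
    let mut indices = vec![0usize; slots];
    let mut list = Vec::new();

    for _ in 0..delimiters.len().pow(slots as u32) {
        list.push(
            indices
                .iter()
                .map(|&i| delimiters[i].to_string())
                .collect::<Vec<_>>(),
        );

        // Same digit-stepping loop as the committed code: increment digits
        // from the front and stop after resetting the first one that wraps.
        for v in indices.iter_mut() {
            if *v >= delimiters.len() - 1 {
                *v = 0;
                break;
            } else {
                *v += 1;
            }
        }
    }

    list
}

fn main() {
    // With the two default delimiters and two slots this prints
    // 2^2 = 4 combinations.
    for combo in delimiter_combinations(&["/", "_"], 2) {
        println!("{combo:?}");
    }
}

Over `delimiters.len().pow(slots)` passes this visits every combination exactly once, just not in conventional odometer order.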


@@ -15,7 +15,9 @@ pub(crate) fn command_definition() -> Command {
 #[tracing::instrument(skip_all)]
 pub(crate) async fn run(ctx: sdk::Context, matches: &ArgMatches) -> Result<()> {
     match matches.subcommand() {
-        Some(("brute-force-words", sub_matches)) => brute_force_words::run(ctx, sub_matches).await,
+        // It's fine to block here, as this is the only thing that's executing on the runtime.
+        // The other option with `spawn_blocking` would require setting up values to be Send+Sync.
+        Some(("brute-force-words", sub_matches)) => brute_force_words::run(ctx, sub_matches),
         Some(("extract-words", sub_matches)) => extract_words::run(ctx, sub_matches).await,
         _ => unreachable!(
             "clap is configured to require a subcommand, and they're all handled above"