From 09f357d72dc81fbcd4dc4b213a28b4e5771e8ddc Mon Sep 17 00:00:00 2001 From: Lucas Schwiderski Date: Sat, 25 Feb 2023 11:28:37 +0100 Subject: [PATCH] fix(sdk): Prevent duplicates in bundle database bundles Ref #28. --- crates/dtmm/src/engine.rs | 4 +- crates/dtmt/src/cmd/bundle/list.rs | 4 +- crates/dtmt/src/mods/archive.rs | 3 +- lib/sdk/src/bundle/database.rs | 26 +++++++++---- lib/sdk/src/bundle/mod.rs | 12 +++--- lib/sdk/src/murmur/mod.rs | 61 ++++++++++++++++++++++++++++-- 6 files changed, 87 insertions(+), 23 deletions(-) diff --git a/crates/dtmm/src/engine.rs b/crates/dtmm/src/engine.rs index a0e07a2..1981ef4 100644 --- a/crates/dtmm/src/engine.rs +++ b/crates/dtmm/src/engine.rs @@ -202,7 +202,7 @@ fn build_mod_data_lua(state: Arc) -> String { #[tracing::instrument(skip_all)] async fn build_bundles(state: Arc) -> Result<()> { - let mut bundle = Bundle::new(MOD_BUNDLE_NAME.into()); + let mut bundle = Bundle::new(MOD_BUNDLE_NAME); let mut tasks = Vec::new(); let bundle_dir = Arc::new(state.get_game_dir().join("bundle")); @@ -312,7 +312,7 @@ async fn build_bundles(state: Arc) -> Result<()> { db.add_bundle(&bundle); { - let path = bundle_dir.join(format!("{:x}", Murmur64::hash(bundle.name()))); + let path = bundle_dir.join(format!("{:x}", bundle.name().to_murmur64())); tracing::trace!("Writing mod bundle to '{}'", path.display()); fs::write(&path, bundle.to_binary()?) .await diff --git a/crates/dtmt/src/cmd/bundle/list.rs b/crates/dtmt/src/cmd/bundle/list.rs index a206af3..b985ad2 100644 --- a/crates/dtmt/src/cmd/bundle/list.rs +++ b/crates/dtmt/src/cmd/bundle/list.rs @@ -50,13 +50,13 @@ where match fmt { OutputFormat::Text => { - println!("Bundle: {}", bundle.name()); + println!("Bundle: {}", bundle.name().display()); for f in bundle.files().iter() { if f.variants().len() != 1 { let err = eyre::eyre!("Expected exactly one version for this file.") .with_section(|| f.variants().len().to_string().header("Bundle:")) - .with_section(|| bundle.name().clone().header("Bundle:")); + .with_section(|| bundle.name().display().header("Bundle:")); tracing::error!("{:#}", err); } diff --git a/crates/dtmt/src/mods/archive.rs b/crates/dtmt/src/mods/archive.rs index 9f9eaa1..2947640 100644 --- a/crates/dtmt/src/mods/archive.rs +++ b/crates/dtmt/src/mods/archive.rs @@ -5,7 +5,6 @@ use std::path::{Path, PathBuf}; use color_eyre::eyre::{self, Context}; use color_eyre::Result; -use sdk::murmur::Murmur64; use sdk::Bundle; use zip::ZipWriter; @@ -70,7 +69,7 @@ impl Archive { map_entry.insert(file.name(false, None)); } - let name = Murmur64::hash(bundle.name().as_bytes()); + let name = bundle.name().to_murmur64(); let path = base_path.join(name.to_string().to_ascii_lowercase()); zip.start_file(path.to_string_lossy(), Default::default())?; diff --git a/lib/sdk/src/bundle/database.rs b/lib/sdk/src/bundle/database.rs index 8438b40..20d8501 100644 --- a/lib/sdk/src/bundle/database.rs +++ b/lib/sdk/src/bundle/database.rs @@ -38,17 +38,27 @@ pub struct BundleDatabase { impl BundleDatabase { pub fn add_bundle(&mut self, bundle: &Bundle) { - let hash = Murmur64::hash(bundle.name().as_bytes()); + let hash = bundle.name().to_murmur64(); let name = hash.to_string(); let stream = format!("{}.stream", &name); - let file = BundleFile { - name, - stream, - file_time: 0, - platform_specific: false, - }; - self.stored_files.entry(hash).or_default().push(file); + { + let entry = self.stored_files.entry(hash).or_default(); + let existing = entry.iter().position(|f| f.name == name); + + let file = BundleFile { + name, + stream, + file_time: 0, + platform_specific: false, + }; + + entry.push(file); + + if let Some(pos) = existing { + entry.swap_remove(pos); + } + } for f in bundle.files() { let file_name = FileName { diff --git a/lib/sdk/src/bundle/mod.rs b/lib/sdk/src/bundle/mod.rs index 18a4f52..cfc0d06 100644 --- a/lib/sdk/src/bundle/mod.rs +++ b/lib/sdk/src/bundle/mod.rs @@ -8,7 +8,7 @@ use oodle_sys::{OodleLZ_CheckCRC, OodleLZ_FuzzSafe, CHUNK_SIZE}; use crate::binary::sync::*; use crate::bundle::file::Properties; -use crate::murmur::{HashGroup, Murmur64}; +use crate::murmur::{HashGroup, IdString64, Murmur64}; pub(crate) mod database; pub(crate) mod file; @@ -46,13 +46,13 @@ pub struct Bundle { format: BundleFormat, properties: [Murmur64; 32], files: Vec, - name: String, + name: IdString64, } impl Bundle { - pub fn new(name: String) -> Self { + pub fn new>(name: S) -> Self { Self { - name, + name: name.into(), format: BundleFormat::F8, properties: [0.into(); 32], files: Vec::new(), @@ -201,7 +201,7 @@ impl Bundle { } Ok(Self { - name: bundle_name, + name: bundle_name.into(), format, files, properties, @@ -281,7 +281,7 @@ impl Bundle { Ok(w.into_inner()) } - pub fn name(&self) -> &String { + pub fn name(&self) -> &IdString64 { &self.name } diff --git a/lib/sdk/src/murmur/mod.rs b/lib/sdk/src/murmur/mod.rs index 9ea432d..7ede170 100644 --- a/lib/sdk/src/murmur/mod.rs +++ b/lib/sdk/src/murmur/mod.rs @@ -289,9 +289,9 @@ impl IdString64 { } } -impl From for IdString64 { - fn from(value: String) -> Self { - Self::String(value) +impl> From for IdString64 { + fn from(value: S) -> Self { + Self::String(value.into()) } } @@ -313,6 +313,61 @@ impl PartialEq for IdString64 { } } +impl std::hash::Hash for IdString64 { + fn hash(&self, state: &mut H) { + state.write_u64(self.to_murmur64().into()); + } +} + +impl serde::Serialize for IdString64 { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_u64(self.to_murmur64().into()) + } +} + +struct IdString64Visitor; + +impl<'de> serde::de::Visitor<'de> for IdString64Visitor { + type Value = IdString64; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an u64 or a string") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(IdString64::Hash(value.into())) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(IdString64::String(v.to_string())) + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(IdString64::String(v)) + } +} + +impl<'de> serde::Deserialize<'de> for IdString64 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_u64(IdString64Visitor) + } +} + pub struct IdString64Display(String); impl std::fmt::Display for IdString64Display {