diff --git a/benches/compare.rs b/benches/compare.rs index 4e1cf73c..762b8b64 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -25,11 +25,25 @@ pub struct InternedInput<'db> { pub text: String, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)] +enum SupertypeInput<'db> { + InternedInput(InternedInput<'db>), + Input(Input), +} + #[salsa::tracked] pub fn interned_length<'db>(db: &'db dyn salsa::Database, input: InternedInput<'db>) -> usize { input.text(db).len() } +#[salsa::tracked] +pub fn either_length<'db>(db: &'db dyn salsa::Database, input: SupertypeInput<'db>) -> usize { + match input { + SupertypeInput::InternedInput(input) => interned_length(db, input), + SupertypeInput::Input(input) => length(db, input), + } +} + fn mutating_inputs(c: &mut Criterion) { let mut group: codspeed_criterion_compat::BenchmarkGroup< codspeed_criterion_compat::measurement::WallTime, @@ -150,6 +164,77 @@ fn inputs(c: &mut Criterion) { ) }); + group.bench_function(BenchmarkId::new("new", "SupertypeInput"), |b| { + b.iter_batched_ref( + || { + let db = salsa::DatabaseImpl::default(); + + // Prepopulate ingredients. + let input = SupertypeInput::Input(Input::new( + black_box(&db), + black_box("hello, world!".to_owned()), + )); + let interned_input = SupertypeInput::InternedInput(InternedInput::new( + black_box(&db), + black_box("hello, world!".to_owned()), + )); + let len = either_length(black_box(&db), black_box(input)); + assert_eq!(black_box(len), 13); + let len = either_length(black_box(&db), black_box(interned_input)); + assert_eq!(black_box(len), 13); + + db + }, + |db| { + let input = SupertypeInput::Input(Input::new( + black_box(db), + black_box("hello, world!".to_owned()), + )); + let interned_input = SupertypeInput::InternedInput(InternedInput::new( + black_box(db), + black_box("hello, world!".to_owned()), + )); + let len = either_length(black_box(db), black_box(input)); + assert_eq!(black_box(len), 13); + let len = either_length(black_box(db), black_box(interned_input)); + assert_eq!(black_box(len), 13); + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function(BenchmarkId::new("amortized", "SupertypeInput"), |b| { + b.iter_batched_ref( + || { + let db = salsa::DatabaseImpl::default(); + + let input = SupertypeInput::Input(Input::new( + black_box(&db), + black_box("hello, world!".to_owned()), + )); + let interned_input = SupertypeInput::InternedInput(InternedInput::new( + black_box(&db), + black_box("hello, world!".to_owned()), + )); + // we can't pass this along otherwise, and the lifetime is generally informational + let interned_input: SupertypeInput<'static> = unsafe { transmute(interned_input) }; + let len = either_length(black_box(&db), black_box(input)); + assert_eq!(black_box(len), 13); + let len = either_length(black_box(&db), black_box(interned_input)); + assert_eq!(black_box(len), 13); + + (db, input, interned_input) + }, + |&mut (ref db, input, interned_input)| { + let len = either_length(black_box(db), black_box(input)); + assert_eq!(black_box(len), 13); + let len = either_length(black_box(db), black_box(interned_input)); + assert_eq!(black_box(len), 13); + }, + BatchSize::SmallInput, + ) + }); + group.finish(); } diff --git a/components/salsa-macro-rules/src/setup_accumulator_impl.rs b/components/salsa-macro-rules/src/setup_accumulator_impl.rs index f10318a2..e8d4da49 100644 --- a/components/salsa-macro-rules/src/setup_accumulator_impl.rs +++ b/components/salsa-macro-rules/src/setup_accumulator_impl.rs @@ -24,7 +24,8 @@ macro_rules! setup_accumulator_impl { fn $ingredient(db: &dyn $zalsa::Database) -> &$zalsa_struct::IngredientImpl<$Struct> { $CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Struct>>::default()) + db.zalsa() + .add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Struct>>() }) } diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index f89122ac..7274697f 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -89,14 +89,14 @@ macro_rules! setup_input_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } pub fn ingredient_mut(db: &mut dyn $zalsa::Database) -> (&mut $zalsa_struct::IngredientImpl, &mut $zalsa::Runtime) { let zalsa_mut = db.zalsa_mut(); let current_revision = zalsa_mut.new_revision(); - let index = zalsa_mut.add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()); + let index = zalsa_mut.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>(); let (ingredient, runtime) = zalsa_mut.lookup_ingredient_mut(index); let ingredient = ingredient.assert_type_mut::<$zalsa_struct::IngredientImpl>(); (ingredient, runtime) @@ -135,8 +135,19 @@ macro_rules! setup_input_struct { } impl $zalsa::SalsaStructInDb for $Struct { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some($Struct(id)) + } else { + $zalsa::None + } } } @@ -198,7 +209,7 @@ macro_rules! setup_input_struct { // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, { - $Configuration::ingredient(db.as_dyn_database()).get_singleton_input() + $Configuration::ingredient(db.as_dyn_database()).get_singleton_input(db) } #[track_caller] diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index 9af02dd8..668f444e 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -141,7 +141,7 @@ macro_rules! setup_interned_struct { static CACHE: $zalsa::IngredientCache<$zalsa_struct::IngredientImpl<$Configuration>> = $zalsa::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -171,8 +171,19 @@ macro_rules! setup_interned_struct { } impl< $($db_lt_arg)? > $zalsa::SalsaStructInDb for $Struct< $($db_lt_arg)? > { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some(<$Struct as $zalsa::FromId>::from_id(id)) + } else { + $zalsa::None + } } } diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index c5ffc350..5d8a4c2a 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -101,8 +101,19 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientCache::new(); impl $zalsa::SalsaStructInDb for $InternedData<'_> { - fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - None + type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + $zalsa::IngredientIndices::empty() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: ::core::any::TypeId) -> Option { + if type_id == ::core::any::TypeId::of::<$InternedData>() { + Some($InternedData(id, ::core::marker::PhantomData)) + } else { + None + } } } @@ -132,13 +143,13 @@ macro_rules! setup_tracked_fn { fn fn_ingredient(db: &dyn $Db) -> &$zalsa::function::IngredientImpl<$Configuration> { $FN_CACHE.get_or_create(db.as_dyn_database(), || { ::zalsa_db(db); - db.zalsa().add_or_lookup_jar_by_type(&$Configuration) + db.zalsa().add_or_lookup_jar_by_type::<$Configuration>() }) } pub fn fn_ingredient_mut(db: &mut dyn $Db) -> &mut $zalsa::function::IngredientImpl { let zalsa_mut = db.zalsa_mut(); - let index = zalsa_mut.add_or_lookup_jar_by_type(&$Configuration); + let index = zalsa_mut.add_or_lookup_jar_by_type::<$Configuration>(); let (ingredient, _) = zalsa_mut.lookup_ingredient_mut(index); ingredient.assert_type_mut::<$zalsa::function::IngredientImpl>() } @@ -148,7 +159,7 @@ macro_rules! setup_tracked_fn { db: &dyn $Db, ) -> &$zalsa::interned::IngredientImpl<$Configuration> { $INTERN_CACHE.get_or_create(db.as_dyn_database(), || { - db.zalsa().add_or_lookup_jar_by_type(&$Configuration).successor(0) + db.zalsa().add_or_lookup_jar_by_type::<$Configuration>().successor(0) }) } } @@ -201,33 +212,43 @@ macro_rules! setup_tracked_fn { if $needs_interner { $Configuration::intern_ingredient(db).data(db.as_dyn_database(), key).clone() } else { - $zalsa::FromId::from_id(key) + $zalsa::FromIdWithDb::from_id(key, db) } } } } impl $zalsa::Jar for $Configuration { + fn create_dependencies(zalsa: &$zalsa::Zalsa) -> $zalsa::IngredientIndices + where + Self: Sized + { + $zalsa::macro_if! { + if $needs_interner { + $zalsa::IngredientIndices::empty() + } else { + <$InternedData as $zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(zalsa) + } + } + } + fn create_ingredients( - &self, - aux: &dyn $zalsa::JarAux, + zalsa: &$zalsa::Zalsa, first_index: $zalsa::IngredientIndex, + struct_index: $zalsa::IngredientIndices, ) -> Vec> { - let struct_index = $zalsa::macro_if! { + let struct_index: $zalsa::IngredientIndices = $zalsa::macro_if! { if $needs_interner { - first_index.successor(0) + first_index.successor(0).into() } else { - <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux) - .expect( - "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" - ) + struct_index } }; + let memo_ingredient_indices = From::from((zalsa, struct_index, first_index)); let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( - struct_index, first_index, - aux, + memo_ingredient_indices, $lru ); $zalsa::macro_if! { @@ -246,8 +267,8 @@ macro_rules! setup_tracked_fn { } } - fn salsa_struct_type_id(&self) -> Option { - None + fn id_struct_type_id() -> $zalsa::TypeId { + $zalsa::TypeId::of::<$InternedData<'static>>() } } diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index 7e06ac68..8a448380 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -180,7 +180,7 @@ macro_rules! setup_tracked_struct { $zalsa::IngredientCache::new(); CACHE.get_or_create(db, || { - db.zalsa().add_or_lookup_jar_by_type(&<$zalsa_struct::JarImpl::<$Configuration>>::default()) + db.zalsa().add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>() }) } } @@ -198,8 +198,19 @@ macro_rules! setup_tracked_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { - fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { - aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + type MemoIngredientMap = $zalsa::MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(aux: &$zalsa::Zalsa) -> $zalsa::IngredientIndices { + aux.add_or_lookup_jar_by_type::<$zalsa_struct::JarImpl<$Configuration>>().into() + } + + #[inline] + fn cast(id: $zalsa::Id, type_id: $zalsa::TypeId) -> $zalsa::Option { + if type_id == $zalsa::TypeId::of::<$Struct>() { + $zalsa::Some(<$Struct as $zalsa::FromId>::from_id(id)) + } else { + $zalsa::None + } } } diff --git a/components/salsa-macros/src/lib.rs b/components/salsa-macros/src/lib.rs index 2b2de522..e86c1bfc 100644 --- a/components/salsa-macros/src/lib.rs +++ b/components/salsa-macros/src/lib.rs @@ -44,6 +44,7 @@ mod input; mod interned; mod options; mod salsa_struct; +mod supertype; mod tracked; mod tracked_fn; mod tracked_impl; @@ -66,6 +67,11 @@ pub fn interned(args: TokenStream, input: TokenStream) -> TokenStream { interned::interned(args, input) } +#[proc_macro_derive(Supertype)] +pub fn supertype(input: TokenStream) -> TokenStream { + supertype::supertype(input) +} + #[proc_macro_attribute] pub fn input(args: TokenStream, input: TokenStream) -> TokenStream { input::input(args, input) diff --git a/components/salsa-macros/src/supertype.rs b/components/salsa-macros/src/supertype.rs new file mode 100644 index 00000000..a67aa92c --- /dev/null +++ b/components/salsa-macros/src/supertype.rs @@ -0,0 +1,105 @@ +use crate::token_stream_with_error; +use proc_macro2::TokenStream; + +/// The implementation of the `supertype` macro. +/// +/// For an entity enum `Foo` with variants `Variant1, ..., VariantN`, we generate +/// mappings between the variants and their corresponding supertypes. +pub(crate) fn supertype(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let enum_item = parse_macro_input!(input as syn::ItemEnum); + match enum_impl(enum_item) { + Ok(v) => v.into(), + Err(e) => token_stream_with_error(input, e), + } +} + +fn enum_impl(enum_item: syn::ItemEnum) -> syn::Result { + let enum_name = enum_item.ident.clone(); + let mut variant_names = Vec::new(); + let mut variant_types = Vec::new(); + if enum_item.variants.is_empty() { + return Err(syn::Error::new( + enum_item.enum_token.span, + "empty enums are not permitted", + )); + } + for variant in &enum_item.variants { + let valid = match &variant.fields { + syn::Fields::Unnamed(fields) => { + variant_names.push(variant.ident.clone()); + variant_types.push(fields.unnamed[0].ty.clone()); + fields.unnamed.len() == 1 + } + syn::Fields::Unit | syn::Fields::Named(_) => false, + }; + if !valid { + return Err(syn::Error::new( + variant.ident.span(), + "the only form allowed is `Variant(SalsaStruct)`", + )); + } + } + + let (impl_generics, type_generics, where_clause) = enum_item.generics.split_for_impl(); + + let as_id = quote! { + impl #impl_generics zalsa::AsId for #enum_name #type_generics + #where_clause { + #[inline] + fn as_id(&self) -> zalsa::Id { + match self { + #( Self::#variant_names(__v) => zalsa::AsId::as_id(__v), )* + } + } + } + }; + + let from_id = quote! { + impl #impl_generics zalsa::FromIdWithDb for #enum_name #type_generics + #where_clause { + #[inline] + fn from_id(__id: zalsa::Id, __db: &(impl ?Sized + zalsa::Database)) -> Self { + let __zalsa = __db.zalsa(); + let __type_id = __zalsa.lookup_page_type_id(__id); + ::cast(__id, __type_id).expect("invalid enum variant") + } + } + }; + + let salsa_struct_in_db = quote! { + impl #impl_generics zalsa::SalsaStructInDb for #enum_name #type_generics + #where_clause { + type MemoIngredientMap = zalsa::MemoIngredientIndices; + + #[inline] + fn lookup_or_create_ingredient_index(__zalsa: &zalsa::Zalsa) -> zalsa::IngredientIndices { + zalsa::IngredientIndices::merge([ #( <#variant_types as zalsa::SalsaStructInDb>::lookup_or_create_ingredient_index(__zalsa) ),* ]) + } + + #[inline] + fn cast(id: zalsa::Id, type_id: ::core::any::TypeId) -> Option { + #( + // Subtle: the ingredient can be missing, but in this case the id cannot come + // from it - because it wasn't initialized yet. + if let Some(result) = <#variant_types as zalsa::SalsaStructInDb>::cast(id, type_id) { + Some(Self::#variant_names(result)) + } else + )* + { + None + } + } + } + }; + + let all_impls = quote! { + const _: () = { + use salsa::plumbing as zalsa; + + #as_id + #from_id + #salsa_struct_in_db + }; + }; + Ok(all_impls) +} diff --git a/src/accumulator.rs b/src/accumulator.rs index d96c6b3b..725f6086 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -1,7 +1,7 @@ //! Basic test of accumulator functionality. use std::{ - any::Any, + any::{Any, TypeId}, fmt::{self, Debug}, marker::PhantomData, panic::UnwindSafe, @@ -13,8 +13,8 @@ use accumulated::AnyAccumulated; use crate::{ cycle::CycleRecoveryStrategy, ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, - plumbing::JarAux, - zalsa::IngredientIndex, + plumbing::IngredientIndices, + zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, Revision, }; @@ -47,15 +47,15 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, first_index: IngredientIndex, + _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(>::new(first_index))] } - fn salsa_struct_type_id(&self) -> Option { - None + fn id_struct_type_id() -> TypeId { + TypeId::of::() } } @@ -70,9 +70,8 @@ impl IngredientImpl { where Db: ?Sized + Database, { - let jar: JarImpl = Default::default(); let zalsa = db.zalsa(); - let index = zalsa.add_or_lookup_jar_by_type(&jar); + let index = zalsa.add_or_lookup_jar_by_type::>(); let ingredient = zalsa.lookup_ingredient(index).assert_type::(); Some(ingredient) } diff --git a/src/function.rs b/src/function.rs index f2bb267a..3c1f31ce 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use crate::{ cycle::CycleRecoveryStrategy, ingredient::{fmt_index, MaybeChangedAfter}, key::DatabaseKeyIndex, - plumbing::JarAux, + plumbing::MemoIngredientMap, salsa_struct::SalsaStructInDb, table::Table, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, @@ -96,7 +96,11 @@ pub struct IngredientImpl { index: IngredientIndex, /// The index for the memo/sync tables - memo_ingredient_index: MemoIngredientIndex, + /// + /// This may be a [`crate::memo_ingredient_indices::MemoIngredientSingletonIndex`] or a + /// [`crate::memo_ingredient_indices::MemoIngredientIndices`], depending on whether the + /// tracked function's struct is a plain salsa struct or an enum `#[derive(Supertype)]`. + memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, /// Used to find memos to throw out when we have too many memoized values. lru: lru::Lru, @@ -128,14 +132,13 @@ where C: Configuration, { pub fn new( - struct_index: IngredientIndex, index: IngredientIndex, - aux: &dyn JarAux, + memo_ingredient_indices: as SalsaStructInDb>::MemoIngredientMap, lru: usize, ) -> Self { Self { index, - memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index), + memo_ingredient_indices, lru: lru::Lru::new(lru), deleted_entries: Default::default(), } @@ -171,6 +174,7 @@ where zalsa: &'db Zalsa, id: Id, memo: memo::Memo>, + memo_ingredient_index: MemoIngredientIndex, ) -> &'db memo::Memo> { let memo = Arc::new(memo); // Unsafety conditions: memo must be in the map (it's not yet, but it will be by the time this @@ -178,7 +182,9 @@ where let db_memo = unsafe { self.extend_memo_lifetime(&memo) }; // Safety: We delay the drop of `old_value` until a new revision starts which ensures no // references will exist for the memo contents. - if let Some(old_value) = unsafe { self.insert_memo_into_table_for(zalsa, id, memo) } { + if let Some(old_value) = + unsafe { self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index) } + { // In case there is a reference to the old memo out there, we have to store it // in the deleted entries. This will get cleared when a new revision starts. self.deleted_entries @@ -186,6 +192,11 @@ where } db_memo } + + #[inline] + fn memo_ingredient_index(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex { + self.memo_ingredient_indices.get_zalsa_id(zalsa, id) + } } impl Ingredient for IngredientImpl @@ -240,10 +251,11 @@ where fn reset_for_new_revision(&mut self, table: &mut Table) { self.lru.for_each_evicted(|evict| { + let ingredient_index = table.ingredient_index(evict); Self::evict_value_from_memo_for( table.memos_mut(evict), &self.deleted_entries, - self.memo_ingredient_index, + self.memo_ingredient_indices.get(ingredient_index), ) }); std::mem::take(&mut self.deleted_entries); diff --git a/src/function/execute.rs b/src/function/execute.rs index 185bea33..843047fa 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -84,6 +84,12 @@ where tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); + self.insert_memo( + zalsa, + id, + Memo::new(Some(value), revision_now, revisions), + memo_ingredient_index, + ) } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 569cb970..6164a5b3 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,4 +1,5 @@ use super::{memo::Memo, Configuration, IngredientImpl}; +use crate::zalsa::MemoIngredientIndex; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, runtime::StampedValue, @@ -43,10 +44,11 @@ where id: Id, ) -> &'db Memo> { let zalsa = db.zalsa(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); loop { if let Some(memo) = self - .fetch_hot(zalsa, db, id) - .or_else(|| self.fetch_cold(zalsa, db, id)) + .fetch_hot(zalsa, db, id, memo_ingredient_index) + .or_else(|| self.fetch_cold(zalsa, db, id, memo_ingredient_index)) { return memo; } @@ -59,8 +61,9 @@ where zalsa: &'db Zalsa, db: &'db C::DbView, id: Id, + memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo>> { - let memo_guard = self.get_memo_from_table_for(zalsa, id); + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = &memo_guard { if memo.value.is_some() && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) @@ -78,6 +81,7 @@ where zalsa: &'db Zalsa, db: &'db C::DbView, id: Id, + memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo>> { let zalsa_local = db.zalsa_local(); let database_key_index = self.database_key_index(id); @@ -88,14 +92,14 @@ where zalsa, zalsa_local, database_key_index, - self.memo_ingredient_index, + memo_ingredient_index, )?; // Push the query on the stack. let active_query = zalsa_local.push_query(database_key_index); // Now that we've claimed the item, check again to see if there's a "hot" value. - let opt_old_memo = self.get_memo_from_table_for(zalsa, id); + let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(old_memo) = &opt_old_memo { if old_memo.value.is_some() && self.deep_verify_memo(db, zalsa, old_memo, &active_query) { diff --git a/src/function/inputs.rs b/src/function/inputs.rs index 8dce73da..40060ddd 100644 --- a/src/function/inputs.rs +++ b/src/function/inputs.rs @@ -7,7 +7,8 @@ where C: Configuration, { pub(super) fn origin(&self, zalsa: &Zalsa, key: Id) -> Option { - self.get_memo_from_table_for(zalsa, key) + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); + self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) .map(|m| m.revisions.origin.clone()) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 2c964656..6c35da58 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -2,8 +2,7 @@ use crate::{ accumulator::accumulated_map::InputAccumulatedValues, ingredient::MaybeChangedAfter, key::DatabaseKeyIndex, - plumbing::ZalsaLocal, - zalsa::{Zalsa, ZalsaDatabase}, + zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; @@ -21,6 +20,7 @@ where revision: Revision, ) -> MaybeChangedAfter { let zalsa = db.zalsa(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); zalsa.unwind_if_revision_cancelled(db); loop { @@ -29,7 +29,7 @@ where tracing::debug!("{database_key_index:?}: maybe_changed_after(revision = {revision:?})"); // Check if we have a verified version: this is the hot path. - let memo_guard = self.get_memo_from_table_for(zalsa, id); + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = &memo_guard { if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { return if memo.revisions.changed_at > revision { @@ -40,7 +40,7 @@ where } drop(memo_guard); // release the arc-swap guard before cold path if let Some(mcs) = - self.maybe_changed_after_cold(zalsa, db.zalsa_local(), db, id, revision) + self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index) { return mcs; } else { @@ -56,24 +56,26 @@ where fn maybe_changed_after_cold<'db>( &'db self, zalsa: &Zalsa, - zalsa_local: &ZalsaLocal, db: &'db C::DbView, key_index: Id, revision: Revision, + memo_ingredient_index: MemoIngredientIndex, ) -> Option { let database_key_index = self.database_key_index(key_index); + let zalsa_local = db.zalsa_local(); let _claim_guard = zalsa.sync_table_for(key_index).claim( db.as_dyn_database(), zalsa, zalsa_local, database_key_index, - self.memo_ingredient_index, + memo_ingredient_index, )?; let active_query = zalsa_local.push_query(database_key_index); // Load the current memo, if any. - let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { + let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) + else { return Some(MaybeChangedAfter::Yes); }; diff --git a/src/function/memo.rs b/src/function/memo.rs index ae2a2099..e1dd8830 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -46,12 +46,13 @@ impl IngredientImpl { zalsa: &'db Zalsa, id: Id, memo: ArcMemo<'db, C>, + memo_ingredient_index: MemoIngredientIndex, ) -> Option>> { let static_memo = unsafe { self.to_static(memo) }; let old_static_memo = unsafe { zalsa .memo_table_for(id) - .insert(self.memo_ingredient_index, static_memo) + .insert(memo_ingredient_index, static_memo) }?; let old_static_memo = ManuallyDrop::into_inner(old_static_memo); Some(ManuallyDrop::new(unsafe { self.to_self(old_static_memo) })) @@ -64,8 +65,9 @@ impl IngredientImpl { &'db self, zalsa: &'db Zalsa, id: Id, + memo_ingredient_index: MemoIngredientIndex, ) -> Option> { - let static_memo = zalsa.memo_table_for(id).get(self.memo_ingredient_index)?; + let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?; unsafe { Some(self.to_self(static_memo)) } } diff --git a/src/function/specify.rs b/src/function/specify.rs index 85e721f5..33284c05 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -73,7 +73,8 @@ where accumulated_inputs: Default::default(), }; - if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); + if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { self.backdate_if_appropriate(&old_memo, &mut revisions, &value); self.diff_outputs(zalsa, db, database_key_index, &old_memo, &mut revisions); } @@ -89,7 +90,7 @@ where memo.tracing_debug(), key ); - self.insert_memo(zalsa, key, memo); + self.insert_memo(zalsa, key, memo, memo_ingredient_index); // Record that the current query *specified* a value for this cell. let database_key_index = self.database_key_index(key); @@ -107,8 +108,9 @@ where key: Id, ) { let zalsa = db.zalsa(); + let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); - let memo = match self.get_memo_from_table_for(zalsa, key) { + let memo = match self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { Some(m) => m, None => return, }; diff --git a/src/id.rs b/src/id.rs index 06b54b88..c1c57b1b 100644 --- a/src/id.rs +++ b/src/id.rs @@ -2,6 +2,8 @@ use std::fmt::Debug; use std::hash::Hash; use std::num::NonZeroU32; +use crate::Database; + /// The `Id` of a salsa struct in the database [`Table`](`crate::table::Table`). /// /// The higher-order bits of an `Id` identify a [`Page`](`crate::table::Page`) @@ -72,3 +74,16 @@ impl FromId for Id { id } } + +/// Enums cannot use [`FromId`] because they need access to the DB to tell the `TypeId` of the variant, +/// so they use this trait instead, that has a blanket implementation for `FromId`. +pub trait FromIdWithDb: AsId + Copy + Eq + Hash + Debug { + fn from_id(id: Id, db: &(impl ?Sized + Database)) -> Self; +} + +impl FromIdWithDb for T { + #[inline] + fn from_id(id: Id, _db: &(impl ?Sized + Database)) -> Self { + FromId::from_id(id) + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index 5ccf2eae..037e004d 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,8 +6,9 @@ use std::{ use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, cycle::CycleRecoveryStrategy, + plumbing::IngredientIndices, table::Table, - zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, MemoIngredientIndex}, + zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, }; @@ -17,40 +18,33 @@ use super::Revision; /// A "jar" is a group of ingredients that are added atomically. /// Each type implementing jar can be added to the database at most once. pub trait Jar: Any { + /// This creates the ingredient dependencies of this jar. We need to split this from `create_ingredients()` + /// because while `create_ingredients()` is called, a lock on the ingredient map is held (to guarantee + /// atomicity), so other ingredients could not be created. + /// + /// Only tracked fns use this. + fn create_dependencies(_zalsa: &Zalsa) -> IngredientIndices + where + Self: Sized, + { + IngredientIndices::empty() + } + /// Create the ingredients given the index of the first one. /// All subsequent ingredients will be assigned contiguous indices. fn create_ingredients( - &self, - aux: &dyn JarAux, + zalsa: &Zalsa, first_index: IngredientIndex, - ) -> Vec>; - - /// If this jar's first ingredient is a salsa struct, return its `TypeId` - fn salsa_struct_type_id(&self) -> Option; -} - -/// Methods on the Salsa database available to jars while they are creating their ingredients. -pub trait JarAux { - /// Return index of first ingredient from `jar` (based on the dynamic type of `jar`). - /// Returns `None` if the jar has not yet been added. - /// Used by tracked functions to lookup the ingredient index for the salsa struct they take as argument. - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; - - /// Returns the memo ingredient index that should be used to attach data from the given tracked function - /// to the given salsa struct (which the fn accepts as argument). - /// - /// The memo ingredient indices for a given function must be distinct from the memo indices - /// of all other functions that take the same salsa struct. - /// - /// # Parameters - /// - /// * `struct_ingredient_index`, the index of the salsa struct the memo will be attached to - /// * `ingredient_index`, the index of the tracked function whose data is stored in the memo - fn next_memo_ingredient_index( - &self, - struct_ingredient_index: IngredientIndex, - ingredient_index: IngredientIndex, - ) -> MemoIngredientIndex; + dependencies: IngredientIndices, + ) -> Vec> + where + Self: Sized; + + /// This returns the [`TypeId`] of the ID struct, that is, the struct that wraps `salsa::Id` + /// and carry the name of the jar. + fn id_struct_type_id() -> TypeId + where + Self: Sized; } pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { diff --git a/src/input.rs b/src/input.rs index e064bf62..b8b7473b 100644 --- a/src/input.rs +++ b/src/input.rs @@ -13,11 +13,11 @@ use input_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - id::{AsId, FromId}, + id::{AsId, FromIdWithDb}, ingredient::{fmt_index, Ingredient, MaybeChangedAfter}, input::singleton::{Singleton, SingletonChoice}, key::{DatabaseKeyIndex, InputDependencyIndex}, - plumbing::{Jar, JarAux, Stamp}, + plumbing::{Jar, Stamp}, table::{memo::MemoTable, sync::SyncTable, Slot, Table}, zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, @@ -32,7 +32,7 @@ pub trait Configuration: Any { type Singleton: SingletonChoice + Send + Sync; /// The input struct (which wraps an `Id`) - type Struct: FromId + 'static + Send + Sync; + type Struct: FromIdWithDb + 'static + Send + Sync; /// A (possibly empty) tuple of the fields for this struct. type Fields: Send + Sync; @@ -55,9 +55,9 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, struct_index: crate::zalsa::IngredientIndex, + _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient: IngredientImpl = IngredientImpl::new(struct_index); @@ -68,8 +68,8 @@ impl Jar for JarImpl { .collect() } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::() } } @@ -115,7 +115,7 @@ impl IngredientImpl { }) }); - FromId::from_id(id) + FromIdWithDb::from_id(id, db) } /// Change the value of the field `field_index` to a new value. @@ -153,12 +153,14 @@ impl IngredientImpl { setter(&mut r.fields) } - /// Get the singleton input previously created. - pub fn get_singleton_input(&self) -> Option + /// Get the singleton input previously created (if any). + pub fn get_singleton_input(&self, db: &(impl ?Sized + Database)) -> Option where C: Configuration, { - self.singleton.index().map(FromId::from_id) + self.singleton + .index() + .map(|id| FromIdWithDb::from_id(id, db)) } /// Access field of an input. diff --git a/src/interned.rs b/src/interned.rs index 41e1f0ab..9340323b 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -4,11 +4,11 @@ use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::durability::Durability; use crate::ingredient::{fmt_index, MaybeChangedAfter}; use crate::key::InputDependencyIndex; -use crate::plumbing::{Jar, JarAux}; +use crate::plumbing::{IngredientIndices, Jar}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; use crate::table::Slot; -use crate::zalsa::IngredientIndex; +use crate::zalsa::{IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; use std::any::TypeId; @@ -100,15 +100,15 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, first_index: IngredientIndex, + _dependencies: IngredientIndices, ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct<'static>>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::>() } } diff --git a/src/lib.rs b/src/lib.rs index 9b985d4c..bf6ff4e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ mod ingredient; mod input; mod interned; mod key; +mod memo_ingredient_indices; mod nonce; #[cfg(feature = "rayon")] mod par_map; @@ -56,6 +57,7 @@ pub use salsa_macros::db; pub use salsa_macros::input; pub use salsa_macros::interned; pub use salsa_macros::tracked; +pub use salsa_macros::Supertype; pub use salsa_macros::Update; pub mod prelude { @@ -70,6 +72,9 @@ pub mod prelude { /// /// The contents of this module are NOT subject to semver. pub mod plumbing { + pub use std::any::TypeId; + pub use std::option::Option::{self, None, Some}; + pub use crate::accumulator::Accumulator; pub use crate::array::Array; pub use crate::attach::attach; @@ -81,11 +86,14 @@ pub mod plumbing { pub use crate::function::should_backdate_value; pub use crate::id::AsId; pub use crate::id::FromId; + pub use crate::id::FromIdWithDb; pub use crate::id::Id; pub use crate::ingredient::Ingredient; pub use crate::ingredient::Jar; - pub use crate::ingredient::JarAux; pub use crate::key::DatabaseKeyIndex; + pub use crate::memo_ingredient_indices::{ + IngredientIndices, MemoIngredientIndices, MemoIngredientMap, MemoIngredientSingletonIndex, + }; pub use crate::revision::Revision; pub use crate::runtime::stamp; pub use crate::runtime::Runtime; diff --git a/src/memo_ingredient_indices.rs b/src/memo_ingredient_indices.rs new file mode 100644 index 00000000..a784b4ea --- /dev/null +++ b/src/memo_ingredient_indices.rs @@ -0,0 +1,129 @@ +use crate::zalsa::{MemoIngredientIndex, Zalsa}; +use crate::{Id, IngredientIndex}; + +/// An ingredient has an [ingredient index][IngredientIndex]. However, Salsa also supports +/// enums of salsa structs (and other salsa enums), and those don't have a constant ingredient index, +/// because they are not ingredients by themselves but rather composed of them. However, an enum can +/// be viewed as a *set* of [`IngredientIndex`], where each instance of the enum can belong +/// to one, potentially different, index. This is what this type represents: a set of +/// `IngredientIndex`. +#[derive(Clone)] +pub struct IngredientIndices { + indices: Box<[IngredientIndex]>, +} + +impl From for IngredientIndices { + #[inline] + fn from(value: IngredientIndex) -> Self { + Self { + indices: Box::new([value]), + } + } +} + +impl IngredientIndices { + #[inline] + pub fn empty() -> Self { + Self { + indices: Box::default(), + } + } + + pub fn merge(iter: impl IntoIterator) -> Self { + let mut indices = Vec::new(); + for index in iter { + indices.extend(index.indices); + } + indices.sort_unstable(); + indices.dedup(); + Self { + indices: indices.into_boxed_slice(), + } + } +} + +impl From<(&Zalsa, IngredientIndices, IngredientIndex)> for MemoIngredientIndices { + #[inline] + fn from( + (zalsa, struct_indices, ingredient): (&Zalsa, IngredientIndices, IngredientIndex), + ) -> Self { + let Some(&last) = struct_indices.indices.last() else { + unreachable!("Attempting to construct struct memo mapping for non tracked function?") + }; + let mut indices = Vec::new(); + indices.resize( + last.as_usize() + 1, + MemoIngredientIndex::from_usize((u32::MAX - 1) as usize), + ); + for &struct_ingredient in &struct_indices.indices { + indices[struct_ingredient.as_usize()] = + zalsa.next_memo_ingredient_index(struct_ingredient, ingredient); + } + MemoIngredientIndices { + indices: indices.into_boxed_slice(), + } + } +} + +/// This type is to [`MemoIngredientIndex`] what [`IngredientIndices`] is to [`IngredientIndex`]: +/// since enums can contain different ingredient indices, they can also have different memo indices, +/// so we need to keep track of them. +/// +/// This acts a map from [`IngredientIndex`] to [`MemoIngredientIndex`] but implemented +/// via a slice for fast lookups, trading memory for speed. With these changes, lookups are `O(1)` +/// instead of `O(n)`. +/// +/// A database tends to have few ingredients (i), less function ingredients and even less +/// function ingredients targeting `#[derive(Supertype)]` enums (e). +/// While this is bounded as `O(i * e)` memory usage, the average case is significantly smaller: a +/// function ingredient targeting enums only stores a slice whose length corresponds to the largest +/// ingredient index's _value_. For example, if we have the ingredient indices `[2, 6, 17]`, then we +/// will allocate a slice whose length is `17 + 1`. +/// +/// Assuming a heavy example scenario of 1000 ingredients (500 of which are function ingredients, 100 +/// of which are enum targeting functions) this would come out to a maximum possibly memory usage of +/// 4bytes * 1000 * 100 ~= 0.38MB which is negligible. +pub struct MemoIngredientIndices { + indices: Box<[MemoIngredientIndex]>, +} + +impl MemoIngredientMap for MemoIngredientIndices { + #[inline(always)] + fn get_zalsa_id(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex { + self.get(zalsa.ingredient_index(id)) + } + #[inline(always)] + fn get(&self, index: IngredientIndex) -> MemoIngredientIndex { + self.indices[index.as_usize()] + } +} + +#[derive(Debug)] +pub struct MemoIngredientSingletonIndex(MemoIngredientIndex); + +impl MemoIngredientMap for MemoIngredientSingletonIndex { + #[inline(always)] + fn get_zalsa_id(&self, _: &Zalsa, _: Id) -> MemoIngredientIndex { + self.0 + } + #[inline(always)] + fn get(&self, _: IngredientIndex) -> MemoIngredientIndex { + self.0 + } +} + +impl From<(&Zalsa, IngredientIndices, IngredientIndex)> for MemoIngredientSingletonIndex { + #[inline] + fn from((zalsa, indices, ingredient): (&Zalsa, IngredientIndices, IngredientIndex)) -> Self { + let &[struct_ingredient] = &*indices.indices else { + unreachable!("Attempting to construct struct memo mapping from enum?") + }; + + Self(zalsa.next_memo_ingredient_index(struct_ingredient, ingredient)) + } +} + +pub trait MemoIngredientMap: Send + Sync { + fn get_zalsa_id(&self, zalsa: &Zalsa, id: Id) -> MemoIngredientIndex; + fn get(&self, index: IngredientIndex) -> MemoIngredientIndex; +} diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index 8674dc12..642dbe75 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1,5 +1,65 @@ -use crate::{plumbing::JarAux, IngredientIndex}; +use std::any::TypeId; -pub trait SalsaStructInDb { - fn lookup_ingredient_index(aux: &dyn JarAux) -> Option; +use crate::memo_ingredient_indices::{IngredientIndices, MemoIngredientMap}; +use crate::zalsa::Zalsa; +use crate::Id; + +pub trait SalsaStructInDb: Sized { + type MemoIngredientMap: MemoIngredientMap; + + /// Lookup or create ingredient indices. + /// + /// Note that this method does *not* create the ingredients themselves, this is handled by + /// [`Zalsa::add_or_lookup_jar_by_type()`]. This method only creates + /// or looks up the indices corresponding to the ingredients. + /// + /// While implementors of this trait may call [`Zalsa::add_or_lookup_jar_by_type()`] + /// to create the ingredient, they aren't required to. For example, supertypes recursively + /// call [`Zalsa::add_or_lookup_jar_by_type()`] for their variants and combine them. + fn lookup_or_create_ingredient_index(zalsa: &Zalsa) -> IngredientIndices; + + /// Plumbing to support nested salsa supertypes. + /// + /// In the example below, there are two supertypes: `InnerEnum` and `OuterEnum`, + /// where the former is a supertype of `Input` and `Interned1` and the latter + /// is a supertype of `InnerEnum` and `Interned2`. + /// + /// ```ignore + /// #[salsa::input] + /// struct Input {} + /// + /// #[salsa::interned] + /// struct Interned1 {} + /// + /// #[salsa::interned] + /// struct Interned2 {} + /// + /// #[derive(Debug, salsa::Enum)] + /// enum InnerEnum { + /// Input(Input), + /// Interned1(Interned1), + /// } + /// + /// #[derive(Debug, salsa::Enum)] + /// enum OuterEnum { + /// InnerEnum(InnerEnum), + /// Interned2(Interned2), + /// } + /// ``` + /// + /// Imagine `OuterEnum` got a [`salsa::Id`][Id] and it wants to know which variant it belongs to. + /// + /// `OuterEnum` cannot ask each variant "what is your ingredient index?" and compare because `InnerEnum` + /// has *multiple*, possible ingredient indices. Alternatively, `OuterEnum` could ask eaach variant + /// "is this value yours?" and then invoke [`FromId`][crate::id::FromId] with the correct variant, + /// but this duplicates work: now, `InnerEnum` will have to repeat this check-and-cast for *its* + /// variants. + /// + /// Instead, the implementor keeps track of the [`std::any::TypeId`] of the ID struct, and ask each + /// variant to "cast" to it. If it succeeds, `cast` returns that value; if not, we + /// go to the next variant. + /// + /// Why `TypeId` and not `IngredientIndex`? Because it's cheaper and easier: the `TypeId` is readily + /// available at compile time, while the `IngredientIndex` requires a runtime lookup. + fn cast(id: Id, type_id: TypeId) -> Option; } diff --git a/src/table.rs b/src/table.rs index 07b84e75..3abfb27b 100644 --- a/src/table.rs +++ b/src/table.rs @@ -32,6 +32,8 @@ pub struct Table { pub(crate) trait TablePage: Any + Send + Sync { fn hidden_type_name(&self) -> &'static str; + fn ingredient_index(&self) -> IngredientIndex; + /// Access the memos attached to `slot`. /// /// # Safety condition @@ -52,7 +54,6 @@ pub(crate) trait TablePage: Any + Send + Sync { pub(crate) struct Page { /// The ingredient for elements on this page. - #[allow(dead_code)] // pretty sure we'll need this ingredient: IngredientIndex, /// Number of elements of `data` that are initialized. @@ -127,6 +128,13 @@ impl Default for Table { } impl Table { + /// Returns the [`IngredientIndex`] for an [`Id`]. + #[inline] + pub fn ingredient_index(&self, id: Id) -> IngredientIndex { + let (page_idx, _) = split_id(id); + self.pages[page_idx.0].ingredient_index() + } + /// Get a reference to the data for `id`, which must have been allocated from this table with type `T`. /// /// # Panics @@ -320,6 +328,10 @@ impl TablePage for Page { std::any::type_name::() } + fn ingredient_index(&self) -> IngredientIndex { + self.ingredient + } + unsafe fn memos(&self, slot: SlotIndex, current_revision: Revision) -> &MemoTable { unsafe { self.get(slot).memos(current_revision) } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 52a6d343..082bcc47 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -6,7 +6,7 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, JarAux, MaybeChangedAfter}, + ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::ZalsaLocal, revision::OptionalAtomicRevision, @@ -111,9 +111,9 @@ impl Default for JarImpl { impl Jar for JarImpl { fn create_ingredients( - &self, - _aux: &dyn JarAux, + _zalsa: &Zalsa, struct_index: crate::zalsa::IngredientIndex, + _dependencies: crate::memo_ingredient_indices::IngredientIndices, ) -> Vec> { let struct_ingredient = >::new(struct_index); @@ -133,8 +133,8 @@ impl Jar for JarImpl { .collect() } - fn salsa_struct_type_id(&self) -> Option { - Some(TypeId::of::<::Struct<'static>>()) + fn id_struct_type_id() -> TypeId { + TypeId::of::>() } } diff --git a/src/zalsa.rs b/src/zalsa.rs index 1816ea53..2069ce52 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,6 +1,7 @@ use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; +use std::collections::hash_map; use std::marker::PhantomData; use std::mem; use std::num::NonZeroU32; @@ -8,7 +9,7 @@ use std::panic::RefUnwindSafe; use std::sync::atomic::{AtomicU64, Ordering}; use crate::cycle::CycleRecoveryStrategy; -use crate::ingredient::{Ingredient, Jar, JarAux}; +use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; use crate::table::memo::MemoTable; @@ -139,6 +140,11 @@ pub struct Zalsa { /// adding new kinds of ingredients. jar_map: Mutex>, + /// A map from the `IngredientIndex` to the `TypeId` of its ID struct. + /// + /// Notably this is not the reverse mapping of `jar_map`. + ingredient_to_id_struct_type_id_map: RwLock>, + /// Vector of ingredients. /// /// Immutable unless the mutex on `ingredients_map` is held. @@ -165,6 +171,7 @@ impl Zalsa { views_of: Views::new::(), nonce: NONCE.nonce(), jar_map: Default::default(), + ingredient_to_id_struct_type_id_map: Default::default(), ingredients_vec: boxcar::Vec::new(), ingredients_requiring_reset: boxcar::Vec::new(), runtime: Runtime::default(), @@ -236,51 +243,85 @@ impl Zalsa { db.zalsa_local().unwind_cancelled(self.current_revision()); } } + + pub(crate) fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex { + let mut memo_ingredients = self.memo_ingredient_indices.write(); + let idx = struct_ingredient_index.as_usize(); + let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { + memo_ingredients + } else { + memo_ingredients.resize_with(idx + 1, Vec::new); + memo_ingredients.get_mut(idx).unwrap() + }; + let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); + memo_ingredients.push(ingredient_index); + mi + } } /// Semver unstable APIs used by the macro expansions impl Zalsa { + /// **NOT SEMVER STABLE** + #[inline] + pub fn lookup_page_type_id(&self, id: Id) -> TypeId { + let ingredient_index = self.ingredient_index(id); + *self + .ingredient_to_id_struct_type_id_map + .read() + .get(&ingredient_index) + .expect("should have the ingredient index available") + } + /// **NOT SEMVER STABLE** #[doc(hidden)] - pub fn add_or_lookup_jar_by_type(&self, jar: &dyn Jar) -> IngredientIndex { - { - let jar_type_id = jar.type_id(); - let mut jar_map = self.jar_map.lock(); - let mut should_create = false; - // First record the index we will use into the map and then go and create the ingredients. - // Those ingredients may invoke methods on the `JarAux` trait that read from this map - // to lookup ingredient indices for already created jars. - // - // Note that we still hold the lock above so only one jar is being created at a time and hence - // ingredient indices cannot overlap. - let index = *jar_map.entry(jar_type_id).or_insert_with(|| { - should_create = true; - IngredientIndex::from(self.ingredients_vec.count()) - }); - if should_create { - let aux = JarAuxImpl(self, &jar_map); - let ingredients = jar.create_ingredients(&aux, index); - for ingredient in ingredients { - let expected_index = ingredient.ingredient_index(); - - if ingredient.requires_reset_for_new_revision() { - self.ingredients_requiring_reset.push(expected_index); - } - - let actual_index = self.ingredients_vec.push(ingredient); - assert_eq!( - expected_index.as_usize(), - actual_index, - "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", - self.ingredients_vec[actual_index], - expected_index, - actual_index, - ); - } + pub fn add_or_lookup_jar_by_type(&self) -> IngredientIndex { + let jar_type_id = TypeId::of::(); + let mut jar_map = self.jar_map.lock(); + if let Some(index) = jar_map.get(&jar_type_id) { + return *index; + }; + // Drop the map as `J::create_dependencies` may recurse into this function taking the lock again. + drop(jar_map); + let dependencies = J::create_dependencies(self); + + jar_map = self.jar_map.lock(); + let index = IngredientIndex::from(self.ingredients_vec.count()); + match jar_map.entry(jar_type_id) { + hash_map::Entry::Occupied(entry) => { + // Someone made it earlier than us. + return *entry.get(); + } + hash_map::Entry::Vacant(entry) => entry.insert(index), + }; + let ingredients = J::create_ingredients(self, index, dependencies); + for ingredient in ingredients { + let expected_index = ingredient.ingredient_index(); + + if ingredient.requires_reset_for_new_revision() { + self.ingredients_requiring_reset.push(expected_index); } - index + let actual_index = self.ingredients_vec.push(ingredient); + assert_eq!( + expected_index.as_usize(), + actual_index, + "ingredient `{:?}` was predicted to have index `{:?}` but actually has index `{:?}`", + self.ingredients_vec[actual_index], + expected_index, + actual_index, + ); } + + drop(jar_map); + self.ingredient_to_id_struct_type_id_map + .write() + .insert(index, J::id_struct_type_id()); + + index } /// **NOT SEMVER STABLE** @@ -339,31 +380,10 @@ impl Zalsa { .reset_for_new_revision(self.runtime.table_mut()); } } -} -struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); - -impl JarAux for JarAuxImpl<'_> { - fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { - self.1.get(&jar.type_id()).map(ToOwned::to_owned) - } - - fn next_memo_ingredient_index( - &self, - struct_ingredient_index: IngredientIndex, - ingredient_index: IngredientIndex, - ) -> MemoIngredientIndex { - let mut memo_ingredients = self.0.memo_ingredient_indices.write(); - let idx = struct_ingredient_index.as_usize(); - let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { - memo_ingredients - } else { - memo_ingredients.resize_with(idx + 1, Vec::new); - &mut memo_ingredients[idx] - }; - let mi = MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); - memo_ingredients.push(ingredient_index); - mi + #[inline] + pub fn ingredient_index(&self, id: Id) -> IngredientIndex { + self.table().ingredient_index(id) } } diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs index cb1e829e..272e2166 100644 --- a/tests/interned-structs_self_ref.rs +++ b/tests/interned-structs_self_ref.rs @@ -1,8 +1,10 @@ //! Test that a `tracked` fn on a `salsa::input` //! compiles and executes successfully. +use std::any::TypeId; use std::convert::identity; +use salsa::plumbing::Zalsa; use test_log::test; #[test] @@ -86,7 +88,7 @@ const _: () = { zalsa_::IngredientCache::new(); CACHE.get_or_create(db.as_dyn_database(), || { db.zalsa() - .add_or_lookup_jar_by_type(&>::default()) + .add_or_lookup_jar_by_type::>() }) } } @@ -110,10 +112,20 @@ const _: () = { } } impl zalsa_::SalsaStructInDb for InternedString<'_> { - fn lookup_ingredient_index( - aux: &dyn zalsa_::JarAux, - ) -> core::option::Option { - aux.lookup_jar_by_type(&>::default()) + type MemoIngredientMap = zalsa_::MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(aux: &Zalsa) -> salsa::plumbing::IngredientIndices { + aux.add_or_lookup_jar_by_type::>() + .into() + } + + #[inline] + fn cast(id: zalsa_::Id, type_id: TypeId) -> Option { + if type_id == TypeId::of::() { + Some(::from_id(id)) + } else { + None + } } } diff --git a/tests/tracked_fn_on_interned_enum.rs b/tests/tracked_fn_on_interned_enum.rs new file mode 100644 index 00000000..0a4b0a37 --- /dev/null +++ b/tests/tracked_fn_on_interned_enum.rs @@ -0,0 +1,93 @@ +//! Test that a `tracked` fn on a `salsa::interned` +//! compiles and executes successfully. + +#[salsa::interned(no_lifetime)] +struct Name { + name: String, +} + +#[salsa::interned] +struct NameAndAge<'db> { + name_and_age: String, +} + +#[salsa::interned(no_lifetime)] +struct Age { + age: u32, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)] +enum Enum<'db> { + Name(Name), + NameAndAge(NameAndAge<'db>), + Age(Age), +} + +#[salsa::input] +struct Input { + value: String, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Supertype)] +enum EnumOfEnum<'db> { + Enum(Enum<'db>), + Input(Input), +} + +#[salsa::tracked] +fn tracked_fn<'db>(db: &'db dyn salsa::Database, enum_: Enum<'db>) -> String { + match enum_ { + Enum::Name(name) => name.name(db), + Enum::NameAndAge(name_and_age) => name_and_age.name_and_age(db), + Enum::Age(age) => age.age(db).to_string(), + } +} + +#[salsa::tracked] +fn tracked_fn2<'db>(db: &'db dyn salsa::Database, enum_: EnumOfEnum<'db>) -> String { + match enum_ { + EnumOfEnum::Enum(enum_) => tracked_fn(db, enum_), + EnumOfEnum::Input(input) => input.value(db), + } +} + +#[test] +fn execute() { + let db = salsa::DatabaseImpl::new(); + let name = Name::new(&db, "Salsa".to_string()); + let name_and_age = NameAndAge::new(&db, "Salsa 3".to_string()); + let age = Age::new(&db, 123); + + assert_eq!(tracked_fn(&db, Enum::Name(name)), "Salsa"); + assert_eq!(tracked_fn(&db, Enum::NameAndAge(name_and_age)), "Salsa 3"); + assert_eq!(tracked_fn(&db, Enum::Age(age)), "123"); + assert_eq!(tracked_fn(&db, Enum::Name(name)), "Salsa"); + assert_eq!(tracked_fn(&db, Enum::NameAndAge(name_and_age)), "Salsa 3"); + assert_eq!(tracked_fn(&db, Enum::Age(age)), "123"); + + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::Name(name))), + "Salsa" + ); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::NameAndAge(name_and_age))), + "Salsa 3" + ); + assert_eq!(tracked_fn2(&db, EnumOfEnum::Enum(Enum::Age(age))), "123"); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::Name(name))), + "Salsa" + ); + assert_eq!( + tracked_fn2(&db, EnumOfEnum::Enum(Enum::NameAndAge(name_and_age))), + "Salsa 3" + ); + assert_eq!(tracked_fn2(&db, EnumOfEnum::Enum(Enum::Age(age))), "123"); + assert_eq!( + tracked_fn2( + &db, + EnumOfEnum::Input(Input::new(&db, "Hello world!".to_string())) + ), + "Hello world!" + ); +}