From ef8f4d38fe6bfab7b616c6383edd6dfaf7e8001c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 8 Nov 2023 23:16:29 +0000 Subject: [PATCH] feat: polymorphic function types (inc OpDefs) using dyn trait (#630) * Add type variable as a variant of PrimType and TypeArg * Add types/poly_func.rs containing PolyFuncType - a list of binders ('TypeParam's) and a FunctionType * can be instantiated to a FunctionType, or (perhaps-partially) a PolyFuncType * PrimType contains PolyFuncType, not FunctionType * Add `substitute` method to Type, TypeArg, et al., taking a Substitution trait which provides values for type variables * Implementations SubstValues (when instantiating) as well as InsideBinders and Renumber for dealing with PolyFuncTypes * OpDefs may have either binary compute_signature functions or PolyFuncType "type schemes"; the former (i.e. binary functions) return the latter. * Add LeafOp::TypeApply whose signature represents applying a PolyFuncType to a type-argument - the argument of kind (i.e. TypeParam) depending on the former BREAKING CHANGE: * Rename OpDef::instantiate_concrete to just instantiate; * PrimType::Function and Type::new_function take PolyFuncType not FunctionType --------- Co-authored-by: Seyon Sivarajah --- src/extension.rs | 74 +++- src/extension/op_def.rs | 168 +++++--- src/extension/prelude.rs | 6 +- src/extension/type_def.rs | 18 +- src/hugr/rewrite/replace.rs | 2 +- src/hugr/validate.rs | 51 ++- src/ops/leaf.rs | 134 +++++++ src/std_extensions/collections.rs | 17 +- src/std_extensions/quantum.rs | 13 +- src/types.rs | 111 +++++- src/types/check.rs | 2 +- src/types/custom.rs | 44 ++- src/types/poly_func.rs | 619 ++++++++++++++++++++++++++++++ src/types/primitive.rs | 9 +- src/types/serialize.rs | 11 +- src/types/signature.rs | 26 +- src/types/type_param.rs | 152 +++++++- 17 files changed, 1292 insertions(+), 165 deletions(-) create mode 100644 src/types/poly_func.rs diff --git a/src/extension.rs b/src/extension.rs index dd7401cfe..25efa58f1 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -14,9 +14,9 @@ use thiserror::Error; use crate::hugr::IdentList; use crate::ops; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::types::type_param::{check_type_arg, TypeArgError}; +use crate::types::type_param::{check_type_args, TypeArgError}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{CustomType, TypeBound}; +use crate::types::{check_typevar_decl, CustomType, PolyFuncType, Substitution, TypeBound}; mod infer; pub use infer::{infer_extensions, ExtensionSolution, InferExtensionError}; @@ -91,6 +91,24 @@ pub enum SignatureError { actual: TypeBound, expected: TypeBound, }, + /// A Type Variable's cache of its declared kind is incorrect + #[error("Type Variable claims to be {cached:?} but actual declaration {actual:?}")] + TypeVarDoesNotMatchDeclaration { + actual: TypeParam, + cached: TypeParam, + }, + /// A type variable that was used has not been declared + #[error("Type variable {idx} was not declared ({num_decls} in scope)")] + FreeTypeVar { idx: usize, num_decls: usize }, + /// The type stored in a [LeafOp::TypeApply] is not what we compute from the + /// [ExtensionRegistry]. + /// + /// [LeafOp::TypeApply]: crate::ops::LeafOp::TypeApply + #[error("Incorrect result of type application - cached {cached} but expected {expected}")] + TypeApplyIncorrectCache { + cached: PolyFuncType, + expected: PolyFuncType, + }, } /// Concrete instantiations of types and operations defined in extensions. @@ -140,15 +158,7 @@ trait TypeParametrised { fn extension(&self) -> &ExtensionId; /// Check provided type arguments are valid against parameters. fn check_args_impl(&self, args: &[TypeArg]) -> Result<(), SignatureError> { - if args.len() != self.params().len() { - return Err(SignatureError::TypeArgMismatch( - TypeArgError::WrongNumberArgs(args.len(), self.params().len()), - )); - } - for (a, p) in args.iter().zip(self.params().iter()) { - check_type_arg(a, p).map_err(SignatureError::TypeArgMismatch)?; - } - Ok(()) + check_type_args(args, self.params()).map_err(SignatureError::TypeArgMismatch) } } @@ -315,6 +325,14 @@ impl ExtensionSet { self.0.insert(extension.clone()); } + /// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set + pub fn insert_type_var(&mut self, idx: usize) { + // Represent type vars as string representation of DeBruijn index. + // This is not a legal IdentList or ExtensionId so should not conflict. + self.0 + .insert(ExtensionId::new_unchecked(idx.to_string().as_str())); + } + /// Returns `true` if the set contains the given extension. pub fn contains(&self, extension: &ExtensionId) -> bool { self.0.contains(extension) @@ -337,6 +355,14 @@ impl ExtensionSet { set } + /// An ExtensionSet containing a single type variable + /// (which must have been declared as a [TypeParam::Extensions]) + pub fn type_var(idx: usize) -> Self { + let mut set = Self::new(); + set.insert_type_var(idx); + set + } + /// Returns the union of two extension sets. pub fn union(mut self, other: &Self) -> Self { self.0.extend(other.0.iter().cloned()); @@ -357,6 +383,32 @@ impl ExtensionSet { pub fn is_empty(&self) -> bool { self.0.is_empty() } + + pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> { + self.iter() + .filter_map(as_typevar) + .try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions)) + } + + pub(crate) fn substitute(&self, t: &impl Substitution) -> Self { + Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) { + None => vec![e.clone()], + Some(i) => match t.apply_var(i, &TypeParam::Extensions) { + TypeArg::Extensions{es} => es.iter().cloned().collect::>(), + _ => panic!("value for type var was not extension set - type scheme should be validated first"), + }, + })) + } +} + +fn as_typevar(e: &ExtensionId) -> Option { + // Type variables are represented as radix-10 numbers, which are illegal + // as standard ExtensionIds. Hence if an ExtensionId starts with a digit, + // we assume it must be a type variable, and fail fast if it isn't. + match e.chars().next() { + Some(c) if c.is_ascii_digit() => Some(str::parse(e).unwrap()), + _ => None, + } } impl Display for ExtensionSet { diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 8fe92e4fc..6a6ba0f97 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -1,24 +1,20 @@ -use crate::Hugr; +use std::cmp::min; use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::sync::Arc; +use smol_str::SmolStr; + use super::{ Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeParametrised, }; -use crate::types::FunctionType; - -use crate::types::type_param::TypeArg; - use crate::ops::custom::OpaqueOp; - -use std::collections::HashMap; - -use crate::types::type_param::TypeParam; - -use smol_str::SmolStr; +use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; +use crate::types::{FunctionType, PolyFuncType}; +use crate::Hugr; /// Trait for extensions to provide custom binary code for computing signature. pub trait CustomSignatureFunc: Send + Sync { @@ -31,14 +27,14 @@ pub trait CustomSignatureFunc: Send + Sync { arg_values: &[TypeArg], misc: &HashMap, extension_registry: &ExtensionRegistry, - ) -> Result; + ) -> Result; } // Note this is very much a utility, rather than definitive; // one can only do so much without the ExtensionRegistry! -impl CustomSignatureFunc for F +impl> CustomSignatureFunc for F where - F: Fn(&[TypeArg]) -> Result + Send + Sync, + F: Fn(&[TypeArg]) -> Result + Send + Sync, { fn compute_signature( &self, @@ -46,8 +42,8 @@ where arg_values: &[TypeArg], _misc: &HashMap, _extension_registry: &ExtensionRegistry, - ) -> Result { - self(arg_values) + ) -> Result { + Ok(self(arg_values)?.into()) } } @@ -77,25 +73,25 @@ pub trait CustomLowerFunc: Send + Sync { /// The two ways in which an OpDef may compute the Signature of each operation node. #[derive(serde::Deserialize, serde::Serialize)] pub(super) enum SignatureFunc { - // Note: I'd prefer to make the YAML version just implement the same CustomSignatureFunc trait, - // and then just have a Box instead of this enum, but that seems less likely - // to serialize well. - /// TODO: these types need to be whatever representation we want of a type scheme encoded in the YAML + // Note: except for serialization, we could have type schemes just implement the same + // CustomSignatureFunc trait too, and replace this enum with Box. + // However instead we treat all CustomFunc's as non-serializable. #[serde(rename = "signature")] - FromDecl { inputs: String, outputs: String }, + TypeScheme(PolyFuncType), #[serde(skip)] - CustomFunc(Box), + CustomFunc { + /// Type parameters passed to [func]. (The returned [PolyFuncType] + /// may require further type parameters, not declared here.) + static_params: Vec, + func: Box, + }, } impl Debug for SignatureFunc { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::FromDecl { inputs, outputs } => f - .debug_struct("signature") - .field("inputs", inputs) - .field("outputs", outputs) - .finish(), - Self::CustomFunc(_) => f.write_str(""), + Self::TypeScheme(scheme) => scheme.fmt(f), + Self::CustomFunc { .. } => f.write_str(""), } } } @@ -135,8 +131,6 @@ pub struct OpDef { name: SmolStr, /// Human readable description of the operation. description: String, - /// Declared type parameters, values must be provided for each operation node - params: Vec, /// Miscellaneous data associated with the operation. #[serde(default, skip_serializing_if = "HashMap::is_empty")] misc: HashMap, @@ -178,16 +172,24 @@ impl OpDef { args: &[TypeArg], exts: &ExtensionRegistry, ) -> Result { - self.check_args(args)?; - let res = match &self.signature_func { - SignatureFunc::FromDecl { .. } => { - // Sig should be computed solely from inputs + outputs + args. - todo!() - } - SignatureFunc::CustomFunc(bf) => { - bf.compute_signature(&self.name, args, &self.misc, exts)? + // Hugr's are monomorphic, so check the args have no free variables + args.iter().try_for_each(|ta| ta.validate(exts, &[]))?; + + let temp: PolyFuncType; // to keep alive + let (pf, args) = match &self.signature_func { + SignatureFunc::TypeScheme(ts) => (ts, args), + SignatureFunc::CustomFunc { + static_params, + func, + } => { + let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); + check_type_args(static_args, static_params)?; + temp = func.compute_signature(&self.name, static_args, &self.misc, exts)?; + (&temp, other_args) } }; + + let res = pf.instantiate(args, exts)?; // TODO bring this assert back once resource inference is done? // https://github.com/CQCL-DEV/hugr/issues/425 // assert!(res.contains(self.extension())); @@ -196,8 +198,8 @@ impl OpDef { pub(crate) fn should_serialize_signature(&self) -> bool { match self.signature_func { - SignatureFunc::CustomFunc(_) => true, - SignatureFunc::FromDecl { .. } => false, + SignatureFunc::TypeScheme { .. } => true, + SignatureFunc::CustomFunc { .. } => false, } } @@ -238,7 +240,10 @@ impl OpDef { /// Returns a reference to the params of this [`OpDef`]. pub fn params(&self) -> &[TypeParam] { - self.params.as_ref() + match &self.signature_func { + SignatureFunc::TypeScheme(ts) => ts.params(), + SignatureFunc::CustomFunc { static_params, .. } => static_params, + } } } @@ -248,7 +253,6 @@ impl Extension { &mut self, name: SmolStr, description: String, - params: Vec, misc: HashMap, lower_funcs: Vec, signature_func: SignatureFunc, @@ -257,7 +261,6 @@ impl Extension { extension: self.name.clone(), name, description, - params, misc, signature_func, lower_funcs, @@ -274,7 +277,7 @@ impl Extension { &mut self, name: SmolStr, description: String, - params: Vec, + static_params: Vec, misc: HashMap, lower_funcs: Vec, signature_func: impl CustomSignatureFunc + 'static, @@ -282,50 +285,103 @@ impl Extension { self.add_op( name, description, - params, misc, lower_funcs, - SignatureFunc::CustomFunc(Box::new(signature_func)), + SignatureFunc::CustomFunc { + static_params, + func: Box::new(signature_func), + }, ) } - /// Create an OpDef with custom binary code to compute the signature, and no "misc" or "lowering - /// functions" defined. + /// Create an OpDef with custom binary code to compute the type scheme + /// (which may be polymorphic); and no "misc" or "lowering functions" defined. pub fn add_op_custom_sig_simple( &mut self, name: SmolStr, description: String, - params: Vec, + static_params: Vec, signature_func: impl CustomSignatureFunc + 'static, ) -> Result<&OpDef, ExtensionBuildError> { self.add_op_custom_sig( name, description, - params, + static_params, HashMap::default(), Vec::new(), signature_func, ) } - /// Create an OpDef with a signature (inputs+outputs) read from the + /// Create an OpDef with a signature (inputs+outputs) read from e.g. /// declarative YAML - pub fn add_op_decl_sig( + pub fn add_op_type_scheme( &mut self, name: SmolStr, description: String, - params: Vec, misc: HashMap, lower_funcs: Vec, - (inputs, outputs): (String, String), // separating these makes clippy complain about too many args + type_scheme: PolyFuncType, ) -> Result<&OpDef, ExtensionBuildError> { self.add_op( name, description, - params, misc, lower_funcs, - SignatureFunc::FromDecl { inputs, outputs }, + SignatureFunc::TypeScheme(type_scheme), ) } } + +#[cfg(test)] +mod test { + use smol_str::SmolStr; + + use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use crate::extension::prelude::USIZE_T; + use crate::extension::PRELUDE; + use crate::ops::custom::ExternalOp; + use crate::ops::LeafOp; + use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; + use crate::types::Type; + use crate::types::{type_param::TypeParam, FunctionType, PolyFuncType, TypeArg, TypeBound}; + use crate::{const_extension_ids, Extension}; + + const_extension_ids! { + const EXT_ID: ExtensionId = "MyExt"; + } + + #[test] + fn op_def_with_type_scheme() -> Result<(), Box> { + let reg1 = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let mut e = Extension::new(EXT_ID); + const TP: TypeParam = TypeParam::Type(TypeBound::Any); + let list_of_var = + Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); + const OP_NAME: SmolStr = SmolStr::new_inline("Reverse"); + let type_scheme = PolyFuncType::new_validated( + vec![TP], + FunctionType::new_linear(vec![list_of_var]), + ®1, + )?; + e.add_op_type_scheme(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?; + + let list_usize = + Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?); + let mut dfg = DFGBuilder::new(FunctionType::new_linear(vec![list_usize]))?; + let rev = dfg.add_dataflow_op( + LeafOp::from(ExternalOp::Extension( + e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], ®1) + .unwrap(), + )), + dfg.input_wires(), + )?; + dfg.finish_hugr_with_outputs( + rev.outputs(), + &[PRELUDE.to_owned(), EXTENSION.to_owned(), e].into(), + )?; + + Ok(()) + } +} diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index 07e88ddb8..3102cdcb7 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; -use super::{ExtensionRegistry, EMPTY_REG}; +use super::ExtensionRegistry; /// Name of prelude extension. pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); @@ -92,7 +92,7 @@ pub const BOOL_T: Type = Type::new_unit_sum(2); pub fn array_type(element_ty: Type, size: u64) -> Type { let array_def = PRELUDE.get_type("array").unwrap(); let custom_t = array_def - .instantiate_concrete(vec![ + .instantiate(vec![ TypeArg::Type { ty: element_ty }, TypeArg::BoundedNat { n: size }, ]) @@ -112,7 +112,7 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { TypeArg::Type { ty: element_ty }, TypeArg::BoundedNat { n: size }, ], - &EMPTY_REG, + &PRELUDE_REGISTRY, ) .unwrap() .into() diff --git a/src/extension/type_def.rs b/src/extension/type_def.rs index 1ad0d28ac..dc03039ad 100644 --- a/src/extension/type_def.rs +++ b/src/extension/type_def.rs @@ -93,10 +93,7 @@ impl TypeDef { /// /// This function will return an error if the provided arguments are not /// valid instances of the type parameters. - pub fn instantiate_concrete( - &self, - args: impl Into>, - ) -> Result { + pub fn instantiate(&self, args: impl Into>) -> Result { let args = args.into(); self.check_args_impl(&args)?; let bound = self.bound(&args); @@ -188,21 +185,18 @@ mod test { bound: TypeDefBound::FromParams(vec![0]), }; let typ = Type::new_extension( - def.instantiate_concrete(vec![TypeArg::Type { + def.instantiate(vec![TypeArg::Type { ty: Type::new_function(FunctionType::new(vec![], vec![])), }]) .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); - let typ2 = Type::new_extension( - def.instantiate_concrete([TypeArg::Type { ty: USIZE_T }]) - .unwrap(), - ); + let typ2 = Type::new_extension(def.instantiate([TypeArg::Type { ty: USIZE_T }]).unwrap()); assert_eq!(typ2.least_upper_bound(), TypeBound::Eq); // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate_concrete([TypeArg::Type { ty: QB_T }]), + def.instantiate([TypeArg::Type { ty: QB_T }]), Err(SignatureError::TypeArgMismatch( TypeArgError::TypeMismatch { arg: TypeArg::Type { ty: QB_T }, @@ -212,12 +206,12 @@ mod test { ); // Too few arguments: assert_eq!( - def.instantiate_concrete([]).unwrap_err(), + def.instantiate([]).unwrap_err(), SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(0, 1)) ); // Too many arguments: assert_eq!( - def.instantiate_concrete([ + def.instantiate([ TypeArg::Type { ty: FLOAT64_TYPE }, TypeArg::Type { ty: FLOAT64_TYPE }, ]) diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index ac250ad60..0b01df5aa 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -463,7 +463,7 @@ mod test { collections::EXTENSION .get_type(collections::LIST_TYPENAME.as_str()) .unwrap() - .instantiate_concrete([TypeArg::Type { ty: USIZE_T }]) + .instantiate([TypeArg::Type { ty: USIZE_T }]) .unwrap(), ); let pop: LeafOp = collections::EXTENSION diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index ec601dfa6..0d0724379 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -164,29 +164,38 @@ impl<'a, 'b> ValidationContext<'a, 'b> { } // Check operation-specific constraints. - if let OpType::LeafOp(crate::ops::LeafOp::CustomOp(b)) = op_type { - // Check TypeArgs are valid (in themselves, not necessarily wrt the TypeParams) - for arg in b.args() { - arg.validate(self.extension_registry) - .map_err(|cause| ValidationError::SignatureError { node, cause })?; - } - // Try to resolve serialized names to actual OpDefs in Extensions. - let e: Option; - let ext_op = match &**b { - ExternalOp::Opaque(op) => { - // If resolve_extension_ops has been called first, this would always return Ok(None) - e = resolve_opaque_op(node, op, self.extension_registry)?; - e.as_ref() + // TODO make a separate method for this (perhaps producing Result<(), SignatureError>) + match op_type { + OpType::LeafOp(crate::ops::LeafOp::CustomOp(b)) => { + // Check TypeArgs are valid (in themselves, not necessarily wrt the TypeParams) + for arg in b.args() { + // Hugrs are monomorphic, so no type variables in scope + arg.validate(self.extension_registry, &[]) + .map_err(|cause| ValidationError::SignatureError { node, cause })?; } - ExternalOp::Extension(ext) => Some(ext), - }; - // If successful, check TypeArgs are valid for the declared TypeParams - if let Some(ext_op) = ext_op { - ext_op - .def() - .check_args(ext_op.args()) + // Try to resolve serialized names to actual OpDefs in Extensions. + let e: Option; + let ext_op = match &**b { + ExternalOp::Opaque(op) => { + // If resolve_extension_ops has been called first, this would always return Ok(None) + e = resolve_opaque_op(node, op, self.extension_registry)?; + e.as_ref() + } + ExternalOp::Extension(ext) => Some(ext), + }; + // If successful, check TypeArgs are valid for the declared TypeParams + if let Some(ext_op) = ext_op { + ext_op + .def() + .check_args(ext_op.args()) + .map_err(|cause| ValidationError::SignatureError { node, cause })?; + } + } + OpType::LeafOp(crate::ops::LeafOp::TypeApply { ta }) => { + ta.validate(self.extension_registry) .map_err(|cause| ValidationError::SignatureError { node, cause })?; } + _ => (), } // Secondly that the node has correct children self.validate_children(node, node_type)?; @@ -241,7 +250,7 @@ impl<'a, 'b> ValidationContext<'a, 'b> { match &port_kind { EdgeKind::Value(ty) | EdgeKind::Static(ty) => ty - .validate(self.extension_registry) + .validate(self.extension_registry, &[]) // no type vars inside the Hugr .map_err(|cause| ValidationError::SignatureError { node, cause })?, _ => (), } diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 768d35fcb..24f1ebb7e 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -5,6 +5,9 @@ use smol_str::SmolStr; use super::custom::ExternalOp; use super::{OpName, OpTag, OpTrait, StaticTag}; +use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::types::type_param::TypeArg; +use crate::types::PolyFuncType; use crate::{ extension::{ExtensionId, ExtensionSet}, types::{EdgeKind, FunctionType, Type, TypeRow}, @@ -49,6 +52,59 @@ pub enum LeafOp { /// The extensions which we're adding to the inputs new_extension: ExtensionId, }, + /// Fixes some [TypeParam]s of a polymorphic type by providing [TypeArg]s + /// + /// [TypeParam]: crate::types::type_param::TypeParam + TypeApply { + /// The type and args, plus a cache of the resulting type + ta: TypeApplication, + }, +} + +/// Records details of an application of a [PolyFuncType] to some [TypeArg]s +/// and the result (a less-, but still potentially-, polymorphic type). +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct TypeApplication { + input: PolyFuncType, + args: Vec, + output: PolyFuncType, // cached +} + +impl TypeApplication { + /// Checks that the specified args are correct for the [TypeParam]s of the polymorphic input. + /// Note the extension registry is required here to recompute [Type::least_upper_bound]s. + /// + /// [TypeParam]: crate::types::type_param::TypeParam + pub fn try_new( + input: PolyFuncType, + args: impl Into>, + extension_registry: &ExtensionRegistry, + ) -> Result { + let args = args.into(); + // Should we require >=1 `arg`s here? Or that input declares >=1 params? + // At the moment we allow an identity TypeApply on a monomorphic function type. + let output = input.instantiate_poly(&args, extension_registry)?; + Ok(Self { + input, + args, + output, + }) + } + + pub(crate) fn validate( + &self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), SignatureError> { + let other = Self::try_new(self.input.clone(), self.args.clone(), extension_registry)?; + if other.output == self.output { + Ok(()) + } else { + Err(SignatureError::TypeApplyIncorrectCache { + cached: self.output.clone(), + expected: other.output.clone(), + }) + } + } } impl Default for LeafOp { @@ -66,6 +122,7 @@ impl OpName for LeafOp { LeafOp::UnpackTuple { tys: _ } => "UnpackTuple", LeafOp::Tag { .. } => "Tag", LeafOp::Lift { .. } => "Lift", + LeafOp::TypeApply { .. } => "TypeApply", } .into() } @@ -85,6 +142,9 @@ impl OpTrait for LeafOp { LeafOp::UnpackTuple { tys: _ } => "UnpackTuple operation", LeafOp::Tag { .. } => "Tag Sum operation", LeafOp::Lift { .. } => "Add a extension requirement to an edge", + LeafOp::TypeApply { .. } => { + "Instantiate (perhaps partially) a polymorphic type with some type arguments" + } } } @@ -115,6 +175,10 @@ impl OpTrait for LeafOp { new_extension, } => FunctionType::new(type_row.clone(), type_row.clone()) .with_extension_delta(&ExtensionSet::singleton(new_extension)), + LeafOp::TypeApply { ta } => FunctionType::new( + vec![Type::new_function(ta.input.clone())], + vec![Type::new_function(ta.output.clone())], + ), } } @@ -126,3 +190,73 @@ impl OpTrait for LeafOp { Some(EdgeKind::StateOrder) } } + +#[cfg(test)] +mod test { + use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; + use crate::extension::prelude::BOOL_T; + use crate::extension::SignatureError; + use crate::extension::{prelude::USIZE_T, PRELUDE}; + use crate::hugr::ValidationError; + use crate::ops::handle::NodeHandle; + use crate::std_extensions::collections::EXTENSION; + use crate::types::Type; + use crate::types::{test::nested_func, FunctionType, TypeArg}; + + use super::{LeafOp, TypeApplication}; + + const USIZE_TA: TypeArg = TypeArg::Type { ty: USIZE_T }; + + #[test] + fn hugr_with_type_apply() -> Result<(), Box> { + let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let pf_in = nested_func(); + let pf_out = pf_in.instantiate(&[USIZE_TA], ®)?; + let mut dfg = DFGBuilder::new(FunctionType::new( + vec![Type::new_function(pf_in.clone())], + vec![Type::new_function(pf_out)], + ))?; + let ta = dfg.add_dataflow_op( + LeafOp::TypeApply { + ta: TypeApplication::try_new(pf_in, [USIZE_TA], ®).unwrap(), + }, + dfg.input_wires(), + )?; + dfg.finish_hugr_with_outputs(ta.outputs(), ®)?; + Ok(()) + } + + #[test] + fn bad_type_apply() -> Result<(), Box> { + let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let pf = nested_func(); + let pf_usz = pf.instantiate_poly(&[USIZE_TA], ®)?; + let pf_bool = pf.instantiate_poly(&[TypeArg::Type { ty: BOOL_T }], ®)?; + let mut dfg = DFGBuilder::new(FunctionType::new( + vec![Type::new_function(pf.clone())], + vec![Type::new_function(pf_usz.clone())], + ))?; + let ta = dfg.add_dataflow_op( + LeafOp::TypeApply { + ta: TypeApplication { + input: pf, + args: vec![TypeArg::Type { ty: BOOL_T }], + output: pf_usz.clone(), + }, + }, + dfg.input_wires(), + )?; + let res = dfg.finish_hugr_with_outputs(ta.outputs(), ®); + assert_eq!( + res.unwrap_err(), + BuildError::InvalidHUGR(ValidationError::SignatureError { + node: ta.node(), + cause: SignatureError::TypeApplyIncorrectCache { + cached: pf_usz, + expected: pf_bool + } + }) + ); + Ok(()) + } +} diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 5e193c27a..40b10f6a2 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -113,7 +113,7 @@ fn get_type(name: &str) -> &TypeDef { } fn list_types(args: &[TypeArg]) -> Result<(Type, Type), SignatureError> { - let list_custom_type = get_type(&LIST_TYPENAME).instantiate_concrete(args)?; + let list_custom_type = get_type(&LIST_TYPENAME).instantiate(args)?; let [TypeArg::Type { ty: element_type }] = args else { panic!("should be checked by def.") }; @@ -127,9 +127,9 @@ mod test { use crate::{ extension::{ prelude::{ConstUsize, QB_T, USIZE_T}, - OpDef, + OpDef, PRELUDE, }, - std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}, + std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, types::{type_param::TypeArg, Type}, Extension, }; @@ -152,11 +152,11 @@ mod test { let list_def = r.get_type(&LIST_TYPENAME).unwrap(); let list_type = list_def - .instantiate_concrete([TypeArg::Type { ty: USIZE_T }]) + .instantiate([TypeArg::Type { ty: USIZE_T }]) .unwrap(); assert!(list_def - .instantiate_concrete([TypeArg::BoundedNat { n: 3 }]) + .instantiate([TypeArg::BoundedNat { n: 3 }]) .is_err()); list_def.check_custom(&list_type).unwrap(); @@ -170,7 +170,12 @@ mod test { #[test] fn test_list_ops() { - let reg = &[EXTENSION.to_owned()].into(); + let reg = &[ + EXTENSION.to_owned(), + PRELUDE.to_owned(), + float_types::extension(), + ] + .into(); let pop_sig = get_op(&POP_NAME) .compute_signature(&[TypeArg::Type { ty: QB_T }], reg) .unwrap(); diff --git a/src/std_extensions/quantum.rs b/src/std_extensions/quantum.rs index 2c8d67498..41b9054e4 100644 --- a/src/std_extensions/quantum.rs +++ b/src/std_extensions/quantum.rs @@ -297,22 +297,27 @@ lazy_static! { #[cfg(test)] pub(crate) mod test { + use lazy_static::lazy_static; use std::f64::consts::TAU; use cool_asserts::assert_matches; use crate::{ - extension::EMPTY_REG, + extension::{ExtensionRegistry, PRELUDE}, ops::LeafOp, - std_extensions::quantum::{get_log_denom, ConstAngle}, types::{type_param::TypeArgError, ConstTypeError, TypeArg}, }; - use super::EXTENSION; + use super::{get_log_denom, ConstAngle, EXTENSION}; + + lazy_static! { + /// Quantum extension definition. + static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + } fn get_gate(gate_name: &str) -> LeafOp { EXTENSION - .instantiate_extension_op(gate_name, [], &EMPTY_REG) + .instantiate_extension_op(gate_name, [], ®) .unwrap() .into() } diff --git a/src/types.rs b/src/types.rs index 8ea81504e..e2452d2b3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -2,6 +2,7 @@ mod check; pub mod custom; +mod poly_func; mod primitive; mod serialize; mod signature; @@ -10,6 +11,7 @@ pub mod type_row; pub use check::{ConstTypeError, CustomCheckFailure}; pub use custom::CustomType; +pub use poly_func::PolyFuncType; pub use signature::{FunctionType, Signature}; pub use type_param::TypeArg; pub use type_row::TypeRow; @@ -25,6 +27,7 @@ use crate::type_row; use std::fmt::Debug; pub use self::primitive::PrimType; +use self::type_param::TypeParam; #[cfg(feature = "pyo3")] use pyo3::pyclass; @@ -212,8 +215,8 @@ impl Type { const UNIT_REF: &'static Self = &Self::UNIT; /// Initialize a new function type. - pub fn new_function(signature: FunctionType) -> Self { - Self::new(TypeEnum::Prim(PrimType::Function(Box::new(signature)))) + pub fn new_function(fun_ty: impl Into) -> Self { + Self::new(TypeEnum::Prim(PrimType::Function(Box::new(fun_ty.into())))) } /// Initialize a new tuple type by providing the elements. @@ -260,6 +263,13 @@ impl Type { Self(TypeEnum::Sum(SumType::Unit { size }), TypeBound::Eq) } + /// New use (occurrence) of the type variable with specified DeBruijn index. + /// For use in type schemes only: `bound` must match that with which the + /// variable was declared (i.e. as a [TypeParam::Type]`(bound)`). + pub fn new_var_use(idx: usize, bound: TypeBound) -> Self { + Self(TypeEnum::Prim(PrimType::Variable(idx, bound)), bound) + } + /// Report the least upper TypeBound, if there is one. #[inline(always)] pub const fn least_upper_bound(&self) -> TypeBound { @@ -278,24 +288,99 @@ impl Type { TypeBound::Copyable.contains(self.least_upper_bound()) } + /// Checks all variables used in the type are in the provided list + /// of bound variables, and that for each [CustomType] the corresponding + /// [TypeDef] is in the [ExtensionRegistry] and the type arguments + /// [validate] and fit into the def's declared parameters. + /// + /// [validate]: crate::types::type_param::TypeArg::validate + /// [TypeDef]: crate::extension::TypeDef pub(crate) fn validate( &self, extension_registry: &ExtensionRegistry, + var_decls: &[TypeParam], ) -> Result<(), SignatureError> { // There is no need to check the components against the bound, // that is guaranteed by construction (even for deserialization) match &self.0 { - TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General { row }) => { - row.iter().try_for_each(|t| t.validate(extension_registry)) - } + TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General { row }) => row + .iter() + .try_for_each(|t| t.validate(extension_registry, var_decls)), TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there TypeEnum::Prim(PrimType::Alias(_)) => Ok(()), - TypeEnum::Prim(PrimType::Extension(custy)) => custy.validate(extension_registry), - TypeEnum::Prim(PrimType::Function(ft)) => ft - .input - .iter() - .chain(ft.output.iter()) - .try_for_each(|t| t.validate(extension_registry)), + TypeEnum::Prim(PrimType::Extension(custy)) => { + custy.validate(extension_registry, var_decls) + } + TypeEnum::Prim(PrimType::Function(ft)) => ft.validate(extension_registry, var_decls), + TypeEnum::Prim(PrimType::Variable(idx, bound)) => { + check_typevar_decl(var_decls, *idx, &TypeParam::Type(*bound)) + } + } + } + + pub(crate) fn substitute(&self, t: &impl Substitution) -> Self { + match &self.0 { + TypeEnum::Prim(PrimType::Alias(_)) | TypeEnum::Sum(SumType::Unit { .. }) => { + self.clone() + } + TypeEnum::Prim(PrimType::Variable(idx, bound)) => t.apply_typevar(*idx, *bound), + TypeEnum::Prim(PrimType::Extension(cty)) => Type::new_extension(cty.substitute(t)), + TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(t)), + TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, t)), + TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, t)), + } + } +} + +/// A function that replaces type variables with values. +/// (The values depend upon the implementation, to allow dynamic computation; +/// and [Substitution] deals only with type variables, other/containing types/typeargs +/// are handled by [Type::substitute], [TypeArg::substitute] and friends.) +pub(crate) trait Substitution { + /// Apply to a variable of kind [TypeParam::Type] + fn apply_typevar(&self, idx: usize, bound: TypeBound) -> Type { + let TypeArg::Type { ty } = self.apply_var(idx, &TypeParam::Type(bound)) else { + panic!("Variable was not a type - try validate() first") + }; + ty + } + + /// Apply to a variable whose kind is any given [TypeParam] + fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg; + + fn extension_registry(&self) -> &ExtensionRegistry; +} + +fn subst_row(row: &TypeRow, tr: &impl Substitution) -> TypeRow { + let res = row + .iter() + .map(|ty| ty.substitute(tr)) + .collect::>() + .into(); + res +} + +pub(crate) fn check_typevar_decl( + decls: &[TypeParam], + idx: usize, + cached_decl: &TypeParam, +) -> Result<(), SignatureError> { + match decls.get(idx) { + None => Err(SignatureError::FreeTypeVar { + idx, + num_decls: decls.len(), + }), + Some(actual) => { + // The cache here just mirrors the declaration. The typevar can be used + // anywhere expecting a kind *containing* the decl - see `check_type_arg`. + if actual == cached_decl { + Ok(()) + } else { + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + cached: cached_decl.clone(), + actual: actual.clone(), + }) + } } } } @@ -315,6 +400,7 @@ where #[cfg(test)] pub(crate) mod test { + pub(crate) use poly_func::test::nested_func; use super::*; use crate::{ @@ -344,7 +430,8 @@ pub(crate) mod test { ]); assert_eq!( t.to_string(), - "Tuple([usize([]), Function([[]][]), my_custom([]), Alias(my_alias)])".to_string() + "Tuple([usize([]), Function(forall . [[]][]), my_custom([]), Alias(my_alias)])" + .to_string() ); } diff --git a/src/types/check.rs b/src/types/check.rs index dafbb5087..496cc2a2d 100644 --- a/src/types/check.rs +++ b/src/types/check.rs @@ -65,7 +65,7 @@ impl PrimType { Ok(()) } (PrimType::Function(t), PrimValue::Function { hugr: v }) - if Some(t.as_ref()) == v.get_function_type() => + if v.get_function_type().is_some_and(|f| &**t == f) => { // exact signature equality, in future this may need to be // relaxed to be compatibility checks between the signatures. diff --git a/src/types/custom.rs b/src/types/custom.rs index 4e10cfd44..68e70b2c5 100644 --- a/src/types/custom.rs +++ b/src/types/custom.rs @@ -4,9 +4,12 @@ use smol_str::SmolStr; use std::fmt::{self, Display}; -use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError}; +use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef}; -use super::{type_param::TypeArg, TypeBound}; +use super::{ + type_param::{TypeArg, TypeParam}, + Substitution, TypeBound, +}; /// An opaque type element. Contains the unique identifier of its definition. #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] @@ -59,27 +62,50 @@ impl CustomType { pub(super) fn validate( &self, extension_registry: &ExtensionRegistry, + var_decls: &[TypeParam], ) -> Result<(), SignatureError> { // Check the args are individually ok self.args .iter() - .try_for_each(|a| a.validate(extension_registry))?; + .try_for_each(|a| a.validate(extension_registry, var_decls))?; // And check they fit into the TypeParams declared by the TypeDef + let def = self.get_type_def(extension_registry)?; + def.check_custom(self) + } + + fn get_type_def<'a>( + &self, + extension_registry: &'a ExtensionRegistry, + ) -> Result<&'a TypeDef, SignatureError> { let ex = extension_registry.get(&self.extension); // Even if OpDef's (+binaries) are not available, the part of the Extension definition // describing the TypeDefs can easily be passed around (serialized), so should be available. let ex = ex.ok_or(SignatureError::ExtensionNotFound(self.extension.clone()))?; - let def = ex - .get_type(&self.id) + ex.get_type(&self.id) .ok_or(SignatureError::ExtensionTypeNotFound { exn: self.extension.clone(), typ: self.id.clone(), - })?; - def.check_custom(self) + }) + } + + pub(super) fn substitute(&self, tr: &impl Substitution) -> Self { + let args = self + .args + .iter() + .map(|arg| arg.substitute(tr)) + .collect::>(); + let bound = self + .get_type_def(tr.extension_registry()) + .expect("validate should rule this out") + .bound(&args); + debug_assert!(self.bound.contains(bound)); + Self { + args, + bound, + ..self.clone() + } } -} -impl CustomType { /// unique name of the type. pub fn name(&self) -> &SmolStr { &self.id diff --git a/src/types/poly_func.rs b/src/types/poly_func.rs new file mode 100644 index 000000000..a8ab29ff7 --- /dev/null +++ b/src/types/poly_func.rs @@ -0,0 +1,619 @@ +//! Polymorphic Function Types + +use crate::{ + extension::{ExtensionRegistry, SignatureError}, + types::type_param::check_type_arg, +}; +use itertools::Itertools; + +use super::type_param::{check_type_args, TypeArg, TypeParam}; +use super::{FunctionType, Substitution}; + +/// A polymorphic function type, e.g. of a [Graph], or perhaps an [OpDef]. +/// (Nodes/operations in the Hugr are not polymorphic.) +/// +/// [Graph]: crate::values::PrimValue::Function +/// [OpDef]: crate::extension::OpDef +#[derive( + Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, +)] +#[display( + fmt = "forall {}. {}", + "params.iter().map(ToString::to_string).join(\" \")", + "body" +)] +pub struct PolyFuncType { + /// The declared type parameters, i.e., these must be instantiated with + /// the same number of [TypeArg]s before the function can be called. Note that within + /// the [Self::body], variable (DeBruijn) index 0 is element 0 of this array, i.e. the + /// variables are bound from right to left. + /// + /// [TypeArg]: super::type_param::TypeArg + params: Vec, + /// Template for the function. May contain variables up to length of [Self::params] + pub(super) body: FunctionType, +} + +impl From for PolyFuncType { + fn from(body: FunctionType) -> Self { + Self { + params: vec![], + body, + } + } +} + +impl PolyFuncType { + /// The type parameters, aka binders, over which this type is polymorphic + pub fn params(&self) -> &[TypeParam] { + &self.params + } + + /// Create a new PolyFuncType and validates it. (This will only succeed + /// for outermost PolyFuncTypes i.e. with no free type-variables.) + /// The [ExtensionRegistry] should be the same (or a subset) of that which will later + /// be used to validate the Hugr; at this point we only need the types. + /// + /// #Errors + /// Validates that all types in the schema are well-formed and all variables in the body + /// are declared with [TypeParam]s that guarantee they will fit. + pub fn new_validated( + params: impl Into>, + body: FunctionType, + extension_registry: &ExtensionRegistry, + ) -> Result { + let params = params.into(); + body.validate(extension_registry, ¶ms)?; + Ok(Self { params, body }) + } + + pub(super) fn validate( + &self, + reg: &ExtensionRegistry, + external_var_decls: &[TypeParam], + ) -> Result<(), SignatureError> { + let mut v; // Declared here so live until end of scope + let all_var_decls = if self.params.is_empty() { + external_var_decls + } else { + // Type vars declared here go at lowest indices (as per DeBruijn) + v = self.params.clone(); + v.extend_from_slice(external_var_decls); + v.as_slice() + }; + self.body.validate(reg, all_var_decls) + } + + pub(super) fn substitute(&self, t: &impl Substitution) -> Self { + if self.params.is_empty() { + // Avoid using complex code for simple Monomorphic case + return self.body.substitute(t).into(); + } + PolyFuncType { + params: self.params.clone(), + body: self.body.substitute(&InsideBinders { + num_binders: self.params.len(), + underlying: t, + }), + } + } + + /// (Perhaps-partially) instantiates this [PolyFuncType] into another with fewer binders. + /// Note that indices into `args` correspond to the same index within [Self::params], + /// so we instantiate the lowest-index [Self::params] first, even though these + /// would be considered "innermost" / "closest" according to DeBruijn numbering. + pub(crate) fn instantiate_poly( + &self, + args: &[TypeArg], + exts: &ExtensionRegistry, + ) -> Result { + let remaining = self.params.get(args.len()..).unwrap_or_default(); + let mut v; + let args = if remaining.is_empty() { + args // instantiate below will fail if there were too many + } else { + // Partial application - renumber remaining params (still bound) downward + v = args.to_vec(); + v.extend( + remaining + .iter() + .enumerate() + .map(|(i, decl)| TypeArg::new_var_use(i, decl.clone())), + ); + v.as_slice() + }; + Ok(Self { + params: remaining.to_vec(), + body: self.instantiate(args, exts)?, + }) + } + + /// Instantiates an outer [PolyFuncType], i.e. with no free variables + /// (as ensured by [Self::validate]), into a monomorphic type. + /// + /// # Errors + /// If there is not exactly one [TypeArg] for each binder ([Self::params]), + /// or an arg does not fit into its corresponding [TypeParam] + pub(crate) fn instantiate( + &self, + args: &[TypeArg], + ext_reg: &ExtensionRegistry, + ) -> Result { + // Check that args are applicable, and that we have a value for each binder, + // i.e. each possible free variable within the body. + check_type_args(args, &self.params)?; + Ok(self.body.substitute(&SubstValues(args, ext_reg))) + } +} + +impl PartialEq for PolyFuncType { + fn eq(&self, other: &FunctionType) -> bool { + self.params.is_empty() && &self.body == other + } +} + +/// A [Substitution] with a finite list of known values. +/// (Variables out of the range of the list will result in a panic) +struct SubstValues<'a>(&'a [TypeArg], &'a ExtensionRegistry); + +impl<'a> Substitution for SubstValues<'a> { + fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { + let arg = self + .0 + .get(idx) + .expect("Undeclared type variable - call validate() ?"); + debug_assert_eq!(check_type_arg(arg, decl), Ok(())); + arg.clone() + } + + fn extension_registry(&self) -> &ExtensionRegistry { + self.1 + } +} + +/// A [Substitution] that renumbers any type variable to another (of the same kind) +/// with a index increased by a fixed `usize``. +struct Renumber<'a> { + offset: usize, + exts: &'a ExtensionRegistry, +} + +impl<'a> Substitution for Renumber<'a> { + fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { + TypeArg::new_var_use(idx + self.offset, decl.clone()) + } + + fn extension_registry(&self) -> &ExtensionRegistry { + self.exts + } +} + +/// Given a [Substitution] defined outside a binder (i.e. [PolyFuncType]), +/// applies that transformer to types inside the binder (i.e. arguments/results of said function) +struct InsideBinders<'a> { + /// The number of binders we have entered since (beneath where) we started to apply + /// [Self::underlying]). + /// That is, the lowest `num_binders` variable indices refer to locals bound since then. + num_binders: usize, + /// Substitution that was being applied outside those binders (i.e. in outer scope) + underlying: &'a dyn Substitution, +} + +impl<'a> Substitution for InsideBinders<'a> { + fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { + // Convert variable index into outer scope + match idx.checked_sub(self.num_binders) { + None => TypeArg::new_var_use(idx, decl.clone()), // Bound locally, unknown to `underlying` + Some(idx_in_outer_scope) => { + let result_in_outer_scope = self.underlying.apply_var(idx_in_outer_scope, decl); + // Transform returned value into the current scope, i.e. avoid the variables newly bound + result_in_outer_scope.substitute(&Renumber { + offset: self.num_binders, + exts: self.extension_registry(), + }) + } + } + } + + fn extension_registry(&self) -> &ExtensionRegistry { + self.underlying.extension_registry() + } +} + +#[cfg(test)] +pub(crate) mod test { + use std::num::NonZeroU64; + + use smol_str::SmolStr; + + use crate::extension::prelude::{PRELUDE_ID, USIZE_CUSTOM_T, USIZE_T}; + use crate::extension::{ + ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, PRELUDE, PRELUDE_REGISTRY, + }; + use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; + use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; + use crate::types::{CustomType, FunctionType, Type, TypeBound}; + use crate::Extension; + + use super::PolyFuncType; + + #[test] + fn test_opaque() -> Result<(), SignatureError> { + let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let tyvar = TypeArg::new_var_use(0, TypeParam::Type(TypeBound::Any)); + let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); + let reg: ExtensionRegistry = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let list_len = PolyFuncType::new_validated( + [TypeParam::Type(TypeBound::Any)], + FunctionType::new(vec![list_of_var], vec![USIZE_T]), + ®, + )?; + + let t = list_len.instantiate(&[TypeArg::Type { ty: USIZE_T }], ®)?; + assert_eq!( + t, + FunctionType::new( + vec![Type::new_extension( + list_def + .instantiate([TypeArg::Type { ty: USIZE_T }]) + .unwrap() + )], + vec![USIZE_T] + ) + ); + + Ok(()) + } + + fn id_fn(t: Type) -> FunctionType { + FunctionType::new(vec![t.clone()], vec![t]) + } + + #[test] + fn test_mismatched_args() -> Result<(), SignatureError> { + let ar_def = PRELUDE.get_type("array").unwrap(); + let typarams = [TypeParam::Type(TypeBound::Any), TypeParam::max_nat()]; + let [tyvar, szvar] = + [0, 1].map(|i| TypeArg::new_var_use(i, typarams.get(i).unwrap().clone())); + + // Valid schema... + let good_array = Type::new_extension(ar_def.instantiate([tyvar.clone(), szvar.clone()])?); + let good_ts = + PolyFuncType::new_validated(typarams.clone(), id_fn(good_array), &PRELUDE_REGISTRY)?; + + // Sanity check (good args) + good_ts.instantiate( + &[TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 5 }], + &PRELUDE_REGISTRY, + )?; + + let wrong_args = good_ts.instantiate( + &[TypeArg::BoundedNat { n: 5 }, TypeArg::Type { ty: USIZE_T }], + &PRELUDE_REGISTRY, + ); + assert_eq!( + wrong_args, + Err(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: typarams[0].clone(), + arg: TypeArg::BoundedNat { n: 5 } + } + )) + ); + + // (Try to) make a schema with bad args + let arg_err = SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { + param: typarams[0].clone(), + arg: szvar.clone(), + }); + assert_eq!( + ar_def.instantiate([szvar.clone(), tyvar.clone()]), + Err(arg_err.clone()) + ); + // ok, so that doesn't work - well, it shouldn't! So let's say we just have this signature (with bad args)... + let bad_array = Type::new_extension(CustomType::new( + "array", + [szvar, tyvar], + PRELUDE_ID, + TypeBound::Any, + )); + let bad_ts = + PolyFuncType::new_validated(typarams.clone(), id_fn(bad_array), &PRELUDE_REGISTRY); + assert_eq!(bad_ts.err(), Some(arg_err)); + + Ok(()) + } + + #[test] + fn test_misused_variables() -> Result<(), SignatureError> { + // Variables in args have different bounds from variable declaration + let tv = TypeArg::new_var_use(0, TypeParam::Type(TypeBound::Copyable)); + let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); + let body_type = id_fn(Type::new_extension(list_def.instantiate([tv])?)); + let reg = [EXTENSION.to_owned()].into(); + for decl in [ + TypeParam::Extensions, + TypeParam::List(Box::new(TypeParam::max_nat())), + TypeParam::Opaque(USIZE_CUSTOM_T), + TypeParam::Tuple(vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()]), + ] { + let invalid_ts = PolyFuncType::new_validated([decl.clone()], body_type.clone(), ®); + assert_eq!( + invalid_ts.err(), + Some(SignatureError::TypeVarDoesNotMatchDeclaration { + cached: TypeParam::Type(TypeBound::Copyable), + actual: decl + }) + ); + } + // Variable not declared at all + let invalid_ts = PolyFuncType::new_validated([], body_type, ®); + assert_eq!( + invalid_ts.err(), + Some(SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }) + ); + + Ok(()) + } + + fn decl_accepts_rejects_var( + bound: TypeParam, + accepted: &[TypeParam], + rejected: &[TypeParam], + ) -> Result<(), SignatureError> { + const EXT_ID: ExtensionId = ExtensionId::new_unchecked("my_ext"); + const TYPE_NAME: SmolStr = SmolStr::new_inline("MyType"); + + let mut e = Extension::new(EXT_ID); + e.add_type( + TYPE_NAME, + vec![bound.clone()], + "".into(), + TypeDefBound::Explicit(TypeBound::Any), + ) + .unwrap(); + + let reg: ExtensionRegistry = [e].into(); + + let make_scheme = |tp: TypeParam| { + PolyFuncType::new_validated( + [tp.clone()], + id_fn(Type::new_extension(CustomType::new( + TYPE_NAME, + [TypeArg::new_var_use(0, tp)], + EXT_ID, + TypeBound::Any, + ))), + ®, + ) + }; + for decl in accepted { + make_scheme(decl.clone())?; + } + for decl in rejected { + assert_eq!( + make_scheme(decl.clone()).err(), + Some(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: bound.clone(), + arg: TypeArg::new_var_use(0, decl.clone()) + } + )) + ); + } + Ok(()) + } + + #[test] + fn test_bound_covariance() -> Result<(), SignatureError> { + decl_accepts_rejects_var( + TypeParam::Type(TypeBound::Copyable), + &[ + TypeParam::Type(TypeBound::Copyable), + TypeParam::Type(TypeBound::Eq), + ], + &[TypeParam::Type(TypeBound::Any)], + )?; + + let list_of_tys = |b| TypeParam::List(Box::new(TypeParam::Type(b))); + decl_accepts_rejects_var( + list_of_tys(TypeBound::Copyable), + &[list_of_tys(TypeBound::Copyable), list_of_tys(TypeBound::Eq)], + &[list_of_tys(TypeBound::Any)], + )?; + + decl_accepts_rejects_var( + TypeParam::max_nat(), + &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], + &[], + )?; + decl_accepts_rejects_var( + TypeParam::bounded_nat(NonZeroU64::new(10).unwrap()), + &[TypeParam::bounded_nat(NonZeroU64::new(5).unwrap())], + &[TypeParam::max_nat()], + )?; + Ok(()) + } + + fn new_pf1(param: TypeParam, input: Type, output: Type) -> PolyFuncType { + PolyFuncType { + params: vec![param], + body: FunctionType::new(vec![input], vec![output]), + } + } + + // The standard library new_array does not allow passing in a variable for size. + fn new_array(ty: Type, s: TypeArg) -> Type { + let array_def = PRELUDE.get_type("array").unwrap(); + Type::new_extension( + array_def + .instantiate(vec![TypeArg::Type { ty }, s]) + .unwrap(), + ) + } + + const USIZE_TA: TypeArg = TypeArg::Type { ty: USIZE_T }; + + #[test] + fn partial_instantiate() -> Result<(), SignatureError> { + // forall A,N.(Array -> A) + let array_max = PolyFuncType::new_validated( + vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()], + FunctionType::new( + vec![new_array( + Type::new_var_use(0, TypeBound::Any), + TypeArg::new_var_use(1, TypeParam::max_nat()), + )], + vec![Type::new_var_use(0, TypeBound::Any)], + ), + &PRELUDE_REGISTRY, + )?; + + let concrete = FunctionType::new( + vec![new_array(USIZE_T, TypeArg::BoundedNat { n: 3 })], + vec![USIZE_T], + ); + let actual = array_max + .instantiate_poly(&[USIZE_TA, TypeArg::BoundedNat { n: 3 }], &PRELUDE_REGISTRY)?; + + assert_eq!(actual, concrete); + + // forall N.(Array -> usize) + let partial = PolyFuncType::new_validated( + vec![TypeParam::max_nat()], + FunctionType::new( + vec![new_array( + USIZE_T, + TypeArg::new_var_use(0, TypeParam::max_nat()), + )], + vec![USIZE_T], + ), + &PRELUDE_REGISTRY, + )?; + let res = array_max.instantiate_poly(&[USIZE_TA], &PRELUDE_REGISTRY)?; + assert_eq!(res, partial); + + Ok(()) + } + + fn list_of_tup(t1: Type, t2: Type) -> Type { + let list_def = EXTENSION.get_type(LIST_TYPENAME.as_str()).unwrap(); + Type::new_extension( + list_def + .instantiate([TypeArg::Type { + ty: Type::new_tuple(vec![t1, t2]), + }]) + .unwrap(), + ) + } + + // forall A. A -> (forall C. C -> List(Tuple(C, A)) + pub(crate) fn nested_func() -> PolyFuncType { + PolyFuncType::new_validated( + vec![TypeParam::Type(TypeBound::Any)], + FunctionType::new( + vec![Type::new_var_use(0, TypeBound::Any)], + vec![Type::new_function(new_pf1( + TypeParam::Type(TypeBound::Copyable), + Type::new_var_use(0, TypeBound::Copyable), + list_of_tup( + Type::new_var_use(0, TypeBound::Copyable), + Type::new_var_use(1, TypeBound::Any), // The outer variable (renumbered) + ), + ))], + ), + &[EXTENSION.to_owned()].into(), + ) + .unwrap() + } + + #[test] + fn test_instantiate_nested() -> Result<(), SignatureError> { + let outer = nested_func(); + let reg: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + + let arg = new_array(USIZE_T, TypeArg::BoundedNat { n: 5 }); + // `arg` -> (forall C. C -> List(Tuple(C, `arg`))) + let outer_applied = FunctionType::new( + vec![arg.clone()], // This had index 0, but is replaced + vec![Type::new_function(new_pf1( + TypeParam::Type(TypeBound::Copyable), + // We are checking that the substitution has been applied to the right var + // - NOT to the inner_var which has index 0 here + Type::new_var_use(0, TypeBound::Copyable), + list_of_tup( + Type::new_var_use(0, TypeBound::Copyable), + arg.clone(), // This had index 1, but is replaced + ), + ))], + ); + + let res = outer.instantiate(&[TypeArg::Type { ty: arg }], ®)?; + assert_eq!(res, outer_applied); + Ok(()) + } + + #[test] + fn free_var_under_binder() { + let outer = nested_func(); + + // Now substitute in a free var from further outside + let reg = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + const FREE: usize = 3; + const TP_EQ: TypeParam = TypeParam::Type(TypeBound::Eq); + let res = outer + .instantiate(&[TypeArg::new_var_use(FREE, TP_EQ)], ®) + .unwrap(); + assert_eq!( + res, + // F -> forall C. (C -> List(Tuple(C, F))) + FunctionType::new( + vec![Type::new_var_use(FREE, TypeBound::Eq)], + vec![Type::new_function(new_pf1( + TypeParam::Type(TypeBound::Copyable), + Type::new_var_use(0, TypeBound::Copyable), // unchanged + list_of_tup( + Type::new_var_use(0, TypeBound::Copyable), + // Next is the free variable that we substituted in (hence Eq) + // - renumbered because of the intervening forall (Copyable) + Type::new_var_use(FREE + 1, TypeBound::Eq) + ) + ))] + ) + ); + + // Also try substituting in a type containing both free and bound vars + let rhs = |i| { + Type::new_function(new_pf1( + TP_EQ, + Type::new_var_use(0, TypeBound::Eq), + new_array( + Type::new_var_use(0, TypeBound::Eq), + TypeArg::new_var_use(i, TypeParam::max_nat()), + ), + )) + }; + + let res = outer + .instantiate(&[TypeArg::Type { ty: rhs(FREE) }], ®) + .unwrap(); + assert_eq!( + res, + FunctionType::new( + vec![rhs(FREE)], // Input: forall TEQ. (TEQ -> Array(TEQ, FREE)) + // Output: forall C. C -> List(Tuple(C, Input)) + vec![Type::new_function(new_pf1( + TypeParam::Type(TypeBound::Copyable), + Type::new_var_use(0, TypeBound::Copyable), + list_of_tup( + Type::new_var_use(0, TypeBound::Copyable), // not renumbered... + rhs(FREE + 1) // renumbered + ) + ))] + ) + ) + } +} diff --git a/src/types/primitive.rs b/src/types/primitive.rs index 7602cc8c0..fe34151e5 100644 --- a/src/types/primitive.rs +++ b/src/types/primitive.rs @@ -2,7 +2,7 @@ use crate::ops::AliasDecl; -use super::{CustomType, FunctionType, TypeBound}; +use super::{CustomType, PolyFuncType, TypeBound}; #[derive( Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, @@ -18,7 +18,11 @@ pub enum PrimType { Alias(AliasDecl), #[allow(missing_docs)] #[display(fmt = "Function({})", "_0")] - Function(Box), + Function(Box), + // DeBruijn index, and cache of TypeBound (checked in validation) + #[allow(missing_docs)] + #[display(fmt = "Variable({})", _0)] + Variable(usize, TypeBound), } impl PrimType { @@ -28,6 +32,7 @@ impl PrimType { PrimType::Extension(c) => c.bound(), PrimType::Alias(a) => a.bound, PrimType::Function(_) => TypeBound::Copyable, + PrimType::Variable(_, b) => *b, } } } diff --git a/src/types/serialize.rs b/src/types/serialize.rs index f55258fc9..34ad609ed 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -1,9 +1,7 @@ -use super::{SumType, Type, TypeEnum, TypeRow}; +use super::{PolyFuncType, SumType, Type, TypeBound, TypeEnum, TypeRow}; use super::custom::CustomType; -use super::FunctionType; - use crate::extension::prelude::{array_type, QB_T, USIZE_T}; use crate::ops::AliasDecl; use crate::types::primitive::PrimType; @@ -13,12 +11,13 @@ use crate::types::primitive::PrimType; pub(super) enum SerSimpleType { Q, I, - G(Box), + G(Box), Tuple { inner: TypeRow }, Sum(SumType), Array { inner: Box, len: u64 }, Opaque(CustomType), Alias(AliasDecl), + V { i: usize, b: TypeBound }, } impl From for SerSimpleType { @@ -35,7 +34,8 @@ impl From for SerSimpleType { TypeEnum::Prim(t) => match t { PrimType::Extension(c) => SerSimpleType::Opaque(c), PrimType::Alias(a) => SerSimpleType::Alias(a), - PrimType::Function(sig) => SerSimpleType::G(Box::new(*sig)), + PrimType::Function(sig) => SerSimpleType::G(sig), + PrimType::Variable(i, b) => SerSimpleType::V { i, b }, }, TypeEnum::Sum(sum) => SerSimpleType::Sum(sum), TypeEnum::Tuple(inner) => SerSimpleType::Tuple { inner }, @@ -54,6 +54,7 @@ impl From for Type { SerSimpleType::Array { inner, len } => array_type((*inner).into(), len), SerSimpleType::Opaque(custom) => Type::new_extension(custom), SerSimpleType::Alias(a) => Type::new_alias(a), + SerSimpleType::V { i, b } => Type::new_var_use(i, b), } } } diff --git a/src/types/signature.rs b/src/types/signature.rs index bcd28d2ea..90df38efd 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -6,8 +6,10 @@ use pyo3::{pyclass, pymethods}; use delegate::delegate; use std::fmt::{self, Display, Write}; -use crate::extension::ExtensionSet; -use crate::types::{Type, TypeRow}; +use super::type_param::TypeParam; +use super::{subst_row, Substitution, Type, TypeRow}; + +use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; #[cfg_attr(feature = "pyo3", pyclass)] @@ -51,6 +53,26 @@ impl FunctionType { pub fn pure(self) -> Signature { self.with_input_extensions(ExtensionSet::new()) } + + pub(crate) fn validate( + &self, + extension_registry: &ExtensionRegistry, + var_decls: &[TypeParam], + ) -> Result<(), SignatureError> { + self.input + .iter() + .chain(self.output.iter()) + .try_for_each(|t| t.validate(extension_registry, var_decls))?; + self.extension_reqs.validate(var_decls) + } + + pub(crate) fn substitute(&self, tr: &impl Substitution) -> Self { + FunctionType { + input: subst_row(&self.input, tr), + output: subst_row(&self.output, tr), + extension_reqs: self.extension_reqs.substitute(tr), + } + } } impl From for FunctionType { diff --git a/src/types/type_param.rs b/src/types/type_param.rs index 468bb40d6..d8546b58f 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -4,21 +4,22 @@ //! //! [`TypeDef`]: crate::extension::TypeDef +use itertools::Itertools; use std::num::NonZeroU64; - use thiserror::Error; use crate::extension::ExtensionRegistry; use crate::extension::ExtensionSet; use crate::extension::SignatureError; -use super::CustomType; -use super::Type; -use super::TypeBound; +use super::{check_typevar_decl, CustomType, Substitution, Type, TypeBound}; -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] // A None inner value implies the maximum bound: u64::MAX + 1 (all u64 values valid) +#[derive( + Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, +)] +#[display(fmt = "{}", "_0.map(|i|i.to_string()).unwrap_or(\"-\".to_string())")] pub struct UpperBound(Option); impl UpperBound { fn valid_value(&self, val: u64) -> bool { @@ -28,11 +29,23 @@ impl UpperBound { _ => false, } } + fn contains(&self, other: &UpperBound) -> bool { + match (self.0, other.0) { + (None, _) => true, + (Some(b1), Some(b2)) if b1 >= b2 => true, + _ => false, + } + } } -/// A parameter declared by an OpDef. Specifies a value -/// that must be provided by each operation node. -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +/// A *kind* of [TypeArg]. Thus, a parameter declared by a [PolyFuncType] (e.g. [OpDef]), +/// specifying a value that may (resp. must) be provided to instantiate it. +/// +/// [PolyFuncType]: super::PolyFuncType +/// [OpDef]: crate::extension::OpDef +#[derive( + Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Deserialize, serde::Serialize, +)] #[non_exhaustive] pub enum TypeParam { /// Argument is a [TypeArg::Type]. @@ -44,6 +57,7 @@ pub enum TypeParam { /// Argument is a [TypeArg::Sequence]. A list of indeterminate size containing parameters. List(Box), /// Argument is a [TypeArg::Sequence]. A tuple of parameters. + #[display(fmt = "Tuple({})", "_0.iter().map(|t|t.to_string()).join(\", \")")] Tuple(Vec), /// Argument is a [TypeArg::Extensions]. A set of [ExtensionId]s. /// @@ -61,6 +75,20 @@ impl TypeParam { pub const fn bounded_nat(upper_bound: NonZeroU64) -> Self { Self::BoundedNat(UpperBound(Some(upper_bound))) } + + fn contains(&self, other: &TypeParam) -> bool { + match (self, other) { + (TypeParam::Type(b1), TypeParam::Type(b2)) => b1.contains(*b2), + (TypeParam::BoundedNat(b1), TypeParam::BoundedNat(b2)) => b1.contains(b2), + (TypeParam::Opaque(c1), TypeParam::Opaque(c2)) => c1 == c2, + (TypeParam::List(e1), TypeParam::List(e2)) => e1.contains(e2), + (TypeParam::Tuple(es1), TypeParam::Tuple(es2)) => { + es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) + } + (TypeParam::Extensions, TypeParam::Extensions) => true, + _ => false, + } + } } /// A statically-known argument value to an operation. @@ -84,35 +112,102 @@ pub enum TypeArg { arg: CustomTypeArg, }, /// Instance of [TypeParam::List] or [TypeParam::Tuple], defined by a - /// sequence of arguments. + /// sequence of elements. Sequence { #[allow(missing_docs)] - args: Vec, + elems: Vec, }, /// Instance of [TypeParam::Extensions], providing the extension ids. Extensions { #[allow(missing_docs)] es: ExtensionSet, }, + /// Variable (used in type schemes only), that is not a [TypeArg::Type] + /// or [TypeArg::Extensions] - see [TypeArg::new_var_use] + Variable { + #[allow(missing_docs)] + v: TypeArgVariable, + }, +} + +/// Variable in a TypeArg, that is not a [TypeArg::Type] or [TypeArg::Extensions], +#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct TypeArgVariable { + idx: usize, + cached_decl: TypeParam, } impl TypeArg { + /// Makes a TypeArg representing a use (occurrence) of the type variable + /// with the specified DeBruijn index. For use within type schemes only: + /// `bound` must match that with which the variable was declared. + pub fn new_var_use(idx: usize, decl: TypeParam) -> Self { + match decl { + TypeParam::Type(b) => TypeArg::Type { + ty: Type::new_var_use(idx, b), + }, + TypeParam::Extensions => TypeArg::Extensions { + es: ExtensionSet::type_var(idx), + }, + _ => TypeArg::Variable { + v: TypeArgVariable { + idx, + cached_decl: decl, + }, + }, + } + } + + /// Much as [Type::validate], also checks that the type of any [TypeArg::Opaque] + /// is valid and closed. pub(crate) fn validate( &self, extension_registry: &ExtensionRegistry, + var_decls: &[TypeParam], ) -> Result<(), SignatureError> { match self { - TypeArg::Type { ty } => ty.validate(extension_registry), + TypeArg::Type { ty } => ty.validate(extension_registry, var_decls), TypeArg::BoundedNat { .. } => Ok(()), TypeArg::Opaque { arg: custarg } => { // We could also add a facility to Extension to validate that the constant *value* // here is a valid instance of the type. - custarg.typ.validate(extension_registry) - } - TypeArg::Sequence { args } => { - args.iter().try_for_each(|a| a.validate(extension_registry)) + // The type must be equal to that declared (in a TypeParam) by the instantiated TypeDef, + // so cannot contain variables declared by the instantiator (providing the TypeArgs) + custarg.typ.validate(extension_registry, &[]) } + TypeArg::Sequence { elems } => elems + .iter() + .try_for_each(|a| a.validate(extension_registry, var_decls)), TypeArg::Extensions { es: _ } => Ok(()), + TypeArg::Variable { + v: TypeArgVariable { idx, cached_decl }, + } => check_typevar_decl(var_decls, *idx, cached_decl), + } + } + + pub(crate) fn substitute(&self, t: &impl Substitution) -> Self { + match self { + TypeArg::Type { ty } => TypeArg::Type { + ty: ty.substitute(t), + }, + TypeArg::BoundedNat { .. } => self.clone(), // We do not allow variables as bounds on BoundedNat's + TypeArg::Opaque { + arg: CustomTypeArg { typ, .. }, + } => { + // The type must be equal to that declared (in a TypeParam) by the instantiated TypeDef, + // so cannot contain variables declared by the instantiator (providing the TypeArgs) + debug_assert_eq!(&typ.substitute(t), typ); + self.clone() + } + TypeArg::Sequence { elems } => TypeArg::Sequence { + elems: elems.iter().map(|ta| ta.substitute(t)).collect(), + }, + TypeArg::Extensions { es } => TypeArg::Extensions { + es: es.substitute(t), + }, + TypeArg::Variable { + v: TypeArgVariable { idx, cached_decl }, + } => t.apply_var(*idx, cached_decl), } } } @@ -143,15 +238,21 @@ impl CustomTypeArg { /// Checks a [TypeArg] is as expected for a [TypeParam] pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { match (arg, param) { - (TypeArg::Type { ty: t }, TypeParam::Type(bound)) - if bound.contains(t.least_upper_bound()) => + ( + TypeArg::Variable { + v: TypeArgVariable { cached_decl, .. }, + }, + _, + ) if param.contains(cached_decl) => Ok(()), + (TypeArg::Type { ty }, TypeParam::Type(bound)) + if bound.contains(ty.least_upper_bound()) => { Ok(()) } - (TypeArg::Sequence { args: items }, TypeParam::List(param)) => { - items.iter().try_for_each(|arg| check_type_arg(arg, param)) + (TypeArg::Sequence { elems }, TypeParam::List(param)) => { + elems.iter().try_for_each(|arg| check_type_arg(arg, param)) } - (TypeArg::Sequence { args: items }, TypeParam::Tuple(types)) => { + (TypeArg::Sequence { elems: items }, TypeParam::Tuple(types)) => { if items.len() != types.len() { Err(TypeArgError::WrongNumberTuple(items.len(), types.len())) } else { @@ -180,6 +281,17 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr } } +/// Check a list of type arguments match a list of required type parameters +pub fn check_type_args(args: &[TypeArg], params: &[TypeParam]) -> Result<(), TypeArgError> { + if args.len() != params.len() { + return Err(TypeArgError::WrongNumberArgs(args.len(), params.len())); + } + for (a, p) in args.iter().zip(params.iter()) { + check_type_arg(a, p)?; + } + Ok(()) +} + /// Errors that can occur fitting a [TypeArg] into a [TypeParam] #[derive(Clone, Debug, PartialEq, Eq, Error)] pub enum TypeArgError {