diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 71ddff368..a40ea306b 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,7 +7,7 @@ use super::{ Extension, ExtensionBuildError, ExtensionId, ExtensionSet, SignatureError, TypeParametrised, }; -use crate::types::{SignatureDescription, TypeRow}; +use crate::types::SignatureDescription; use crate::types::FunctionType; @@ -31,8 +31,8 @@ pub trait CustomSignatureFunc: Send + Sync { name: &SmolStr, arg_values: &[TypeArg], misc: &HashMap, - // TODO: Make return type an FunctionType - ) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError>; + ) -> Result; + /// Describe the signature of a node, given the operation name, /// values for the type parameters, /// and 'misc' data from the extension definition YAML. @@ -48,14 +48,14 @@ pub trait CustomSignatureFunc: Send + Sync { impl CustomSignatureFunc for F where - F: Fn(&[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> + Send + Sync, + F: Fn(&[TypeArg]) -> Result + Send + Sync, { fn compute_signature( &self, _name: &SmolStr, arg_values: &[TypeArg], _misc: &HashMap, - ) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { + ) -> Result { self(arg_values) } } @@ -217,7 +217,7 @@ impl OpDef { /// OpDef with statically-provided [TypeArg]s. pub fn compute_signature(&self, args: &[TypeArg]) -> Result { self.check_args(args)?; - let (ins, outs, res) = match &self.signature_func { + let res = match &self.signature_func { SignatureFunc::FromDecl { .. } => { // Sig should be computed solely from inputs + outputs + args. todo!() @@ -227,7 +227,7 @@ impl OpDef { // TODO bring this assert back once resource inference is done? // https://github.com/CQCL-DEV/hugr/issues/425 // assert!(res.contains(self.extension())); - Ok(FunctionType::new(ins, outs).with_extension_delta(&res)) + Ok(res) } /// Optional description of the ports in the signature. diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index c2cad2fa0..c299a9eed 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -7,7 +7,7 @@ use smol_str::SmolStr; use crate::{ extension::{ExtensionSet, SignatureError}, type_row, - types::{type_param::TypeArg, Type, TypeRow}, + types::{type_param::TypeArg, FunctionType, Type}, utils::collect_array, Extension, }; @@ -18,25 +18,22 @@ use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.conversions"); -fn ftoi_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn ftoi_sig(arg_values: &[TypeArg]) -> Result { let [arg] = collect_array(arg_values); - Ok(( + Ok(FunctionType::new( type_row![FLOAT64_TYPE], vec![Type::new_sum(vec![ int_type(arg.clone()), crate::extension::prelude::ERROR_TYPE, - ])] - .into(), - ExtensionSet::default(), + ])], )) } -fn itof_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn itof_sig(arg_values: &[TypeArg]) -> Result { let [arg] = collect_array(arg_values); - Ok(( - vec![int_type(arg.clone())].into(), + Ok(FunctionType::new( + vec![int_type(arg.clone())], type_row![FLOAT64_TYPE], - ExtensionSet::default(), )) } diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 16a3592b5..f328e20fb 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -5,7 +5,7 @@ use smol_str::SmolStr; use crate::{ extension::{ExtensionSet, SignatureError}, type_row, - types::{type_param::TypeArg, TypeRow}, + types::{type_param::TypeArg, FunctionType}, Extension, }; @@ -14,27 +14,24 @@ use super::float_types::FLOAT64_TYPE; /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.float"); -fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( +fn fcmp_sig(_arg_values: &[TypeArg]) -> Result { + Ok(FunctionType::new( type_row![FLOAT64_TYPE; 2], type_row![crate::extension::prelude::BOOL_T], - ExtensionSet::default(), )) } -fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( +fn fbinop_sig(_arg_values: &[TypeArg]) -> Result { + Ok(FunctionType::new( type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE], - ExtensionSet::default(), )) } -fn funop_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( +fn funop_sig(_arg_values: &[TypeArg]) -> Result { + Ok(FunctionType::new( type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE], - ExtensionSet::default(), )) } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index eae8847de..b647a7d47 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -5,6 +5,7 @@ use smol_str::SmolStr; use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; use crate::type_row; +use crate::types::FunctionType; use crate::utils::collect_array; use crate::{ extension::{ExtensionSet, SignatureError}, @@ -15,111 +16,100 @@ use crate::{ /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.int"); -fn iwiden_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn iwiden_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); let m: u8 = get_log_width(arg0)?; let n: u8 = get_log_width(arg1)?; if m > n { return Err(SignatureError::InvalidTypeArgs); } - Ok(( - vec![int_type(arg0.clone())].into(), - vec![int_type(arg1.clone())].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg0.clone())], + vec![int_type(arg1.clone())], )) } -fn inarrow_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn inarrow_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); let m: u8 = get_log_width(arg0)?; let n: u8 = get_log_width(arg1)?; if m < n { return Err(SignatureError::InvalidTypeArgs); } - Ok(( - vec![int_type(arg0.clone())].into(), - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg0.clone())], + vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], )) } -fn itob_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( - vec![int_type(type_arg(0))].into(), +fn itob_sig(_arg_values: &[TypeArg]) -> Result { + Ok(FunctionType::new( + vec![int_type(type_arg(0))], type_row![BOOL_T], - ExtensionSet::default(), )) } -fn btoi_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( +fn btoi_sig(_arg_values: &[TypeArg]) -> Result { + Ok(FunctionType::new( type_row![BOOL_T], - vec![int_type(type_arg(0))].into(), - ExtensionSet::default(), + vec![int_type(type_arg(0))], )) } -fn icmp_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn icmp_sig(arg_values: &[TypeArg]) -> Result { let [arg] = collect_array(arg_values); - Ok(( - vec![int_type(arg.clone()); 2].into(), + Ok(FunctionType::new( + vec![int_type(arg.clone()); 2], type_row![BOOL_T], - ExtensionSet::default(), )) } -fn ibinop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn ibinop_sig(arg_values: &[TypeArg]) -> Result { let [arg] = collect_array(arg_values); - Ok(( - vec![int_type(arg.clone()); 2].into(), - vec![int_type(arg.clone())].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg.clone()); 2], + vec![int_type(arg.clone())], )) } -fn iunop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn iunop_sig(arg_values: &[TypeArg]) -> Result { let [arg] = collect_array(arg_values); - Ok(( - vec![int_type(arg.clone())].into(), - vec![int_type(arg.clone())].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg.clone())], + vec![int_type(arg.clone())], )) } -fn idivmod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn idivmod_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(( + Ok(FunctionType::new( intpair.clone(), - vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])].into(), - ExtensionSet::default(), + vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], )) } -fn idiv_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn idiv_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); - Ok(( - vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), - vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg0.clone()), int_type(arg1.clone())], + vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])], )) } -fn imod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn imod_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); - Ok(( - vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg0.clone()), int_type(arg1.clone())], + vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], )) } -fn ish_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { +fn ish_sig(arg_values: &[TypeArg]) -> Result { let [arg0, arg1] = collect_array(arg_values); - Ok(( - vec![int_type(arg0.clone()), int_type(arg1.clone())].into(), - vec![int_type(arg0.clone())].into(), - ExtensionSet::default(), + Ok(FunctionType::new( + vec![int_type(arg0.clone()), int_type(arg1.clone())], + vec![int_type(arg0.clone())], )) } diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 9718edd9e..b311b2c01 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -8,7 +8,7 @@ use crate::{ extension::{ExtensionSet, SignatureError, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, - CustomCheckFailure, CustomType, Type, TypeBound, TypeRow, + CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow, }, values::{CustomConst, Value}, Extension, @@ -80,11 +80,11 @@ fn extension() -> Extension { vec![], move |args: &[TypeArg]| { let (list_type, element_type) = list_types(args)?; - - let inputs = TypeRow::from(vec![list_type.clone()]); - let outputs = TypeRow::from(vec![list_type, element_type]); - let extension_set = ExtensionSet::singleton(&EXTENSION_NAME); - Ok((inputs, outputs, extension_set)) + Ok(FunctionType { + input: TypeRow::from(vec![list_type.clone()]), + output: TypeRow::from(vec![list_type, element_type]), + extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), + }) }, ) .unwrap(); @@ -97,11 +97,11 @@ fn extension() -> Extension { vec![], move |args: &[TypeArg]| { let (list_type, element_type) = list_types(args)?; - - let outputs = TypeRow::from(vec![list_type.clone()]); - let inputs = TypeRow::from(vec![list_type, element_type]); - let extension_set = ExtensionSet::singleton(&EXTENSION_NAME); - Ok((inputs, outputs, extension_set)) + Ok(FunctionType { + output: TypeRow::from(vec![list_type.clone()]), + input: TypeRow::from(vec![list_type, element_type]), + extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), + }) }, ) .unwrap(); diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 957985482..565bca4f4 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -4,9 +4,12 @@ use itertools::Itertools; use smol_str::SmolStr; use crate::{ - extension::{prelude::BOOL_T, ExtensionSet}, + extension::prelude::BOOL_T, ops, type_row, - types::type_param::{TypeArg, TypeArgError, TypeParam}, + types::{ + type_param::{TypeArg, TypeArgError, TypeParam}, + FunctionType, + }, Extension, }; use lazy_static::lazy_static; @@ -35,13 +38,7 @@ fn extension() -> Extension { SmolStr::new_inline(NOT_NAME), "logical 'not'".into(), vec![], - |_arg_values: &[TypeArg]| { - Ok(( - type_row![BOOL_T], - type_row![BOOL_T], - ExtensionSet::default(), - )) - }, + |_arg_values: &[TypeArg]| Ok(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])), ) .unwrap(); @@ -62,10 +59,9 @@ fn extension() -> Extension { .into()); } }; - Ok(( - vec![BOOL_T; n as usize].into(), + Ok(FunctionType::new( + vec![BOOL_T; n as usize], type_row![BOOL_T], - ExtensionSet::default(), )) }, ) @@ -88,10 +84,9 @@ fn extension() -> Extension { .into()); } }; - Ok(( - vec![BOOL_T; n as usize].into(), + Ok(FunctionType::new( + vec![BOOL_T; n as usize], type_row![BOOL_T], - ExtensionSet::default(), )) }, ) diff --git a/src/std_extensions/quantum.rs b/src/std_extensions/quantum.rs index d97ecc62f..2d41c272c 100644 --- a/src/std_extensions/quantum.rs +++ b/src/std_extensions/quantum.rs @@ -3,26 +3,25 @@ use smol_str::SmolStr; use crate::extension::prelude::{BOOL_T, QB_T}; -use crate::extension::{ExtensionSet, SignatureError}; +use crate::extension::SignatureError; use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use crate::type_row; use crate::types::type_param::TypeArg; -use crate::types::TypeRow; +use crate::types::FunctionType; use crate::Extension; use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("quantum"); -fn one_qb_func(_: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok((type_row![QB_T], type_row![QB_T], ExtensionSet::new())) +fn one_qb_func(_: &[TypeArg]) -> Result { + Ok(FunctionType::new(type_row![QB_T], type_row![QB_T])) } -fn two_qb_func(_: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> { - Ok(( +fn two_qb_func(_: &[TypeArg]) -> Result { + Ok(FunctionType::new( type_row![QB_T, QB_T], type_row![QB_T, QB_T], - ExtensionSet::new(), )) } @@ -43,10 +42,9 @@ fn extension() -> Extension { "Rotation specified by float".into(), vec![], |_: &[_]| { - Ok(( + Ok(FunctionType::new( type_row![QB_T, FLOAT64_TYPE], type_row![QB_T], - ExtensionSet::new(), )) }, ) @@ -62,14 +60,10 @@ fn extension() -> Extension { "Measure a qubit, returning the qubit and the measurement result.".into(), vec![], |_arg_values: &[TypeArg]| { - Ok(( - type_row![QB_T], - type_row![QB_T, BOOL_T], - // TODO add logic as an extension delta when inference is - // done? - // https://github.com/CQCL-DEV/hugr/issues/425 - ExtensionSet::new(), - )) + Ok(FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T])) + // TODO add logic as an extension delta when inference is + // done? + // https://github.com/CQCL-DEV/hugr/issues/425 }, ) .unwrap(); diff --git a/src/std_extensions/rotation.rs b/src/std_extensions/rotation.rs index 47a1fcbde..eb5ba0602 100644 --- a/src/std_extensions/rotation.rs +++ b/src/std_extensions/rotation.rs @@ -10,9 +10,8 @@ use smol_str::SmolStr; #[cfg(feature = "pyo3")] use pyo3::prelude::*; -use crate::extension::ExtensionSet; use crate::types::type_param::TypeArg; -use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeRow}; +use crate::types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow}; use crate::values::CustomConst; use crate::{ops, Extension}; @@ -47,7 +46,7 @@ pub fn extension() -> Extension { |_arg_values: &[TypeArg]| { let t: TypeRow = vec![Type::new_extension(RotationType::Angle.custom_type())].into(); - Ok((t.clone(), t, ExtensionSet::default())) + Ok(FunctionType::new(t.clone(), t)) }, ) .unwrap();