Skip to content

Commit

Permalink
Change compute_signature to return FunctionType (et al) (#438)
Browse files Browse the repository at this point in the history
....as this is now equal to the current tuple
(TypeRow,TypeRow,Extensions). A longstanding TODO.

Note this does *not* address the confusion/disagreement across the
codebase as to whether the ExtensionSet should include the resource
declaring the op.
  • Loading branch information
acl-cqc authored Aug 29, 2023
1 parent 95c0b56 commit a323545
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 124 deletions.
14 changes: 7 additions & 7 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
Extension, ExtensionBuildError, ExtensionId, ExtensionSet, SignatureError, TypeParametrised,
};

use crate::types::{SignatureDescription, TypeRow};
use crate::types::SignatureDescription;

use crate::types::FunctionType;

Expand All @@ -31,8 +31,8 @@ pub trait CustomSignatureFunc: Send + Sync {
name: &SmolStr,
arg_values: &[TypeArg],
misc: &HashMap<String, serde_yaml::Value>,
// TODO: Make return type an FunctionType
) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError>;
) -> Result<FunctionType, SignatureError>;

/// Describe the signature of a node, given the operation name,
/// values for the type parameters,
/// and 'misc' data from the extension definition YAML.
Expand All @@ -48,14 +48,14 @@ pub trait CustomSignatureFunc: Send + Sync {

impl<F> CustomSignatureFunc for F
where
F: Fn(&[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> + Send + Sync,
F: Fn(&[TypeArg]) -> Result<FunctionType, SignatureError> + Send + Sync,
{
fn compute_signature(
&self,
_name: &SmolStr,
arg_values: &[TypeArg],
_misc: &HashMap<String, serde_yaml::Value>,
) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
) -> Result<FunctionType, SignatureError> {
self(arg_values)
}
}
Expand Down Expand Up @@ -217,7 +217,7 @@ impl OpDef {
/// OpDef with statically-provided [TypeArg]s.
pub fn compute_signature(&self, args: &[TypeArg]) -> Result<FunctionType, SignatureError> {
self.check_args(args)?;
let (ins, outs, res) = match &self.signature_func {
let res = match &self.signature_func {
SignatureFunc::FromDecl { .. } => {
// Sig should be computed solely from inputs + outputs + args.
todo!()
Expand All @@ -227,7 +227,7 @@ impl OpDef {
// TODO bring this assert back once resource inference is done?
// https://github.com/CQCL-DEV/hugr/issues/425
// assert!(res.contains(self.extension()));
Ok(FunctionType::new(ins, outs).with_extension_delta(&res))
Ok(res)
}

/// Optional description of the ports in the signature.
Expand Down
17 changes: 7 additions & 10 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use smol_str::SmolStr;
use crate::{
extension::{ExtensionSet, SignatureError},
type_row,
types::{type_param::TypeArg, Type, TypeRow},
types::{type_param::TypeArg, FunctionType, Type},
utils::collect_array,
Extension,
};
Expand All @@ -18,25 +18,22 @@ use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
/// The extension identifier.
pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.conversions");

fn ftoi_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn ftoi_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok((
Ok(FunctionType::new(
type_row![FLOAT64_TYPE],
vec![Type::new_sum(vec![
int_type(arg.clone()),
crate::extension::prelude::ERROR_TYPE,
])]
.into(),
ExtensionSet::default(),
])],
))
}

fn itof_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn itof_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok((
vec![int_type(arg.clone())].into(),
Ok(FunctionType::new(
vec![int_type(arg.clone())],
type_row![FLOAT64_TYPE],
ExtensionSet::default(),
))
}

Expand Down
17 changes: 7 additions & 10 deletions src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use smol_str::SmolStr;
use crate::{
extension::{ExtensionSet, SignatureError},
type_row,
types::{type_param::TypeArg, TypeRow},
types::{type_param::TypeArg, FunctionType},
Extension,
};

Expand All @@ -14,27 +14,24 @@ use super::float_types::FLOAT64_TYPE;
/// The extension identifier.
pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.float");

fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
ExtensionSet::default(),
))
}

fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![FLOAT64_TYPE],
ExtensionSet::default(),
))
}

fn funop_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
fn funop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE],
type_row![FLOAT64_TYPE],
ExtensionSet::default(),
))
}

Expand Down
92 changes: 41 additions & 51 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use smol_str::SmolStr;
use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM};
use crate::extension::prelude::{BOOL_T, ERROR_TYPE};
use crate::type_row;
use crate::types::FunctionType;
use crate::utils::collect_array;
use crate::{
extension::{ExtensionSet, SignatureError},
Expand All @@ -15,111 +16,100 @@ use crate::{
/// The extension identifier.
pub const EXTENSION_ID: SmolStr = SmolStr::new_inline("arithmetic.int");

fn iwiden_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn iwiden_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_width(arg0)?;
let n: u8 = get_log_width(arg1)?;
if m > n {
return Err(SignatureError::InvalidTypeArgs);
}
Ok((
vec![int_type(arg0.clone())].into(),
vec![int_type(arg1.clone())].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg0.clone())],
vec![int_type(arg1.clone())],
))
}

fn inarrow_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn inarrow_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let m: u8 = get_log_width(arg0)?;
let n: u8 = get_log_width(arg1)?;
if m < n {
return Err(SignatureError::InvalidTypeArgs);
}
Ok((
vec![int_type(arg0.clone())].into(),
vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg0.clone())],
vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])],
))
}

fn itob_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
vec![int_type(type_arg(0))].into(),
fn itob_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
vec![int_type(type_arg(0))],
type_row![BOOL_T],
ExtensionSet::default(),
))
}

fn btoi_sig(_arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
Ok((
fn btoi_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![BOOL_T],
vec![int_type(type_arg(0))].into(),
ExtensionSet::default(),
vec![int_type(type_arg(0))],
))
}

fn icmp_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn icmp_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok((
vec![int_type(arg.clone()); 2].into(),
Ok(FunctionType::new(
vec![int_type(arg.clone()); 2],
type_row![BOOL_T],
ExtensionSet::default(),
))
}

fn ibinop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn ibinop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok((
vec![int_type(arg.clone()); 2].into(),
vec![int_type(arg.clone())].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg.clone()); 2],
vec![int_type(arg.clone())],
))
}

fn iunop_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn iunop_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok((
vec![int_type(arg.clone())].into(),
vec![int_type(arg.clone())].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg.clone())],
vec![int_type(arg.clone())],
))
}

fn idivmod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn idivmod_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into();
Ok((
Ok(FunctionType::new(
intpair.clone(),
vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])].into(),
ExtensionSet::default(),
vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])],
))
}

fn idiv_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn idiv_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
Ok((
vec![int_type(arg0.clone()), int_type(arg1.clone())].into(),
vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg0.clone()), int_type(arg1.clone())],
vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])],
))
}

fn imod_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn imod_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
Ok((
vec![int_type(arg0.clone()), int_type(arg1.clone())].into(),
vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg0.clone()), int_type(arg1.clone())],
vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])],
))
}

fn ish_sig(arg_values: &[TypeArg]) -> Result<(TypeRow, TypeRow, ExtensionSet), SignatureError> {
fn ish_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg0, arg1] = collect_array(arg_values);
Ok((
vec![int_type(arg0.clone()), int_type(arg1.clone())].into(),
vec![int_type(arg0.clone())].into(),
ExtensionSet::default(),
Ok(FunctionType::new(
vec![int_type(arg0.clone()), int_type(arg1.clone())],
vec![int_type(arg0.clone())],
))
}

Expand Down
22 changes: 11 additions & 11 deletions src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
extension::{ExtensionSet, SignatureError, TypeDef, TypeDefBound},
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, Type, TypeBound, TypeRow,
CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow,
},
values::{CustomConst, Value},
Extension,
Expand Down Expand Up @@ -80,11 +80,11 @@ fn extension() -> Extension {
vec![],
move |args: &[TypeArg]| {
let (list_type, element_type) = list_types(args)?;

let inputs = TypeRow::from(vec![list_type.clone()]);
let outputs = TypeRow::from(vec![list_type, element_type]);
let extension_set = ExtensionSet::singleton(&EXTENSION_NAME);
Ok((inputs, outputs, extension_set))
Ok(FunctionType {
input: TypeRow::from(vec![list_type.clone()]),
output: TypeRow::from(vec![list_type, element_type]),
extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME),
})
},
)
.unwrap();
Expand All @@ -97,11 +97,11 @@ fn extension() -> Extension {
vec![],
move |args: &[TypeArg]| {
let (list_type, element_type) = list_types(args)?;

let outputs = TypeRow::from(vec![list_type.clone()]);
let inputs = TypeRow::from(vec![list_type, element_type]);
let extension_set = ExtensionSet::singleton(&EXTENSION_NAME);
Ok((inputs, outputs, extension_set))
Ok(FunctionType {
output: TypeRow::from(vec![list_type.clone()]),
input: TypeRow::from(vec![list_type, element_type]),
extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME),
})
},
)
.unwrap();
Expand Down
25 changes: 10 additions & 15 deletions src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use itertools::Itertools;
use smol_str::SmolStr;

use crate::{
extension::{prelude::BOOL_T, ExtensionSet},
extension::prelude::BOOL_T,
ops, type_row,
types::type_param::{TypeArg, TypeArgError, TypeParam},
types::{
type_param::{TypeArg, TypeArgError, TypeParam},
FunctionType,
},
Extension,
};
use lazy_static::lazy_static;
Expand Down Expand Up @@ -35,13 +38,7 @@ fn extension() -> Extension {
SmolStr::new_inline(NOT_NAME),
"logical 'not'".into(),
vec![],
|_arg_values: &[TypeArg]| {
Ok((
type_row![BOOL_T],
type_row![BOOL_T],
ExtensionSet::default(),
))
},
|_arg_values: &[TypeArg]| Ok(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])),
)
.unwrap();

Expand All @@ -62,10 +59,9 @@ fn extension() -> Extension {
.into());
}
};
Ok((
vec![BOOL_T; n as usize].into(),
Ok(FunctionType::new(
vec![BOOL_T; n as usize],
type_row![BOOL_T],
ExtensionSet::default(),
))
},
)
Expand All @@ -88,10 +84,9 @@ fn extension() -> Extension {
.into());
}
};
Ok((
vec![BOOL_T; n as usize].into(),
Ok(FunctionType::new(
vec![BOOL_T; n as usize],
type_row![BOOL_T],
ExtensionSet::default(),
))
},
)
Expand Down
Loading

0 comments on commit a323545

Please sign in to comment.