diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 144a9d5d7..36cd249c8 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -406,6 +406,7 @@ pub trait Dataflow: Container { (predicate_inputs, predicate_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, output_types: TypeRow, + extension_delta: ExtensionSet, ) -> Result, BuildError> { let mut input_wires = vec![predicate_wire]; let (input_types, rest_input_wires): (Vec, Vec) = @@ -422,6 +423,7 @@ pub trait Dataflow: Container { predicate_inputs, other_inputs: inputs, outputs: output_types, + extension_delta, }, input_wires, )?; diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index e894e56bc..7eb2a388a 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -1,8 +1,8 @@ use crate::hugr::views::HugrView; use crate::types::{FunctionType, TypeRow}; -use crate::ops; use crate::ops::handle::CaseID; +use crate::ops::{self, OpTrait}; use super::build_traits::SubContainer; use super::handle::BuildHandle; @@ -15,6 +15,7 @@ use super::{ use crate::Node; use crate::{ + extension::ExtensionSet, hugr::{HugrMut, NodeType}, Hugr, }; @@ -102,6 +103,7 @@ impl + AsRef> ConditionalBuilder { pub fn case_builder(&mut self, case: usize) -> Result, BuildError> { let conditional = self.conditional_node; let control_op = self.hugr().get_optype(self.conditional_node); + let extension_delta = control_op.signature().extension_reqs; let cond: ops::Conditional = control_op .clone() @@ -117,7 +119,8 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { - signature: FunctionType::new(inputs.clone(), outputs.clone()), + signature: FunctionType::new(inputs.clone(), outputs.clone()) + .with_extension_delta(&extension_delta), }; let case_node = // add case before any existing subsequent cases @@ -134,7 +137,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - FunctionType::new(inputs, outputs), + FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta), None, )?; @@ -143,7 +146,8 @@ impl + AsRef> ConditionalBuilder { } impl HugrBuilder for ConditionalBuilder { - fn finish_hugr(self) -> Result { + fn finish_hugr(mut self) -> Result { + self.base.infer_extensions()?; self.base.validate()?; Ok(self.base) } @@ -155,6 +159,7 @@ impl ConditionalBuilder { predicate_inputs: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, + extension_delta: ExtensionSet, ) -> Result { let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect(); let other_inputs = other_inputs.into(); @@ -167,6 +172,7 @@ impl ConditionalBuilder { predicate_inputs, other_inputs, outputs, + extension_delta, }; // TODO: Allow input extensions to be specified let base = Hugr::new(NodeType::pure(op)); @@ -183,10 +189,7 @@ impl ConditionalBuilder { impl CaseBuilder { /// Initialize a Case rooted HUGR - pub fn new(input: impl Into, output: impl Into) -> Result { - let input = input.into(); - let output = output.into(); - let signature = FunctionType::new(input, output); + pub fn new(signature: FunctionType) -> Result { let op = ops::Case { signature: signature.clone(), }; @@ -209,6 +212,7 @@ mod test { test::{n_identity, NAT}, Dataflow, }, + extension::ExtensionSet, ops::Const, type_row, }; @@ -218,8 +222,12 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { let predicate_inputs = vec![type_row![]; 2]; - let mut conditional_b = - ConditionalBuilder::new(predicate_inputs, type_row![NAT], type_row![NAT])?; + let mut conditional_b = ConditionalBuilder::new( + predicate_inputs, + type_row![NAT], + type_row![NAT], + ExtensionSet::new(), + )?; n_identity(conditional_b.case_builder(1)?)?; n_identity(conditional_b.case_builder(0)?)?; @@ -246,6 +254,7 @@ mod test { (predicate_inputs, const_wire), other_inputs, outputs, + ExtensionSet::new(), )?; n_identity(conditional_b.case_builder(0)?)?; @@ -267,7 +276,12 @@ mod test { #[test] fn test_not_all_cases() -> Result<(), BuildError> { let predicate_inputs = vec![type_row![]; 2]; - let mut builder = ConditionalBuilder::new(predicate_inputs, type_row![], type_row![])?; + let mut builder = ConditionalBuilder::new( + predicate_inputs, + type_row![], + type_row![], + ExtensionSet::new(), + )?; n_identity(builder.case_builder(0)?)?; assert_matches!( builder.finish_sub_container().map(|_| ()), @@ -281,7 +295,12 @@ mod test { #[test] fn test_case_already_built() -> Result<(), BuildError> { let predicate_inputs = vec![type_row![]; 2]; - let mut builder = ConditionalBuilder::new(predicate_inputs, type_row![], type_row![])?; + let mut builder = ConditionalBuilder::new( + predicate_inputs, + type_row![], + type_row![], + ExtensionSet::new(), + )?; n_identity(builder.case_builder(0)?)?; assert_matches!( builder.case_builder(0).map(|_| ()), diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index ed463c6b6..0c7bbd130 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -98,6 +98,7 @@ mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }, extension::prelude::{ConstUsize, USIZE_T}, + extension::ExtensionSet, hugr::ValidationError, ops::Const, type_row, Hugr, @@ -143,6 +144,7 @@ mod test { (predicate_inputs, const_wire), vec![(BIT, b1)], output_row, + ExtensionSet::new(), )?; let mut branch_0 = conditional_b.case_builder(0)?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 457b463b7..fac254feb 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -881,21 +881,28 @@ mod test { let [w] = mult.outputs_arr(); builder.set_outputs([w])?; - let hugr = builder.base; - // TODO: when we put new extensions onto the graph after inference, we - // can call `finish_hugr` and just look at the graph - let (solution, extra) = infer_extensions(&hugr)?; - assert!(extra.is_empty()); + let mut hugr = builder.base; + let closure = hugr.infer_extensions()?; + assert!(closure.is_empty()); assert_eq!( - *solution.get(&(src.node(), Direction::Outgoing)).unwrap(), + hugr.get_nodetype(src.node()) + .signature() + .unwrap() + .output_extensions(), rs ); assert_eq!( - *solution.get(&(mult.node(), Direction::Incoming)).unwrap(), + hugr.get_nodetype(mult.node()) + .signature() + .unwrap() + .input_extensions, rs ); assert_eq!( - *solution.get(&(mult.node(), Direction::Outgoing)).unwrap(), + hugr.get_nodetype(mult.node()) + .signature() + .unwrap() + .output_extensions(), rs ); Ok(()) diff --git a/src/hugr.rs b/src/hugr.rs index 87698e02f..51535e584 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -193,7 +193,7 @@ impl Hugr { rw.apply(self) } - /// Infer extension requirements + /// Infer extension requirements and add new information to `op_types` field pub fn infer_extensions( &mut self, ) -> Result, InferExtensionError> { @@ -202,9 +202,22 @@ impl Hugr { Ok(extension_closure) } - /// TODO: Write this - fn instantiate_extensions(&mut self, _solution: ExtensionSolution) { - //todo!() + /// Add extension requirement information to the hugr in place. + fn instantiate_extensions(&mut self, solution: ExtensionSolution) { + // We only care about inferred _input_ extensions, because `NodeType` + // uses those to infer the output extensions + for ((node, _), input_extensions) in solution + .iter() + .filter(|((_, dir), _)| *dir == Direction::Incoming) + { + let nodetype = self.op_types.try_get_mut(node.index).unwrap(); + match nodetype.signature() { + None => nodetype.input_extensions = Some(input_extensions.clone()), + Some(existing_ext_reqs) => { + debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions) + } + } + } } } @@ -428,7 +441,14 @@ impl From for PyErr { #[cfg(test)] mod test { - use super::Hugr; + use super::{Hugr, HugrView, NodeType}; + use crate::extension::ExtensionSet; + use crate::hugr::HugrMut; + use crate::ops; + use crate::type_row; + use crate::types::{FunctionType, Type}; + + use std::error::Error; #[test] fn impls_send_and_sync() { @@ -447,4 +467,55 @@ mod test { let hugr = simple_dfg_hugr(); assert_matches!(hugr.get_io(hugr.root()), Some(_)); } + + #[test] + fn extension_instantiation() -> Result<(), Box> { + const BIT: Type = crate::extension::prelude::USIZE_T; + let r = ExtensionSet::singleton(&"R".into()); + + let root = NodeType::pure(ops::DFG { + signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r), + }); + let mut hugr = Hugr::new(root); + let input = hugr.add_node_with_parent( + hugr.root(), + NodeType::pure(ops::Input { + types: type_row![BIT], + }), + )?; + let output = hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::Output { + types: type_row![BIT], + }), + )?; + let lift = hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::LeafOp::Lift { + type_row: type_row![BIT], + new_extension: "R".into(), + }), + )?; + hugr.connect(input, 0, lift, 0)?; + hugr.connect(lift, 0, output, 0)?; + hugr.infer_extensions()?; + + assert_eq!( + hugr.op_types + .get(lift.index) + .signature() + .unwrap() + .input_extensions, + ExtensionSet::new() + ); + assert_eq!( + hugr.op_types + .get(output.index) + .signature() + .unwrap() + .input_extensions, + r + ); + Ok(()) + } } diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 395516943..88e6ee435 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -2,6 +2,7 @@ use smol_str::SmolStr; +use crate::extension::ExtensionSet; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; use super::dataflow::DataflowOpTrait; @@ -59,6 +60,8 @@ pub struct Conditional { pub other_inputs: TypeRow, /// Output types pub outputs: TypeRow, + /// Extensions used to produce the outputs + pub extension_delta: ExtensionSet, } impl_op_name!(Conditional); @@ -74,7 +77,7 @@ impl DataflowOpTrait for Conditional { inputs .to_mut() .insert(0, Type::new_predicate(self.predicate_inputs.clone())); - FunctionType::new(inputs, self.outputs.clone()) + FunctionType::new(inputs, self.outputs.clone()).with_extension_delta(&self.extension_delta) } } @@ -209,6 +212,10 @@ impl OpTrait for Case { fn tag(&self) -> OpTag { ::TAG } + + fn signature(&self) -> FunctionType { + self.signature.clone() + } } impl Case {