Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Add extension info to Conditional and Case #463

Merged
merged 4 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ pub trait Dataflow: Container {
(predicate_inputs, predicate_wire): (impl IntoIterator<Item = TypeRow>, Wire),
other_inputs: impl IntoIterator<Item = (Type, Wire)>,
output_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<ConditionalBuilder<&mut Hugr>, BuildError> {
let mut input_wires = vec![predicate_wire];
let (input_types, rest_input_wires): (Vec<Type>, Vec<Wire>) =
Expand All @@ -422,6 +423,7 @@ pub trait Dataflow: Container {
predicate_inputs,
other_inputs: inputs,
outputs: output_types,
extension_delta,
},
input_wires,
)?;
Expand Down
43 changes: 31 additions & 12 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::hugr::views::HugrView;
use crate::types::{FunctionType, TypeRow};

use crate::ops;
use crate::ops::handle::CaseID;
use crate::ops::{self, OpTrait};

use super::build_traits::SubContainer;
use super::handle::BuildHandle;
Expand All @@ -15,6 +15,7 @@ use super::{

use crate::Node;
use crate::{
extension::ExtensionSet,
hugr::{HugrMut, NodeType},
Hugr,
};
Expand Down Expand Up @@ -102,6 +103,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
pub fn case_builder(&mut self, case: usize) -> Result<CaseBuilder<&mut Hugr>, BuildError> {
let conditional = self.conditional_node;
let control_op = self.hugr().get_optype(self.conditional_node);
let extension_delta = control_op.signature().extension_reqs;

let cond: ops::Conditional = control_op
.clone()
Expand All @@ -117,7 +119,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {

let outputs = cond.outputs;
let case_op = ops::Case {
signature: FunctionType::new(inputs.clone(), outputs.clone()),
signature: FunctionType::new(inputs.clone(), outputs.clone())
.with_extension_delta(&extension_delta),
};
let case_node =
// add case before any existing subsequent cases
Expand All @@ -134,7 +137,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let dfg_builder = DFGBuilder::create_with_io(
self.hugr_mut(),
case_node,
FunctionType::new(inputs, outputs),
FunctionType::new(inputs, outputs).with_extension_delta(&extension_delta),
None,
)?;

Expand All @@ -143,7 +146,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
}

impl HugrBuilder for ConditionalBuilder<Hugr> {
fn finish_hugr(self) -> Result<Hugr, crate::hugr::ValidationError> {
fn finish_hugr(mut self) -> Result<Hugr, crate::hugr::ValidationError> {
self.base.infer_extensions()?;
self.base.validate()?;
Ok(self.base)
}
Expand All @@ -155,6 +159,7 @@ impl ConditionalBuilder<Hugr> {
predicate_inputs: impl IntoIterator<Item = TypeRow>,
other_inputs: impl Into<TypeRow>,
outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
let predicate_inputs: Vec<_> = predicate_inputs.into_iter().collect();
let other_inputs = other_inputs.into();
Expand All @@ -167,6 +172,7 @@ impl ConditionalBuilder<Hugr> {
predicate_inputs,
other_inputs,
outputs,
extension_delta,
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::pure(op));
Expand All @@ -183,10 +189,7 @@ impl ConditionalBuilder<Hugr> {

impl CaseBuilder<Hugr> {
/// Initialize a Case rooted HUGR
pub fn new(input: impl Into<TypeRow>, output: impl Into<TypeRow>) -> Result<Self, BuildError> {
let input = input.into();
let output = output.into();
let signature = FunctionType::new(input, output);
pub fn new(signature: FunctionType) -> Result<Self, BuildError> {
let op = ops::Case {
signature: signature.clone(),
};
Expand All @@ -209,6 +212,7 @@ mod test {
test::{n_identity, NAT},
Dataflow,
},
extension::ExtensionSet,
ops::Const,
type_row,
};
Expand All @@ -218,8 +222,12 @@ mod test {
#[test]
fn basic_conditional() -> Result<(), BuildError> {
let predicate_inputs = vec![type_row![]; 2];
let mut conditional_b =
ConditionalBuilder::new(predicate_inputs, type_row![NAT], type_row![NAT])?;
let mut conditional_b = ConditionalBuilder::new(
predicate_inputs,
type_row![NAT],
type_row![NAT],
ExtensionSet::new(),
)?;

n_identity(conditional_b.case_builder(1)?)?;
n_identity(conditional_b.case_builder(0)?)?;
Expand All @@ -246,6 +254,7 @@ mod test {
(predicate_inputs, const_wire),
other_inputs,
outputs,
ExtensionSet::new(),
)?;

n_identity(conditional_b.case_builder(0)?)?;
Expand All @@ -267,7 +276,12 @@ mod test {
#[test]
fn test_not_all_cases() -> Result<(), BuildError> {
let predicate_inputs = vec![type_row![]; 2];
let mut builder = ConditionalBuilder::new(predicate_inputs, type_row![], type_row![])?;
let mut builder = ConditionalBuilder::new(
predicate_inputs,
type_row![],
type_row![],
ExtensionSet::new(),
)?;
n_identity(builder.case_builder(0)?)?;
assert_matches!(
builder.finish_sub_container().map(|_| ()),
Expand All @@ -281,7 +295,12 @@ mod test {
#[test]
fn test_case_already_built() -> Result<(), BuildError> {
let predicate_inputs = vec![type_row![]; 2];
let mut builder = ConditionalBuilder::new(predicate_inputs, type_row![], type_row![])?;
let mut builder = ConditionalBuilder::new(
predicate_inputs,
type_row![],
type_row![],
ExtensionSet::new(),
)?;
n_identity(builder.case_builder(0)?)?;
assert_matches!(
builder.case_builder(0).map(|_| ()),
Expand Down
2 changes: 2 additions & 0 deletions src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ mod test {
DataflowSubContainer, HugrBuilder, ModuleBuilder,
},
extension::prelude::{ConstUsize, USIZE_T},
extension::ExtensionSet,
hugr::ValidationError,
ops::Const,
type_row, Hugr,
Expand Down Expand Up @@ -143,6 +144,7 @@ mod test {
(predicate_inputs, const_wire),
vec![(BIT, b1)],
output_row,
ExtensionSet::new(),
)?;

let mut branch_0 = conditional_b.case_builder(0)?;
Expand Down
23 changes: 15 additions & 8 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,21 +881,28 @@ mod test {
let [w] = mult.outputs_arr();

builder.set_outputs([w])?;
let hugr = builder.base;
// TODO: when we put new extensions onto the graph after inference, we
// can call `finish_hugr` and just look at the graph
let (solution, extra) = infer_extensions(&hugr)?;
assert!(extra.is_empty());
let mut hugr = builder.base;
let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
assert_eq!(
*solution.get(&(src.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(src.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Incoming)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.input_extensions,
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
Ok(())
Expand Down
81 changes: 76 additions & 5 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl Hugr {
rw.apply(self)
}

/// Infer extension requirements
/// Infer extension requirements and add new information to `op_types` field
pub fn infer_extensions(
&mut self,
) -> Result<HashMap<(Node, Direction), ExtensionSet>, InferExtensionError> {
Expand All @@ -202,9 +202,22 @@ impl Hugr {
Ok(extension_closure)
}

/// TODO: Write this
fn instantiate_extensions(&mut self, _solution: ExtensionSolution) {
//todo!()
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for ((node, _), input_extensions) in solution
.iter()
.filter(|((_, dir), _)| *dir == Direction::Incoming)
{
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
match nodetype.signature() {
None => nodetype.input_extensions = Some(input_extensions.clone()),
Some(existing_ext_reqs) => {
debug_assert_eq!(existing_ext_reqs.input_extensions, *input_extensions)
}
}
}
}
}

Expand Down Expand Up @@ -428,7 +441,14 @@ impl From<HugrError> for PyErr {

#[cfg(test)]
mod test {
use super::Hugr;
use super::{Hugr, HugrView, NodeType};
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops;
use crate::type_row;
use crate::types::{FunctionType, Type};

use std::error::Error;

#[test]
fn impls_send_and_sync() {
Expand All @@ -447,4 +467,55 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
const BIT: Type = crate::extension::prelude::USIZE_T;
let r = ExtensionSet::singleton(&"R".into());

let root = NodeType::pure(ops::DFG {
signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r),
});
let mut hugr = Hugr::new(root);
let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
types: type_row![BIT],
}),
)?;
let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
types: type_row![BIT],
}),
)?;
let lift = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".into(),
}),
)?;
hugr.connect(input, 0, lift, 0)?;
hugr.connect(lift, 0, output, 0)?;
hugr.infer_extensions()?;

assert_eq!(
hugr.op_types
.get(lift.index)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.op_types
.get(output.index)
.signature()
.unwrap()
.input_extensions,
r
);
Ok(())
}
}
9 changes: 8 additions & 1 deletion src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use smol_str::SmolStr;

use crate::extension::ExtensionSet;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};

use super::dataflow::DataflowOpTrait;
Expand Down Expand Up @@ -59,6 +60,8 @@ pub struct Conditional {
pub other_inputs: TypeRow,
/// Output types
pub outputs: TypeRow,
/// Extensions used to produce the outputs
pub extension_delta: ExtensionSet,
}
impl_op_name!(Conditional);

Expand All @@ -74,7 +77,7 @@ impl DataflowOpTrait for Conditional {
inputs
.to_mut()
.insert(0, Type::new_predicate(self.predicate_inputs.clone()));
FunctionType::new(inputs, self.outputs.clone())
FunctionType::new(inputs, self.outputs.clone()).with_extension_delta(&self.extension_delta)
}
}

Expand Down Expand Up @@ -209,6 +212,10 @@ impl OpTrait for Case {
fn tag(&self) -> OpTag {
<Self as StaticTag>::TAG
}

fn signature(&self) -> FunctionType {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this return a reference (and clone on use)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just following the OpTrait interface, but I think they could all just as easily be implemented that way

self.signature.clone()
}
}

impl Case {
Expand Down