Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support encoding float and sympy ops #618

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tket2-py/tket2/extensions/_json_defs/tket2/rotation.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"from_halfturns": {
"extension": "tket2.rotation",
"name": "from_halfturns",
"description": "Construct rotation from number of half-turns (would be multiples of π in radians).",
"description": "Construct rotation from number of half-turns (would be multiples of π in radians). Returns None if the float is non-finite.",
"signature": {
"params": [],
"body": {
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/extension/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl MakeOpDef for RotationOp {
fn description(&self) -> String {
match self {
RotationOp::from_halfturns => {
"Construct rotation from number of half-turns (would be multiples of π in radians)."
"Construct rotation from number of half-turns (would be multiples of π in radians). Returns None if the float is non-finite."
}
RotationOp::to_halfturns => {
"Convert rotation to number of half-turns (would be multiples of π in radians)."
Expand Down
15 changes: 11 additions & 4 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod decoder;
mod encoder;
mod op;

use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::Type;

use hugr::Node;
Expand Down Expand Up @@ -300,15 +301,21 @@ fn try_param_to_constant(param: &str) -> Option<Value> {
ConstRotation::new(half_turns).ok().map(Into::into)
}

/// Convert a HUGR angle constant to a TKET1 parameter.
/// Convert a HUGR rotation or float constant to a TKET1 parameter.
///
/// Angle parameters in TKET1 are encoded as a number of half-turns,
/// whereas HUGR uses radians.
#[inline]
fn try_constant_to_param(val: &Value) -> Option<String> {
let const_angle = val.get_custom_value::<ConstRotation>()?;
let half_turns = const_angle.half_turns();
Some(half_turns.to_string())
if let Some(const_angle) = val.get_custom_value::<ConstRotation>() {
let half_turns = const_angle.half_turns();
Some(half_turns.to_string())
} else if let Some(const_float) = val.get_custom_value::<ConstF64>() {
let float = const_float.value();
Some(float.to_string())
} else {
None
}
}

/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
Expand Down
91 changes: 85 additions & 6 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::collections::{HashMap, HashSet, VecDeque};

use hugr::extension::prelude::{BOOL_T, QB_T};
use hugr::ops::{OpTrait, OpType};
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use hugr::types::Type;
use hugr::{HugrView, Wire};
use itertools::Itertools;
use tket_json_rs::circuit_json::Register as RegisterUnit;
Expand All @@ -13,6 +16,7 @@ use tket_json_rs::circuit_json::{self, SerialCircuit};
use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
use crate::extension::rotation::{RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOp;
use crate::ops::match_symb_const_op;
use crate::serialize::pytket::RegisterHash;
use crate::Tk2Op;
Expand Down Expand Up @@ -573,10 +577,19 @@ impl ParameterTracker {
optype: &OpType,
) -> Result<bool, OpConvertError> {
let input_count = if let Some(signature) = optype.dataflow_signature() {
// Only consider commands where all inputs are parameters,
// and some outputs are also parameters.
let all_inputs = signature.input().iter().all(|ty| ty == &ROTATION_TYPE);
let some_output = signature.output().iter().any(|ty| ty == &ROTATION_TYPE);
// Only consider commands where all inputs and some outputs are
// parameters that we can track.
//
// TODO: We should track Option<T> parameters too, `RotationOp::from_halfturns` returns options.
const TRACKED_PARAMS: [Type; 2] = [ROTATION_TYPE, FLOAT64_TYPE];
let all_inputs = signature
.input()
.iter()
.all(|ty| TRACKED_PARAMS.contains(ty));
let some_output = signature
.output()
.iter()
.any(|ty| TRACKED_PARAMS.contains(ty));
if !all_inputs || !some_output {
return Ok(false);
}
Expand Down Expand Up @@ -619,8 +632,28 @@ impl ParameterTracker {
// Re-use the parameter from the input.
inputs[0].clone()
}
OpType::ExtensionOp(_) if optype.cast() == Some(RotationOp::radd) => {
format!("{} + {}", inputs[0], inputs[1])
// Encode some angle and float operations directly as strings using
// the already encoded inputs. Fail if the operation is not
// supported, and let the operation encoding process it instead.
OpType::ExtensionOp(_) => {
if let Some(s) = optype
.cast::<RotationOp>()
.and_then(|op| self.encode_rotation_op(&op, inputs.as_slice()))
{
s
} else if let Some(s) = optype
.cast::<FloatOps>()
.and_then(|op| self.encode_float_op(&op, inputs.as_slice()))
{
s
} else if let Some(s) = optype
.cast::<SympyOp>()
.and_then(|op| self.encode_sympy_op(&op, inputs.as_slice()))
{
s
} else {
return Ok(false);
}
Comment on lines +639 to +656
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any way this could be some sort of match statement on the optype type instead of trying each out?

Copy link
Collaborator Author

@aborgna-q aborgna-q Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each cast has a different type, so all options I tried for writing this look messy (matching on a tuple, chaining Option::or_else, ...).

}
_ => {
let Some(s) = match_symb_const_op(optype) else {
Expand All @@ -647,6 +680,52 @@ impl ParameterTracker {
fn get(&self, wire: &Wire) -> Option<&String> {
self.parameters.get(wire)
}

/// Encode an [`RotationOp`]s as a string, given its encoded inputs.
///
/// `inputs` contains the expressions to compute each input.
fn encode_rotation_op(&self, op: &RotationOp, inputs: &[&String]) -> Option<String> {
let s = match op {
RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]),
// Encode/decode the rotation as pytket parameters, expressed as half-turns.
// Note that the tracked parameter strings are always written in half-turns,
// so the conversion here is a no-op.
RotationOp::to_halfturns => inputs[0].clone(),
RotationOp::from_halfturns => inputs[0].clone(),
};
Some(s)
}

/// Encode an [`FloatOps`] as a string, given its encoded inputs.
fn encode_float_op(&self, op: &FloatOps, inputs: &[&String]) -> Option<String> {
let s = match op {
FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]),
FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]),
FloatOps::fneg => format!("(-{})", inputs[0]),
FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]),
FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]),
FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]),
FloatOps::ffloor => format!("floor({})", inputs[0]),
FloatOps::fceil => format!("ceil({})", inputs[0]),
FloatOps::fround => format!("round({})", inputs[0]),
FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]),
FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]),
FloatOps::fabs => format!("abs({})", inputs[0]),
_ => return None,
};
Some(s)
}

/// Encode a [`SympyOp`]s as a string.
///
/// Note that the sympy operation does not have any inputs.
fn encode_sympy_op(&self, op: &SympyOp, inputs: &[&String]) -> Option<String> {
if !inputs.is_empty() {
return None;
}

Some(op.expr.clone())
}
}

/// A utility class for finding new unused qubit/bit names.
Expand Down
8 changes: 6 additions & 2 deletions tket2/src/serialize/pytket/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ impl Tk1Op {
}
Ok(Some(Tk1Op::Native(native)))
} else {
let opaque = OpaqueTk1Op::try_from_tket2(&op)?;
Ok(opaque.map(Tk1Op::Opaque))
// Unrecognised opaque operation. If it's an opaque tket1 op, return it.
// Otherwise, it's an unsupported operation and we should fail.
match OpaqueTk1Op::try_from_tket2(&op)? {
Some(opaque) => Ok(Some(Tk1Op::Opaque(opaque))),
None => Err(OpConvertError::UnsupportedOpSerialization(op.clone())),
}
}
}

Expand Down
34 changes: 32 additions & 2 deletions tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tket_json_rs::optype;
use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS};
use crate::circuit::Circuit;
use crate::extension::rotation::{ConstRotation, RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOpDef;
use crate::extension::REGISTRY;
use crate::Tk2Op;

Expand Down Expand Up @@ -226,6 +227,34 @@ fn circ_add_angles_constants() -> Circuit {
h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
}

#[fixture]
/// An Rx operation using some complex ops to compute its angle `cos(pi) + 1`.
fn circ_complex_angle_computation() -> Circuit {
let qb_row = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap();

let qb = h.input_wires().next().unwrap();

let point2 = h.add_load_value(ConstRotation::new(0.2).unwrap());
let sympy = h
.add_dataflow_op(SympyOpDef.with_expr("cos(pi)".to_string()), [])
.unwrap()
.out_wire(0);
let final_rot = h
.add_dataflow_op(RotationOp::radd, [sympy, point2])
.unwrap()
.out_wire(0);

// TODO: Mix in some float ops. This requires unwrapping the result of `RotationOp::from_halfturns`.

let qbs = h
.add_dataflow_op(Tk2Op::Rx, [qb, final_rot])
.unwrap()
.outputs();

h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
}

#[rstest]
#[case::simple(SIMPLE_JSON, 2, 2)]
#[case::simple(MULTI_REGISTER, 2, 3)]
Expand Down Expand Up @@ -289,8 +318,9 @@ fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) {
/// converted back to circuit inputs. This would require parsing symbolic
/// expressions.
#[rstest]
#[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")]
#[case::constants(circ_add_angles_constants(), "0.2 + 0.3")]
#[case::symbolic(circ_add_angles_symbolic(), "(f0 + f1)")]
#[case::constants(circ_add_angles_constants(), "(0.2 + 0.3)")]
#[case::complex(circ_complex_angle_computation(), "(cos(pi) + 0.2)")]
fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: &str) {
let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap();
assert_eq!(ser.commands.len(), 1);
Expand Down
Loading