diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 6528009f6..2f042b7bd 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -55,9 +55,15 @@ impl ConstFold for LogicOp { (!res || inps.len() as u64 == 1) .then_some(vec![(0.into(), ops::Value::from_bool(res))]) } + Self::Xor => { + let inps = read_inputs(consts)?; + let res = inps.iter().fold(false, |acc, x| acc ^ *x); + (inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))]) + } } } } + /// Logic extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs)] @@ -67,12 +73,13 @@ pub enum LogicOp { Or, Eq, Not, + Xor, } impl MakeOpDef for LogicOp { fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { match self { - LogicOp::And | LogicOp::Or | LogicOp::Eq => { + LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => { Signature::new(vec![bool_t(); 2], vec![bool_t()]) } LogicOp::Not => Signature::new_endo(vec![bool_t()]), @@ -90,6 +97,7 @@ impl MakeOpDef for LogicOp { LogicOp::Or => "logical 'or'", LogicOp::Eq => "test if bools are equal", LogicOp::Not => "logical 'not'", + LogicOp::Xor => "logical 'xor'", } .to_string() } @@ -181,7 +189,7 @@ pub(crate) mod test { fn test_logic_extension() { let r: Arc = extension(); assert_eq!(r.name() as &str, "logic"); - assert_eq!(r.operations().count(), 4); + assert_eq!(r.operations().count(), 5); for op in LogicOp::iter() { assert_eq!( @@ -230,6 +238,8 @@ pub(crate) mod test { #[case(LogicOp::Eq, [false, false], true)] #[case(LogicOp::Not, [false], true)] #[case(LogicOp::Not, [true], false)] + #[case(LogicOp::Xor, [true, false], true)] + #[case(LogicOp::Xor, [true, true], false)] fn const_fold( #[case] op: LogicOp, #[case] ins: impl IntoIterator, @@ -256,6 +266,7 @@ pub(crate) mod test { #[case(LogicOp::Or, [None, Some(true)], Some(true))] #[case(LogicOp::Eq, [None, Some(true)], None)] #[case(LogicOp::Not, [None], None)] + #[case(LogicOp::Xor, [None, Some(true)], None)] fn partial_const_fold( #[case] op: LogicOp, #[case] ins: impl IntoIterator>, diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 32a58923c..4441e272f 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -57,6 +57,13 @@ fn emit_logic_op<'c, H: HugrView>( acc } LogicOp::Not => builder.build_not(inputs[0], "")?, + LogicOp::Xor => { + let mut acc = inputs[0]; + for inp in inputs.into_iter().skip(1) { + acc = builder.build_xor(acc, inp, "")?; + } + acc + } op => { return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}")); } @@ -80,6 +87,7 @@ pub fn add_logic_extensions<'a, H: HugrView + 'a>( .extension_op(logic::EXTENSION_ID, LogicOp::And.name(), emit_logic_op) .extension_op(logic::EXTENSION_ID, LogicOp::Or.name(), emit_logic_op) .extension_op(logic::EXTENSION_ID, LogicOp::Not.name(), emit_logic_op) + .extension_op(logic::EXTENSION_ID, LogicOp::Xor.name(), emit_logic_op) // Added Xor } impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { @@ -148,4 +156,11 @@ mod test { let hugr = test_logic_op(LogicOp::Not, 1); check_emission!(hugr, llvm_ctx); } + + #[rstest] + fn xor(mut llvm_ctx: TestContext) { + llvm_ctx.add_extensions(add_logic_extensions); + let hugr = test_logic_op(LogicOp::Xor, 2); + check_emission!(hugr, llvm_ctx); + } } diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap new file mode 100644 index 000000000..934ef459b --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@llvm14.snap @@ -0,0 +1,16 @@ +--- +source: hugr-llvm/src/extension/logic.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1(i1 %0, i1 %1) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %2 = xor i1 %0, %1 + %3 = select i1 %2, i1 true, i1 false + ret i1 %3 +} diff --git a/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..a5dcf022d --- /dev/null +++ b/hugr-llvm/src/extension/snapshots/hugr_llvm__extension__logic__test__xor@pre-mem2reg@llvm14.snap @@ -0,0 +1,28 @@ +--- +source: hugr-llvm/src/extension/logic.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i1 @_hl.main.1(i1 %0, i1 %1) { +alloca_block: + %"0" = alloca i1, align 1 + %"2_0" = alloca i1, align 1 + %"2_1" = alloca i1, align 1 + %"4_0" = alloca i1, align 1 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i1 %0, i1* %"2_0", align 1 + store i1 %1, i1* %"2_1", align 1 + %"2_01" = load i1, i1* %"2_0", align 1 + %"2_12" = load i1, i1* %"2_1", align 1 + %2 = xor i1 %"2_01", %"2_12" + %3 = select i1 %2, i1 true, i1 false + store i1 %3, i1* %"4_0", align 1 + %"4_03" = load i1, i1* %"4_0", align 1 + store i1 %"4_03", i1* %"0", align 1 + %"04" = load i1, i1* %"0", align 1 + ret i1 %"04" +} diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index 7f90392ff..ad9f02019 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -150,6 +150,37 @@ } }, "binary": false + }, + "Xor": { + "extension": "logic", + "name": "Xor", + "description": "logical 'xor'", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "runtime_reqs": [] + } + }, + "binary": false } } } diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index 7f90392ff..ad9f02019 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -150,6 +150,37 @@ } }, "binary": false + }, + "Xor": { + "extension": "logic", + "name": "Xor", + "description": "logical 'xor'", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "runtime_reqs": [] + } + }, + "binary": false } } }