Skip to content

Commit

Permalink
feat: add xor to logic extension
Browse files Browse the repository at this point in the history
Closes  #1418
  • Loading branch information
ss2165 committed Feb 6, 2025
1 parent 6489977 commit fa2e966
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 2 deletions.
15 changes: 13 additions & 2 deletions hugr-core/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,15 @@ impl ConstFold for LogicOp {
(!res || inps.len() as u64 == 1)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
Self::Xor => {
let inps = read_inputs(consts)?;
let res = inps.iter().fold(false, |acc, x| acc ^ *x);
(inps.len() as u64 == 2).then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
}
}
}

/// Logic extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs)]
Expand All @@ -67,12 +73,13 @@ pub enum LogicOp {
Or,
Eq,
Not,
Xor,
}

impl MakeOpDef for LogicOp {
fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
match self {
LogicOp::And | LogicOp::Or | LogicOp::Eq => {
LogicOp::And | LogicOp::Or | LogicOp::Eq | LogicOp::Xor => {
Signature::new(vec![bool_t(); 2], vec![bool_t()])
}
LogicOp::Not => Signature::new_endo(vec![bool_t()]),
Expand All @@ -90,6 +97,7 @@ impl MakeOpDef for LogicOp {
LogicOp::Or => "logical 'or'",
LogicOp::Eq => "test if bools are equal",
LogicOp::Not => "logical 'not'",
LogicOp::Xor => "logical 'xor'",
}
.to_string()
}
Expand Down Expand Up @@ -181,7 +189,7 @@ pub(crate) mod test {
fn test_logic_extension() {
let r: Arc<Extension> = extension();
assert_eq!(r.name() as &str, "logic");
assert_eq!(r.operations().count(), 4);
assert_eq!(r.operations().count(), 5);

for op in LogicOp::iter() {
assert_eq!(
Expand Down Expand Up @@ -230,6 +238,8 @@ pub(crate) mod test {
#[case(LogicOp::Eq, [false, false], true)]
#[case(LogicOp::Not, [false], true)]
#[case(LogicOp::Not, [true], false)]
#[case(LogicOp::Xor, [true, false], true)]
#[case(LogicOp::Xor, [true, true], false)]
fn const_fold(
#[case] op: LogicOp,
#[case] ins: impl IntoIterator<Item = bool>,
Expand All @@ -256,6 +266,7 @@ pub(crate) mod test {
#[case(LogicOp::Or, [None, Some(true)], Some(true))]
#[case(LogicOp::Eq, [None, Some(true)], None)]
#[case(LogicOp::Not, [None], None)]
#[case(LogicOp::Xor, [None, Some(true)], None)]
fn partial_const_fold(
#[case] op: LogicOp,
#[case] ins: impl IntoIterator<Item = Option<bool>>,
Expand Down
15 changes: 15 additions & 0 deletions hugr-llvm/src/extension/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ fn emit_logic_op<'c, H: HugrView>(
acc
}
LogicOp::Not => builder.build_not(inputs[0], "")?,
LogicOp::Xor => {
let mut acc = inputs[0];
for inp in inputs.into_iter().skip(1) {
acc = builder.build_xor(acc, inp, "")?;
}
acc
}
op => {
return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}"));
}
Expand All @@ -80,6 +87,7 @@ pub fn add_logic_extensions<'a, H: HugrView + 'a>(
.extension_op(logic::EXTENSION_ID, LogicOp::And.name(), emit_logic_op)
.extension_op(logic::EXTENSION_ID, LogicOp::Or.name(), emit_logic_op)
.extension_op(logic::EXTENSION_ID, LogicOp::Not.name(), emit_logic_op)
.extension_op(logic::EXTENSION_ID, LogicOp::Xor.name(), emit_logic_op) // Added Xor
}

impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
Expand Down Expand Up @@ -148,4 +156,11 @@ mod test {
let hugr = test_logic_op(LogicOp::Not, 1);
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn xor(mut llvm_ctx: TestContext) {
llvm_ctx.add_extensions(add_logic_extensions);
let hugr = test_logic_op(LogicOp::Xor, 2);
check_emission!(hugr, llvm_ctx);
}
}
16 changes: 16 additions & 0 deletions hugr-llvm/src/extension/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
---
source: hugr-llvm/src/extension/logic.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i1 @_hl.main.1(i1 %0, i1 %1) {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
%2 = xor i1 %0, %1
%3 = select i1 %2, i1 true, i1 false
ret i1 %3
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
---
source: hugr-llvm/src/extension/logic.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i1 @_hl.main.1(i1 %0, i1 %1) {
alloca_block:
%"0" = alloca i1, align 1
%"2_0" = alloca i1, align 1
%"2_1" = alloca i1, align 1
%"4_0" = alloca i1, align 1
br label %entry_block

entry_block: ; preds = %alloca_block
store i1 %0, i1* %"2_0", align 1
store i1 %1, i1* %"2_1", align 1
%"2_01" = load i1, i1* %"2_0", align 1
%"2_12" = load i1, i1* %"2_1", align 1
%2 = xor i1 %"2_01", %"2_12"
%3 = select i1 %2, i1 true, i1 false
store i1 %3, i1* %"4_0", align 1
%"4_03" = load i1, i1* %"4_0", align 1
store i1 %"4_03", i1* %"0", align 1
%"04" = load i1, i1* %"0", align 1
ret i1 %"04"
}
31 changes: 31 additions & 0 deletions hugr-py/src/hugr/std/_json_defs/logic.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,37 @@
}
},
"binary": false
},
"Xor": {
"extension": "logic",
"name": "Xor",
"description": "logical 'xor'",
"signature": {
"params": [],
"body": {
"input": [
{
"t": "Sum",
"s": "Unit",
"size": 2
},
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"output": [
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"runtime_reqs": []
}
},
"binary": false
}
}
}
31 changes: 31 additions & 0 deletions specification/std_extensions/logic.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,37 @@
}
},
"binary": false
},
"Xor": {
"extension": "logic",
"name": "Xor",
"description": "logical 'xor'",
"signature": {
"params": [],
"body": {
"input": [
{
"t": "Sum",
"s": "Unit",
"size": 2
},
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"output": [
{
"t": "Sum",
"s": "Unit",
"size": 2
}
],
"runtime_reqs": []
}
},
"binary": false
}
}
}

0 comments on commit fa2e966

Please sign in to comment.