Skip to content

Commit 251fe9a

Browse files
committed
feat: Add a LoadedParam struct for the pytket decoder
1 parent 2cdc11b commit 251fe9a

File tree

5 files changed

+406
-233
lines changed

5 files changed

+406
-233
lines changed

tket/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ derive_more = { workspace = true, features = [
5757
"into",
5858
"sum",
5959
"add",
60+
"add_assign",
6061
] }
6162
hugr = { workspace = true }
6263
hugr-core = { workspace = true }

tket/src/serialize/pytket/decoder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use crate::extension::rotation::{rotation_type, RotationOp};
3131
use crate::serialize::pytket::METADATA_INPUT_PARAMETERS;
3232
use crate::symbolic_constant_op;
3333
use op::Tk1Op;
34-
use param::{parse_pytket_param, PytketParam};
34+
use param::parser::{parse_pytket_param, PytketParam};
3535

3636
/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`].
3737
///
Lines changed: 99 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -1,250 +1,117 @@
1-
//! Definitions for decoding parameter expressions from pytket operations.
2-
//!
3-
//! This is based on the `pest` grammar defined in `param.pest`.
1+
//! Definition of a loaded parameter (either floating point or a rotation type) attached to a HUGR wire.
2+
#![expect(
3+
dead_code,
4+
reason = "Temporarily unused while we refactor the pytket decoder"
5+
)]
46

5-
use derive_more::Display;
6-
use hugr::ops::OpType;
7-
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
8-
use itertools::Itertools;
9-
use pest::iterators::{Pair, Pairs};
10-
use pest::pratt_parser::PrattParser;
11-
use pest::Parser;
12-
use pest_derive::Parser;
7+
pub(super) mod parser;
8+
use std::sync::LazyLock;
139

14-
/// The parsed AST for a pytket operation parameter.
15-
///
16-
/// The leafs of the AST are either a constant value, a variable name, or an
17-
/// unrecognized sympy expression.
18-
///
19-
/// Return type of [`parse_pytket_param`].
20-
#[derive(Debug, Display, Clone, PartialEq)]
21-
pub enum PytketParam<'a> {
22-
/// A constant value that can be loaded directly.
23-
#[display("{_0}")]
24-
Constant(f64),
25-
/// A variable that should be routed as an input.
26-
#[display("\"{name}\"")]
27-
InputVariable {
28-
/// The variable name.
29-
name: &'a str,
30-
},
31-
/// Unrecognized sympy expression.
32-
/// Will be emitted as a [`SympyOp`].
33-
#[display("Sympy(\"{_0}\")")]
34-
Sympy(&'a str),
35-
/// An operation on some nested expressions.
36-
#[display("{}({})", op.to_string(), args.iter().map(|a| a.to_string()).join(", "))]
37-
Operation {
38-
op: OpType,
39-
args: Vec<PytketParam<'a>>,
40-
},
41-
}
42-
43-
/// Parse a TKET1 operation parameter, and return an AST representing the expression.
44-
#[inline]
45-
pub fn parse_pytket_param(param: &str) -> PytketParam<'_> {
46-
let Ok(mut parsed) = ParamParser::parse(Rule::parameter, param) else {
47-
// The parameter could not be parsed, so we just return it as an opaque sympy expression.
48-
return PytketParam::Sympy(param);
49-
};
50-
let parsed = parsed
51-
.next()
52-
.expect("The `parameter` rule can only be matched once.");
10+
use hugr::builder::{Dataflow, FunctionBuilder};
11+
use hugr::std_extensions::arithmetic::float_types::float64_type;
12+
use hugr::types::Type;
13+
use hugr::{Hugr, Wire};
5314

54-
parse_infix_ops(parsed.into_inner())
55-
}
56-
57-
#[derive(Parser)]
58-
#[grammar = "serialize/pytket/decoder/param.pest"]
59-
struct ParamParser;
60-
61-
lazy_static::lazy_static! {
62-
/// Precedence parser used to define the order of infix operations.
63-
///
64-
/// Based on the calculator example from `pest`.
65-
/// https://pest.rs/book/examples/calculator.html
66-
static ref PRATT_PARSER: PrattParser<Rule> = {
67-
use pest::pratt_parser::{Assoc::*, Op};
68-
use Rule::*;
15+
use crate::extension::rotation::{rotation_type, RotationOp};
6916

70-
// Precedence is defined lowest to highest
71-
PrattParser::new()
72-
// Addition and subtract have equal precedence
73-
.op(Op::infix(add, Left) | Op::infix(subtract, Left))
74-
.op(Op::infix(multiply, Left) | Op::infix(divide, Left))
75-
.op(Op::infix(power, Left))
76-
};
17+
/// The type of a loaded parameter in the Hugr.
18+
#[derive(Debug, derive_more::Display, Clone, Copy, Hash, PartialEq, Eq)]
19+
pub enum LoadedParameterType {
20+
/// A float parameter.
21+
Float,
22+
/// A rotation parameter.
23+
Rotation,
7724
}
7825

79-
/// Parse a match of the [`Rule::expr`] rule.
26+
/// A loaded parameter in the Hugr.
8027
///
81-
/// This takes a sequence of rule matches alternating [`Rule::term`]s and infix operations.
82-
fn parse_infix_ops(pairs: Pairs<'_, Rule>) -> PytketParam<'_> {
83-
PRATT_PARSER
84-
.map_primary(|primary| parse_term(primary))
85-
.map_infix(|lhs, op, rhs| {
86-
let op = match op.as_rule() {
87-
Rule::add => FloatOps::fadd,
88-
Rule::subtract => FloatOps::fsub,
89-
Rule::multiply => FloatOps::fmul,
90-
Rule::divide => FloatOps::fdiv,
91-
Rule::power => FloatOps::fpow,
92-
rule => unreachable!("Expr::parse expected infix operation, found {:?}", rule),
93-
}
94-
.into();
95-
PytketParam::Operation {
96-
op,
97-
args: vec![lhs, rhs],
98-
}
99-
})
100-
.parse(pairs)
28+
/// Tracking the type of the wire lets us delay conversion between the types
29+
/// until they are actually needed.
30+
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
31+
pub struct LoadedParameter {
32+
/// The type of the parameter.
33+
pub typ: LoadedParameterType,
34+
/// The wire where the parameter is loaded.
35+
pub wire: Wire,
10136
}
10237

103-
/// Parse a match of the silent [`Rule::term`] rule.
104-
fn parse_term(pair: Pair<'_, Rule>) -> PytketParam<'_> {
105-
match pair.as_rule() {
106-
Rule::expr => parse_infix_ops(pair.into_inner()),
107-
Rule::implicit_multiply => {
108-
let mut pairs = pair.into_inner();
109-
let lhs = parse_term(pairs.next().unwrap());
110-
let rhs = parse_term(pairs.next().unwrap());
111-
PytketParam::Operation {
112-
op: FloatOps::fmul.into(),
113-
args: vec![lhs, rhs],
114-
}
38+
impl LoadedParameter {
39+
/// Returns a `LoadedParameter` for a float param.
40+
pub fn float(wire: Wire) -> LoadedParameter {
41+
LoadedParameter {
42+
typ: LoadedParameterType::Float,
43+
wire,
11544
}
116-
Rule::num => parse_number(pair),
117-
Rule::unary_minus => PytketParam::Operation {
118-
op: FloatOps::fneg.into(),
119-
args: vec![parse_term(pair.into_inner().next().unwrap())],
120-
},
121-
Rule::function_call => parse_function_call(pair),
122-
Rule::ident => PytketParam::InputVariable {
123-
name: pair.as_str(),
124-
},
125-
rule => unreachable!("Term::parse expected a term, found {:?}", rule),
12645
}
127-
}
12846

129-
/// Parse a match of the [`Rule::num`] rule.
130-
fn parse_number(pair: Pair<'_, Rule>) -> PytketParam<'_> {
131-
let num = pair.as_str();
132-
let half_turns = num
133-
.parse::<f64>()
134-
.unwrap_or_else(|_| panic!("`num` rule matched invalid number \"{num}\""));
135-
PytketParam::Constant(half_turns)
136-
}
47+
/// Returns a `LoadedParameter` for a rotation param.
48+
pub fn rotation(wire: Wire) -> LoadedParameter {
49+
LoadedParameter {
50+
typ: LoadedParameterType::Rotation,
51+
wire,
52+
}
53+
}
13754

138-
/// Parse a match of the [`Rule::function_call`] rule.
139-
fn parse_function_call(pair: Pair<'_, Rule>) -> PytketParam<'_> {
140-
let pair_str = pair.as_str();
141-
let mut args = pair.into_inner();
142-
let name = args
143-
.next()
144-
.expect("Function call must have a name")
145-
.as_str();
146-
let op = match name {
147-
"max" => FloatOps::fmax.into(),
148-
"min" => FloatOps::fmin.into(),
149-
"abs" => FloatOps::fabs.into(),
150-
"floor" => FloatOps::ffloor.into(),
151-
"ceil" => FloatOps::fceil.into(),
152-
"round" => FloatOps::fround.into(),
153-
// Unrecognized function name.
154-
// Treat it as an opaque sympy expression.
155-
_ => return PytketParam::Sympy(pair_str),
156-
};
55+
/// Returns the hugr type for the parameter.
56+
pub fn wire_type(&self) -> &Type {
57+
static FLOAT_TYPE: LazyLock<Type> = LazyLock::new(float64_type);
58+
static ROTATION_TYPE: LazyLock<Type> = LazyLock::new(rotation_type);
59+
match self.typ {
60+
LoadedParameterType::Float => &FLOAT_TYPE,
61+
LoadedParameterType::Rotation => &ROTATION_TYPE,
62+
}
63+
}
15764

158-
let args = args.map(|arg| parse_term(arg)).collect::<Vec<_>>();
159-
PytketParam::Operation { op, args }
160-
}
65+
/// Convert the parameter into a given type, if necessary.
66+
///
67+
/// Adds the necessary operations to the Hugr and returns a new wire.
68+
///
69+
/// See [`LoadedParameter::as_float`] and [`LoadedParameter::as_rotation`]
70+
/// for more convenient methods.
71+
pub fn with_type<H: AsRef<Hugr> + AsMut<Hugr>>(
72+
&self,
73+
typ: LoadedParameterType,
74+
hugr: &mut FunctionBuilder<H>,
75+
) -> LoadedParameter {
76+
match (self.typ, typ) {
77+
(LoadedParameterType::Float, LoadedParameterType::Rotation) => {
78+
let wire = hugr
79+
.add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire])
80+
.unwrap()
81+
.out_wire(0);
82+
LoadedParameter::rotation(wire)
83+
}
84+
(LoadedParameterType::Rotation, LoadedParameterType::Float) => {
85+
let wire = hugr
86+
.add_dataflow_op(RotationOp::to_halfturns, [self.wire])
87+
.unwrap()
88+
.out_wire(0);
89+
LoadedParameter::float(wire)
90+
}
91+
_ => {
92+
debug_assert_eq!(self.typ, typ, "cannot convert {} to {}", self.typ, typ);
93+
*self
94+
}
95+
}
96+
}
16197

162-
#[cfg(test)]
163-
mod test {
164-
use super::*;
165-
use rstest::rstest;
98+
/// Convert the parameter into a float, if necessary.
99+
///
100+
/// Adds the necessary operations to the Hugr and returns a new wire.
101+
pub fn as_float<H: AsRef<Hugr> + AsMut<Hugr>>(
102+
&self,
103+
hugr: &mut FunctionBuilder<H>,
104+
) -> LoadedParameter {
105+
self.with_type(LoadedParameterType::Float, hugr)
106+
}
166107

167-
#[rstest]
168-
#[case::int("42", PytketParam::Constant(42.0))]
169-
#[case::float("42.37", PytketParam::Constant(42.37))]
170-
#[case::float_pointless("37.", PytketParam::Constant(37.))]
171-
#[case::exp("42e4", PytketParam::Constant(42e4))]
172-
#[case::neg("-42.55", PytketParam::Constant(-42.55))]
173-
#[case::parens("(42)", PytketParam::Constant(42.))]
174-
#[case::var("f64", PytketParam::InputVariable{name: "f64"})]
175-
#[case::add("42 + f64", PytketParam::Operation {
176-
op: FloatOps::fadd.into(),
177-
args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}]
178-
})]
179-
#[case::sub("42 - 2", PytketParam::Operation {
180-
op: FloatOps::fsub.into(),
181-
args: vec![PytketParam::Constant(42.), PytketParam::Constant(2.)]
182-
})]
183-
#[case::product_implicit("42 f64", PytketParam::Operation {
184-
op: FloatOps::fmul.into(),
185-
args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}]
186-
})]
187-
#[case::product_implicit2("42f64", PytketParam::Operation {
188-
op: FloatOps::fmul.into(),
189-
args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}]
190-
})]
191-
#[case::product_implicit3("42 e4", PytketParam::Operation {
192-
op: FloatOps::fmul.into(),
193-
args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "e4"}]
194-
})]
195-
#[case::max("max(42, f64)", PytketParam::Operation {
196-
op: FloatOps::fmax.into(),
197-
args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}]
198-
})]
199-
#[case::minus("-f64", PytketParam::Operation {
200-
op: FloatOps::fneg.into(),
201-
args: vec![PytketParam::InputVariable{name: "f64"}]
202-
})]
203-
#[case::unknown("unknown_op(42, f64)", PytketParam::Sympy("unknown_op(42, f64)"))]
204-
#[case::unknown_no_params("unknown_op()", PytketParam::Sympy("unknown_op()"))]
205-
#[case::nested("max(42, unknown_op(37))", PytketParam::Operation {
206-
op: FloatOps::fmax.into(),
207-
args: vec![PytketParam::Constant(42.), PytketParam::Sympy("unknown_op(37)")]
208-
})]
209-
#[case::precedence("5-2/3x+4**6", PytketParam::Operation {
210-
op: FloatOps::fadd.into(),
211-
args: vec![
212-
PytketParam::Operation {
213-
op: FloatOps::fsub.into(),
214-
args: vec![
215-
PytketParam::Constant(5.),
216-
PytketParam::Operation { op: FloatOps::fdiv.into(), args: vec![
217-
PytketParam::Constant(2.),
218-
PytketParam::Operation { op: FloatOps::fmul.into(), args: vec![
219-
PytketParam::Constant(3.),
220-
PytketParam::InputVariable{name: "x"},
221-
]}
222-
]}
223-
]
224-
},
225-
PytketParam::Operation { op: FloatOps::fpow.into(), args: vec![
226-
PytketParam::Constant(4.),
227-
PytketParam::Constant(6.),
228-
]}
229-
]
230-
})]
231-
#[case::associativity("1-2-3+4", PytketParam::Operation {
232-
op: FloatOps::fadd.into(),
233-
args: vec![
234-
PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![
235-
PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![
236-
PytketParam::Constant(1.),
237-
PytketParam::Constant(2.),
238-
]},
239-
PytketParam::Constant(3.),
240-
]},
241-
PytketParam::Constant(4.),
242-
]
243-
})]
244-
fn parse_param(#[case] param: &str, #[case] expected: PytketParam) {
245-
let parsed = parse_pytket_param(param);
246-
if parsed != expected {
247-
panic!("Incorrect parameter parsing\n\texpression: \"{param}\"\n\tparsed: {parsed}\n\texpected: {expected}");
248-
}
108+
/// Convert the parameter into a rotation, if necessary.
109+
///
110+
/// Adds the necessary operations to the Hugr and returns a new wire.
111+
pub fn as_rotation<H: AsRef<Hugr> + AsMut<Hugr>>(
112+
&self,
113+
hugr: &mut FunctionBuilder<H>,
114+
) -> LoadedParameter {
115+
self.with_type(LoadedParameterType::Rotation, hugr)
249116
}
250117
}

0 commit comments

Comments
 (0)