|
1 |
| -//! Definitions for decoding parameter expressions from pytket operations. |
2 |
| -//! |
3 |
| -//! This is based on the `pest` grammar defined in `param.pest`. |
| 1 | +//! Definition of a loaded parameter (either floating point or a rotation type) attached to a HUGR wire. |
| 2 | +#![expect( |
| 3 | + dead_code, |
| 4 | + reason = "Temporarily unused while we refactor the pytket decoder" |
| 5 | +)] |
4 | 6 |
|
5 |
| -use derive_more::Display; |
6 |
| -use hugr::ops::OpType; |
7 |
| -use hugr::std_extensions::arithmetic::float_ops::FloatOps; |
8 |
| -use itertools::Itertools; |
9 |
| -use pest::iterators::{Pair, Pairs}; |
10 |
| -use pest::pratt_parser::PrattParser; |
11 |
| -use pest::Parser; |
12 |
| -use pest_derive::Parser; |
| 7 | +pub(super) mod parser; |
| 8 | +use std::sync::LazyLock; |
13 | 9 |
|
14 |
| -/// The parsed AST for a pytket operation parameter. |
15 |
| -/// |
16 |
| -/// The leafs of the AST are either a constant value, a variable name, or an |
17 |
| -/// unrecognized sympy expression. |
18 |
| -/// |
19 |
| -/// Return type of [`parse_pytket_param`]. |
20 |
| -#[derive(Debug, Display, Clone, PartialEq)] |
21 |
| -pub enum PytketParam<'a> { |
22 |
| - /// A constant value that can be loaded directly. |
23 |
| - #[display("{_0}")] |
24 |
| - Constant(f64), |
25 |
| - /// A variable that should be routed as an input. |
26 |
| - #[display("\"{name}\"")] |
27 |
| - InputVariable { |
28 |
| - /// The variable name. |
29 |
| - name: &'a str, |
30 |
| - }, |
31 |
| - /// Unrecognized sympy expression. |
32 |
| - /// Will be emitted as a [`SympyOp`]. |
33 |
| - #[display("Sympy(\"{_0}\")")] |
34 |
| - Sympy(&'a str), |
35 |
| - /// An operation on some nested expressions. |
36 |
| - #[display("{}({})", op.to_string(), args.iter().map(|a| a.to_string()).join(", "))] |
37 |
| - Operation { |
38 |
| - op: OpType, |
39 |
| - args: Vec<PytketParam<'a>>, |
40 |
| - }, |
41 |
| -} |
42 |
| - |
43 |
| -/// Parse a TKET1 operation parameter, and return an AST representing the expression. |
44 |
| -#[inline] |
45 |
| -pub fn parse_pytket_param(param: &str) -> PytketParam<'_> { |
46 |
| - let Ok(mut parsed) = ParamParser::parse(Rule::parameter, param) else { |
47 |
| - // The parameter could not be parsed, so we just return it as an opaque sympy expression. |
48 |
| - return PytketParam::Sympy(param); |
49 |
| - }; |
50 |
| - let parsed = parsed |
51 |
| - .next() |
52 |
| - .expect("The `parameter` rule can only be matched once."); |
| 10 | +use hugr::builder::{Dataflow, FunctionBuilder}; |
| 11 | +use hugr::std_extensions::arithmetic::float_types::float64_type; |
| 12 | +use hugr::types::Type; |
| 13 | +use hugr::{Hugr, Wire}; |
53 | 14 |
|
54 |
| - parse_infix_ops(parsed.into_inner()) |
55 |
| -} |
56 |
| - |
57 |
| -#[derive(Parser)] |
58 |
| -#[grammar = "serialize/pytket/decoder/param.pest"] |
59 |
| -struct ParamParser; |
60 |
| - |
61 |
| -lazy_static::lazy_static! { |
62 |
| - /// Precedence parser used to define the order of infix operations. |
63 |
| - /// |
64 |
| - /// Based on the calculator example from `pest`. |
65 |
| - /// https://pest.rs/book/examples/calculator.html |
66 |
| - static ref PRATT_PARSER: PrattParser<Rule> = { |
67 |
| - use pest::pratt_parser::{Assoc::*, Op}; |
68 |
| - use Rule::*; |
| 15 | +use crate::extension::rotation::{rotation_type, RotationOp}; |
69 | 16 |
|
70 |
| - // Precedence is defined lowest to highest |
71 |
| - PrattParser::new() |
72 |
| - // Addition and subtract have equal precedence |
73 |
| - .op(Op::infix(add, Left) | Op::infix(subtract, Left)) |
74 |
| - .op(Op::infix(multiply, Left) | Op::infix(divide, Left)) |
75 |
| - .op(Op::infix(power, Left)) |
76 |
| - }; |
| 17 | +/// The type of a loaded parameter in the Hugr. |
| 18 | +#[derive(Debug, derive_more::Display, Clone, Copy, Hash, PartialEq, Eq)] |
| 19 | +pub enum LoadedParameterType { |
| 20 | + /// A float parameter. |
| 21 | + Float, |
| 22 | + /// A rotation parameter. |
| 23 | + Rotation, |
77 | 24 | }
|
78 | 25 |
|
79 |
| -/// Parse a match of the [`Rule::expr`] rule. |
| 26 | +/// A loaded parameter in the Hugr. |
80 | 27 | ///
|
81 |
| -/// This takes a sequence of rule matches alternating [`Rule::term`]s and infix operations. |
82 |
| -fn parse_infix_ops(pairs: Pairs<'_, Rule>) -> PytketParam<'_> { |
83 |
| - PRATT_PARSER |
84 |
| - .map_primary(|primary| parse_term(primary)) |
85 |
| - .map_infix(|lhs, op, rhs| { |
86 |
| - let op = match op.as_rule() { |
87 |
| - Rule::add => FloatOps::fadd, |
88 |
| - Rule::subtract => FloatOps::fsub, |
89 |
| - Rule::multiply => FloatOps::fmul, |
90 |
| - Rule::divide => FloatOps::fdiv, |
91 |
| - Rule::power => FloatOps::fpow, |
92 |
| - rule => unreachable!("Expr::parse expected infix operation, found {:?}", rule), |
93 |
| - } |
94 |
| - .into(); |
95 |
| - PytketParam::Operation { |
96 |
| - op, |
97 |
| - args: vec![lhs, rhs], |
98 |
| - } |
99 |
| - }) |
100 |
| - .parse(pairs) |
| 28 | +/// Tracking the type of the wire lets us delay conversion between the types |
| 29 | +/// until they are actually needed. |
| 30 | +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] |
| 31 | +pub struct LoadedParameter { |
| 32 | + /// The type of the parameter. |
| 33 | + pub typ: LoadedParameterType, |
| 34 | + /// The wire where the parameter is loaded. |
| 35 | + pub wire: Wire, |
101 | 36 | }
|
102 | 37 |
|
103 |
| -/// Parse a match of the silent [`Rule::term`] rule. |
104 |
| -fn parse_term(pair: Pair<'_, Rule>) -> PytketParam<'_> { |
105 |
| - match pair.as_rule() { |
106 |
| - Rule::expr => parse_infix_ops(pair.into_inner()), |
107 |
| - Rule::implicit_multiply => { |
108 |
| - let mut pairs = pair.into_inner(); |
109 |
| - let lhs = parse_term(pairs.next().unwrap()); |
110 |
| - let rhs = parse_term(pairs.next().unwrap()); |
111 |
| - PytketParam::Operation { |
112 |
| - op: FloatOps::fmul.into(), |
113 |
| - args: vec![lhs, rhs], |
114 |
| - } |
| 38 | +impl LoadedParameter { |
| 39 | + /// Returns a `LoadedParameter` for a float param. |
| 40 | + pub fn float(wire: Wire) -> LoadedParameter { |
| 41 | + LoadedParameter { |
| 42 | + typ: LoadedParameterType::Float, |
| 43 | + wire, |
115 | 44 | }
|
116 |
| - Rule::num => parse_number(pair), |
117 |
| - Rule::unary_minus => PytketParam::Operation { |
118 |
| - op: FloatOps::fneg.into(), |
119 |
| - args: vec![parse_term(pair.into_inner().next().unwrap())], |
120 |
| - }, |
121 |
| - Rule::function_call => parse_function_call(pair), |
122 |
| - Rule::ident => PytketParam::InputVariable { |
123 |
| - name: pair.as_str(), |
124 |
| - }, |
125 |
| - rule => unreachable!("Term::parse expected a term, found {:?}", rule), |
126 | 45 | }
|
127 |
| -} |
128 | 46 |
|
129 |
| -/// Parse a match of the [`Rule::num`] rule. |
130 |
| -fn parse_number(pair: Pair<'_, Rule>) -> PytketParam<'_> { |
131 |
| - let num = pair.as_str(); |
132 |
| - let half_turns = num |
133 |
| - .parse::<f64>() |
134 |
| - .unwrap_or_else(|_| panic!("`num` rule matched invalid number \"{num}\"")); |
135 |
| - PytketParam::Constant(half_turns) |
136 |
| -} |
| 47 | + /// Returns a `LoadedParameter` for a rotation param. |
| 48 | + pub fn rotation(wire: Wire) -> LoadedParameter { |
| 49 | + LoadedParameter { |
| 50 | + typ: LoadedParameterType::Rotation, |
| 51 | + wire, |
| 52 | + } |
| 53 | + } |
137 | 54 |
|
138 |
| -/// Parse a match of the [`Rule::function_call`] rule. |
139 |
| -fn parse_function_call(pair: Pair<'_, Rule>) -> PytketParam<'_> { |
140 |
| - let pair_str = pair.as_str(); |
141 |
| - let mut args = pair.into_inner(); |
142 |
| - let name = args |
143 |
| - .next() |
144 |
| - .expect("Function call must have a name") |
145 |
| - .as_str(); |
146 |
| - let op = match name { |
147 |
| - "max" => FloatOps::fmax.into(), |
148 |
| - "min" => FloatOps::fmin.into(), |
149 |
| - "abs" => FloatOps::fabs.into(), |
150 |
| - "floor" => FloatOps::ffloor.into(), |
151 |
| - "ceil" => FloatOps::fceil.into(), |
152 |
| - "round" => FloatOps::fround.into(), |
153 |
| - // Unrecognized function name. |
154 |
| - // Treat it as an opaque sympy expression. |
155 |
| - _ => return PytketParam::Sympy(pair_str), |
156 |
| - }; |
| 55 | + /// Returns the hugr type for the parameter. |
| 56 | + pub fn wire_type(&self) -> &Type { |
| 57 | + static FLOAT_TYPE: LazyLock<Type> = LazyLock::new(float64_type); |
| 58 | + static ROTATION_TYPE: LazyLock<Type> = LazyLock::new(rotation_type); |
| 59 | + match self.typ { |
| 60 | + LoadedParameterType::Float => &FLOAT_TYPE, |
| 61 | + LoadedParameterType::Rotation => &ROTATION_TYPE, |
| 62 | + } |
| 63 | + } |
157 | 64 |
|
158 |
| - let args = args.map(|arg| parse_term(arg)).collect::<Vec<_>>(); |
159 |
| - PytketParam::Operation { op, args } |
160 |
| -} |
| 65 | + /// Convert the parameter into a given type, if necessary. |
| 66 | + /// |
| 67 | + /// Adds the necessary operations to the Hugr and returns a new wire. |
| 68 | + /// |
| 69 | + /// See [`LoadedParameter::as_float`] and [`LoadedParameter::as_rotation`] |
| 70 | + /// for more convenient methods. |
| 71 | + pub fn with_type<H: AsRef<Hugr> + AsMut<Hugr>>( |
| 72 | + &self, |
| 73 | + typ: LoadedParameterType, |
| 74 | + hugr: &mut FunctionBuilder<H>, |
| 75 | + ) -> LoadedParameter { |
| 76 | + match (self.typ, typ) { |
| 77 | + (LoadedParameterType::Float, LoadedParameterType::Rotation) => { |
| 78 | + let wire = hugr |
| 79 | + .add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire]) |
| 80 | + .unwrap() |
| 81 | + .out_wire(0); |
| 82 | + LoadedParameter::rotation(wire) |
| 83 | + } |
| 84 | + (LoadedParameterType::Rotation, LoadedParameterType::Float) => { |
| 85 | + let wire = hugr |
| 86 | + .add_dataflow_op(RotationOp::to_halfturns, [self.wire]) |
| 87 | + .unwrap() |
| 88 | + .out_wire(0); |
| 89 | + LoadedParameter::float(wire) |
| 90 | + } |
| 91 | + _ => { |
| 92 | + debug_assert_eq!(self.typ, typ, "cannot convert {} to {}", self.typ, typ); |
| 93 | + *self |
| 94 | + } |
| 95 | + } |
| 96 | + } |
161 | 97 |
|
162 |
| -#[cfg(test)] |
163 |
| -mod test { |
164 |
| - use super::*; |
165 |
| - use rstest::rstest; |
| 98 | + /// Convert the parameter into a float, if necessary. |
| 99 | + /// |
| 100 | + /// Adds the necessary operations to the Hugr and returns a new wire. |
| 101 | + pub fn as_float<H: AsRef<Hugr> + AsMut<Hugr>>( |
| 102 | + &self, |
| 103 | + hugr: &mut FunctionBuilder<H>, |
| 104 | + ) -> LoadedParameter { |
| 105 | + self.with_type(LoadedParameterType::Float, hugr) |
| 106 | + } |
166 | 107 |
|
167 |
| - #[rstest] |
168 |
| - #[case::int("42", PytketParam::Constant(42.0))] |
169 |
| - #[case::float("42.37", PytketParam::Constant(42.37))] |
170 |
| - #[case::float_pointless("37.", PytketParam::Constant(37.))] |
171 |
| - #[case::exp("42e4", PytketParam::Constant(42e4))] |
172 |
| - #[case::neg("-42.55", PytketParam::Constant(-42.55))] |
173 |
| - #[case::parens("(42)", PytketParam::Constant(42.))] |
174 |
| - #[case::var("f64", PytketParam::InputVariable{name: "f64"})] |
175 |
| - #[case::add("42 + f64", PytketParam::Operation { |
176 |
| - op: FloatOps::fadd.into(), |
177 |
| - args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] |
178 |
| - })] |
179 |
| - #[case::sub("42 - 2", PytketParam::Operation { |
180 |
| - op: FloatOps::fsub.into(), |
181 |
| - args: vec![PytketParam::Constant(42.), PytketParam::Constant(2.)] |
182 |
| - })] |
183 |
| - #[case::product_implicit("42 f64", PytketParam::Operation { |
184 |
| - op: FloatOps::fmul.into(), |
185 |
| - args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] |
186 |
| - })] |
187 |
| - #[case::product_implicit2("42f64", PytketParam::Operation { |
188 |
| - op: FloatOps::fmul.into(), |
189 |
| - args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] |
190 |
| - })] |
191 |
| - #[case::product_implicit3("42 e4", PytketParam::Operation { |
192 |
| - op: FloatOps::fmul.into(), |
193 |
| - args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "e4"}] |
194 |
| - })] |
195 |
| - #[case::max("max(42, f64)", PytketParam::Operation { |
196 |
| - op: FloatOps::fmax.into(), |
197 |
| - args: vec![PytketParam::Constant(42.), PytketParam::InputVariable{name: "f64"}] |
198 |
| - })] |
199 |
| - #[case::minus("-f64", PytketParam::Operation { |
200 |
| - op: FloatOps::fneg.into(), |
201 |
| - args: vec![PytketParam::InputVariable{name: "f64"}] |
202 |
| - })] |
203 |
| - #[case::unknown("unknown_op(42, f64)", PytketParam::Sympy("unknown_op(42, f64)"))] |
204 |
| - #[case::unknown_no_params("unknown_op()", PytketParam::Sympy("unknown_op()"))] |
205 |
| - #[case::nested("max(42, unknown_op(37))", PytketParam::Operation { |
206 |
| - op: FloatOps::fmax.into(), |
207 |
| - args: vec![PytketParam::Constant(42.), PytketParam::Sympy("unknown_op(37)")] |
208 |
| - })] |
209 |
| - #[case::precedence("5-2/3x+4**6", PytketParam::Operation { |
210 |
| - op: FloatOps::fadd.into(), |
211 |
| - args: vec![ |
212 |
| - PytketParam::Operation { |
213 |
| - op: FloatOps::fsub.into(), |
214 |
| - args: vec![ |
215 |
| - PytketParam::Constant(5.), |
216 |
| - PytketParam::Operation { op: FloatOps::fdiv.into(), args: vec![ |
217 |
| - PytketParam::Constant(2.), |
218 |
| - PytketParam::Operation { op: FloatOps::fmul.into(), args: vec![ |
219 |
| - PytketParam::Constant(3.), |
220 |
| - PytketParam::InputVariable{name: "x"}, |
221 |
| - ]} |
222 |
| - ]} |
223 |
| - ] |
224 |
| - }, |
225 |
| - PytketParam::Operation { op: FloatOps::fpow.into(), args: vec![ |
226 |
| - PytketParam::Constant(4.), |
227 |
| - PytketParam::Constant(6.), |
228 |
| - ]} |
229 |
| - ] |
230 |
| - })] |
231 |
| - #[case::associativity("1-2-3+4", PytketParam::Operation { |
232 |
| - op: FloatOps::fadd.into(), |
233 |
| - args: vec![ |
234 |
| - PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![ |
235 |
| - PytketParam::Operation { op: FloatOps::fsub.into(), args: vec![ |
236 |
| - PytketParam::Constant(1.), |
237 |
| - PytketParam::Constant(2.), |
238 |
| - ]}, |
239 |
| - PytketParam::Constant(3.), |
240 |
| - ]}, |
241 |
| - PytketParam::Constant(4.), |
242 |
| - ] |
243 |
| - })] |
244 |
| - fn parse_param(#[case] param: &str, #[case] expected: PytketParam) { |
245 |
| - let parsed = parse_pytket_param(param); |
246 |
| - if parsed != expected { |
247 |
| - panic!("Incorrect parameter parsing\n\texpression: \"{param}\"\n\tparsed: {parsed}\n\texpected: {expected}"); |
248 |
| - } |
| 108 | + /// Convert the parameter into a rotation, if necessary. |
| 109 | + /// |
| 110 | + /// Adds the necessary operations to the Hugr and returns a new wire. |
| 111 | + pub fn as_rotation<H: AsRef<Hugr> + AsMut<Hugr>>( |
| 112 | + &self, |
| 113 | + hugr: &mut FunctionBuilder<H>, |
| 114 | + ) -> LoadedParameter { |
| 115 | + self.with_type(LoadedParameterType::Rotation, hugr) |
249 | 116 | }
|
250 | 117 | }
|
0 commit comments