Skip to content

fix!: Fix rotation -> float param type conversion #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tket-py/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl PyPauliIter {
//
// TODO: These can no longer be constructed from Python. Since `hugr-rs 0.14`
// we need an extension and `OpDef` to defines these.
// If fixing this, make sure to fix `PyHugrType` too.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dummy change so this gets a line in the tket-py changelog

// When fixing this, make sure to fix `PyHugrType` too.
#[pyclass]
#[pyo3(name = "CustomOp")]
#[repr(transparent)]
Expand Down
35 changes: 29 additions & 6 deletions tket-qsystem/src/pytket/qsystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use hugr::extension::simple_op::MakeExtensionOp;
use hugr::extension::ExtensionId;
use hugr::ops::ExtensionOp;
use hugr::HugrView;
use itertools::Itertools as _;
use tket::serialize::pytket::decoder::{
DecodeStatus, LoadedParameter, PytketDecoderContext, TrackedBit, TrackedQubit,
DecodeStatus, LoadedParameter, ParameterType, PytketDecoderContext, TrackedBit, TrackedQubit,
};
use tket::serialize::pytket::encoder::EncodeStatus;
use tket::serialize::pytket::encoder::{make_tk1_operation, EncodeStatus};
use tket::serialize::pytket::extension::PytketDecoder;
use tket::serialize::pytket::{
PytketDecodeError, PytketEmitter, PytketEncodeError, PytketEncoderContext,
Expand Down Expand Up @@ -83,8 +84,22 @@ impl QSystemEmitter {
}
};

// Most operations map directly to a pytket one.
encoder.emit_node(serial_op, node, circ)?;
// pytket parameters are always in half-turns.
// Since the `tket.qsystem` op inputs are in radians, we have to convert them here.
encoder.emit_node_command(
node,
circ,
|_| Vec::new(),
move |mut inputs| {
for param in inputs.params.to_mut() {
*param = match param.strip_suffix(") * (pi)") {
Some(s) if s.starts_with("(") => s[1..].to_string(),
_ => format!("{param} / (pi)"),
};
}
make_tk1_operation(serial_op, inputs)
},
)?;

Ok(EncodeStatus::Success)
}
Expand Down Expand Up @@ -126,15 +141,23 @@ impl PytketDecoder for QSystemEmitter {
PytketOptype::ZZPhase => QSystemOp::ZZPhase,
PytketOptype::ZZMax => {
// This is a ZZPhase with a 1/2 angle.
let param = decoder.load_parameter("pi/2");
let param =
Arc::new(decoder.load_parameter_with_type("pi/2", ParameterType::FloatRadians));
decoder.add_node_with_wires(QSystemOp::ZZPhase, qubits, bits, &[param])?;
return Ok(DecodeStatus::Success);
}
_ => {
return Ok(DecodeStatus::Unsupported);
}
};
decoder.add_node_with_wires(op, qubits, bits, params)?;

// We expect all parameters to be floats in radians.
let params = params
.iter()
.map(|p| Arc::new(p.as_float_radians(&mut decoder.builder)))
.collect_vec();

decoder.add_node_with_wires(op, qubits, bits, &params)?;

Ok(DecodeStatus::Success)
}
Expand Down
2 changes: 1 addition & 1 deletion tket-qsystem/src/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) {
let count_b = b_command_count.get(a).copied().unwrap_or_default();
assert_eq!(
count_a, count_b,
"command {a:?} appears {count_a} times in rhs and {count_b} times in lhs"
"command {a:?} appears {count_a} times in rhs and {count_b} times in lhs.\ncounts for a: {a_command_count:#?}\ncounts for b: {b_command_count:#?}"
);
}
assert_eq!(a_command_count.len(), b_command_count.len());
Expand Down
26 changes: 22 additions & 4 deletions tket/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod param;
mod tracked_elem;
mod wires;

pub use param::{LoadedParameter, LoadedParameterType};
pub use param::{LoadedParameter, ParameterType};
pub use tracked_elem::{TrackedBit, TrackedQubit};
pub use wires::TrackedWires;

Expand Down Expand Up @@ -420,8 +420,13 @@ impl<'h> PytketDecoderContext<'h> {
/// the first registers in `wires` for the bit inputs and the remaining
/// registers for the outputs.
///
/// The input wire types must match the operation's input signature,
/// no type conversion is performed.
/// The input wire types must match the operation's input signature, no type
/// conversion is performed.
///
/// The caller must take care of converting the parameter wires to the
/// required types and units expected by the operation. An error will be
/// returned if the parameter does not match the expected wire type, but the
/// unit (radians or half-turns) cannot be checked automatically.
///
/// # Arguments
///
Expand All @@ -435,6 +440,8 @@ impl<'h> PytketDecoderContext<'h> {
/// input ports.
/// - Returns an error if the node's output ports cannot be assigned to
/// arguments from the input wire set.
/// - Returns an error if the parameter wires do not match the expected
/// types.
pub fn add_node_with_wires(
&mut self,
op: impl Into<OpType>,
Expand Down Expand Up @@ -593,7 +600,18 @@ impl<'h> PytketDecoderContext<'h> {
/// - If the parameter is a variable, adds a new `rotation` input to the region.
/// - If the parameter is a sympy expressions, adds it as a [`SympyOpDef`][crate::extension::sympy::SympyOpDef] black box.
pub fn load_parameter(&mut self, param: &str) -> Arc<LoadedParameter> {
self.wire_tracker.load_parameter(&mut self.builder, param)
Arc::new(
self.wire_tracker
.load_parameter(&mut self.builder, param, None),
)
}

/// Loads the given parameter expression as a [`LoadedParameter`] in the hugr, and converts it to the requested type and unit.
///
/// See [`PytketDecoderContext::load_parameter`] for more details.
pub fn load_parameter_with_type(&mut self, param: &str, typ: ParameterType) -> LoadedParameter {
self.wire_tracker
.load_parameter(&mut self.builder, param, Some(typ))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems strange to pass the typ to both load and with, but I guess it is optional in the first case

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in the load_parameter docs, that is a type hint that lets us decide for example between loading a rotation const or a float const. The output type is not guaranteed to have that type.

I guess it'd be less confusing if we change the hint for an assurance and do the with_type internally. I'll make the change.

}
}

Expand Down
166 changes: 121 additions & 45 deletions tket/src/serialize/pytket/decoder/param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,110 +4,186 @@ pub(super) mod parser;
use std::sync::LazyLock;

use hugr::builder::{Dataflow, FunctionBuilder};
use hugr::std_extensions::arithmetic::float_types::float64_type;
use hugr::ops::Value;
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
use hugr::types::Type;
use hugr::{Hugr, Wire};

use crate::extension::rotation::{rotation_type, RotationOp};

/// The type of a loaded parameter in the Hugr.
/// The type of a loaded parameter in the Hugr, including its unit.
#[derive(Debug, derive_more::Display, Clone, Copy, Hash, PartialEq, Eq)]
pub enum LoadedParameterType {
/// A float parameter.
Float,
/// A rotation parameter.
pub enum ParameterType {
/// A float parameter in radians.
FloatRadians,
/// A float parameter in half-turns.
FloatHalfTurns,
/// A rotation parameter in half-turns.
Rotation,
}

impl ParameterType {
/// Returns the type of the parameter.
pub fn to_type(&self) -> &'static Type {
static FLOAT_TYPE: LazyLock<Type> = LazyLock::new(float64_type);
static ROTATION_TYPE: LazyLock<Type> = LazyLock::new(rotation_type);
match self {
ParameterType::FloatRadians => &FLOAT_TYPE,
ParameterType::FloatHalfTurns => &FLOAT_TYPE,
ParameterType::Rotation => &ROTATION_TYPE,
}
}
}

/// A loaded parameter in the Hugr.
///
/// Tracking the type of the wire lets us delay conversion between the types
/// until they are actually needed.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct LoadedParameter {
/// The type of the parameter.
pub typ: LoadedParameterType,
typ: ParameterType,
/// The wire where the parameter is loaded.
pub wire: Wire,
wire: Wire,
}

impl LoadedParameter {
/// Returns a `LoadedParameter` for a float param.
pub fn float(wire: Wire) -> LoadedParameter {
/// Returns a `LoadedParameter` with the given type and unit.
pub fn new(typ: ParameterType, wire: Wire) -> LoadedParameter {
LoadedParameter { typ, wire }
}

/// Returns the type of the parameter.
#[inline]
pub fn typ(&self) -> ParameterType {
self.typ
}

/// Returns the wire where the parameter is loaded.
#[inline]
pub fn wire(&self) -> Wire {
self.wire
}

/// Returns a `LoadedParameter` for a float param in radians.
#[inline]
pub fn float_radians(wire: Wire) -> LoadedParameter {
LoadedParameter {
typ: LoadedParameterType::Float,
typ: ParameterType::FloatRadians,
wire,
}
}

/// Returns a `LoadedParameter` for a rotation param.
/// Returns a `LoadedParameter` for a float param in half-turns.
#[inline]
pub fn float_half_turns(wire: Wire) -> LoadedParameter {
LoadedParameter {
typ: ParameterType::FloatHalfTurns,
wire,
}
}

/// Returns a `LoadedParameter` for a rotation param in half-turns.
#[inline]
pub fn rotation(wire: Wire) -> LoadedParameter {
LoadedParameter {
typ: LoadedParameterType::Rotation,
typ: ParameterType::Rotation,
wire,
}
}

/// Returns the hugr type for the parameter.
pub fn wire_type(&self) -> &Type {
static FLOAT_TYPE: LazyLock<Type> = LazyLock::new(float64_type);
static ROTATION_TYPE: LazyLock<Type> = LazyLock::new(rotation_type);
match self.typ {
LoadedParameterType::Float => &FLOAT_TYPE,
LoadedParameterType::Rotation => &ROTATION_TYPE,
}
#[inline]
pub fn wire_type(&self) -> &'static Type {
self.typ.to_type()
}

/// Convert the parameter into a given type, if necessary.
///
/// Adds the necessary operations to the Hugr and returns a new wire.
///
/// See [`LoadedParameter::as_float`] and [`LoadedParameter::as_rotation`]
/// for more convenient methods.
/// See [`LoadedParameter::as_rotation`],
/// [`LoadedParameter::as_float_radians`] and
/// [`LoadedParameter::as_float_half_turns`] for more convenient methods.
#[inline]
pub fn with_type<H: AsRef<Hugr> + AsMut<Hugr>>(
&self,
typ: LoadedParameterType,
typ: ParameterType,
hugr: &mut FunctionBuilder<H>,
) -> LoadedParameter {
match (self.typ, typ) {
(LoadedParameterType::Float, LoadedParameterType::Rotation) => {
let wire = hugr
.add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire])
.unwrap()
.out_wire(0);
LoadedParameter::rotation(wire)
}
(LoadedParameterType::Rotation, LoadedParameterType::Float) => {
let wire = hugr
.add_dataflow_op(RotationOp::to_halfturns, [self.wire])
.unwrap()
match typ {
ParameterType::FloatRadians => self.as_float_radians(hugr),
ParameterType::FloatHalfTurns => self.as_float_half_turns(hugr),
ParameterType::Rotation => self.as_rotation(hugr),
}
}

/// Convert the parameter into a float in radians.
///
/// Adds the necessary operations to the Hugr and returns a new wire.
pub fn as_float_radians<H: AsRef<Hugr> + AsMut<Hugr>>(
&self,
hugr: &mut FunctionBuilder<H>,
) -> LoadedParameter {
match self.typ {
ParameterType::FloatRadians => *self,
ParameterType::FloatHalfTurns => {
let pi = hugr.add_load_const(Value::from(ConstF64::new(std::f64::consts::PI)));
let float_radians = hugr
.add_dataflow_op(FloatOps::fmul, [self.wire(), pi])
.expect("Error converting float to rotation")
.out_wire(0);
LoadedParameter::float(wire)
}
_ => {
debug_assert_eq!(self.typ, typ, "cannot convert {} to {}", self.typ, typ);
*self
LoadedParameter::float_radians(float_radians)
}
ParameterType::Rotation => self.as_float_half_turns(hugr).as_float_radians(hugr),
}
}

/// Convert the parameter into a float, if necessary.
/// Convert the parameter into a float in half-turns.
///
/// Adds the necessary operations to the Hugr and returns a new wire.
pub fn as_float<H: AsRef<Hugr> + AsMut<Hugr>>(
pub fn as_float_half_turns<H: AsRef<Hugr> + AsMut<Hugr>>(
&self,
hugr: &mut FunctionBuilder<H>,
) -> LoadedParameter {
self.with_type(LoadedParameterType::Float, hugr)
match self.typ {
ParameterType::FloatHalfTurns => *self,
ParameterType::FloatRadians => {
let pi = hugr.add_load_const(Value::from(ConstF64::new(std::f64::consts::PI)));
let float_halfturns = hugr
.add_dataflow_op(FloatOps::fdiv, [self.wire, pi])
.expect("Error converting float to rotation")
.out_wire(0);
LoadedParameter::float_half_turns(float_halfturns)
}
ParameterType::Rotation => {
let wire = hugr
.add_dataflow_op(RotationOp::to_halfturns, [self.wire()])
.unwrap()
.out_wire(0);
LoadedParameter::float_half_turns(wire)
}
}
}

/// Convert the parameter into a rotation, if necessary.
/// Convert the parameter into a rotation in half-turns.
///
/// Adds the necessary operations to the Hugr and returns a new wire.
pub fn as_rotation<H: AsRef<Hugr> + AsMut<Hugr>>(
&self,
hugr: &mut FunctionBuilder<H>,
) -> LoadedParameter {
self.with_type(LoadedParameterType::Rotation, hugr)
match self.typ {
ParameterType::Rotation => *self,
ParameterType::FloatHalfTurns => {
let wire = hugr
.add_dataflow_op(RotationOp::from_halfturns_unchecked, [self.wire()])
.unwrap()
.out_wire(0);
LoadedParameter::rotation(wire)
}
ParameterType::FloatRadians => self.as_float_half_turns(hugr).as_rotation(hugr),
}
}
}
Loading
Loading