Skip to content

Commit d5a5f66

Browse files
authored
Add ASIC area model (#145)
1 parent 5558d93 commit d5a5f66

File tree

9 files changed

+173
-53
lines changed

9 files changed

+173
-53
lines changed

src/asic.rs

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ use super::check::Check;
88
use super::driver::Comparison;
99
use super::driver::Report;
1010
use super::driver::{Canonical, CircuitLang, EquivCheck, Explanable, Extractable};
11+
use super::verilog::PrimitiveType;
1112
use egg::{
1213
Analysis, AstSize, CostFunction, DidMerge, EGraph, Id, Language, RecExpr, Rewrite, Symbol,
1314
define_language, rewrite,
1415
};
1516
use serde::Serialize;
1617
use std::collections::BTreeMap;
18+
use std::str::FromStr;
1719

1820
define_language! {
1921
/// Definitions of e-node types. Programs are the only node type that is not a net/signal.
@@ -119,6 +121,28 @@ impl CostFunction<CellLang> for CellCountFn {
119121
}
120122
}
121123

124+
/// A cost function that extracts a circuit with the least area
125+
pub struct AreaFn;
126+
127+
impl CostFunction<CellLang> for AreaFn {
128+
type Cost = f32;
129+
fn cost<C>(&mut self, enode: &CellLang, mut costs: C) -> Self::Cost
130+
where
131+
C: FnMut(Id) -> Self::Cost,
132+
{
133+
let op_cost = match enode {
134+
CellLang::Const(_) | CellLang::Var(_) => PrimitiveType::INV.get_min_area().unwrap(),
135+
CellLang::Cell(n, _l) => {
136+
let prim = PrimitiveType::from_str(n.as_str()).unwrap();
137+
prim.get_min_area().unwrap_or(1.33)
138+
}
139+
_ => f32::MAX,
140+
};
141+
142+
enode.fold(op_cost, |sum, id| sum + costs(id))
143+
}
144+
}
145+
122146
impl Extractable for CellLang {
123147
fn depth_cost_fn() -> impl CostFunction<Self, Cost = i64> {
124148
DepthCostFn
@@ -128,6 +152,10 @@ impl Extractable for CellLang {
128152
CellCountFn::new(cut_size)
129153
}
130154

155+
fn exact_area_cost_fn() -> impl CostFunction<Self> {
156+
AreaFn
157+
}
158+
131159
fn filter_cost_fn(_set: std::collections::HashSet<String>) -> impl CostFunction<Self> {
132160
eprintln!("TODO: CellLang::filter_cost_fn");
133161
AstSize
@@ -186,22 +214,51 @@ impl Analysis<CellLang> for CellAnalysis {
186214
fn make(_egraph: &mut EGraph<CellLang, Self>, _enode: &CellLang) -> Self::Data {}
187215
}
188216

189-
fn get_cell_counts(expr: &RecExpr<CellLang>) -> BTreeMap<String, usize> {
190-
let mut counts = BTreeMap::new();
191-
for node in expr.iter() {
192-
if let CellLang::Cell(name, _) = node {
193-
*counts.entry(name.to_string()).or_insert(0) += 1;
194-
}
195-
}
196-
counts
197-
}
198-
199217
#[derive(Debug, Serialize)]
200218
struct CircuitStats {
201219
/// AST size of the circuit
202220
ast_size: usize,
203221
/// Number of cells in the circuit
204222
cell_counts: BTreeMap<String, usize>,
223+
/// The area of the circuit
224+
area: f32,
225+
}
226+
227+
impl CircuitStats {
228+
fn get_cell_counts(expr: &RecExpr<CellLang>) -> BTreeMap<String, usize> {
229+
let mut counts = BTreeMap::new();
230+
for node in expr.iter() {
231+
if let CellLang::Cell(name, _) = node {
232+
*counts.entry(name.to_string()).or_insert(0) += 1;
233+
}
234+
}
235+
counts
236+
}
237+
238+
fn get_area(expr: &RecExpr<CellLang>) -> f32 {
239+
expr.iter()
240+
.map(|n| {
241+
if let CellLang::Cell(name, _) = n {
242+
PrimitiveType::from_str(name.as_str())
243+
.unwrap()
244+
.get_min_area()
245+
.unwrap_or(1.33)
246+
} else if matches!(n, CellLang::Const(_) | CellLang::Var(_)) {
247+
PrimitiveType::INV.get_min_area().unwrap()
248+
} else {
249+
0.0
250+
}
251+
})
252+
.sum()
253+
}
254+
255+
fn new(expr: &RecExpr<CellLang>) -> Self {
256+
Self {
257+
ast_size: expr.len(),
258+
cell_counts: Self::get_cell_counts(expr),
259+
area: Self::get_area(expr),
260+
}
261+
}
205262
}
206263

207264
/// An empty report struct for synthesizing CellLang
@@ -235,14 +292,8 @@ impl Report<CellLang> for CellRpt {
235292
{
236293
Ok(CellRpt::new(
237294
"top".to_string(),
238-
CircuitStats {
239-
ast_size: input.len(),
240-
cell_counts: get_cell_counts(input),
241-
},
242-
CircuitStats {
243-
ast_size: output.len(),
244-
cell_counts: get_cell_counts(output),
245-
},
295+
CircuitStats::new(input),
296+
CircuitStats::new(output),
246297
))
247298
}
248299

src/bin/cellmap.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ struct Args {
2828
#[arg(long)]
2929
dump_graph: Option<PathBuf>,
3030

31-
/// Return an error if the graph does not reach saturation
31+
/// Use a cost model that weighs the cells by exact area
3232
#[arg(short = 'a', long, default_value_t = false)]
33-
assert_sat: bool,
33+
area: bool,
3434

3535
/// Perform an exact extraction using ILP (much slower)
3636
#[cfg(feature = "exactness")]
@@ -118,12 +118,6 @@ fn main() -> std::io::Result<()> {
118118
}
119119
};
120120

121-
let req = if args.assert_sat {
122-
req.with_asserts()
123-
} else {
124-
req
125-
};
126-
127121
let req = if args.verbose { req.with_proof() } else { req };
128122

129123
let req = if args.report.is_some() {
@@ -140,6 +134,8 @@ fn main() -> std::io::Result<()> {
140134

141135
let req = if args.min_depth {
142136
req.with_min_depth()
137+
} else if args.area {
138+
req.with_area()
143139
} else {
144140
req.with_k(args.k)
145141
};

src/bin/opt.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,14 @@ fn get_main_runner(
3131
/// parse an expression, simplify it with DSD and at most 4 fan-in, and pretty print it back out
3232
fn simplify(s: &str) -> String {
3333
let mut req = get_main_runner(s).unwrap();
34-
req.simplify_expr::<SynthReport>()
35-
.unwrap()
36-
.get_expr()
37-
.to_string()
34+
req.synth::<SynthReport>().unwrap().get_expr().to_string()
3835
}
3936

4037
#[allow(dead_code)]
4138
/// parse an expression, simplify it with DSD and at most 4 fan-in, and pretty print it back out
4239
fn simplify_w_proof(s: &str) -> String {
4340
let mut req = get_main_runner(s).unwrap().with_proof();
44-
req.simplify_expr::<SynthReport>()
45-
.unwrap()
46-
.get_expr()
47-
.to_string()
41+
req.synth::<SynthReport>().unwrap().get_expr().to_string()
4842
}
4943

5044
/// Technology Mapping Optimization with E-Graphs

src/bin/optcell.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,14 @@ fn get_main_runner(
2626
/// parse an expression, simplify it with DSD and at most 4 fan-in, and pretty print it back out
2727
fn simplify(s: &str) -> String {
2828
let mut req = get_main_runner(s).unwrap();
29-
req.simplify_expr::<CellRpt>()
30-
.unwrap()
31-
.get_expr()
32-
.to_string()
29+
req.synth::<CellRpt>().unwrap().get_expr().to_string()
3330
}
3431

3532
#[allow(dead_code)]
3633
/// parse an expression, simplify it with DSD and at most 4 fan-in, and pretty print it back out
3734
fn simplify_w_proof(s: &str) -> String {
3835
let mut req = get_main_runner(s).unwrap().with_proof();
39-
req.simplify_expr::<CellRpt>()
40-
.unwrap()
41-
.get_expr()
42-
.to_string()
36+
req.synth::<CellRpt>().unwrap().get_expr().to_string()
4337
}
4438

4539
/// ASIC Technology Mapping Optimization with E-Graphs
@@ -207,7 +201,7 @@ fn simple_tests() {
207201
#[test]
208202
fn cell_rpt() {
209203
let mut req = get_main_runner("(INV a)").unwrap().with_report();
210-
let result = req.simplify_expr::<CellRpt>().unwrap();
204+
let result = req.synth::<CellRpt>().unwrap();
211205
let rpt = result.write_report_to_string();
212206
assert!(rpt.is_ok());
213207
let rpt = rpt.unwrap();

src/driver.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,8 @@ enum BuildStrat {
341341
/// Only [ExtractStrat::Exact] uses ILP.
342342
#[derive(Debug, Clone)]
343343
enum ExtractStrat {
344+
/// Extract the cirucit using exact cell areas.
345+
Area,
344346
/// Extract maximum circuit depth (RAM bomb).
345347
MaxDepth,
346348
/// Extract minimum circuit depth.
@@ -430,6 +432,9 @@ where
430432
Self::cell_cost_with_reg_weight_fn(cut_size, 1)
431433
}
432434

435+
/// Returns the cost function using exact cell areas.
436+
fn exact_area_cost_fn() -> impl CostFunction<Self>;
437+
433438
/// Returns a cost function used for extracting only certain types nodes.
434439
fn filter_cost_fn(set: HashSet<String>) -> impl CostFunction<Self>;
435440
}
@@ -575,6 +580,14 @@ where
575580
}
576581
}
577582

583+
/// Request greedy extraction using exact cell areas.
584+
pub fn with_area(self) -> Self {
585+
Self {
586+
extract_strat: ExtractStrat::Area,
587+
..self
588+
}
589+
}
590+
578591
/// Request exact LUT extraction using ILP with `timeout` in seconds.
579592
#[cfg(feature = "exactness")]
580593
pub fn with_exactness(self, timeout: u64) -> Self {
@@ -988,12 +1001,13 @@ where
9881001
L::get_explanations(root_expr, best, runner)
9891002
}
9901003

991-
/// Simplify expression with the extraction strategy set in `self`.
992-
pub fn simplify_expr<R>(&mut self) -> Result<SynthOutput<L, R>, String>
1004+
/// Synthesize with the extraction strategy set in `self`.
1005+
pub fn synth<R>(&mut self) -> Result<SynthOutput<L, R>, String>
9931006
where
9941007
R: Report<L>,
9951008
{
9961009
match self.extract_strat.to_owned() {
1010+
ExtractStrat::Area => self.greedy_extract_with(L::exact_area_cost_fn()),
9971011
ExtractStrat::MinDepth => self.greedy_extract_with(L::depth_cost_fn()),
9981012
ExtractStrat::MaxDepth => {
9991013
eprintln!("WARNING: Maximizing cost on e-graphs with cycles will crash.");
@@ -1062,7 +1076,7 @@ where
10621076
let mut req = req.with_expr(expr.clone());
10631077

10641078
let result = req
1065-
.simplify_expr()
1079+
.synth()
10661080
.map_err(|s| std::io::Error::new(std::io::ErrorKind::Other, s))?;
10671081

10681082
#[cfg(feature = "graph_dumps")]

src/lib.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ mod tests {
2525
use std::collections::HashMap;
2626

2727
use analysis::LutAnalysis;
28-
use asic::CellLang;
29-
use driver::Canonical;
28+
use asic::{CellAnalysis, CellLang, CellRpt, asic_rewrites};
29+
use driver::{Canonical, SynthRequest};
3030
use egg::{Analysis, Language, RecExpr};
3131
use lut::{LutExprInfo, LutLang};
3232
use verilog::{PrimitiveType, SVModule, sv_parse_wrapper};
@@ -967,7 +967,25 @@ endmodule\n"
967967
PrimitiveType::AOI22.get_input_list(),
968968
vec!["A1", "A2", "B1", "B2"]
969969
);
970+
971+
assert_eq!(
972+
PrimitiveType::AOI211.get_input_list(),
973+
vec!["A", "B", "C1", "C2"]
974+
);
975+
976+
assert_eq!(
977+
PrimitiveType::AOI221.get_input_list(),
978+
vec!["A", "B1", "B2", "C1", "C2"]
979+
);
980+
981+
assert_eq!(PrimitiveType::XOR2.get_output(), "Z".to_string());
982+
970983
// LUT input list is backwards relative to the IR
984+
assert_eq!(
985+
PrimitiveType::LUT5.get_input_list(),
986+
vec!["I4", "I3", "I2", "I1", "I0"]
987+
);
988+
971989
assert_eq!(
972990
PrimitiveType::LUT6.get_input_list(),
973991
vec!["I5", "I4", "I3", "I2", "I1", "I0"]
@@ -1030,4 +1048,24 @@ endmodule\n"
10301048
let expr3: RecExpr<LutLang> = "(LUT 51952 s0 (LUT 61643 s0 s1 b d) a c)".parse().unwrap();
10311049
assert!(LutLang::func_equiv(&expr2, &expr3).is_not_equiv());
10321050
}
1051+
1052+
#[test]
1053+
fn test_cell_area() {
1054+
assert!(PrimitiveType::INV.get_min_area().is_some());
1055+
assert_eq!(PrimitiveType::INV.get_min_area().unwrap(), 0.532);
1056+
let expr: RecExpr<CellLang> = "(INV a)".parse().unwrap();
1057+
let mut req: SynthRequest<CellLang, CellAnalysis> = SynthRequest::default()
1058+
.with_expr(expr)
1059+
.with_report()
1060+
.with_rules(asic_rewrites())
1061+
.without_progress_bar();
1062+
let result = req
1063+
.synth::<CellRpt>()
1064+
.unwrap()
1065+
.write_report_to_string()
1066+
.unwrap();
1067+
1068+
assert!(result.contains("area"));
1069+
assert!(result.contains("\"area\": 1.064"));
1070+
}
10331071
}

src/lut.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,10 @@ impl Extractable for LutLang {
10991099
KLUTCostFn::new(cut_size).with_reg_weight(w)
11001100
}
11011101

1102+
fn exact_area_cost_fn() -> impl CostFunction<Self> {
1103+
KLUTCostFn::new(6).with_reg_weight(1)
1104+
}
1105+
11021106
fn filter_cost_fn(set: std::collections::HashSet<String>) -> impl CostFunction<Self> {
11031107
GateCostFn::new(set)
11041108
}

src/rewrite.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -930,11 +930,7 @@ pub mod decomp {
930930
.without_progress_bar()
931931
.with_joint_limits(20, 20_000, 30);
932932

933-
let ans = req
934-
.simplify_expr::<SynthReport>()
935-
.unwrap()
936-
.get_expr()
937-
.to_string();
933+
let ans = req.synth::<SynthReport>().unwrap().get_expr().to_string();
938934
assert_eq!(ans, "(LUT 202 s1 s0 (LUT 202 s0 c d))");
939935
}
940936
}

0 commit comments

Comments
 (0)