// leo_passes/loop_unrolling/unroller.rs
use leo_ast::{
Block,
Expression,
ExpressionReconstructor,
IterationStatement,
Literal,
Node,
NodeBuilder,
NodeID,
Statement,
StatementReconstructor,
Type,
Value,
};
use leo_errors::{emitter::Handler, loop_unroller::LoopUnrollerError};
use leo_span::{Span, Symbol};
use crate::{Clusivity, LoopBound, RangeIterator, SymbolTable, TypeTable};
/// State carried by the loop-unrolling pass.
///
/// Bundles the borrowed tables and services the pass works against with a few
/// pieces of pass-local bookkeeping.
pub struct Unroller<'a> {
/// Symbol table on which scopes are entered/left and in which each loop
/// variable is recorded as a constant while an iteration is unrolled.
pub(crate) symbol_table: &'a mut SymbolTable,
/// Type table; receives the loop variable's type for each freshly created
/// literal node (inserted through a shared reference).
pub(crate) type_table: &'a TypeTable,
/// Sink for errors emitted by the pass.
pub(crate) handler: &'a Handler,
/// Source of fresh node IDs for the nodes the pass creates.
pub(crate) node_builder: &'a NodeBuilder,
/// Set to `true` while a duplicated loop body is being reconstructed, and
/// restored to its previous value afterwards.
pub(crate) is_unrolling: bool,
/// Program under which loop-variable constants are inserted; starts as `None`.
/// It is unwrapped during unrolling, so it must be set beforehand
/// (NOTE(review): the code that sets it is not in this part of the file).
pub(crate) current_program: Option<Symbol>,
/// Starts as `None`. Presumably the span of a loop that could not be unrolled —
/// not read or written in this part of the file; confirm against callers.
pub(crate) loop_not_unrolled: Option<Span>,
/// Starts as `false`. Presumably records that some loop was unrolled —
/// not read or written in this part of the file; confirm against callers.
pub(crate) loop_unrolled: bool,
}
impl<'a> Unroller<'a> {
    /// Builds an unroller over the given tables, error handler and node builder.
    ///
    /// All bookkeeping starts cleared: not currently unrolling, no current
    /// program, no loop recorded as un-unrolled and no loop unrolled.
    pub(crate) fn new(
        symbol_table: &'a mut SymbolTable,
        type_table: &'a TypeTable,
        handler: &'a Handler,
        node_builder: &'a NodeBuilder,
    ) -> Self {
        Self {
            // Borrowed collaborators.
            symbol_table,
            type_table,
            handler,
            node_builder,
            // Pass-local bookkeeping.
            current_program: None,
            is_unrolling: false,
            loop_unrolled: false,
            loop_not_unrolled: None,
        }
    }

    /// Runs `func` inside the symbol-table scope identified by `id`.
    ///
    /// The scope is entered before `func` is invoked and the parent scope is
    /// re-entered once it returns; whatever `func` produced is handed back.
    pub(crate) fn in_scope<T>(&mut self, id: NodeID, func: impl FnOnce(&mut Self) -> T) -> T {
        self.symbol_table.enter_scope(Some(id));
        let output = func(self);
        self.symbol_table.enter_parent();
        output
    }

    /// Forwards a loop-unrolling error to the error handler.
    pub(crate) fn emit_err(&self, err: LoopUnrollerError) {
        self.handler.emit_err(err);
    }

    /// Replaces an iteration statement with a block holding one unrolled copy
    /// of the loop body per value of the iteration range.
    ///
    /// The start and stop values are converted to the bound type `I`. If a
    /// conversion fails, its error is emitted and a dummy statement spanning
    /// the loop is returned instead; the stop value is only converted when the
    /// start value converted successfully.
    ///
    /// # Panics
    ///
    /// Panics if the statement's start or stop value has not been set.
    pub(crate) fn unroll_iteration_statement<I: LoopBound>(&mut self, input: IterationStatement) -> Statement {
        let start: Value = input.start_value.borrow().as_ref().expect("Failed to get start value").clone();
        let stop: Value = input.stop_value.borrow().as_ref().expect("Failed to get stop value").clone();

        // Converts a bound to `I`; on failure the error is reported and a
        // placeholder statement is produced for the caller to return.
        let to_bound = |value: Value| -> Result<I, Statement> {
            let converted: Result<I, _> = value.try_into();
            converted.map_err(|err| {
                self.handler.emit_err(err);
                Statement::dummy(input.span, self.node_builder.next_id())
            })
        };

        // `and_then` keeps the short-circuit: a failed start skips the stop conversion.
        let bounds = to_bound(start).and_then(|lower| to_bound(stop).map(|upper| (lower, upper)));
        let (start, stop) = match bounds {
            Ok(pair) => pair,
            Err(placeholder) => return placeholder,
        };

        let new_block_id = self.node_builder.next_id();
        let clusivity = if input.inclusive { Clusivity::Inclusive } else { Clusivity::Exclusive };

        // The replacement block gets its own scope; each iteration is unrolled inside it.
        self.in_scope(new_block_id, |slf| {
            let statements: Vec<Statement> = RangeIterator::new(start, stop, clusivity)
                .map(|iteration_count| slf.unroll_single_iteration(&input, iteration_count))
                .collect();
            Statement::Block(Block { span: input.span, statements, id: new_block_id })
        })
    }

    /// Unrolls a single iteration of `input` for the given `iteration_count`.
    ///
    /// Within a fresh scope, the loop variable is registered as a constant
    /// integer literal holding `iteration_count`, the loop body is duplicated
    /// and reconstructed, and the result is wrapped in an outer block.
    ///
    /// # Panics
    ///
    /// Panics if the loop variable's type is not an integer type, or if no
    /// current program has been set.
    fn unroll_single_iteration<I: LoopBound>(&mut self, input: &IterationStatement, iteration_count: I) -> Statement {
        // Fresh ID for the literal standing in for the loop variable, typed in the type table.
        let const_id = self.node_builder.next_id();
        self.type_table.insert(const_id, input.type_.clone());

        let outer_block_id = self.node_builder.next_id();

        let integer_type = match &input.type_ {
            Type::Integer(integer_type) => integer_type,
            _ => unreachable!("Type checking enforces that the iteration variable is of integer type"),
        };

        self.in_scope(outer_block_id, |slf| {
            // Bind the loop variable to this iteration's value as a constant.
            let loop_value = Expression::Literal(Literal::Integer(
                *integer_type,
                iteration_count.to_string(),
                Default::default(),
                const_id,
            ));
            let program = slf.current_program.unwrap();
            slf.symbol_table.insert_const(program, input.variable.name, loop_value);

            let duplicated_body = super::duplicate::duplicate(input.block.clone(), slf.symbol_table, slf.node_builder);

            // Flag that we are inside an unrolled body while reconstructing, then restore the flag.
            let prior_is_unrolling = std::mem::replace(&mut slf.is_unrolling, true);
            let unrolled_body = Statement::Block(slf.reconstruct_block(duplicated_body).0);
            slf.is_unrolling = prior_is_unrolling;

            Statement::Block(Block { statements: vec![unrolled_body], span: input.span(), id: outer_block_id })
        })
    }
}
// Expression reconstruction for the unroller: no trait methods are overridden
// here; only the associated output type is fixed.
impl ExpressionReconstructor for Unroller<'_> {
// Reconstructing an expression yields no extra data beyond the expression itself.
type AdditionalOutput = ();
}