use leo_ast::{
Block,
Expression,
IntegerType,
IterationStatement,
Literal,
NodeBuilder,
Statement,
StatementReconstructor,
Type,
Value,
};
use std::cell::RefCell;
use leo_errors::{emitter::Handler, loop_unroller::LoopUnrollerError};
use leo_span::Symbol;
use crate::{
Clusivity,
LoopBound,
RangeIterator,
SymbolTable,
TypeTable,
constant_propagation_table::ConstantPropagationTable,
};
/// State for the loop-unrolling pass.
///
/// Replaces iteration statements whose bounds (`start_value` / `stop_value`) are already known
/// with a block containing one copy of the loop body per iteration. While walking the program it
/// tracks the current scope of both the symbol table and the constant propagation table, and binds
/// the loop variable to an integer literal in each unrolled copy.
/// NOTE(review): the statement-reconstruction trait impl that drives this is not in this view.
pub struct Unroller<'a> {
/// Constants for the current scope; each unrolled iteration binds the loop variable here.
/// While inside a child scope, its `parent` field holds the enclosing table.
pub(crate) constant_propagation_table: RefCell<ConstantPropagationTable>,
/// Symbol table for the current scope; while inside a child scope, its `parent` field holds
/// the enclosing table.
pub(crate) symbol_table: RefCell<SymbolTable>,
/// Type table; the type of every loop-variable literal created during unrolling is recorded here.
pub(crate) type_table: &'a TypeTable,
/// Index of the next child scope to enter when not unrolling. Reset to 0 on entering a scope and
/// set to `index + 1` on leaving one.
pub(crate) scope_index: usize,
/// Sink for errors emitted by this pass.
pub(crate) handler: &'a Handler,
/// Source of fresh node IDs for nodes created during unrolling.
pub(crate) node_builder: &'a NodeBuilder,
/// True while an unrolled iteration's body is being reconstructed; makes `current_scope_index`
/// insert a fresh block instead of returning `scope_index`.
pub(crate) is_unrolling: bool,
/// Program currently being processed. NOTE(review): only initialized to `None` in this view;
/// presumably set elsewhere when a program is entered — confirm.
pub(crate) current_program: Option<Symbol>,
}
impl<'a> Unroller<'a> {
pub(crate) fn new(
symbol_table: SymbolTable,
type_table: &'a TypeTable,
handler: &'a Handler,
node_builder: &'a NodeBuilder,
) -> Self {
Self {
constant_propagation_table: RefCell::new(ConstantPropagationTable::default()),
symbol_table: RefCell::new(symbol_table),
type_table,
scope_index: 0,
handler,
node_builder,
is_unrolling: false,
current_program: None,
}
}
pub(crate) fn current_scope_index(&mut self) -> usize {
if self.is_unrolling { self.symbol_table.borrow_mut().insert_block() } else { self.scope_index }
}
pub(crate) fn enter_scope(&mut self, index: usize) -> usize {
let previous_symbol_table = std::mem::take(&mut self.symbol_table);
self.symbol_table.swap(previous_symbol_table.borrow().lookup_scope_by_index(index).unwrap());
self.symbol_table.borrow_mut().parent = Some(Box::new(previous_symbol_table.into_inner()));
self.constant_propagation_table.borrow_mut().insert_block();
let previous_constant_propagation_table = std::mem::take(&mut self.constant_propagation_table);
self.constant_propagation_table
.swap(previous_constant_propagation_table.borrow().lookup_scope_by_index(index).unwrap());
self.constant_propagation_table.borrow_mut().parent =
Some(Box::new(previous_constant_propagation_table.into_inner()));
core::mem::replace(&mut self.scope_index, 0)
}
pub(crate) fn exit_scope(&mut self, index: usize) {
let prev_ct = *self.constant_propagation_table.borrow_mut().parent.take().unwrap();
self.constant_propagation_table.swap(prev_ct.lookup_scope_by_index(index).unwrap());
self.constant_propagation_table = RefCell::new(prev_ct);
let prev_st = *self.symbol_table.borrow_mut().parent.take().unwrap();
self.symbol_table.swap(prev_st.lookup_scope_by_index(index).unwrap());
self.symbol_table = RefCell::new(prev_st);
self.scope_index = index + 1;
}
pub(crate) fn emit_err(&self, err: LoopUnrollerError) {
self.handler.emit_err(err);
}
pub(crate) fn unroll_iteration_statement<I: LoopBound>(&mut self, input: IterationStatement) -> Statement {
let start: Value = input.start_value.borrow().as_ref().expect("Failed to get start value").clone();
let stop: Value = input.stop_value.borrow().as_ref().expect("Failed to get stop value").clone();
let cast_to_number = |v: Value| -> Result<I, Statement> {
match v.try_into() {
Ok(val_as_u128) => Ok(val_as_u128),
Err(err) => {
self.handler.emit_err(err);
Err(Statement::dummy(input.span, self.node_builder.next_id()))
}
}
};
let start = match cast_to_number(start) {
Ok(v) => v,
Err(s) => return s,
};
let stop = match cast_to_number(stop) {
Ok(v) => v,
Err(s) => return s,
};
let scope_index = self.current_scope_index();
let previous_scope_index = self.enter_scope(scope_index);
self.symbol_table.borrow_mut().variables.clear();
self.symbol_table.borrow_mut().scopes.clear();
self.symbol_table.borrow_mut().scope_index = 0;
let iter_blocks = Statement::Block(Block {
span: input.span,
statements: match input.inclusive {
true => {
let iter = RangeIterator::new(start, stop, Clusivity::Inclusive);
iter.map(|iteration_count| self.unroll_single_iteration(&input, iteration_count)).collect()
}
false => {
let iter = RangeIterator::new(start, stop, Clusivity::Exclusive);
iter.map(|iteration_count| self.unroll_single_iteration(&input, iteration_count)).collect()
}
},
id: input.id,
});
self.exit_scope(previous_scope_index);
iter_blocks
}
fn unroll_single_iteration<I: LoopBound>(&mut self, input: &IterationStatement, iteration_count: I) -> Statement {
let scope_index = self.symbol_table.borrow_mut().insert_block();
let previous_scope_index = self.enter_scope(scope_index);
let prior_is_unrolling = self.is_unrolling;
self.is_unrolling = true;
let id = self.node_builder.next_id();
self.type_table.insert(id, input.type_.clone());
let value = match input.type_ {
Type::Integer(IntegerType::I8) => {
Literal::Integer(IntegerType::I8, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::I16) => {
Literal::Integer(IntegerType::I16, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::I32) => {
Literal::Integer(IntegerType::I32, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::I64) => {
Literal::Integer(IntegerType::I64, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::I128) => {
Literal::Integer(IntegerType::I128, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::U8) => {
Literal::Integer(IntegerType::U8, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::U16) => {
Literal::Integer(IntegerType::U16, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::U32) => {
Literal::Integer(IntegerType::U32, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::U64) => {
Literal::Integer(IntegerType::U64, iteration_count.to_string(), Default::default(), id)
}
Type::Integer(IntegerType::U128) => {
Literal::Integer(IntegerType::U128, iteration_count.to_string(), Default::default(), id)
}
_ => unreachable!(
"The iteration variable must be an integer type. This should be enforced by type checking."
),
};
self.constant_propagation_table
.borrow_mut()
.insert_constant(input.variable.name, Expression::Literal(value.clone()))
.expect("Failed to insert constant into CPT");
let statements: Vec<_> = input
.block
.statements
.clone()
.into_iter()
.filter_map(|s| {
let (reconstructed_statement, additional_output) = self.reconstruct_statement(s);
if additional_output {
None } else {
Some(reconstructed_statement)
}
})
.collect();
let block = Statement::Block(Block { statements, span: input.block.span, id: input.block.id });
self.is_unrolling = prior_is_unrolling;
self.exit_scope(previous_scope_index);
block
}
}