use super::*;
use crate::data_keeper::MergeCtx;
use air_interpreter_data::FoldSubTraceLore;
use air_interpreter_data::SubTraceDesc;
use air_interpreter_data::TracePos;
use std::collections::HashMap;
/// Number of trace states covered by a fold: the sum of the lengths of all its recorded
/// before/after subtraces (accumulated with overflow checks in `compute_lens_convolution`).
pub type FoldStatesCount = u32;
/// Fold lore whose subtrace lengths have been resolved by `resolve_fold_lore`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ResolvedFold {
    /// Resolved subtrace descriptors keyed by the trace position of the value each lore
    /// entry was recorded for (the entry's `value_pos`); positions are unique.
    pub lore: HashMap<TracePos, ResolvedSubTraceDescs>,
    /// Total number of states across all raw subtraces of the fold.
    pub fold_states_count: FoldStatesCount,
}
/// The pair of resolved subtrace descriptors for one lore entry.
///
/// Begin positions are taken unchanged from the entry's two recorded descriptors; lengths are
/// the convoluted values computed from the whole fold lore.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedSubTraceDescs {
    /// Descriptor derived from the entry's first recorded subtrace.
    pub before_subtrace: SubTraceDesc,
    /// Descriptor derived from the entry's second recorded subtrace.
    pub after_subtrace: SubTraceDesc,
}
pub(super) fn resolve_fold_lore(fold: &FoldResult, merge_ctx: &MergeCtx) -> MergeResult<ResolvedFold> {
let (fold_states_count, lens) = compute_lens_convolution(fold, merge_ctx)?;
let lore = fold.lore.iter().zip(lens).try_fold::<_, _, MergeResult<_>>(
HashMap::with_capacity(fold.lore.len()),
|mut resolved_lore, (lore, lens)| {
let before_subtrace = SubTraceDesc::new(lore.subtraces_desc[0].begin_pos, lens.before_len as _);
let after_subtrace = SubTraceDesc::new(lore.subtraces_desc[1].begin_pos, lens.after_len as _);
let resolved_descs = ResolvedSubTraceDescs::new(before_subtrace, after_subtrace);
match resolved_lore.insert(lore.value_pos, resolved_descs) {
Some(_) => {
Err(FoldResultError::SeveralRecordsWithSamePos(fold.clone(), lore.value_pos)).map_err(Into::into)
}
None => Ok(resolved_lore),
}
},
)?;
let resolved_fold_lore = ResolvedFold::new(lore, fold_states_count);
Ok(resolved_fold_lore)
}
/// Computes the convoluted before/after lengths for every lore entry of `fold`.
///
/// Consecutive lore entries whose values have the same generation (as reported by
/// `merge_ctx.try_get_generation`) form a run. Within a run:
/// - the resulting `after_len` of entry `i` is the sum of the raw after-lengths of the run's
///   entries up to and including `i`;
/// - the resulting `before_len` of entry `i` is the sum of the raw before-lengths of the run's
///   entries from `i` to the end of the run, plus the run's total after-length
///   (computed by `compute_before_lens`).
///
/// Returns the total number of states over all raw subtraces together with the per-entry
/// lengths, in the same order as `fold.lore`.
///
/// # Errors
/// - a lore entry does not have exactly two subtrace descriptors;
/// - `merge_ctx` fails to provide the generation of an entry's `value_pos`;
/// - the total number of states overflows `FoldStatesCount`.
fn compute_lens_convolution(fold: &FoldResult, merge_ctx: &MergeCtx) -> MergeResult<(FoldStatesCount, Vec<LoresLen>)> {
    let mut lens: Vec<LoresLen> = Vec::with_capacity(fold.lore.len());
    let mut total_states: FoldStatesCount = 0;

    // Bookkeeping for the generation run currently being accumulated.
    let mut run_generation = GenerationIdx::from(0);
    let mut run_start = 0;
    let mut run_after_len = 0;

    for (subtrace_id, subtrace_lore) in fold.lore.iter().enumerate() {
        check_subtrace_lore(subtrace_lore)?;

        let generation = merge_ctx.try_get_generation(subtrace_lore.value_pos)?;
        if run_generation != generation {
            // A new run starts here: finalize the before-lengths of the run that just ended.
            if let Some(prev_id) = subtrace_id.checked_sub(1) {
                compute_before_lens(&mut lens, run_start, prev_id);
            }
            run_generation = generation;
            run_start = subtrace_id;
            run_after_len = 0;
        }

        let raw_before_len = subtrace_lore.subtraces_desc[0].subtrace_len;
        let raw_after_len = subtrace_lore.subtraces_desc[1].subtrace_len;

        // Checking the grand total also bounds every partial sum computed below and in
        // `compute_before_lens`, so those can use plain addition.
        total_states = total_states
            .checked_add(raw_before_len)
            .and_then(|states| states.checked_add(raw_after_len))
            .ok_or_else(|| FoldResultError::SubtraceLenOverflow {
                fold_result: fold.clone(),
                count: subtrace_id,
            })?;

        run_after_len += raw_after_len;
        lens.push(LoresLen::new(raw_before_len, run_after_len));
    }

    // Finalize the last run, if there was any entry at all.
    if let Some(last_id) = fold.lore.len().checked_sub(1) {
        compute_before_lens(&mut lens, run_start, last_id);
    }

    Ok((total_states, lens))
}
/// Rewrites, in place, the `before_len` of every entry in `lore_lens[begin_pos..=end_pos]`.
///
/// The new `before_len` of entry `i` is the sum of the current `before_len`s of entries
/// `i..=end_pos` plus `lore_lens[end_pos].after_len`. In `compute_lens_convolution` that
/// `after_len` is the cumulative after-length of the whole run. Entries outside the range
/// are untouched. An empty range (`begin_pos > end_pos`) is a no-op.
///
/// # Panics
/// Panics if `end_pos` is out of bounds of `lore_lens`.
fn compute_before_lens(lore_lens: &mut [LoresLen], begin_pos: usize, end_pos: usize) {
    let run_after_len = lore_lens[end_pos].after_len;
    if begin_pos > end_pos {
        return;
    }

    // Walk the run backwards, accumulating a suffix sum of before-lengths.
    let mut suffix_before_len = 0;
    for entry in lore_lens[begin_pos..=end_pos].iter_mut().rev() {
        suffix_before_len += entry.before_len;
        entry.before_len = suffix_before_len + run_after_len;
    }
}
/// Validates that a lore entry carries exactly two subtrace descriptors.
///
/// # Errors
/// Returns `FoldResultError::FoldIncorrectSubtracesCount` with the actual descriptor count
/// when it differs from two.
fn check_subtrace_lore(subtrace_lore: &FoldSubTraceLore) -> MergeResult<()> {
    const SUBTRACE_DESC_COUNT: usize = 2;

    match subtrace_lore.subtraces_desc.len() {
        SUBTRACE_DESC_COUNT => Ok(()),
        actual_count => Err(FoldResultError::FoldIncorrectSubtracesCount(actual_count).into()),
    }
}
impl ResolvedFold {
    /// Builds a resolved fold from already-resolved `lore` and the total number of states
    /// the fold covers.
    pub fn new(lore: HashMap<TracePos, ResolvedSubTraceDescs>, fold_states_count: FoldStatesCount) -> Self {
        Self {
            fold_states_count,
            lore,
        }
    }
}
impl ResolvedSubTraceDescs {
    /// Pairs the resolved `before_subtrace` and `after_subtrace` descriptors of one lore entry.
    pub fn new(before_subtrace: SubTraceDesc, after_subtrace: SubTraceDesc) -> Self {
        Self {
            after_subtrace,
            before_subtrace,
        }
    }
}
/// Before/after subtrace lengths of a single lore entry during convolution.
///
/// `compute_lens_convolution` creates each value with the entry's raw before-length and the
/// cumulative after-length of its generation run. `compute_before_lens` later rewrites
/// `before_len` in place.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct LoresLen {
    before_len: u32,
    after_len: u32,
}

impl LoresLen {
    /// Creates a length pair from the given before and after lengths.
    fn new(before_len: u32, after_len: u32) -> Self {
        LoresLen {
            after_len,
            before_len,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::compute_lens_convolution;
    use crate::data_keeper::TraceSlider;
    use crate::merger::fold_merger::fold_lore_resolver::LoresLen;
    use crate::MergeCtx;
    use air_interpreter_data::ApResult;
    use air_interpreter_data::ExecutedState;
    use air_interpreter_data::FoldResult;
    use air_interpreter_data::FoldSubTraceLore;
    use air_interpreter_data::SubTraceDesc;
    use air_interpreter_data::TracePos;

    fn subtrace_desc(begin_pos: impl Into<TracePos>, subtrace_len: u32) -> SubTraceDesc {
        SubTraceDesc {
            begin_pos: begin_pos.into(),
            subtrace_len,
        }
    }

    /// Lore entry for the value at `value_pos` whose two subtraces both begin at position 0
    /// and have the given lengths.
    fn subtrace_lore(value_pos: impl Into<TracePos>, before_len: u32, after_len: u32) -> FoldSubTraceLore {
        FoldSubTraceLore {
            value_pos: value_pos.into(),
            subtraces_desc: vec![subtrace_desc(0, before_len), subtrace_desc(0, after_len)],
        }
    }

    /// Merge context over the given trace.
    fn merge_ctx(trace: Vec<ExecutedState>) -> MergeCtx {
        MergeCtx {
            slider: TraceSlider::new(trace),
        }
    }

    #[test]
    fn empty_fold_result() {
        let fold_result = FoldResult { lore: vec![] };
        let ctx = merge_ctx(vec![]);

        let (all_states, convoluted_lens) =
            compute_lens_convolution(&fold_result, &ctx).expect("convolution should be successful");

        assert_eq!(all_states, 0);
        assert_eq!(convoluted_lens, vec![]);
    }

    #[test]
    fn convolution_test_1() {
        let fold_result = FoldResult {
            lore: vec![subtrace_lore(0, 1, 1), subtrace_lore(0, 2, 2), subtrace_lore(0, 3, 3)],
        };
        let ctx = merge_ctx(vec![ExecutedState::Ap(ApResult::new(0.into()))]);

        let (all_states, convoluted_lens) =
            compute_lens_convolution(&fold_result, &ctx).expect("convolution should be successful");

        assert_eq!(all_states, 12);
        let expected_lens = vec![LoresLen::new(12, 1), LoresLen::new(11, 3), LoresLen::new(9, 6)];
        assert_eq!(convoluted_lens, expected_lens);
    }

    #[test]
    fn convolution_test_2() {
        // Three generation runs: positions 0 (three entries), 1 (two entries), 2 (one entry).
        let fold_result = FoldResult {
            lore: vec![
                subtrace_lore(0, 1, 1),
                subtrace_lore(0, 2, 2),
                subtrace_lore(0, 3, 3),
                subtrace_lore(1, 4, 4),
                subtrace_lore(1, 5, 5),
                subtrace_lore(2, 1, 1),
            ],
        };
        let ctx = merge_ctx(vec![
            ExecutedState::Ap(ApResult::new(0.into())),
            ExecutedState::Ap(ApResult::new(1.into())),
            ExecutedState::Ap(ApResult::new(2.into())),
        ]);

        let (all_states, convoluted_lens) =
            compute_lens_convolution(&fold_result, &ctx).expect("convolution should be successful");

        assert_eq!(all_states, 32);
        let expected_lens = vec![
            LoresLen::new(12, 1),
            LoresLen::new(11, 3),
            LoresLen::new(9, 6),
            LoresLen::new(18, 4),
            LoresLen::new(14, 9),
            LoresLen::new(2, 1),
        ];
        assert_eq!(convoluted_lens, expected_lens);
    }
}