refactor: ArrowInliner refactoring (#777)

This commit is contained in:
Dima 2023-06-29 14:43:38 +03:00 committed by GitHub
parent 2985baadfc
commit 339d3a8217
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 260 additions and 120 deletions

View File

@ -1,14 +1,14 @@
package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.*
import aqua.raw.ops.RawTag
import aqua.types.ArrowType
import aqua.types.{ArrowType, BoxType, StreamType}
import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{BoxType, StreamType}
import cats.data.{Chain, State, StateT}
import cats.syntax.traverse.*
import cats.syntax.show.*
import cats.Eval
import cats.data.{Chain, State}
import cats.syntax.bifunctor.*
import scribe.Logging
/**
@ -24,68 +24,171 @@ object ArrowInliner extends Logging {
): State[S, OpModel.Tree] =
callArrowRet(arrow, call).map(_._1)
// Get streams that was declared outside of a function
private def getOutsideStreamNames[S: Exports]: State[S, Set[String]] =
Exports[S].exports
.map(exports =>
exports.collect { case (n, VarModel(_, StreamType(_), _)) =>
n
}.toSet
)
// push results to streams if they are exported to streams
private def pushStreamResults[S: Mangler: Exports: Arrows](
outsideStreamNames: Set[String],
exportTo: List[CallModel.Export],
results: List[ValueRaw],
body: OpModel.Tree
): State[S, (List[OpModel.Tree], List[ValueModel])] =
for {
// Fix return values with exports collected in the body
resolvedResult <- RawValueInliner.valueListToModel(results)
} yield {
// Fix the return values
val (ops, rets) = (exportTo zip resolvedResult).map {
case (
CallModel.Export(n, StreamType(_)),
(res @ VarModel(_, StreamType(_), _), resDesugar)
) if !outsideStreamNames.contains(n) =>
resDesugar.toList -> res
case (CallModel.Export(exp, st @ StreamType(_)), (res, resDesugar)) =>
// pass nested function results to a stream
(resDesugar.toList :+ PushToStreamModel(
res,
CallModel.Export(exp, st)
).leaf) -> VarModel(
exp,
st,
Chain.empty
)
case (_, (res, resDesugar)) =>
resDesugar.toList -> res
}.foldLeft[(List[OpModel.Tree], List[ValueModel])](
(body :: Nil, Nil)
) { case ((ops, rets), (fo, r)) =>
(fo ::: ops, r :: rets)
}
(ops, rets)
}
// Apply a callable function, get its fully resolved body & optional value, if any
private def inline[S: Mangler: Arrows: Exports](
fn: FuncArrow,
call: CallModel
): State[S, (OpModel.Tree, List[ValueModel])] =
Exports[S].exports
.map(exports =>
exports.collect { case e @ (_, VarModel(_, StreamType(_), _)) =>
e
}
getOutsideStreamNames.flatMap { outsideDeclaredStreams =>
// Function's internal variables will not be available outside, hence the scope
Exports[S].scope(
for {
// Process renamings, prepare environment
tr <- prelude[S](fn, call)
(tree, results) = tr
// Register captured values as available exports
_ <- Exports[S].resolved(fn.capturedValues)
_ <- Mangler[S].forbid(fn.capturedValues.keySet)
// Now, substitute the arrows that were received as function arguments
// Use the new op tree (args are replaced with values, names are unique & safe)
callableFuncBodyNoTopology <- TagInliner.handleTree(tree, fn.funcName)
callableFuncBody =
fn.capturedTopology
.fold[OpModel](SeqModel)(ApplyTopologyModel.apply)
.wrap(callableFuncBodyNoTopology)
opsAndRets <- pushStreamResults(
outsideDeclaredStreams,
call.exportTo,
results,
callableFuncBody
)
(ops, rets) = opsAndRets
} yield SeqModel.wrap(ops.reverse: _*) -> rets.reverse
)
.flatMap { outsideDeclaredStreams =>
// Function's internal variables will not be available outside, hence the scope
Exports[S].scope(
for {
// Process renamings, prepare environment
tr <- prelude[S](fn, call)
(tree, result) = tr
}
// Register captured values as available exports
_ <- Exports[S].resolved(fn.capturedValues)
_ <- Mangler[S].forbid(fn.capturedValues.keySet)
// Now, substitute the arrows that were received as function arguments
// Use the new op tree (args are replaced with values, names are unique & safe)
callableFuncBodyNoTopology <- TagInliner.handleTree(tree, fn.funcName)
callableFuncBody =
fn.capturedTopology
.fold[OpModel](SeqModel)(ApplyTopologyModel.apply)
.wrap(callableFuncBodyNoTopology)
// Fix return values with exports collected in the body
resolvedResult <- RawValueInliner.valueListToModel(result)
// Fix the return values
(ops, rets) = (call.exportTo zip resolvedResult)
.map[(List[OpModel.Tree], ValueModel)] {
case (
CallModel.Export(n, StreamType(_)),
(res @ VarModel(_, StreamType(_), _), resDesugar)
) if !outsideDeclaredStreams.contains(n) =>
resDesugar.toList -> res
case (CallModel.Export(exp, st @ StreamType(_)), (res, resDesugar)) =>
// pass nested function results to a stream
(resDesugar.toList :+ PushToStreamModel(
res,
CallModel.Export(exp, st)
).leaf) -> VarModel(
exp,
st,
Chain.empty
)
case (_, (res, resDesugar)) =>
resDesugar.toList -> res
}
.foldLeft[(List[OpModel.Tree], List[ValueModel])](
(callableFuncBody :: Nil, Nil)
) { case ((ops, rets), (fo, r)) =>
(fo ::: ops, r :: rets)
}
} yield SeqModel.wrap(ops.reverse: _*) -> rets.reverse
)
// Get all arrows that is arguments from outer Arrows.
// Purge and push captured arrows and arrows as arguments into state.
// Grab all arrows that must be renamed.
private def updateArrowsAndRenameArrowArgs[S: Mangler: Arrows: Exports](
args: ArgsCall,
func: FuncArrow
): State[S, Map[String, String]] = {
for {
// Arrow arguments: expected type is Arrow, given by-name
argsToArrowsRaw <- Arrows[S].argsArrows(args)
argsToArrowsShouldRename <- Mangler[S].findNewNames(
argsToArrowsRaw.keySet
)
argsToArrows = argsToArrowsRaw.map { case (k, v) =>
argsToArrowsShouldRename.getOrElse(k, k) -> v
}
returnedArrows = func.ret.collect { case VarRaw(name, ArrowType(_, _)) =>
name
}.toSet
returnedArrowsShouldRename <- Mangler[S].findNewNames(returnedArrows)
renamedCapturedArrows = func.capturedArrows.map { case (k, v) =>
returnedArrowsShouldRename.getOrElse(k, k) -> v
}
// Going to resolve arrows: collect them all. Names should never collide: it's semantically checked
_ <- Arrows[S].purge
_ <- Arrows[S].resolved(renamedCapturedArrows ++ argsToArrows)
} yield {
argsToArrowsShouldRename ++ returnedArrowsShouldRename
}
}
private def updateExportsAndRenameDataArgs[S: Mangler: Arrows: Exports](
args: ArgsCall
): State[S, Map[String, String]] = {
// DataType arguments
val argsToDataRaw = args.dataArgs
for {
// Find all duplicates in arguments
// we should not rename arguments that will be renamed by 'streamToRename'
argsToDataShouldRename <- Mangler[S].findNewNames(
argsToDataRaw.keySet
)
// Do not rename arguments if they just match external names
argsToData = argsToDataRaw.map { case (k, v) =>
argsToDataShouldRename.getOrElse(k, k) -> v
}
_ <- Exports[S].resolved(argsToData)
} yield argsToDataShouldRename
}
// Rename all exports-to-stream for streams that passed as arguments
private def renameStreams(
tree: RawTag.Tree,
args: ArgsCall
): RawTag.Tree = {
// Stream arguments
val streamArgs = args.streamArgs
// collect arguments with stream type
// to exclude it from resolving and rename it with a higher-level stream that passed by argument
val streamsToRename = streamArgs.view.mapValues(_.name).toMap
if (streamsToRename.isEmpty) tree
else
tree
.map(_.mapValues(_.map {
// if an argument is a BoxType (Array or Option), but we pass a stream,
// change a type as stream to not miss `$` sign in air
// @see ArrowInlinerSpec `pass stream to callback properly` test
case v @ VarRaw(name, baseType: BoxType) if streamsToRename.contains(name) =>
v.copy(baseType = StreamType(baseType.element))
case v: VarRaw if streamsToRename.contains(v.name) =>
v.copy(baseType = StreamType(v.baseType))
case v => v
}))
.renameExports(streamsToRename)
}
/**
* Prepare the state context for this function call
@ -105,69 +208,22 @@ object ArrowInliner extends Logging {
): State[S, (RawTag.Tree, List[ValueRaw])] =
for {
// Collect all arguments: what names are used inside the function, what values are received
argsFull <- State.pure(ArgsCall(fn.arrowType.domain, call.args))
// DataType arguments
argsToDataRaw = argsFull.dataArgs
// Arrow arguments: expected type is Arrow, given by-name
argsToArrowsRaw <- Arrows[S].argsArrows(argsFull)
// collect arguments with stream type
// to exclude it from resolving and rename it with a higher-level stream that passed by argument
// TODO: what if we have streams in property???
streamToRename = argsFull.streamArgs.view.mapValues(_.name).toMap
// Find all duplicates in arguments
// we should not rename arguments that will be renamed by 'streamToRename'
argsShouldRename <- Mangler[S].findNewNames(
argsToDataRaw.keySet ++ argsToArrowsRaw.keySet -- streamToRename.keySet
)
// Do not rename arguments if they just match external names
argsToData = argsToDataRaw.map { case (k, v) =>
argsShouldRename.getOrElse(k, k) -> v
}
_ <- Exports[S].resolved(argsToData)
argsToArrows = argsToArrowsRaw.map { case (k, v) => argsShouldRename.getOrElse(k, k) -> v }
returnedArrows = fn.ret.collect { case VarRaw(name, ArrowType(_, _)) =>
name
}.toSet
returnedArrowsShouldRename <- Mangler[S].findNewNames(returnedArrows)
renamedCapturedArrows = fn.capturedArrows.map { case (k, v) =>
returnedArrowsShouldRename.getOrElse(k, k) -> v
}
// Going to resolve arrows: collect them all. Names should never collide: it's semantically checked
_ <- Arrows[S].purge
_ <- Arrows[S].resolved(renamedCapturedArrows ++ argsToArrows)
args <- State.pure(ArgsCall(fn.arrowType.domain, call.args))
// Update states and rename tags
renamedArrows <- updateArrowsAndRenameArrowArgs(args, fn)
argsToDataShouldRename <- updateExportsAndRenameDataArgs(args)
allShouldRename = argsToDataShouldRename ++ renamedArrows
// Rename all renamed arguments in the body
treeRenamed =
fn.body
.rename(argsShouldRename)
.rename(returnedArrowsShouldRename)
.map(_.mapValues(_.map {
// if an argument is a BoxType (Array or Option), but we pass a stream,
// change a type as stream to not miss `$` sign in air
// @see ArrowInlinerSpec `pass stream to callback properly` test
case v @ VarRaw(name, baseType: BoxType) if streamToRename.contains(name) =>
v.copy(baseType = StreamType(baseType.element))
case v: VarRaw if streamToRename.contains(v.name) =>
v.copy(baseType = StreamType(v.baseType))
case v => v
}))
.renameExports(streamToRename)
treeRenamed = fn.body.rename(allShouldRename)
treeStreamsRenamed = renameStreams(treeRenamed, args)
// Function body on its own defines some values; collect their names
// except stream arguments. They should be already renamed
treeDefines =
treeRenamed.definesVarNames.value --
argsFull.streamArgs.keySet --
argsFull.streamArgs.values.map(_.name) --
treeStreamsRenamed.definesVarNames.value --
args.streamArgs.keySet --
args.streamArgs.values.map(_.name) --
call.exportTo.filter { exp =>
exp.`type` match {
case StreamType(_) => false
@ -176,14 +232,14 @@ object ArrowInliner extends Logging {
}.map(_.name)
// We have some names in scope (forbiddenNames), can't introduce them again; so find new names
shouldRename <- Mangler[S].findNewNames(treeDefines).map(_ ++ argsShouldRename)
shouldRename <- Mangler[S].findNewNames(treeDefines).map(_ ++ allShouldRename)
_ <- Mangler[S].forbid(treeDefines ++ shouldRename.values.toSet)
// If there was a collision, rename exports and usages with new names
tree = treeRenamed.rename(shouldRename)
tree = treeStreamsRenamed.rename(shouldRename)
// Result could be renamed; take care about that
} yield (tree, fn.ret.map(_.renameVars(shouldRename ++ returnedArrowsShouldRename)))
} yield (tree, fn.ret.map(_.renameVars(shouldRename)))
private[inline] def callArrowRet[S: Exports: Arrows: Mangler](
arrow: FuncArrow,

View File

@ -1660,4 +1660,88 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
) should be(true)
}
/*
service Get("get"):
get() -> string
func inner() -> string:
results <- DTGetter.get_dt()
<- results
func outer() -> []string:
results: *string
results <- use_name1()
<- results
*/
"arrow inliner" should "generate result in right order" in {
val innerName = "inner"
val results = VarRaw("results", ScalarType.string)
val resultsOut = VarRaw("results", StreamType(ScalarType.string))
val inner = FuncArrow(
innerName,
SeqTag.wrap(
CallArrowRawTag
.service(
LiteralRaw.quote("Get"),
"get",
Call(Nil, Call.Export(results.name, results.baseType) :: Nil)
)
.leaf
),
ArrowType(
ProductType(Nil),
ProductType(ScalarType.string :: Nil)
),
results :: Nil,
Map.empty,
Map.empty,
None
)
val captured = Map.apply((innerName, inner))
val outer = FuncArrow(
"outer",
SeqTag.wrap(
DeclareStreamTag(resultsOut).leaf,
CallArrowRawTag
.func(innerName, Call(Nil, Call.Export(resultsOut.name, resultsOut.baseType) :: Nil))
.leaf
),
ArrowType(
ProductType(Nil),
ProductType(ArrayType(ScalarType.string) :: Nil)
),
resultsOut :: Nil,
captured,
Map.empty,
None
)
val (state, model: OpModel.Tree) = ArrowInliner
.callArrow[InliningState](outer, CallModel(Nil, Nil))
.run(InliningState())
.value
val resultModel = VarModel("results-0", ScalarType.string)
model.equalsOrShowDiff(
MetaModel
.CallArrowModel(innerName)
.wrap(
SeqModel.wrap(
CallServiceModel(
LiteralModel.quote("Get"),
"get",
CallModel(Nil, CallModel.Export(resultModel.name, resultModel.`type`) :: Nil)
).leaf,
PushToStreamModel(
resultModel,
CallModel.Export(resultsOut.name, resultsOut.baseType)
).leaf
)
)
) should be(true)
}
}