Skip to content

Commit

Permalink
Merge pull request #6 from orangain/brand-new-builders
Browse files Browse the repository at this point in the history
Brand new builders
  • Loading branch information
orangain authored Jul 9, 2022
2 parents 6bca707 + ea3fe8c commit 8260c20
Show file tree
Hide file tree
Showing 8 changed files with 476 additions and 385 deletions.
2 changes: 1 addition & 1 deletion ktcodeshift-cli/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies {
implementation("info.picocli:picocli:4.6.3")
implementation(project(":ktcodeshift-dsl")) // the script definition module
implementation("org.jetbrains.kotlin:kotlin-test")
testImplementation("com.github.orangain.ktast:ast-psi:0.7.2")
testImplementation("com.github.orangain.ktast:ast-psi:0.8.0")
}

tasks.withType<KotlinCompile> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ transform { fileInfo ->
Api
.parse(fileInfo.source)
.find<Node.ValueArgs>()
.filter { _, p -> p is Node.Expr.Call && isListOf(p) }
.filter { _, p -> p is Node.Expression.Call && isListOf(p) }
.replaceWith { v ->
v.copy(elements = v.elements.flatMap { element ->
val expr = element.expr
if (expr is Node.Expr.Call && isListOf(expr)) {
val expr = element.expression
if (expr is Node.Expression.Call && isListOf(expr)) {
expr.args?.elements ?: listOf()
} else {
listOf(element)
Expand All @@ -22,7 +22,7 @@ transform { fileInfo ->
.toSource()
}

fun isListOf(call: Node.Expr.Call): Boolean {
val expr = call.expr
return expr is Node.Expr.Name && expr.name == "listOf"
fun isListOf(call: Node.Expression.Call): Boolean {
val expr = call.expression as? Node.Expression.Name ?: return false
return expr.name == "listOf"
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import ktast.ast.Node
import ktast.ast.Visitor
import ktast.ast.Writer
import ktcodeshift.*
import ktcodeshift.Api
import ktcodeshift.isDataClass
import ktcodeshift.toSource
import ktcodeshift.transform
import org.jetbrains.kotlin.util.capitalizeDecapitalize.decapitalizeSmart
import java.nio.charset.StandardCharsets

Expand All @@ -14,8 +17,8 @@ transform { fileInfo ->
val fqNames = mutableSetOf<List<String>>()

object : Visitor() {
override fun visit(v: Node?, parent: Node) {
if (v is Node.Decl.Structured) {
override fun visit(v: Node, parent: Node?) {
if (v is Node.Declaration.Class) {
val name = v.name?.name.orEmpty()
nestedNames.add(name)

Expand All @@ -35,18 +38,19 @@ transform { fileInfo ->

fun toFqNameType(type: Node.Type.Simple, nestedNames: List<String>): Node.Type.Simple {

// e.g. Make List<Expression> to List<Node.Expression>
if (type.pieces.size == 1 && type.pieces[0].name.name == "List") {
val typeArgs = type.pieces[0].typeArgs
if (typeArgs != null && typeArgs.elements.size == 1) {
val typeArg = typeArgs.elements[0] as Node.TypeArg.Type
val typeArg = typeArgs.elements[0]
return simpleType(
pieces = type.pieces.map {
it.copy(
typeArgs = typeArgs(
typeArg.copy(
typeRef = typeArg.typeRef.copy(
typeRef = typeArg.typeRef?.copy(
type = toFqNameType(
typeArg.typeRef.type!!.asSimpleType(),
typeArg.typeRef?.type as Node.Type.Simple,
nestedNames
),
),
Expand Down Expand Up @@ -74,49 +78,57 @@ transform { fileInfo ->
mutableNestedNames.removeLast()
}

if (pieceNames == listOf("Receiver")) {
return simpleType(
pieces = listOf(
piece(name = nameExpression("Node")),
piece(name = nameExpression("Expression")),
piece(name = nameExpression("DoubleColon")),
) + type.pieces
)
}

return type
}

object : Visitor() {
override fun visit(v: Node?, parent: Node) {
if (v is Node.Decl.Structured) {
override fun visit(v: Node, parent: Node?) {
if (v is Node.Declaration.Class) {
val name = v.name?.name.orEmpty()
nestedNames.add(name)

if (v.isDataClass) {
val params = v.primaryConstructor?.params?.elements.orEmpty()
val functionName = toFunctionName(nestedNames)

val func = function(
val func = functionDeclaration(
name = nameExpression(functionName),
params = functionParams(
elements = params.map { p ->
val fqType = when (val type = p.typeRef?.type) {
is Node.Type.Simple -> toFqNameType(type, nestedNames)
is Node.Type.Nullable -> type.copy(
type = toFqNameType(type.type.asSimpleType(), nestedNames)
type = toFqNameType(type.type as Node.Type.Simple, nestedNames)
)
else -> type
}
functionParam(
name = p.name,
typeRef = p.typeRef?.copy(type = fqType),
initializer = initializerOf(fqType),
defaultValue = defaultValueOf(fqType),
)
},
),
body = functionExpressionBody(
expr = callExpression(
expr = nameExpression(nestedNames.joinToString(".")),
args = valueArgs(
elements = params.map { p ->
valueArg(
name = p.name,
expr = p.name,
)
},
),
)
body = callExpression(
expression = nameExpression(nestedNames.joinToString(".")),
args = valueArgs(
elements = params.map { p ->
valueArg(
name = p.name,
expression = expressionOf(functionName, p.name),
)
},
),
)
)
// println(nestedNames.joinToString(".") + "\t" + toFunctionName(nestedNames))
Expand All @@ -129,27 +141,26 @@ transform { fileInfo ->
"pieces",
).contains(firstParam.name.name)
) {
val firstParamType = firstParam.typeRef?.type?.asSimpleTypeOrNull()
val firstParamType = firstParam.typeRef?.type as? Node.Type.Simple
if (firstParamType != null) {
if (firstParamType.pieces.firstOrNull()?.name?.name == "List") {
val listElementType = firstParamType.pieces.first().typeArgs!!.elements[0].type
val listElementType =
firstParamType.pieces.first().typeArgs!!.elements[0].typeRef?.type
if (listElementType != null) {
val varargFunc = function(
val varargFunc = functionDeclaration(
name = nameExpression(functionName),
params = functionParams(
functionParam(
mods = modifiers(literalModifier(Node.Modifier.Keyword.VARARG)),
modifiers = modifiers(keywordModifier(Node.Modifier.Keyword.Token.VARARG)),
name = nameExpression(firstParam.name.name),
typeRef = typeRef(type = listElementType),
)
),
body = functionExpressionBody(
expr = callExpression(
expr = nameExpression(functionName),
args = valueArgs(
valueArg(
expr = nameExpression("${firstParam.name.name}.toList()"),
)
body = callExpression(
expression = nameExpression(functionName),
args = valueArgs(
valueArg(
expression = nameExpression("${firstParam.name.name}.toList()"),
)
)
)
Expand Down Expand Up @@ -184,55 +195,77 @@ fun toFunctionName(nestedNames: List<String>): String {
val prefix = "$parent."
return fqName.startsWith(prefix) && nestedNames.size == prefix.count { it == '.' } + 1
}
if (isChildOf("Node.Decl.Property.Variable")) {
return "variable"
if (isChildOf("Node.Declaration")) {
return "${name.decapitalizeSmart()}Declaration"
}
if (isChildOf("Node.Type")) {
return "${name.decapitalizeSmart()}Type"
}
if (isChildOf("Node.Expr")) {
if (isChildOf("Node.Expression")) {
return "${name.decapitalizeSmart()}Expression"
}
if (isChildOf("Node.Expr.When.Entry")) {
return "whenEntry$name"
if (isChildOf("Node.Expression.When.Branch")) {
return "${name.decapitalizeSmart()}Branch"
}
if (isChildOf("Node.Expression.When.Condition")) {
return "${name.decapitalizeSmart()}Condition"
}
if (isChildOf("Node.Expr.When.Cond")) {
return "whenCondition$name"
if (isChildOf("Node.Expression.DoubleColon.Receiver")) {
return "${name.decapitalizeSmart()}DoubleColonReceiver"
}
if (isChildOf("Node.Declaration.Class.Parent")) {
return "${name.decapitalizeSmart()}Parent"
}
return when (fqName) {
"Node.Package" -> "packageDirective"
"Node.Imports" -> "importDirectives"
"Node.Import" -> "importDirective"
"Node.Decl.Func" -> "function"
"Node.Decl.Func.Params" -> "functionParams"
"Node.Decl.Func.Param" -> "functionParam"
"Node.Decl.Func.Body.Block" -> "functionBlockBody"
"Node.Decl.Func.Body.Expr" -> "functionExpressionBody"
"Node.Expr.DoubleColonRef.Class" -> "doubleColonClassLiteral"
"Node.Modifier.Lit" -> "literalModifier"
"Node.Declaration.Class.Body" -> "classBody"
"Node.Declaration.Function.Params" -> "functionParams"
"Node.Declaration.Function.Param" -> "functionParam"
"Node.Type.Function.Receiver" -> "functionTypeReceiver"
"Node.Type.Function.Params" -> "functionTypeParams"
"Node.Type.Function.Param" -> "functionTypeParam"
"Node.Expression.Lambda.Params" -> "lambdaParams"
"Node.Expression.Lambda.Param" -> "lambdaParam"
"Node.Expression.Binary.Operator" -> "binaryOperator"
"Node.Expression.Unary.Operator" -> "unaryOperator"
"Node.Expression.BinaryType.Operator" -> "binaryTypeOperator"
"Node.Modifier.AnnotationSet.Target" -> "annotationSetTarget"
"Node.Modifier.Keyword" -> "keywordModifier"
else -> name.decapitalizeSmart()
}
}

fun initializerOf(type: Node.Type?): Node.Initializer? {
val expr = if (type is Node.Type.Nullable) {
Node.Expr.Name("null")
fun defaultValueOf(type: Node.Type?): Node.Expression? {
return if (type is Node.Type.Nullable) {
nameExpression("null")
} else if (type is Node.Type.Simple) {
val fqName = type.pieces.joinToString(".") { it.name.name }
if (fqName == "List") {
Node.Expr.Name("listOf()")
nameExpression("listOf()")
} else if (fqName == "Boolean") {
Node.Expr.Name("false")
nameExpression("false")
} else {
if (fqName.startsWith("Node.Keyword.") && !(fqName.contains(".ValOrVar") || fqName.contains(".Declaration"))) {
Node.Expr.Name("$fqName()")
nameExpression("$fqName()")
} else {
null
}
}
} else {
null
}
}

return expr?.let { Node.Initializer(Node.Keyword.Equal(), it) }
fun expressionOf(functionName: String, paramName: Node.Expression.Name): Node.Expression {
if (paramName.name == "equals") {
val expressionText = when (functionName) {
"functionDeclaration", "getter", "setter" -> "if (equals == null && body != null && body !is Node.Expression.Block) Node.Keyword.Equal() else equals"
"functionParam" -> "if (equals == null && defaultValue != null) Node.Keyword.Equal() else equals"
"propertyDeclaration" -> "if (equals == null && initializer != null) Node.Keyword.Equal() else equals"
else -> null
}
if (expressionText != null) {
return nameExpression(expressionText)
}
}
return paramName
}
58 changes: 26 additions & 32 deletions ktcodeshift-cli/src/test/resources/examples/JUnit4To5.transform.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ val annotationNameMap = mapOf(
transform { fileInfo ->
Api
.parse(fileInfo.source)
.find<Node.Import>()
.find<Node.ImportDirective>()
.filter { v ->
v.names.size == 3 && v.names.take(2).map { it.name } == listOf("org", "junit")
}
Expand All @@ -27,21 +27,19 @@ transform { fileInfo ->
}
.find<Node.Modifier.AnnotationSet.Annotation>()
.replaceWith { v ->
val name = annotationNameMap[v.constructorCallee.type.pieces.last().name.name]?.let(::nameExpression)
val name = annotationNameMap[v.type.pieces.last().name.name]?.let(::nameExpression)
if (name != null) {
v.copy(
constructorCallee = constructorCallee(
type = simpleType(
pieces = v.constructorCallee.type.pieces.dropLast(1) + v.constructorCallee.type.pieces.last()
.copy(name = name)
)
type = simpleType(
pieces = v.type.pieces.dropLast(1) + v.type.pieces.last()
.copy(name = name)
)
)
} else {
v
}
}
.find<Node.Decl.Func>()
.find<Node.Declaration.Function>()
.filter { v ->
val annotation = getAnnotationByName(v.annotations, "Test")
getValueArgByName(annotation?.args, "expected") != null
Expand All @@ -50,38 +48,34 @@ transform { fileInfo ->
val annotation = getAnnotationByName(v.annotations, "Test")
val arg = getValueArgByName(annotation?.args, "expected")
val exceptionType =
((arg?.expr as Node.Expr.DoubleColonRef.Class).recv as Node.Expr.DoubleColonRef.Recv.Type).type
val originalStatements = (v.body as Node.Decl.Func.Body.Block).block.statements
((arg?.expression as Node.Expression.ClassLiteral).lhs as Node.Expression.DoubleColon.Receiver.Type).type
val originalStatements = (v.body as Node.Expression.Block).statements

v.copy(
mods = v.mods!!.copy(
elements = v.mods!!.elements.map {
if (it is Node.Modifier.AnnotationSet && it.anns.contains(annotation)) {
it.copy(anns = listOf(annotation!!.copy(args = null)))
modifiers = v.modifiers!!.copy(
elements = v.modifiers!!.elements.map {
if (it is Node.Modifier.AnnotationSet && it.annotations.contains(annotation)) {
it.copy(annotations = listOf(annotation!!.copy(args = null)))
} else {
it
}
}
),
body = functionBlockBody(
block = blockExpression(
callExpression(
expr = nameExpression("Assertions.assertThrows"),
typeArgs = typeArgs(
type(
typeRef = typeRef(
type = exceptionType,
)
body = blockExpression(
callExpression(
expression = nameExpression("Assertions.assertThrows"),
typeArgs = typeArgs(
typeArg(
typeRef = typeRef(
type = exceptionType,
)
)
),
lambdaArg = lambdaArg(
expression = lambdaExpression(
body = body(originalStatements),
),
lambdaArgs = listOf(
lambdaArg(
func = lambdaExpression(
body = body(originalStatements),
),
)
),
)
),
)
)
)
Expand All @@ -93,7 +87,7 @@ fun getAnnotationByName(
annotations: List<Node.Modifier.AnnotationSet.Annotation>,
name: String
): Node.Modifier.AnnotationSet.Annotation? {
return annotations.find { it.constructorCallee.type.pieces.last().name.name == name }
return annotations.find { it.type.pieces.last().name.name == name }
}

fun getValueArgByName(args: Node.ValueArgs?, name: String): Node.ValueArg? {
Expand Down
Loading

0 comments on commit 8260c20

Please sign in to comment.