[go: nahoru, domu]

Skip to content

Commit

Permalink
Add cli, fix options, add support for extensions, add support for extension ranges, reserved ranges, reserved names
Browse files Browse the repository at this point in the history
  • Loading branch information
valaphee committed Apr 5, 2022
1 parent b4a8477 commit 89cc9c7
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 27 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ repositories {
dependencies {
implementation("com.google.protobuf:protobuf-kotlin:3.19.4")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit:1.6.20")
// Command-line option parsing (ArgParser/ArgType/required) used by Main.kt's entry point.
implementation("org.jetbrains.kotlinx:kotlinx-cli:0.3.4")
}

tasks {
Expand Down
2 changes: 0 additions & 2 deletions protoc.bat

This file was deleted.

40 changes: 35 additions & 5 deletions src/main/kotlin/com/valaphee/protod/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@ package com.valaphee.protod
import com.google.protobuf.CodedInputStream
import com.google.protobuf.DescriptorProtos
import com.valaphee.protod.util.occurrencesOf
import kotlinx.cli.ArgParser
import kotlinx.cli.ArgType
import kotlinx.cli.required
import java.io.File

/**
 * CLI entry point: reads the `-i/--input` file, collects the FileDescriptorProtos found in it,
 * and writes each one as a `.proto` file (prefixed with an auto-generated header) under `-o/--output`.
 *
 * NOTE(review): this rendering interleaves pre- and post-commit lines; the parameterless `main`,
 * the hard-coded input path and the `File("output")` / 3-argument ProtoWriter lines appear to be
 * the removed versions.
 */
fun main() {
val bytes = File("C:\\Program Files (x86)\\Battle.net\\Battle.net.13401\\battle.net.dll").readBytes()
fun main(arguments: Array<String>) {
// Both options are mandatory; parse() reports a usage error if either is missing.
val argumentParser = ArgParser("protod")
val inputArgument by argumentParser.option(ArgType.String, "input", "i", "Input file").required()
val outputArgument by argumentParser.option(ArgType.String, "output", "o", "Output path").required()
argumentParser.parse(arguments)

val bytes = File(inputArgument).readBytes()
val files = mutableListOf<DescriptorProtos.FileDescriptorProto>()
// Scan the raw bytes as ASCII text for ".proto" occurrences; the decoding of each hit into
// `files` happens in the collapsed part of this hunk and is not visible here.
String(bytes, Charsets.US_ASCII).occurrencesOf(".proto").forEach {
var offset = 0
Expand Down Expand Up @@ -53,11 +61,33 @@ fun main() {
}
}

// Index enums and messages (recursing into nested types) by a dotted name with a leading '.',
// so field/extension type names can be resolved by ProtoWriter.
// NOTE(review): when a file has no package, ".${file.`package`}.${name}" yields "..Name";
// type references to package-less types are presumably ".Name", so those lookups would miss.
val enums = mutableMapOf<String, DescriptorProtos.EnumDescriptorProto>()
val messages = mutableMapOf<String, DescriptorProtos.DescriptorProto>()
files.forEach { file -> file.messageTypeList.forEach { message -> messages[".${file.`package`}.${message.name}"] = message } }
files.forEach { file ->
file.enumTypeList.forEach { enums[".${file.`package`}.${it.name}"] = it }
file.messageTypeList.forEach {
fun flatten(name: String, message: DescriptorProtos.DescriptorProto) {
messages[name] = message
message.enumTypeList.forEach { enums["$name.${it.name}"] = it }
message.nestedTypeList.forEach { flatten("$name.${it.name}", it) }
}

flatten(".${file.`package`}.${it.name}", it)
}
}
// extendee name -> (field number -> extension). Only extensions whose type resolves to a known
// message are recorded; scalar-typed extensions are skipped by the `?.let`.
val messageExtensions = mutableMapOf<String, MutableMap<Int, DescriptorProtos.FieldDescriptorProto>>()
files.forEach { file -> file.extensionList.forEach { extension -> messages[extension.typeName]?.let { messageExtensions.getOrPut(extension.extendee) { mutableMapOf() }[extension.number] = extension } } }

val outputPath = File("output")
files.forEach { file -> File(outputPath, file.name).apply { parentFile.mkdirs() }.printWriter().use { printWriter -> ProtoWriter(printWriter, messages, messageExtensions).print(file) } }
val outputPath = File(outputArgument)
// NOTE(review): file.name is taken from data parsed out of the input binary; a name containing
// ".." would resolve outside outputPath. Consider normalizing/validating it before writing.
files.forEach { file ->
File(outputPath, file.name).apply { parentFile.mkdirs() }.printWriter().use { printWriter ->
printWriter.println(
"""
/* AUTO-GENERATED FILE. DO NOT MODIFY.
*/
""".trimIndent()
)
ProtoWriter(printWriter, enums, messages, messageExtensions).print(file)
}
}
}
154 changes: 134 additions & 20 deletions src/main/kotlin/com/valaphee/protod/ProtoWriter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,32 @@ package com.valaphee.protod

import com.google.protobuf.DescriptorProtos
import com.google.protobuf.GeneratedMessageV3
import com.google.protobuf.UnknownFieldSet
import com.google.protobuf.WireFormat
import java.io.PrintWriter

/**
 * Renders [DescriptorProtos.FileDescriptorProto]s back into `.proto` source text.
 *
 * @param printWriter destination for the generated text
 * @param enums enum descriptors keyed by dotted name with a leading '.'; used to turn enum-typed
 *   custom option values back into constant names
 * @param messages message descriptors keyed the same way; used to look up the fields of
 *   message-typed custom option extensions
 * @param messageExtensions extendee name -> (field number -> extension field); used to decode
 *   custom options that arrive as unknown fields
 * @author Kevin Ludwig
 */
class ProtoWriter(
private val printWriter: PrintWriter,
private val enums: MutableMap<String, DescriptorProtos.EnumDescriptorProto>,
private val messages: MutableMap<String, DescriptorProtos.DescriptorProto>,
private val messageExtensions: MutableMap<String, out Map<Int, DescriptorProtos.FieldDescriptorProto>>
) {
// True while nothing has been printed yet in the current scope, so the next section skips its
// separating blank line. Starts false so the first section after the syntax header gets one.
private var first = false
// Current brace-nesting depth; presumably read by the println helpers for indentation (not visible here).
private var indentLevel = 0

/**
 * Writes [file] as `.proto` source: the syntax line and optional package, then imports, file
 * options, top-level extensions, enums, messages and services, in that order.
 *
 * NOTE(review): the syntax line is always "proto2"; file.syntax is not consulted, so a proto3
 * input would be re-emitted as proto2. Confirm this is intended.
 */
fun print(file: DescriptorProtos.FileDescriptorProto) {
println("""syntax = "proto2";""")
println()
println("""package ${file.`package`};""")
printImports(file.dependencyList)
println("syntax = \"proto2\";")
// Omit the package statement entirely for files that do not declare one.
if (file.hasPackage()) {
println()
println("package ${file.`package`};")
}

printImports(file.dependencyList, file.publicDependencyList)
printOptions(file.options)
printExtensions(file.extensionList)
printEnums(file.enumTypeList)
printMessages(file.messageTypeList)
printServices(file.serviceList)
}
Expand All @@ -52,68 +58,176 @@ class ProtoWriter(
printWriter.println(value)
}

/**
 * Prints one `import "...";` statement per dependency. A dependency whose index appears in
 * `publicDependencyList` is printed as `import public`.
 */
private fun printImports(dependencyList: List<String>) {
private fun printImports(dependencyList: List<String>, publicDependencyList: List<Int>) {
// Separating blank line unless this is the first section in the current scope.
if (dependencyList.isNotEmpty()) if (first) first = false else println()
dependencyList.forEach { println("""import "$it";""") }
// publicDependencyList holds indices into dependencyList.
// NOTE(review): only public imports are distinguished; confirm whether weak imports need handling.
dependencyList.forEachIndexed { i, dependency -> println("import ${if (publicDependencyList.contains(i)) "public " else ""}\"$dependency\";") }
}

/** Prints every option produced by [generateOptions] as an `option ...;` statement. */
private fun printOptions(options: GeneratedMessageV3.ExtendableMessage<*>) {
if (options.allFields.isNotEmpty() || options.unknownFields.asMap().isNotEmpty()) if (first) first = false else println()
val generatedOptions = generateOptions(options)
// Emit the separating blank line only when at least one option will actually be printed.
if (generatedOptions.isNotEmpty()) if (first) first = false else println()
generatedOptions.forEach { println("option $it;") }
}

/**
 * Renders the options set on [options] as `name = value` strings, without the `option` keyword
 * or trailing `;`, so callers can emit them as statements or inside a `[...]` list.
 *
 * Options declared on the options message come from `allFields`. Custom options arrive as unknown
 * fields and are decoded using the message-typed extensions registered in `messageExtensions`.
 */
private fun generateOptions(options: GeneratedMessageV3.ExtendableMessage<*>): MutableList<String> {
val generatedOptions = mutableListOf<String>()
// NOTE(review): only String values are special-cased (quoted, not escaped); everything else
// (enums, repeated lists, ...) is rendered via toString(). Verify that yields valid .proto syntax.
options.allFields.forEach {
println("""option ${it.key.name} = ${when (it.value) {
is String -> """"${it.value}""""
generatedOptions += "${it.key.name} = ${when (it.value) {
is String -> "\"${it.value}\""
else -> it.value
}};""")
}}"
}
// Extensions registered for this options message, looked up by its full name with a leading '.'.
val messageExtensions = messageExtensions[".${options.descriptorForType.fullName}"] ?: emptyMap()
options.unknownFields.asMap().forEach { unknownField -> messageExtensions[unknownField.key]?.let { messageExtension -> messages[messageExtension.typeName]?.fieldList?.let { messageExtensionFields -> unknownField.value.lengthDelimitedList.forEach { UnknownFieldSet.parseFrom(it).asMap().forEach { messageExtensionField -> println("""option (${messageExtension.name}).${messageExtensionFields.single { it.number == messageExtensionField.key}?.name} = ${messageExtensionField.value.varintList.firstOrNull() ?: messageExtensionField.value.fixed32List.firstOrNull() ?: messageExtensionField.value.fixed64List.firstOrNull() ?: messageExtensionField.value.lengthDelimitedList.firstOrNull()?.toStringUtf8()?.let { """"$it"""" } ?: TODO()};""") } } } } }
// Unknown fields with no registered extension are skipped. For matched ones, only the
// length-delimited payloads (serialized extension messages) are decoded.
options.unknownFields.asMap().forEach { unknownField ->
messageExtensions[unknownField.key]?.let { messageExtension ->
val messageExtensionFields = checkNotNull(messages[messageExtension.typeName]).fieldList
unknownField.value.lengthDelimitedList.forEach {
// Walk the payload tag by tag until readTag() returns 0 (end of the payload).
val codedInputStream = it.newCodedInput()
while (true) {
val tag = codedInputStream.readTag()
if (tag == 0) break
val fieldNumber = WireFormat.getTagFieldNumber(tag)
// NOTE(review): single() throws if the payload carries a field number the extension message does not declare.
val messageExtensionField = messageExtensionFields.single { it.number == fieldNumber}
// NOTE(review): the extension's short name is used inside the parentheses; an extension from another
// package would need a qualified name. TYPE_STRING values are not escaped, TYPE_BYTES prints
// ByteString.toString(), and unlisted types (e.g. nested messages) reach TODO() and throw.
generatedOptions += "(${messageExtension.name}).${messageExtensionField.name} = ${when (messageExtensionField.type) {
DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE -> codedInputStream.readDouble()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT -> codedInputStream.readFloat()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64 -> codedInputStream.readInt64()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT64 -> codedInputStream.readUInt64()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT32 -> codedInputStream.readInt32()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED64 -> codedInputStream.readFixed64()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED32 -> codedInputStream.readFixed32()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL -> codedInputStream.readBool()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING -> "\"${codedInputStream.readString()}\""
DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES -> codedInputStream.readBytes()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT32 -> codedInputStream.readUInt32()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_ENUM -> {
val value = codedInputStream.readEnum()
checkNotNull(enums[messageExtensionField.typeName]).valueList.single { it.number == value}.name
}
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED32 -> codedInputStream.readSFixed32()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED64 -> codedInputStream.readSFixed64()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32 -> codedInputStream.readSInt32()
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64 -> codedInputStream.readSInt64()
else -> TODO()
}}"
}
}
}
}
return generatedOptions
}

/**
 * Emits one `extend <extendee> { ... }` block per distinct extendee in [extensionList], in the order
 * extendees first appear, with a blank line between consecutive blocks.
 */
private fun printExtensions(extensionList: List<DescriptorProtos.FieldDescriptorProto>) {
    var isFirstBlock = first
    for ((extendee, extensionFields) in extensionList.groupBy { it.extendee }) {
        if (!isFirstBlock) println() else isFirstBlock = false
        println("extend $extendee {")
        // Fields inside the braces start a fresh scope (no leading blank line).
        first = true
        indentLevel++
        printFields(extensionFields)
        // Hand the sibling-separator state back to the enclosing scope.
        first = isFirstBlock
        indentLevel--
        println("}")
    }
}

/**
 * Prints each enum as an `enum Name { ... }` block: options, values, then reserved ranges and
 * reserved names. Consecutive enums are separated by a blank line.
 */
private fun printEnums(enumTypeList: List<DescriptorProtos.EnumDescriptorProto>) {
    var first0 = first
    enumTypeList.forEach { enumType ->
        if (first0) first0 = false else println()
        println("enum ${enumType.name} {")
        first = true
        indentLevel++
        printOptions(enumType.options)
        printEnumValues(enumType.valueList)
        if (enumType.reservedRangeList.isNotEmpty()) {
            if (first) first = false else println()
            // Unlike message ranges, EnumReservedRange.end is inclusive, and `max` is stored as
            // Int.MAX_VALUE (not the message-field sentinel), so compare against that.
            println(enumType.reservedRangeList.joinToString(prefix = "reserved ", postfix = ";") { range ->
                when {
                    range.start == range.end -> "${range.start}"
                    range.end == Int.MAX_VALUE -> "${range.start} to max"
                    else -> "${range.start} to ${range.end}"
                }
            })
        }
        if (enumType.reservedNameList.isNotEmpty()) {
            if (first) first = false else println()
            // Reserved names must be quoted string literals in proto2/proto3 syntax.
            println(enumType.reservedNameList.joinToString(prefix = "reserved ", postfix = ";") { name -> "\"$name\"" })
        }
        first = first0
        indentLevel--
        println("}")
    }
}

/**
 * Prints every enum constant as `NAME = number;`, appending a `[...]` list when the constant has
 * options. A separating blank line precedes the group when it is not the first section in scope.
 */
private fun printEnumValues(enumValueList: List<DescriptorProtos.EnumValueDescriptorProto>) {
    if (enumValueList.isNotEmpty()) {
        if (first) first = false else println()
    }
    for (enumValue in enumValueList) {
        val valueOptions = generateOptions(enumValue.options)
        val optionSuffix = if (valueOptions.isEmpty()) "" else " [${valueOptions.joinToString()}]"
        println("${enumValue.name} = ${enumValue.number}$optionSuffix;")
    }
}

/**
 * Prints each message as a `message Name { ... }` block containing, in order: options, nested
 * extensions, enums, nested messages, fields, extension ranges, reserved ranges and reserved names.
 */
private fun printMessages(messageTypeList: List<DescriptorProtos.DescriptorProto>) {
var first0 = first
messageTypeList.forEach {
if (first0) first0 = false else println()
println("""message ${it.name} {""")
println("message ${it.name} {")
// The first section inside the braces gets no leading blank line.
first = true
indentLevel++
printOptions(it.options)
printExtensions(it.extensionList)
// NOTE(review): oneof declarations are not emitted, so fields belonging to a oneof print as ordinary fields.
/*if (it.oneofDeclList.size != 0)*/
printEnums(it.enumTypeList)
printMessages(it.nestedTypeList)
printFields(it.fieldList)
// NOTE(review): descriptor.proto documents ExtensionRange.end and ReservedRange.end as exclusive.
// If so, `start != end` is always true here and the printed upper bound is one too high
// (e.g. `reserved 5;` comes back as `reserved 5 to 6;`). Consider comparing/printing `end - 1`.
if (it.extensionRangeCount != 0) {
if (first) first = false else println()
println(it.extensionRangeList.joinToString(prefix = "extensions ", postfix = ";") { if (it.start != it.end) "${it.start} to ${if (it.end != max) it.end else "max"}" else "${it.start}" })
}
if (it.reservedRangeCount != 0) {
if (first) first = false else println()
println(it.reservedRangeList.joinToString(prefix = "reserved ", postfix = ";") { if (it.start != it.end) "${it.start} to ${if (it.end != max) it.end else "max"}" else "${it.start}" })
}
// NOTE(review): proto2/proto3 grammar requires reserved names to be quoted string literals ("foo").
if (it.reservedNameCount != 0) {
if (first) first = false else println()
println(it.reservedNameList.joinToString(prefix = "reserved ", postfix = ";"))
}
// Restore the caller's separator state for the next sibling section.
first = first0
indentLevel--
println("""}""")
println("}")
}
}

/**
 * Prints each field as `label type name = number [options];`. Fields that carry a typeName
 * (presumably message/enum references) use it verbatim; other types are mapped through `fieldTypes`.
 * A declared default value is appended to the option list as `default = ...`.
 */
private fun printFields(fieldList: List<DescriptorProtos.FieldDescriptorProto>) {
if (fieldList.isNotEmpty()) if (first) first = false else println()
fieldList.forEach { println("""${checkNotNull(fieldLabels[it.label])} ${if (it.hasTypeName()) it.typeName else checkNotNull(fieldTypes[it.type])} ${it.name} = ${it.number};""") }
fieldList.forEach {
val generatedOptions = generateOptions(it.options)
// NOTE(review): only TYPE_STRING defaults are quoted (and not escaped); a TYPE_BYTES default would be
// printed unquoted, which is not a valid .proto literal.
if (it.hasDefaultValue()) generatedOptions += "default = ${if (it.type == DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING) "\"${it.defaultValue}\"" else it.defaultValue}"
println("${checkNotNull(fieldLabels[it.label])} ${if (it.hasTypeName()) it.typeName else checkNotNull(fieldTypes[it.type])} ${it.name} = ${it.number}${if (generatedOptions.isNotEmpty()) " [${generatedOptions.joinToString()}]" else ""};")
}
}

/** Prints each service as a `service Name { ... }` block containing its options and rpc methods. */
private fun printServices(serviceList: List<DescriptorProtos.ServiceDescriptorProto>) {
var first0 = first
serviceList.forEach {
if (first0) first0 = false else println()
println("""service ${it.name} {""")
println("service ${it.name} {")
first = true
indentLevel++
printOptions(it.options)
printMethods(it.methodList)
// Restore the caller's separator state for the next sibling section.
first = first0
indentLevel--
println("""}""")
println("}")
}
}

/**
 * Prints each method as `rpc Name (Input) returns (Output) { ... }`, with its options inside the braces.
 *
 * NOTE(review): the client/server streaming flags are not consulted, so streaming RPCs are printed
 * without their `stream` keyword.
 */
private fun printMethods(methodList: List<DescriptorProtos.MethodDescriptorProto>) {
var first0 = first
methodList.forEach {
if (first0) first0 = false else println()
println("""rpc ${it.name} (${it.inputType}) returns (${it.outputType}) {""")
println("rpc ${it.name} (${it.inputType}) returns (${it.outputType}) {")
first = true
indentLevel++
printOptions(it.options)
first = first0
indentLevel--
println("""}""")
println("}")
}
}

Expand All @@ -123,7 +237,6 @@ class ProtoWriter(
DescriptorProtos.FieldDescriptorProto.Label.LABEL_REQUIRED to "required",
DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED to "repeated"
)

// FieldDescriptorProto.Type -> .proto type keyword; printFields uses it for fields without a typeName.
private val fieldTypes = mapOf(
DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE to "double",
DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT to "float",
Expand All @@ -141,5 +254,6 @@ class ProtoWriter(
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32 to "sint32",
DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64 to "sint64",
)
// Range end compared against when printing `max` in extension/reserved ranges: 1 shl 29 = 536870912,
// presumably one past the largest field number (2^29 - 1) because message range ends are exclusive.
// NOTE(review): enum reserved ranges are inclusive and appear to encode `max` differently; verify before reusing.
private val max = 1 shl 29
}
}

0 comments on commit 89cc9c7

Please sign in to comment.