diff --git a/wire-compiler/src/test/java/com/squareup/wire/schema/OptionsLinkingTest.kt b/wire-compiler/src/test/java/com/squareup/wire/schema/OptionsLinkingTest.kt index 3fc4052783..4a2df8a160 100644 --- a/wire-compiler/src/test/java/com/squareup/wire/schema/OptionsLinkingTest.kt +++ b/wire-compiler/src/test/java/com/squareup/wire/schema/OptionsLinkingTest.kt @@ -18,12 +18,14 @@ package com.squareup.wire.schema import assertk.assertThat +import assertk.assertions.contains import assertk.assertions.isEqualTo import assertk.assertions.isNotNull import com.squareup.wire.testing.add import okio.Path import okio.Path.Companion.toPath import okio.fakefilesystem.FakeFileSystem +import org.junit.Assert.fail import org.junit.Test class OptionsLinkingTest { @@ -169,6 +171,62 @@ class OptionsLinkingTest { assertThat(typeRange.field("max")).isNotNull() } + @Test + fun rejectsInvalidOptionScalarLiterals() { + fs.add( + "source-path/a.proto", + """ + |import "formatting_options.proto"; + | + |message A { + | option (message_options).enabled = "false; static { } //"; + | optional string s = 1 [ + | (formatting_options).max = "80; static { } //", + | (formatting_options).casing = "LOWER_CASE; static { } //" + | ]; + |} + """.trimMargin(), + ) + fs.add( + "source-path/formatting_options.proto", + """ + |import "google/protobuf/descriptor.proto"; + | + |message MessageOptions { + | optional bool enabled = 1; + |} + | + |message FormattingOptions { + | optional int32 max = 1; + | optional StringCasing casing = 2; + | optional string documentation = 3; + |} + | + |enum StringCasing { + | LOWER_CASE = 1; + |} + | + |extend google.protobuf.MessageOptions { + | optional MessageOptions message_options = 22001; + |} + | + |extend google.protobuf.FieldOptions { + | optional FormattingOptions formatting_options = 22002; + |} + """.trimMargin(), + ) + + try { + loadAndLinkSchema() + fail() + } catch (expected: SchemaException) { + val message = expected.message!! + assertThat(message).contains("invalid option value \"false; static { } //\" for bool") + assertThat(message).contains("invalid option value \"80; static { } //\" for int32") + assertThat(message).contains("invalid option value \"LOWER_CASE; static { } //\" for StringCasing") + } + } + @Test fun extensionTypesInExternalFile() { fs.add( diff --git a/wire-java-generator/src/test/java/com/squareup/wire/java/JavaGeneratorTest.java b/wire-java-generator/src/test/java/com/squareup/wire/java/JavaGeneratorTest.java index 39820078a9..78d5787a75 100644 --- a/wire-java-generator/src/test/java/com/squareup/wire/java/JavaGeneratorTest.java +++ b/wire-java-generator/src/test/java/com/squareup/wire/java/JavaGeneratorTest.java @@ -24,6 +24,7 @@ import com.squareup.wire.schema.MessageType; import com.squareup.wire.schema.PruningRules; import com.squareup.wire.schema.Schema; +import com.squareup.wire.schema.SchemaException; import java.io.IOException; import okio.Path; import org.junit.Test; @@ -608,21 +609,20 @@ public void defaultValues() throws IOException { @Test public void defaultValuesMustNotBeOctal() throws IOException { - Schema schema = - new SchemaBuilder() - .add( - Path.get("message.proto"), - "" - + "message Message {\n" - + " optional int32 a = 1 [default = 020 ];\n" - + " optional int64 b = 2 [default = 021 ];\n" - + "}\n") - .build(); try { - new JavaWithProfilesGenerator(schema).generateJava("Message"); + new SchemaBuilder() + .add( + Path.get("message.proto"), + "" + + "message Message {\n" + + " optional int32 a = 1 [default = 020 ];\n" + + " optional int64 b = 2 [default = 021 ];\n" + + "}\n") + .build(); fail(); - } catch (IllegalStateException expected) { - assertThat(expected).hasMessageThat().contains("Octal literal unsupported: 020"); + } catch (SchemaException expected) { + assertThat(expected).hasMessageThat().contains("invalid default value \"020\" for int32"); + assertThat(expected).hasMessageThat().contains("invalid default value \"021\" for int64"); } } diff --git a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Field.kt b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Field.kt index c353a904a2..62b0b383c6 100644 --- a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Field.kt +++ b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Field.kt @@ -162,6 +162,7 @@ data class Field( } } syntaxRules.validateDefaultValue(default != null, linker.errors) + validateDefaultValue(linker) if (type!!.isMap) { val valueType = linker.get(type!!.valueType!!) if (valueType is EnumType && valueType.constants[0].tag != 0) { @@ -171,6 +172,15 @@ data class Field( linker.validateImportForType(location, type!!) } + private fun validateDefaultValue(linker: Linker) { + val default = default ?: return + val type = type!! + + if (!isValidLiteral(linker, type, default)) { + linker.errors += "invalid default value \"$default\" for $type" + } + } + fun retainAll(schema: Schema, markSet: MarkSet, enclosingType: ProtoType): Field? { // TODO(jwilson): perform this transformation in the Linker. val type = type ?: return null diff --git a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/LiteralValidation.kt b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/LiteralValidation.kt new file mode 100644 index 0000000000..a04c7199ad --- /dev/null +++ b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/LiteralValidation.kt @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2026 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.squareup.wire.schema + +internal fun isValidLiteral(linker: Linker, type: ProtoType, value: String): Boolean = when (type) { + ProtoType.BOOL -> value == "true" || value == "false" + ProtoType.BYTES, ProtoType.STRING -> true + ProtoType.DOUBLE, ProtoType.FLOAT -> value.isValidFloatingPointDefault() + ProtoType.FIXED32, ProtoType.UINT32 -> value.isValidUnsignedIntegerDefault(UINT32_MAX) + ProtoType.FIXED64, ProtoType.UINT64 -> value.isValidUnsignedIntegerDefault(UINT64_MAX) + ProtoType.INT32, ProtoType.SFIXED32, ProtoType.SINT32 -> { + value.isValidSignedIntegerDefault(INT32_MIN, INT32_MAX) + } + ProtoType.INT64, ProtoType.SFIXED64, ProtoType.SINT64 -> { + value.isValidSignedIntegerDefault(INT64_MIN, INT64_MAX) + } + else -> { + val valueType = linker.get(type) + valueType is EnumType && valueType.constant(value) != null + } +} + +private const val INT32_MAX = "2147483647" +private const val INT32_MIN = "-2147483648" +private const val UINT32_MAX = "4294967295" +private const val INT64_MAX = "9223372036854775807" +private const val INT64_MIN = "-9223372036854775808" +private const val UINT64_MAX = "18446744073709551615" + +private val DECIMAL_INTEGER_REGEX = Regex("-?[0-9]+") +private val FLOATING_POINT_REGEX = Regex("-?((([0-9]+)(\\.[0-9]*)?)|(\\.[0-9]+))([eE][+-]?[0-9]+)?") +private val HEX_INTEGER_REGEX = Regex("-?0[xX][0-9a-fA-F]+") + +private fun String.isValidFloatingPointDefault(): Boolean { + if (this in listOf("inf", "-inf", "nan", "-nan")) return true + return FLOATING_POINT_REGEX.matches(this) && toDoubleOrNull() != null +} + +private fun String.isValidSignedIntegerDefault(min: String, max: String): Boolean { + if (!isIntegerLiteral() || hasOctalPrefix()) return false + return compareIntegerLiterals(min, this) <= 0 && compareIntegerLiterals(this, max) <= 0 +} + +private fun String.isValidUnsignedIntegerDefault(max: String): Boolean { + if (!isIntegerLiteral() || hasOctalPrefix() || startsWith("-")) return false + return compareIntegerLiterals(this, max) <= 0 +} + +private fun String.isIntegerLiteral() = DECIMAL_INTEGER_REGEX.matches(this) || HEX_INTEGER_REGEX.matches(this) + +private fun String.hasOctalPrefix(): Boolean { + val digits = removePrefix("-") + return digits.length > 1 && digits[0] == '0' && digits[1] != 'x' && digits[1] != 'X' +} + +private fun compareIntegerLiterals(a: String, b: String): Int { + val aNegative = a.startsWith("-") + val bNegative = b.startsWith("-") + if (aNegative != bNegative) return if (aNegative) -1 else 1 + + val magnitudeComparison = compareIntegerMagnitudes( + a.removePrefix("-"), + b.removePrefix("-"), + ) + return if (aNegative) -magnitudeComparison else magnitudeComparison +} + +private fun compareIntegerMagnitudes(a: String, b: String): Int { + val aDecimal = a.decimalMagnitude() + val bDecimal = b.decimalMagnitude() + if (aDecimal.length != bDecimal.length) return aDecimal.length.compareTo(bDecimal.length) + return aDecimal.compareTo(bDecimal) +} + +private fun String.decimalMagnitude(): String { + val stripped = if (startsWith("0x") || startsWith("0X")) { + hexMagnitudeToDecimal(substring(2)) + } else { + trimStart('0').ifEmpty { "0" } + } + return stripped.trimStart('0').ifEmpty { "0" } +} + +private fun hexMagnitudeToDecimal(hex: String): String { + var decimal = "0" + hex.forEach { digit -> + decimal = decimal.multiplyDecimalBy(16) + decimal = decimal.addDecimal(digit.digitToInt(16)) + } + return decimal +} + +private fun String.multiplyDecimalBy(multiplier: Int): String { + var carry = 0 + val result = StringBuilder(length + 2) + for (i in indices.reversed()) { + val value = (this[i] - '0') * multiplier + carry + result.append(value % 10) + carry = value / 10 + } + while (carry > 0) { + result.append(carry % 10) + carry /= 10 + } + return result.reverse().toString().trimStart('0').ifEmpty { "0" } +} + +private fun String.addDecimal(addend: Int): String { + var carry = addend + val result = StringBuilder(length + 2) + for (i in indices.reversed()) { + val value = (this[i] - '0') + carry + result.append(value % 10) + carry = value / 10 + } + while (carry > 0) { + result.append(carry % 10) + carry /= 10 + } + return result.reverse().toString().trimStart('0').ifEmpty { "0" } +} diff --git a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Options.kt b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Options.kt index 1d66fb3115..87311e2e43 100644 --- a/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Options.kt +++ b/wire-schema/src/commonMain/kotlin/com/squareup/wire/schema/Options.kt @@ -220,6 +220,7 @@ class Options( } is String -> { + validateOptionValue(linker, context, value) return coerceValueForField(context, value, isRepeated) } @@ -233,6 +234,12 @@ class Options( } } + private fun validateOptionValue(linker: Linker, context: ProtoType, value: String) { + if (!isValidLiteral(linker, context, value)) { + linker.errors += "invalid option value \"$value\" for $context" + } + } + private fun coerceValueForField(context: ProtoType, value: Any, isRepeated: Boolean): Any = when { isRepeated || context.isMap -> value as? List<*> ?: listOf(value) value is List<*> -> value.single()!! diff --git a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/SchemaTest.kt b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/SchemaTest.kt index df84028371..4a0265cd24 100644 --- a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/SchemaTest.kt +++ b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/SchemaTest.kt @@ -2162,6 +2162,68 @@ class SchemaTest { } } + @Test + fun proto2AllowsValidDefaultValues() { + buildSchema { + add( + "defaults.proto".toPath(), + """ + |syntax = "proto2"; + | + |message Defaults { + | optional bool bool_field = 1 [default = false]; + | optional int32 int32_field = 2 [default = -0x80000000]; + | optional uint32 uint32_field = 3 [default = 4294967295]; + | optional int64 int64_field = 4 [default = -9223372036854775808]; + | optional uint64 uint64_field = 5 [default = 18446744073709551615]; + | optional float float_field = 6 [default = inf]; + | optional double double_field = 7 [default = -1.23e45]; + | optional string string_field = 8 [default = "source syntax is just data here"]; + | optional bytes bytes_field = 9 [default = "abc"]; + | optional Choice enum_field = 10 [default = TWO]; + | + | enum Choice { + | ONE = 0; + | TWO = 1; + | } + |} + """.trimMargin(), + ) + } + } + + @Test + fun proto2RejectsInvalidDefaultValues() { + try { + buildSchema { + add( + "defaults.proto".toPath(), + """ + |syntax = "proto2"; + | + |message Defaults { + | optional bool bool_field = 1 [default = "false; static { } //"]; + | optional float float_field = 2 [default = "0.0f; init { } //"]; + | optional double double_field = 3 [default = "0.0d; static { } //"]; + | optional Choice enum_field = 4 [default = "ONE; static { } //"]; + | + | enum Choice { + | ONE = 0; + | } + |} + """.trimMargin(), + ) + } + fail() + } catch (expected: SchemaException) { + val message = expected.message!! + assertThat(message).contains("invalid default value \"false; static { } //\" for bool") + assertThat(message).contains("invalid default value \"0.0f; init { } //\" for float") + assertThat(message).contains("invalid default value \"0.0d; static { } //\" for double") + assertThat(message).contains("invalid default value \"ONE; static { } //\" for Defaults.Choice") + } + } + @Test fun repeatedNumericScalarsShouldBePackedByDefaultForProto3() { val schema = buildSchema {