Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
package com.squareup.wire.schema

import assertk.assertThat
import assertk.assertions.contains
import assertk.assertions.isEqualTo
import assertk.assertions.isNotNull
import com.squareup.wire.testing.add
import okio.Path
import okio.Path.Companion.toPath
import okio.fakefilesystem.FakeFileSystem
import org.junit.Assert.fail
import org.junit.Test

class OptionsLinkingTest {
Expand Down Expand Up @@ -169,6 +171,62 @@ class OptionsLinkingTest {
assertThat(typeRange.field("max")).isNotNull()
}

@Test
fun rejectsInvalidOptionScalarLiterals() {
fs.add(
"source-path/a.proto",
"""
|import "formatting_options.proto";
|
|message A {
| option (message_options).enabled = "false; static { } //";
| optional string s = 1 [
| (formatting_options).max = "80; static { } //",
| (formatting_options).casing = "LOWER_CASE; static { } //"
| ];
|}
""".trimMargin(),
)
fs.add(
"source-path/formatting_options.proto",
"""
|import "google/protobuf/descriptor.proto";
|
|message MessageOptions {
| optional bool enabled = 1;
|}
|
|message FormattingOptions {
| optional int32 max = 1;
| optional StringCasing casing = 2;
| optional string documentation = 3;
|}
|
|enum StringCasing {
| LOWER_CASE = 1;
|}
|
|extend google.protobuf.MessageOptions {
| optional MessageOptions message_options = 22001;
|}
|
|extend google.protobuf.FieldOptions {
| optional FormattingOptions formatting_options = 22002;
|}
""".trimMargin(),
)

try {
loadAndLinkSchema()
fail()
} catch (expected: SchemaException) {
val message = expected.message!!
assertThat(message).contains("invalid option value \"false; static { } //\" for bool")
assertThat(message).contains("invalid option value \"80; static { } //\" for int32")
assertThat(message).contains("invalid option value \"LOWER_CASE; static { } //\" for StringCasing")
}
}

@Test
fun extensionTypesInExternalFile() {
fs.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.squareup.wire.schema.MessageType;
import com.squareup.wire.schema.PruningRules;
import com.squareup.wire.schema.Schema;
import com.squareup.wire.schema.SchemaException;
import java.io.IOException;
import okio.Path;
import org.junit.Test;
Expand Down Expand Up @@ -608,21 +609,20 @@ public void defaultValues() throws IOException {

@Test
public void defaultValuesMustNotBeOctal() throws IOException {
Schema schema =
new SchemaBuilder()
.add(
Path.get("message.proto"),
""
+ "message Message {\n"
+ " optional int32 a = 1 [default = 020 ];\n"
+ " optional int64 b = 2 [default = 021 ];\n"
+ "}\n")
.build();
try {
new JavaWithProfilesGenerator(schema).generateJava("Message");
new SchemaBuilder()
.add(
Path.get("message.proto"),
""
+ "message Message {\n"
+ " optional int32 a = 1 [default = 020 ];\n"
+ " optional int64 b = 2 [default = 021 ];\n"
+ "}\n")
.build();
fail();
} catch (IllegalStateException expected) {
assertThat(expected).hasMessageThat().contains("Octal literal unsupported: 020");
} catch (SchemaException expected) {
assertThat(expected).hasMessageThat().contains("invalid default value \"020\" for int32");
assertThat(expected).hasMessageThat().contains("invalid default value \"021\" for int64");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ data class Field(
}
}
syntaxRules.validateDefaultValue(default != null, linker.errors)
validateDefaultValue(linker)
if (type!!.isMap) {
val valueType = linker.get(type!!.valueType!!)
if (valueType is EnumType && valueType.constants[0].tag != 0) {
Expand All @@ -171,6 +172,15 @@ data class Field(
linker.validateImportForType(location, type!!)
}

private fun validateDefaultValue(linker: Linker) {
val default = default ?: return
val type = type!!

if (!isValidLiteral(linker, type, default)) {
linker.errors += "invalid default value \"$default\" for $type"
}
}

fun retainAll(schema: Schema, markSet: MarkSet, enclosingType: ProtoType): Field? {
// TODO(jwilson): perform this transformation in the Linker.
val type = type ?: return null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright (C) 2026 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.wire.schema

internal fun isValidLiteral(linker: Linker, type: ProtoType, value: String): Boolean = when (type) {
ProtoType.BOOL -> value == "true" || value == "false"
ProtoType.BYTES, ProtoType.STRING -> true
ProtoType.DOUBLE, ProtoType.FLOAT -> value.isValidFloatingPointDefault()
ProtoType.FIXED32, ProtoType.UINT32 -> value.isValidUnsignedIntegerDefault(UINT32_MAX)
ProtoType.FIXED64, ProtoType.UINT64 -> value.isValidUnsignedIntegerDefault(UINT64_MAX)
ProtoType.INT32, ProtoType.SFIXED32, ProtoType.SINT32 -> {
value.isValidSignedIntegerDefault(INT32_MIN, INT32_MAX)
}
ProtoType.INT64, ProtoType.SFIXED64, ProtoType.SINT64 -> {
value.isValidSignedIntegerDefault(INT64_MIN, INT64_MAX)
}
else -> {
val valueType = linker.get(type)
valueType is EnumType && valueType.constant(value) != null
}
}

private const val INT32_MAX = "2147483647"
private const val INT32_MIN = "-2147483648"
private const val UINT32_MAX = "4294967295"
private const val INT64_MAX = "9223372036854775807"
private const val INT64_MIN = "-9223372036854775808"
private const val UINT64_MAX = "18446744073709551615"

private val DECIMAL_INTEGER_REGEX = Regex("-?[0-9]+")
private val FLOATING_POINT_REGEX = Regex("-?((([0-9]+)(\\.[0-9]*)?)|(\\.[0-9]+))([eE][+-]?[0-9]+)?")
private val HEX_INTEGER_REGEX = Regex("-?0[xX][0-9a-fA-F]+")

private fun String.isValidFloatingPointDefault(): Boolean {
if (this in listOf("inf", "-inf", "nan", "-nan")) return true
return FLOATING_POINT_REGEX.matches(this) && toDoubleOrNull() != null
}

private fun String.isValidSignedIntegerDefault(min: String, max: String): Boolean {
if (!isIntegerLiteral() || hasOctalPrefix()) return false
return compareIntegerLiterals(min, this) <= 0 && compareIntegerLiterals(this, max) <= 0
}

private fun String.isValidUnsignedIntegerDefault(max: String): Boolean {
if (!isIntegerLiteral() || hasOctalPrefix() || startsWith("-")) return false
return compareIntegerLiterals(this, max) <= 0
}

private fun String.isIntegerLiteral() = DECIMAL_INTEGER_REGEX.matches(this) || HEX_INTEGER_REGEX.matches(this)

private fun String.hasOctalPrefix(): Boolean {
val digits = removePrefix("-")
return digits.length > 1 && digits[0] == '0' && digits[1] != 'x' && digits[1] != 'X'
}

private fun compareIntegerLiterals(a: String, b: String): Int {
val aNegative = a.startsWith("-")
val bNegative = b.startsWith("-")
if (aNegative != bNegative) return if (aNegative) -1 else 1

val magnitudeComparison = compareIntegerMagnitudes(
a.removePrefix("-"),
b.removePrefix("-"),
)
return if (aNegative) -magnitudeComparison else magnitudeComparison
}

private fun compareIntegerMagnitudes(a: String, b: String): Int {
val aDecimal = a.decimalMagnitude()
val bDecimal = b.decimalMagnitude()
if (aDecimal.length != bDecimal.length) return aDecimal.length.compareTo(bDecimal.length)
return aDecimal.compareTo(bDecimal)
}

private fun String.decimalMagnitude(): String {
val stripped = if (startsWith("0x") || startsWith("0X")) {
hexMagnitudeToDecimal(substring(2))
} else {
trimStart('0').ifEmpty { "0" }
}
return stripped.trimStart('0').ifEmpty { "0" }
}

private fun hexMagnitudeToDecimal(hex: String): String {
var decimal = "0"
hex.forEach { digit ->
decimal = decimal.multiplyDecimalBy(16)
decimal = decimal.addDecimal(digit.digitToInt(16))
}
return decimal
}

private fun String.multiplyDecimalBy(multiplier: Int): String {
var carry = 0
val result = StringBuilder(length + 2)
for (i in indices.reversed()) {
val value = (this[i] - '0') * multiplier + carry
result.append(value % 10)
carry = value / 10
}
while (carry > 0) {
result.append(carry % 10)
carry /= 10
}
return result.reverse().toString().trimStart('0').ifEmpty { "0" }
}

private fun String.addDecimal(addend: Int): String {
var carry = addend
val result = StringBuilder(length + 2)
for (i in indices.reversed()) {
val value = (this[i] - '0') + carry
result.append(value % 10)
carry = value / 10
}
while (carry > 0) {
result.append(carry % 10)
carry /= 10
}
return result.reverse().toString().trimStart('0').ifEmpty { "0" }
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class Options(
}

is String -> {
validateOptionValue(linker, context, value)
return coerceValueForField(context, value, isRepeated)
}

Expand All @@ -233,6 +234,12 @@ class Options(
}
}

private fun validateOptionValue(linker: Linker, context: ProtoType, value: String) {
if (!isValidLiteral(linker, context, value)) {
linker.errors += "invalid option value \"$value\" for $context"
}
}

private fun coerceValueForField(context: ProtoType, value: Any, isRepeated: Boolean): Any = when {
isRepeated || context.isMap -> value as? List<*> ?: listOf(value)
value is List<*> -> value.single()!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2162,6 +2162,68 @@ class SchemaTest {
}
}

@Test
fun proto2AllowsValidDefaultValues() {
buildSchema {
add(
"defaults.proto".toPath(),
"""
|syntax = "proto2";
|
|message Defaults {
| optional bool bool_field = 1 [default = false];
| optional int32 int32_field = 2 [default = -0x80000000];
| optional uint32 uint32_field = 3 [default = 4294967295];
| optional int64 int64_field = 4 [default = -9223372036854775808];
| optional uint64 uint64_field = 5 [default = 18446744073709551615];
| optional float float_field = 6 [default = inf];
| optional double double_field = 7 [default = -1.23e45];
| optional string string_field = 8 [default = "source syntax is just data here"];
| optional bytes bytes_field = 9 [default = "abc"];
| optional Choice enum_field = 10 [default = TWO];
|
| enum Choice {
| ONE = 0;
| TWO = 1;
| }
|}
""".trimMargin(),
)
}
}

@Test
fun proto2RejectsInvalidDefaultValues() {
try {
buildSchema {
add(
"defaults.proto".toPath(),
"""
|syntax = "proto2";
|
|message Defaults {
| optional bool bool_field = 1 [default = "false; static { } //"];
| optional float float_field = 2 [default = "0.0f; init { } //"];
| optional double double_field = 3 [default = "0.0d; static { } //"];
| optional Choice enum_field = 4 [default = "ONE; static { } //"];
|
| enum Choice {
| ONE = 0;
| }
|}
""".trimMargin(),
)
}
fail()
} catch (expected: SchemaException) {
val message = expected.message!!
assertThat(message).contains("invalid default value \"false; static { } //\" for bool")
assertThat(message).contains("invalid default value \"0.0f; init { } //\" for float")
assertThat(message).contains("invalid default value \"0.0d; static { } //\" for double")
assertThat(message).contains("invalid default value \"ONE; static { } //\" for Defaults.Choice")
}
}

@Test
fun repeatedNumericScalarsShouldBePackedByDefaultForProto3() {
val schema = buildSchema {
Expand Down
Loading