Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "bugfix",
"category": "Amazon SNS Message Manager",
"description": "Fixed `SnsMessageManager` rejecting valid SNS messages whose signature timestamp falls on a whole second (zero milliseconds): the canonical string was rebuilt with `Instant#toString()`, which drops the `.000` fraction and no longer matched the value Amazon SNS signed.",
"contributor": "henricook"
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureException;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.Locale;
import java.util.StringJoiner;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkBytes;
Expand Down Expand Up @@ -53,6 +57,9 @@ public final class SignatureValidator {

private static final String NEWLINE = "\n";

private static final DateTimeFormatter TIMESTAMP_FORMATTER =
DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'", Locale.ROOT).withZone(ZoneOffset.UTC);

public void validateSignature(SnsMessage message, PublicKey publicKey) {
Validate.paramNotNull(message, "message");
Validate.paramNotNull(publicKey, "publicKey");
Expand Down Expand Up @@ -101,7 +108,7 @@ private static String buildCanonicalMessage(SnsNotification notification) {
joiner.add(SUBJECT).add(notification.subject());
}

joiner.add(TIMESTAMP).add(notification.timestamp().toString());
joiner.add(TIMESTAMP).add(formatTimestamp(notification.timestamp()));
joiner.add(TOPIC_ARN).add(notification.topicArn());
joiner.add(TYPE).add(notification.type().toString());

Expand All @@ -114,7 +121,7 @@ private static String buildCanonicalMessage(SnsSubscriptionConfirmation message)
joiner.add(MESSAGE).add(message.message());
joiner.add(MESSAGE_ID).add(message.messageId());
joiner.add(SUBSCRIBE_URL).add(message.subscribeUrl().toString());
joiner.add(TIMESTAMP).add(message.timestamp().toString());
joiner.add(TIMESTAMP).add(formatTimestamp(message.timestamp()));
joiner.add(TOKEN).add(message.token());
joiner.add(TOPIC_ARN).add(message.topicArn());
joiner.add(TYPE).add(message.type().toString());
Expand All @@ -127,14 +134,18 @@ private static String buildCanonicalMessage(SnsUnsubscribeConfirmation message)
joiner.add(MESSAGE).add(message.message());
joiner.add(MESSAGE_ID).add(message.messageId());
joiner.add(SUBSCRIBE_URL).add(message.subscribeUrl().toString());
joiner.add(TIMESTAMP).add(message.timestamp().toString());
joiner.add(TIMESTAMP).add(formatTimestamp(message.timestamp()));
joiner.add(TOKEN).add(message.token());
joiner.add(TOPIC_ARN).add(message.topicArn());
joiner.add(TYPE).add(message.type().toString());

return joiner.toString();
}

private static String formatTimestamp(Instant timestamp) {
return TIMESTAMP_FORMATTER.format(timestamp);
}

private static void verifySignature(String canonicalMessage, SdkBytes messageSignature, PublicKey publicKey,
Signature signature) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,33 @@

package software.amazon.awssdk.messagemanager.sns.internal;

import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.messagemanager.sns.model.SignatureVersion;
Expand All @@ -44,12 +53,17 @@ class SignatureValidatorTest {
private static final String SIGNING_CERT_RESOURCE = "SimpleNotificationService-7506a1e35b36ef5a444dd1a8e7cc3ed8.pem";
private static final SignatureValidator VALIDATOR = new SignatureValidator();
private static X509Certificate signingCertificate;
private static KeyPair signingKeyPair;

@BeforeAll
static void setup() throws CertificateException {
static void setup() throws Exception {
InputStream is = resourceAsStream(SIGNING_CERT_RESOURCE);
CertificateFactory factory = CertificateFactory.getInstance("X.509");
signingCertificate = (X509Certificate) factory.generateCertificate(is);

KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
keyPairGenerator.initialize(2048);
signingKeyPair = keyPairGenerator.generateKeyPair();
}

@ParameterizedTest(name = "{0}")
Expand All @@ -60,6 +74,18 @@ void validateSignature_signatureValid_doesNotThrow(TestCase tc) {
VALIDATOR.validateSignature(msg, signingCertificate.getPublicKey());
}

@ParameterizedTest(name = "timestamp={0}")
@ValueSource(strings = {
"2024-01-01T00:00:00.000Z", // whole second: Instant#toString() drops the ".000", changing the canonical string
"2024-06-15T12:30:45.123Z" // non-zero milliseconds: control that validates regardless
})
void validateSignature_signatureCoversRawMillisecondTimestamp_doesNotThrow(String timestamp) throws Exception {
SnsMessage notification = signedNotification(timestamp, signingKeyPair);

assertThatCode(() -> VALIDATOR.validateSignature(notification, signingKeyPair.getPublic()))
.doesNotThrowAnyException();
}

@Test
void validateSignature_signatureMismatch_throws() {
SnsNotification notification = SnsNotification.builder()
Expand Down Expand Up @@ -182,6 +208,40 @@ private static InputStream resourceAsStream(String resourceName) {
return SignatureValidatorTest.class.getResourceAsStream(RESOURCE_ROOT + resourceName);
}

private static SnsMessage signedNotification(String timestamp, KeyPair keyPair) throws Exception {
String message = "This notification is signed over a millisecond-precision timestamp.";
String messageId = "11111111-2222-3333-4444-555555555555";
String topicArn = "arn:aws:sns:us-east-1:123456789012:my-topic";

String canonicalMessage = String.join("\n",
"Message", message,
"MessageId", messageId,
"Timestamp", timestamp,
"TopicArn", topicArn,
"Type", "Notification") + "\n";

String signature = sign(canonicalMessage, keyPair.getPrivate());

String json = "{\n"
+ " \"Type\" : \"Notification\",\n"
+ " \"MessageId\" : \"" + messageId + "\",\n"
+ " \"TopicArn\" : \"" + topicArn + "\",\n"
+ " \"Message\" : \"" + message + "\",\n"
+ " \"Timestamp\" : \"" + timestamp + "\",\n"
+ " \"SignatureVersion\" : \"1\",\n"
+ " \"Signature\" : \"" + signature + "\"\n"
+ "}";

return new SnsMessageUnmarshaller().unmarshall(new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)));
}

private static String sign(String canonicalMessage, PrivateKey privateKey) throws Exception {
Signature signer = Signature.getInstance("SHA1withRSA");
signer.initSign(privateKey);
signer.update(canonicalMessage.getBytes(StandardCharsets.UTF_8));
return Base64.getEncoder().encodeToString(signer.sign());
}

private static class TestCase {
private String desription;
private String messageJsonResource;
Expand Down