Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
import java.math.BigInteger;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.IntFunction;
import org.apache.commons.lang3.ClassUtils;
import org.apache.hadoop.hdds.protocol.proto.HddsProtos.ContainerID;
Expand Down Expand Up @@ -55,7 +59,6 @@ public final class ScmCodecFactory {
putProto(DeletedBlocksTransaction.getDefaultInstance());
putProto(DeletedBlocksTransactionSummary.getDefaultInstance());

codecs.put(List.class, new ScmListCodec());
codecs.put(Integer.class, new ScmIntegerCodec());
codecs.put(Long.class, new ScmLongCodec());
codecs.put(String.class, new ScmStringCodec());
Expand All @@ -69,6 +72,10 @@ public final class ScmCodecFactory {
putEnum(LifeCycleEvent.class, LifeCycleEvent::forNumber);
putEnum(PipelineState.class, PipelineState::forNumber);
putEnum(NodeType.class, NodeType::forNumber);

// Must be the last one
final ClassResolver resolver = new ClassResolver(codecs.keySet());
codecs.put(List.class, new ScmListCodec(resolver));
}

static <T extends Message> void putProto(T proto) {
Expand Down Expand Up @@ -97,4 +104,47 @@ public static ScmCodec getCodec(Class<?> type)
throw new InvalidProtocolBufferException(
"Codec for " + type + " not found!");
}

/** Resolve the codec class from a given class. */
static class ClassResolver {
private final Map<String, Class<?>> provided;
private final Map<String, Class<?>> resolved = new ConcurrentHashMap<>();

ClassResolver(Collection<Class<?>> provided) {
final Map<String, Class<?>> map = new TreeMap<>();
for (Class<?> c : provided) {
map.put(c.getName(), c);
}
map.put(List.class.getName(), List.class);
this.provided = Collections.unmodifiableMap(map);
}

Class<?> get(String className) throws InvalidProtocolBufferException {
final Class<?> c = provided.get(className);
if (c != null) {
return c;
}
throw new InvalidProtocolBufferException("Class not found for " + className);
}

Class<?> get(Class<?> clazz) throws InvalidProtocolBufferException {
final String className = clazz.getName();
final Class<?> c = provided.get(className);
if (c != null) {
return c;
}
final Class<?> found = resolved.get(className);
if (found != null) {
return found;
}

for (Class<?> base : provided.values()) {
if (base.isAssignableFrom(clazz)) {
resolved.put(className, base);
return base;
}
}
throw new InvalidProtocolBufferException("Failed to resolve " + clazz);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,65 +17,67 @@

package org.apache.hadoop.hdds.scm.ha.io;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hdds.protocol.proto.SCMRatisProtocol.ListArgument;
import org.apache.hadoop.hdds.scm.ha.ReflectionUtil;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;

/**
* {@link ScmCodec} for {@link List} objects.
*/
public class ScmListCodec implements ScmCodec<Object> {
class ScmListCodec implements ScmCodec<Object> {
private static final ByteString EMPTY_LIST = ListArgument.newBuilder()
.setType(Object.class.getName())
.build()
.toByteString();

private final ScmCodecFactory.ClassResolver resolver;

ScmListCodec(ScmCodecFactory.ClassResolver resolver) {
this.resolver = resolver;
}

@Override
public ByteString serialize(Object object)
throws InvalidProtocolBufferException {
final ListArgument.Builder listArgs = ListArgument.newBuilder();
final List<?> values = (List<?>) object;
if (!values.isEmpty()) {
Class<?> type = values.get(0).getClass();
listArgs.setType(type.getName());
for (Object value : values) {
listArgs.addValue(ScmCodecFactory.getCodec(type).serialize(value));
}
} else {
listArgs.setType(Object.class.getName());
public ByteString serialize(Object object) throws InvalidProtocolBufferException {
if (!(object instanceof List)) {
throw new InvalidProtocolBufferException(
"Unexpected non-list object: " + object.getClass());
}
final List<?> elements = (List<?>) object;
if (elements.isEmpty()) {
return EMPTY_LIST;
}

final Class<?> resolved = resolver.get(elements.get(0).getClass());
final ScmCodec<Object> elementCodec = ScmCodecFactory.getCodec(resolved);
final ListArgument.Builder builder = ListArgument.newBuilder()
.setType(resolved.getName());
for (Object e : elements) {
builder.addValue(elementCodec.serialize(e));
}
return listArgs.build().toByteString();
return builder.build().toByteString();
}

@Override
public Object deserialize(Class<?> type, ByteString value)
throws InvalidProtocolBufferException {
try {
// If argument type is the generic interface, then determine a
// concrete implementation.
Class<?> concreteType = (type == List.class) ? ArrayList.class : type;

List<Object> result = (List<Object>) concreteType.newInstance();
final ListArgument listArgs = (ListArgument) ReflectionUtil
.getMethod(ListArgument.class, "parseFrom", byte[].class)
.invoke(null, (Object) value.toByteArray());

// proto2 required-equivalent check
if (!listArgs.hasType()) {
throw new InvalidProtocolBufferException("Missing ListArgument.type");
}

final Class<?> dataType = ReflectionUtil.getClass(listArgs.getType());
for (ByteString element : listArgs.getValueList()) {
result.add(ScmCodecFactory.getCodec(dataType)
.deserialize(dataType, element));
}
return result;
} catch (InstantiationException | NoSuchMethodException |
IllegalAccessException | InvocationTargetException |
ClassNotFoundException ex) {
if (!List.class.isAssignableFrom(type)) {
throw new InvalidProtocolBufferException(
"Message cannot be decoded: " + ex.getMessage());
"Unexpected non-list type: " + type);
}
final ListArgument argument = ListArgument.parseFrom(
value.asReadOnlyByteBuffer());
if (!argument.hasType()) {
throw new InvalidProtocolBufferException(
"Missing ListArgument.type: " + argument);
}
final Class<?> elementClass = resolver.get(argument.getType());
final ScmCodec<?> elementCodec = ScmCodecFactory.getCodec(elementClass);
final List<Object> list = new ArrayList<>();
for (ByteString element : argument.getValueList()) {
list.add(elementCodec.deserialize(elementClass, element));
}
return list;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import java.util.List;
import org.apache.hadoop.hdds.protocol.proto.HddsProtos;
import org.apache.hadoop.hdds.protocol.proto.SCMRatisProtocol;
import org.apache.hadoop.hdds.scm.ha.io.ScmListCodec;
import org.apache.hadoop.hdds.scm.pipeline.PipelineID;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
Expand Down Expand Up @@ -206,22 +205,4 @@ public void testDecodeMissingArgumentValueShouldFail() throws Exception {

assertTrue(ex.getMessage().contains("Missing argument value"));
}

@Test
public void testListDecodeMissingTypeShouldFail() throws Exception {
// ListArgument without type
SCMRatisProtocol.ListArgument listArg =
SCMRatisProtocol.ListArgument.newBuilder()
// no type
.addValue(ByteString.copyFromUtf8("x"))
.build();

ScmListCodec codec = new ScmListCodec();

InvalidProtocolBufferException ex = assertThrows(
InvalidProtocolBufferException.class,
() -> codec.deserialize(List.class, listArg.toByteString()));

assertTrue(ex.getMessage().contains("Missing ListArgument.type"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.hdds.scm.ha.io;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.Collections;
import java.util.List;
import org.apache.hadoop.hdds.protocol.proto.SCMRatisProtocol;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.thirdparty.com.google.protobuf.InvalidProtocolBufferException;
import org.junit.jupiter.api.Test;

/**
* Tests for {@link ScmListCodec}.
*/
public class TestScmListCodec {

@Test
public void testListDecodeMissingTypeShouldFail() throws Exception {
// ListArgument without type
SCMRatisProtocol.ListArgument listArg =
SCMRatisProtocol.ListArgument.newBuilder()
// no type
.addValue(ByteString.copyFromUtf8("x"))
.build();

ScmListCodec codec = new ScmListCodec(
new ScmCodecFactory.ClassResolver(Collections.emptyList()));

InvalidProtocolBufferException ex = assertThrows(
InvalidProtocolBufferException.class,
() -> codec.deserialize(List.class, listArg.toByteString()));

assertTrue(ex.getMessage().contains("Missing ListArgument.type"));
}
}