Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import jakarta.enterprise.inject.spi.*;
import jakarta.inject.Inject;

import java.util.IdentityHashMap;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

/**
* Allocator instance that ties into the CDI container.
Expand All @@ -16,7 +18,8 @@
*/
@ApplicationScoped
public class CdiClassAllocator implements ClassAllocator {
private final Map<Class<?>, Bean<?>> classBeanMap = new IdentityHashMap<>();
private final Map<Class<?>, Bean<?>> classBeanMap = new WeakHashMap<>();
private final Lock lock = new ReentrantLock();
private final BeanManager beanManager;

@Inject
Expand All @@ -28,20 +31,25 @@ public CdiClassAllocator(@Nonnull BeanManager beanManager) {
@Override
@SuppressWarnings("unchecked")
public <T> T instance(@Nonnull Class<T> cls) throws AllocationException {
lock.lock();
try {
// Create bean
Bean<T> bean = (Bean<T>) classBeanMap.computeIfAbsent(cls, c -> {
AnnotatedType<T> annotatedClass = beanManager.createAnnotatedType(cls);
// TODO bugged.
// Equivalence check is based on the class name and does not include the loader.
AnnotatedType<T> annotatedClass = beanManager.createAnnotatedType((Class<T>) c);
BeanAttributes<T> attributes = beanManager.createBeanAttributes(annotatedClass);
InjectionTargetFactory<T> factory = beanManager.getInjectionTargetFactory(annotatedClass);
return beanManager.createBean(attributes, cls, factory);
return beanManager.createBean(attributes, (Class<T>) c, factory);
});
CreationalContext<T> creationalContext = beanManager.createCreationalContext(bean);

// Allocate instance of bean
return bean.create(creationalContext);
} catch (Throwable t) {
throw new AllocationException(cls, t);
} finally {
lock.unlock();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package software.coley.recaf.services.script;

import software.coley.recaf.util.CancelSignal;

/**
* Cancellation singleton.
* Injected into generated scripts.
* Do not use in normal code.
*
* @author xDark
*/
public final class CancellationSingleton {
private static volatile boolean shouldStop;

private CancellationSingleton() {
}

// Names and descriptors are known to the visitor/runtime.
// See InsertCancelSignalVisitor.
// See GenerateResult.
public static void stop() {
shouldStop = true;
}

public static void poll() {
if (shouldStop) {
throw CancelSignal.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jakarta.annotation.Nullable;
import software.coley.recaf.services.compile.CompilerDiagnostic;

import java.lang.reflect.InvocationTargetException;
import java.util.List;

/**
Expand All @@ -23,4 +24,23 @@ public record GenerateResult(@Nullable Class<?> cls, @Nonnull List<CompilerDiagn
public boolean wasSuccess() {
return cls != null;
}

/**
* Attempts to stop the script. If the generation failed,
* this method will do nothing.
*
* @throws IllegalStateException If something went wrong.
*/
public void requestStop() {
Class<?> cls = this.cls;
if (cls == null) return;
try {
Class<?> cancellationSingleton = cls.getClassLoader().loadClass(CancellationSingleton.class.getName());
cancellationSingleton.getDeclaredMethod("stop").invoke(null);
} catch (InvocationTargetException ex) {
throw new IllegalStateException(ex.getTargetException());
} catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package software.coley.recaf.services.script;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Type;
import software.coley.recaf.RecafConstants;
import software.coley.recaf.util.ReflectUtil;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;

import static org.objectweb.asm.Opcodes.*;

final class InsertCancelSignalVisitor extends ClassVisitor {
private static final MethodHandle GET_OFFSET;
private static final String SINGLETON_TYPE = Type.getInternalName(CancellationSingleton.class);
private static final String POLL = "poll";
private static final String POLL_DESC = "()V";

static {
try {
GET_OFFSET = ReflectUtil.lookup().findVirtual(Label.class, "getOffset", MethodType.methodType(int.class));
} catch (ReflectiveOperationException ex) {
throw new ExceptionInInitializerError(ex);
}
}

InsertCancelSignalVisitor(ClassVisitor cv) {
super(RecafConstants.getAsmVersion(), cv);
}

private static boolean isLoopback(Label label) {
try {
int ignored = (int) GET_OFFSET.invokeExact(label);
return true;
} catch (IllegalStateException ex) {
return false;
} catch (Throwable t) {
throw new RuntimeException(t);
}
}

@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature,
String[] exceptions) {
return new MethodVisitor(api, super.visitMethod(access, name, descriptor, signature, exceptions)) {

private boolean isAnyLoopback(Label[] labels) {
for (Label label : labels) {
if (isLoopback(label)) {
return true;
}
}
return false;
}

private void insertPoll(Label first, Label... labels) {
if (!(isLoopback(first) || isAnyLoopback(labels)))
return;
super.visitMethodInsn(INVOKESTATIC, SINGLETON_TYPE, POLL, POLL_DESC, false);
}

@Override
public void visitJumpInsn(int opcode, Label label) {
insertPoll(label);
super.visitJumpInsn(opcode, label);
}

@Override
public void visitLookupSwitchInsn(Label dflt, int[] keys, Label[] labels) {
insertPoll(dflt, labels);
super.visitLookupSwitchInsn(dflt, keys, labels);
}

@Override
public void visitTableSwitchInsn(int min, int max, Label dflt, Label... labels) {
insertPoll(dflt, labels);
super.visitTableSwitchInsn(min, max, dflt, labels);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
import jakarta.annotation.Nonnull;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Type;
import regexodus.Matcher;
import software.coley.recaf.analytics.logging.DebuggingLogger;
import software.coley.recaf.analytics.logging.Logging;
import software.coley.recaf.services.compile.*;
import software.coley.recaf.services.plugin.CdiClassAllocator;
import software.coley.recaf.util.ClassDefiner;
import software.coley.recaf.util.ReflectUtil;
import software.coley.recaf.util.RegexUtil;
import software.coley.recaf.util.StringUtil;
import software.coley.recaf.util.threading.ThreadPoolFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -73,7 +78,7 @@ public class JavacScriptEngine implements ScriptEngine {
"jakarta.inject.*",
"org.slf4j.Logger"
);
private final Map<Integer, GenerateResult> generateResultMap = new HashMap<>();
private final Map<Integer, ScriptTemplate> generateResultMap = new HashMap<>();
private final ExecutorService compileAndRunPool = ThreadPoolFactory.newSingleThreadExecutor("script-loader");
private final JavacCompiler compiler;
private final CdiClassAllocator allocator;
Expand Down Expand Up @@ -160,11 +165,11 @@ private GenerateResult generate(@Nonnull String script) {
GenerateResult result;
if (RegexUtil.matchesAny(PATTERN_CLASS_NAME, script)) {
logger.debugging(l -> l.info("Compiling script as class"));
result = generateResultMap.computeIfAbsent(hash, n -> generateStandardClass(script));
result = generateResultMap.computeIfAbsent(hash, n -> generateStandardClass(script)).generateResult();
} else {
logger.debugging(l -> l.info("Compiling script as function"));
String className = "Script" + Math.abs(hash);
result = generateResultMap.computeIfAbsent(hash, n -> generateScriptClass(className, script));
result = generateResultMap.computeIfAbsent(hash, n -> generateScriptClass(className, script)).generateResult();
}
return result;
}
Expand All @@ -179,7 +184,7 @@ private GenerateResult generate(@Nonnull String script) {
* @return Compiler result wrapper containing the loaded class reference.
*/
@Nonnull
private GenerateResult generateStandardClass(@Nonnull String source) {
private ScriptTemplate generateStandardClass(@Nonnull String source) {
String originalSource = source;

// Extract package name
Expand Down Expand Up @@ -212,7 +217,7 @@ private GenerateResult generateStandardClass(@Nonnull String source) {
source = source.replace(" " + originalName + "(", " " + modifiedName + "(");
source = source.replace("\t" + originalName + "(", "\t" + modifiedName + "(");
} else {
return new GenerateResult(null, List.of(
return new ScriptTemplate.Failed(List.of(
new CompilerDiagnostic(-1, -1, 0,
"Could not determine name of class", CompilerDiagnostic.Level.ERROR)));
}
Expand All @@ -233,7 +238,7 @@ private GenerateResult generateStandardClass(@Nonnull String source) {
* @return Compiler result wrapper containing the loaded class reference.
*/
@Nonnull
private GenerateResult generateScriptClass(@Nonnull String className, @Nonnull String script) {
private ScriptTemplate generateScriptClass(@Nonnull String className, @Nonnull String script) {
String originalSource = script;
Set<String> imports = new HashSet<>(DEFAULT_IMPORTS);
Matcher matcher = RegexUtil.getMatcher(PATTERN_IMPORT, script);
Expand Down Expand Up @@ -264,6 +269,24 @@ private GenerateResult generateScriptClass(@Nonnull String className, @Nonnull S
return generate(className, originalSource, code.toString());
}

@Nonnull
private byte[] postProcessClass(@Nonnull byte[] classBytes) {
ClassReader cr = new ClassReader(classBytes);
ClassWriter cw = new ClassWriter(cr, 0);
cr.accept(new InsertCancelSignalVisitor(cw), 0);
return cw.toByteArray();
}

private void injectClasses(@Nonnull Map<String, byte[]> classMap) {
for (Class<?> c : List.of(CancellationSingleton.class)) {
try (InputStream in = c.getClassLoader().getResourceAsStream(Type.getInternalName(c) + ".class")) {
classMap.put(c.getName(), in.readAllBytes());
} catch (IOException ex) {
throw new UncheckedIOException(ex);
}
}
}

/**
* @param className
* Name of the script class.
Expand All @@ -275,26 +298,22 @@ private GenerateResult generateScriptClass(@Nonnull String className, @Nonnull S
* @return Compiler result wrapper containing the loaded class reference.
*/
@Nonnull
private GenerateResult generate(@Nonnull String className,
private ScriptTemplate generate(@Nonnull String className,
@Nonnull String originalSource,
@Nonnull String compileSource) {
JavacArguments args = new JavacArgumentsBuilder()
.withClassName(className)
.withClassSource(compileSource)
.build();
CompilerResult result = compiler.compile(args, null, null);
List<CompilerDiagnostic> diagnostics = mapDiagnostics(originalSource, compileSource, result.getDiagnostics());
if (result.wasSuccess()) {
try {
Map<String, byte[]> classes = result.getCompilations().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().replace('/', '.'), Map.Entry::getValue));
ClassDefiner definer = new ClassDefiner(classes);
Class<?> cls = definer.findClass(className.replace('/', '.'));
return new GenerateResult(cls, mapDiagnostics(originalSource, compileSource, result.getDiagnostics()));
} catch (Exception ex) {
logger.error("Failed to define generated script class", ex);
}
Map<String, byte[]> classes = result.getCompilations().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().replace('/', '.'), e -> postProcessClass(e.getValue())));
injectClasses(classes);
return new ScriptTemplate.Generated(className.replace('/', '.'), Map.copyOf(classes), diagnostics);
}
return new GenerateResult(null, mapDiagnostics(originalSource, compileSource, result.getDiagnostics()));
return new ScriptTemplate.Failed(diagnostics);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package software.coley.recaf.services.script;

import software.coley.recaf.analytics.logging.DebuggingLogger;
import software.coley.recaf.analytics.logging.Logging;
import software.coley.recaf.services.compile.CompilerDiagnostic;
import software.coley.recaf.util.ClassDefiner;

import java.util.List;
import java.util.Map;

sealed interface ScriptTemplate {

GenerateResult generateResult();

record Generated(
String className,
Map<String, byte[]> classMap,
List<CompilerDiagnostic> diagnostics
) implements ScriptTemplate {
private static final DebuggingLogger logger = Logging.get(Generated.class);

@Override
public GenerateResult generateResult() {
try {
ClassDefiner definer = new ClassDefiner(classMap);
Class<?> cls = definer.findClass(className);
return new GenerateResult(cls, diagnostics);
} catch (Exception ex) {
logger.error("Failed to define generated script class", ex);
}
return new GenerateResult(null, diagnostics);
}
}

record Failed(List<CompilerDiagnostic> diagnostics) implements ScriptTemplate {

@Override
public GenerateResult generateResult() {
return new GenerateResult(null, diagnostics);
}
}
}
Loading