Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Move compileJava() into ClassLoaderUtils #2600

Merged
merged 1 commit into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 3 additions & 137 deletions api/src/main/java/ai/djl/translate/ServingTranslatorFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,13 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

/** A {@link TranslatorFactory} that creates a generic {@link Translator}. */
public class ServingTranslatorFactory implements TranslatorFactory {
Expand Down Expand Up @@ -99,94 +84,9 @@ public <I, O> Translator<I, O> newInstance(
}

private ServingTranslator findTranslator(Path path, String className) {
try {
Path classesDir = path.resolve("classes");
compileJavaClass(classesDir);

List<Path> jarFiles = new ArrayList<>();
if (Files.isDirectory(path)) {
try (Stream<Path> stream = Files.list(path)) {
stream.forEach(
p -> {
if (p.toString().endsWith(".jar")) {
jarFiles.add(p);
}
});
}
}
List<URL> urls = new ArrayList<>(jarFiles.size() + 1);
urls.add(classesDir.toUri().toURL());
for (Path p : jarFiles) {
urls.add(p.toUri().toURL());
}

ClassLoader parentCl = ClassLoaderUtils.getContextClassLoader();
ClassLoader cl = new URLClassLoader(urls.toArray(new URL[0]), parentCl);
if (className != null && !className.isEmpty()) {
logger.info("Trying to loading specified Translator: {}", className);
return initTranslator(cl, className);
}

ServingTranslator translator = scanDirectory(cl, classesDir);
if (translator != null) {
return translator;
}

for (Path p : jarFiles) {
translator = scanJarFile(cl, p);
if (translator != null) {
return translator;
}
}
} catch (IOException e) {
logger.debug("Failed to find Translator", e);
}
return null;
}

private ServingTranslator scanDirectory(ClassLoader cl, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return null;
}
Collection<Path> files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class"))
.collect(Collectors.toList());
}
for (Path file : files) {
Path p = dir.relativize(file);
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
ServingTranslator translator = initTranslator(cl, className);
if (translator != null) {
logger.info("Found translator in model directory: {}", className);
return translator;
}
}
return null;
}

private ServingTranslator scanJarFile(ClassLoader cl, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
JarEntry entry = en.nextElement();
String fileName = entry.getName();
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
ServingTranslator translator = initTranslator(cl, fileName);
if (translator != null) {
logger.info("Found translator {} in jar {}", fileName, path);
return translator;
}
}
}
}
return null;
Path classesDir = path.resolve("classes");
ClassLoaderUtils.compileJavaClass(classesDir);
return ClassLoaderUtils.findImplementation(path, ServingTranslator.class, className);
}

private TranslatorFactory loadTranslatorFactory(String className) {
Expand All @@ -201,18 +101,6 @@ private TranslatorFactory loadTranslatorFactory(String className) {
return null;
}

private ServingTranslator initTranslator(ClassLoader cl, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Class<? extends ServingTranslator> subclass = clazz.asSubclass(ServingTranslator.class);
Constructor<? extends ServingTranslator> constructor = subclass.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Translator: " + className, e);
}
return null;
}

private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments) {
String appName = ArgumentsUtil.stringValue(arguments, "application");
if (appName != null) {
Expand All @@ -228,26 +116,4 @@ private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments
private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
}

private void compileJavaClass(Path dir) {
try {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return;
}
String[] files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
.map(p -> p.toAbsolutePath().toString())
.toArray(String[]::new);
}
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (files.length > 0) {
compiler.run(null, null, null, files);
}
} catch (Throwable e) {
logger.warn("Failed to compile bundled java file", e);
}
}
}
31 changes: 31 additions & 0 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

/** A utility class that load classes from specific URLs. */
public final class ClassLoaderUtils {
Expand Down Expand Up @@ -229,4 +233,31 @@ public static void nativeLoad(String nativeHelper, String path) {
throw new IllegalArgumentException("Invalid native_helper: " + nativeHelper, e);
}
}

/**
* Tries to compile java classes in the directory.
*
* @param dir the directory to scan java file.
*/
public static void compileJavaClass(Path dir) {
try {
if (!Files.isDirectory(dir)) {
logger.debug("Directory not exists: {}", dir);
return;
}
String[] files;
try (Stream<Path> stream = Files.walk(dir)) {
files =
stream.filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".java"))
.map(p -> p.toAbsolutePath().toString())
.toArray(String[]::new);
}
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
if (files.length > 0) {
compiler.run(null, null, null, files);
}
} catch (Throwable e) {
logger.warn("Failed to compile bundled java file", e);
}
}
}