Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions bom/application/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,46 @@
</dependency>


<!-- Asynchronous NIO client server framework for the jvm -->
<!-- <dependency>
<!-- reactor-core: pin to 3.4.41 to resolve conflict between existing deps (3.3.16)
and langchain4j-azure-open-ai (3.4+). Required for ContextView. -->
<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
<version>3.4.41</version>
</dependency>

<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</dependency>-->
<!-- Netty: pin to 4.1.118.Final to resolve conflict between pgjdbc-ng (4.1.63)
and langchain4j-azure-open-ai (4.1.118). Required for DefaultHeaders$ValueValidator. -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-common</artifactId>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-buffer</artifactId>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-transport</artifactId>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver</artifactId>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec</artifactId>
<version>4.1.118.Final</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
<version>4.1.118.Final</version>
</dependency>


<!--
Expand Down
5 changes: 5 additions & 0 deletions dotCMS/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<!-- LangChain4J Azure OpenAI provider: Chat, Embedding, Image models via Azure OpenAI Service -->
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
</dependency>
<dependency>
<groupId>jakarta.inject</groupId>
<artifactId>jakarta.inject-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package com.dotcms.ai.client.langchain4j;

import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
import dev.langchain4j.model.azure.AzureOpenAiImageModel;
import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
Expand All @@ -20,8 +24,8 @@
* To add support for a new provider, add a case to each switch block below.
* No other class needs to change.
*
* <p>Supported providers (Phase 1): {@code openai}
* <p>Planned (Phase 2): {@code azure_openai}, {@code bedrock}, {@code vertex_ai}
* <p>Supported providers: {@code openai}, {@code azure_openai}
* <p>Planned (Phase 2): {@code bedrock}, {@code vertex_ai}
*/
public class LangChain4jModelFactory {

Expand All @@ -35,7 +39,9 @@ private LangChain4jModelFactory() {}
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static ChatModel buildChatModel(final ProviderConfig config) {
return build(config, "chat", LangChain4jModelFactory::buildOpenAiChatModel);
return build(config, "chat",
LangChain4jModelFactory::buildOpenAiChatModel,
LangChain4jModelFactory::buildAzureOpenAiChatModel);
}

/**
Expand All @@ -46,7 +52,9 @@ public static ChatModel buildChatModel(final ProviderConfig config) {
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static StreamingChatModel buildStreamingChatModel(final ProviderConfig config) {
return build(config, "chat", LangChain4jModelFactory::buildOpenAiStreamingChatModel);
return build(config, "chat",
LangChain4jModelFactory::buildOpenAiStreamingChatModel,
LangChain4jModelFactory::buildAzureOpenAiStreamingChatModel);
}

/**
Expand All @@ -57,7 +65,9 @@ public static StreamingChatModel buildStreamingChatModel(final ProviderConfig co
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) {
return build(config, "embeddings", LangChain4jModelFactory::buildOpenAiEmbeddingModel);
return build(config, "embeddings",
LangChain4jModelFactory::buildOpenAiEmbeddingModel,
LangChain4jModelFactory::buildAzureOpenAiEmbeddingModel);
}

/**
Expand All @@ -68,12 +78,15 @@ public static EmbeddingModel buildEmbeddingModel(final ProviderConfig config) {
* @throws IllegalArgumentException if config or provider is null, or the provider is unsupported
*/
public static ImageModel buildImageModel(final ProviderConfig config) {
return build(config, "image", LangChain4jModelFactory::buildOpenAiImageModel);
return build(config, "image",
LangChain4jModelFactory::buildOpenAiImageModel,
LangChain4jModelFactory::buildAzureOpenAiImageModel);
}

private static <T> T build(final ProviderConfig config,
final String modelType,
final Function<ProviderConfig, T> openAiFn) {
final Function<ProviderConfig, T> openAiFn,
final Function<ProviderConfig, T> azureOpenAiFn) {
if (config == null || config.provider() == null) {
throw new IllegalArgumentException("ProviderConfig or provider name is null for model type: " + modelType);
}
Expand All @@ -82,16 +95,24 @@ private static <T> T build(final ProviderConfig config,
case "openai":
validateOpenAi(config, modelType);
return openAiFn.apply(config);
case "azure_openai":
validateAzureOpenAi(config, modelType);
return azureOpenAiFn.apply(config);
default:
throw new IllegalArgumentException("Unsupported " + modelType + " provider: "
+ config.provider() + ". Supported in Phase 1: openai");
+ config.provider() + ". Supported: openai, azure_openai");
}
}

/**
 * Validates the minimum configuration required by the plain OpenAI provider.
 * Only an API key is mandatory; endpoint/deployment are Azure-specific.
 *
 * @param config    the provider configuration to validate
 * @param modelType the model type ("chat", "embeddings", "image") used in error messages
 */
private static void validateOpenAi(final ProviderConfig config, final String modelType) {
requireNonBlank(config.apiKey(), "apiKey", modelType);
}

/**
 * Validates the minimum configuration required by the Azure OpenAI provider.
 * Azure needs both an API key and the resource endpoint URL.
 *
 * @param config    the provider configuration to validate
 * @param modelType the model type ("chat", "embeddings", "image") used in error messages
 */
private static void validateAzureOpenAi(final ProviderConfig config, final String modelType) {
requireNonBlank(config.apiKey(), "apiKey", modelType);
// Azure routes requests per-resource, so the endpoint is mandatory in addition to the key.
requireNonBlank(config.endpoint(), "endpoint", modelType);
}

private static void requireNonBlank(final String value, final String field, final String modelType) {
if (value == null || value.isBlank()) {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -162,4 +183,55 @@ private static ImageModel buildOpenAiImageModel(final ProviderConfig config) {
return builder.build();
}

// ── Azure OpenAI builders ─────────────────────────────────────────────────

/**
 * Builds an Azure OpenAI streaming chat model from the given configuration.
 * When no explicit deployment name is configured, the model name is used as
 * the deployment name. Optional settings (API version, retries, timeout,
 * temperature, max tokens) are applied only when present in the config.
 *
 * @param config the validated provider configuration
 * @return a configured {@link StreamingChatModel}
 */
private static StreamingChatModel buildAzureOpenAiStreamingChatModel(final ProviderConfig config) {
    // Azure addresses models by deployment; fall back to the model name if none was set.
    final String deployment = config.deploymentName() != null ? config.deploymentName() : config.model();
    final AzureOpenAiStreamingChatModel.Builder builder = AzureOpenAiStreamingChatModel.builder()
            .apiKey(config.apiKey())
            .endpoint(config.endpoint())
            .deploymentName(deployment);
    if (config.apiVersion() != null) {
        builder.serviceVersion(config.apiVersion());
    }
    if (config.maxRetries() != null) {
        builder.maxRetries(config.maxRetries());
    }
    if (config.timeout() != null) {
        builder.timeout(Duration.ofSeconds(config.timeout()));
    }
    if (config.temperature() != null) {
        builder.temperature(config.temperature());
    }
    if (config.maxTokens() != null) {
        builder.maxTokens(config.maxTokens());
    }
    return builder.build();
}

/**
 * Builds an Azure OpenAI (non-streaming) chat model from the given configuration.
 * When no explicit deployment name is configured, the model name is used as
 * the deployment name. Optional settings (API version, retries, timeout,
 * temperature, max tokens) are applied only when present in the config.
 *
 * @param config the validated provider configuration
 * @return a configured {@link ChatModel}
 */
private static ChatModel buildAzureOpenAiChatModel(final ProviderConfig config) {
    // Azure addresses models by deployment; fall back to the model name if none was set.
    final String deployment = config.deploymentName() != null ? config.deploymentName() : config.model();
    final AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
            .apiKey(config.apiKey())
            .endpoint(config.endpoint())
            .deploymentName(deployment);
    if (config.apiVersion() != null) {
        builder.serviceVersion(config.apiVersion());
    }
    if (config.maxRetries() != null) {
        builder.maxRetries(config.maxRetries());
    }
    if (config.timeout() != null) {
        builder.timeout(Duration.ofSeconds(config.timeout()));
    }
    if (config.temperature() != null) {
        builder.temperature(config.temperature());
    }
    if (config.maxTokens() != null) {
        builder.maxTokens(config.maxTokens());
    }
    return builder.build();
}

/**
 * Builds an Azure OpenAI embedding model from the given configuration.
 * When no explicit deployment name is configured, the model name is used as
 * the deployment name. Optional settings (API version, retries, timeout)
 * are applied only when present in the config.
 *
 * @param config the validated provider configuration
 * @return a configured {@link EmbeddingModel}
 */
private static EmbeddingModel buildAzureOpenAiEmbeddingModel(final ProviderConfig config) {
    // Azure addresses models by deployment; fall back to the model name if none was set.
    final String deployment = config.deploymentName() != null ? config.deploymentName() : config.model();
    final AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
            .apiKey(config.apiKey())
            .endpoint(config.endpoint())
            .deploymentName(deployment);
    if (config.apiVersion() != null) {
        builder.serviceVersion(config.apiVersion());
    }
    if (config.maxRetries() != null) {
        builder.maxRetries(config.maxRetries());
    }
    if (config.timeout() != null) {
        builder.timeout(Duration.ofSeconds(config.timeout()));
    }
    return builder.build();
}

/**
 * Builds an Azure OpenAI image model from the given configuration.
 * When no explicit deployment name is configured, the model name is used as
 * the deployment name. Optional settings (API version, retries, timeout,
 * image size) are applied only when present in the config.
 *
 * @param config the validated provider configuration
 * @return a configured {@link ImageModel}
 */
private static ImageModel buildAzureOpenAiImageModel(final ProviderConfig config) {
    // Azure addresses models by deployment; fall back to the model name if none was set.
    final String deployment = config.deploymentName() != null ? config.deploymentName() : config.model();
    final AzureOpenAiImageModel.Builder builder = AzureOpenAiImageModel.builder()
            .apiKey(config.apiKey())
            .endpoint(config.endpoint())
            .deploymentName(deployment);
    if (config.apiVersion() != null) {
        builder.serviceVersion(config.apiVersion());
    }
    if (config.maxRetries() != null) {
        builder.maxRetries(config.maxRetries());
    }
    if (config.timeout() != null) {
        builder.timeout(Duration.ofSeconds(config.timeout()));
    }
    if (config.size() != null) {
        builder.size(config.size());
    }
    return builder.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,32 @@ public void test_buildChatModel_openai_missingApiKey_throws() {
assertThrows(IllegalArgumentException.class, () -> LangChain4jModelFactory.buildChatModel(config));
}

@Test
public void test_buildChatModel_azureOpenai_returnsModel() {
    // A fully-populated azure_openai config must produce a non-null chat model.
    final ProviderConfig config = azureOpenAiConfig("gpt-4o");
    final ChatModel model = LangChain4jModelFactory.buildChatModel(config);
    assertNotNull(model);
}

@Test
public void test_buildChatModel_azureOpenai_missingApiKey_throws() {
    // Endpoint present but no apiKey: Azure validation must reject the config.
    final ProviderConfig noKeyConfig = ImmutableProviderConfig.builder()
            .provider("azure_openai")
            .model("gpt-4o")
            .endpoint("https://my-company.openai.azure.com/")
            .build();
    assertThrows(IllegalArgumentException.class,
            () -> LangChain4jModelFactory.buildChatModel(noKeyConfig));
}

@Test
public void test_buildChatModel_azureOpenai_missingEndpoint_throws() {
    // apiKey present but no endpoint: Azure validation must reject the config.
    final ProviderConfig noEndpointConfig = ImmutableProviderConfig.builder()
            .provider("azure_openai")
            .model("gpt-4o")
            .apiKey("test-key")
            .build();
    assertThrows(IllegalArgumentException.class,
            () -> LangChain4jModelFactory.buildChatModel(noEndpointConfig));
}

@Test
public void test_buildChatModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
Expand All @@ -60,6 +86,12 @@ public void test_buildEmbeddingModel_openai_returnsModel() {
assertNotNull(model);
}

@Test
public void test_buildEmbeddingModel_azureOpenai_returnsModel() {
    // A fully-populated azure_openai config must produce a non-null embedding model.
    final ProviderConfig config = azureOpenAiConfig("text-embedding-ada-002");
    final EmbeddingModel model = LangChain4jModelFactory.buildEmbeddingModel(config);
    assertNotNull(model);
}

@Test
public void test_buildEmbeddingModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
Expand All @@ -81,6 +113,12 @@ public void test_buildImageModel_openai_returnsModel() {
assertNotNull(model);
}

@Test
public void test_buildImageModel_azureOpenai_returnsModel() {
    // A fully-populated azure_openai config must produce a non-null image model.
    final ProviderConfig config = azureOpenAiConfig("dall-e-3");
    final ImageModel model = LangChain4jModelFactory.buildImageModel(config);
    assertNotNull(model);
}

@Test
public void test_buildImageModel_unknownProvider_throws() {
final ProviderConfig config = ImmutableProviderConfig.builder()
Expand All @@ -99,4 +137,15 @@ private static ProviderConfig openAiConfig(final String model) {
.build();
}

/**
 * Builds a fully-populated azure_openai {@link ProviderConfig} for tests,
 * using the given model name as both the model and the deployment name.
 *
 * @param model the model/deployment name (e.g. "gpt-4o")
 * @return a valid Azure OpenAI provider configuration
 */
private static ProviderConfig azureOpenAiConfig(final String model) {
    return ImmutableProviderConfig.builder()
            .provider("azure_openai")
            .apiKey("test-key")
            .endpoint("https://my-company.openai.azure.com/")
            .model(model)
            .deploymentName(model)
            .apiVersion("2024-02-01")
            .build();
}

}
Loading