Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.languagemodels.model;

import java.util.Map;

/**
* Abstract base class for Solr-managed wrappers around langchain4j used in {@code language-models} module
*/
public abstract class SolrLanguageModel {

// common parameters
protected static final String TIMEOUT_PARAM = "timeout";
protected static final String MAX_RETRIES_PARAM = "maxRetries";

protected final String name;
protected final Map<String, Object> params;

protected SolrLanguageModel(String name, Map<String, Object> params) {
this.name = name;
this.params = params;
}

public String getName() {
return name;
}

public Map<String, Object> getParams() {
return params;
}

/** Returns the class name of the underlying langchain4j model instance. */
public abstract String getModelClassName();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

Comment thread
nicolo-rinaldi marked this conversation as resolved.
/** Contains model related classes. */
package org.apache.solr.languagemodels.model;
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.languagemodels.textvectorisation.store;
package org.apache.solr.languagemodels.store;

public class TextToVectorModelException extends RuntimeException {
public class LanguageModelException extends RuntimeException {

private static final long serialVersionUID = 1L;

public TextToVectorModelException(String message) {
public LanguageModelException(String message) {
super(message);
}

public TextToVectorModelException(String message, Exception cause) {
public LanguageModelException(String message, Exception cause) {
super(message, cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
*/

/** Contains model store related classes. */
package org.apache.solr.languagemodels.textvectorisation.store;
package org.apache.solr.languagemodels.store;
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.languagemodels.store.rest;

import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.jcip.annotations.ThreadSafe;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.languagemodels.model.SolrLanguageModel;
import org.apache.solr.languagemodels.store.LanguageModelException;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.rest.BaseSolrResource;
import org.apache.solr.rest.ManagedResource;
import org.apache.solr.rest.ManagedResourceStorage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Abstract base class for {@link ManagedResource} wrappers that expose a {@link LanguageModelStore}
* via the REST API. Concrete subclasses supply the REST endpoint and the model instantiation logic.
*/
@ThreadSafe
public abstract class ManagedLanguageModelStore<ModelT extends SolrLanguageModel> extends ManagedResource
implements ManagedResource.ChildResourceSupport {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the logger here, while for SolrLanguageModel it is in the classes that extend the abstract class?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept it here due to the fact that most of the logic is in the function that are already in the abstract class

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the standard for it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loggers in Solr are overridden when used in the same way:

private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

An example of this is RequestHandlerBase, which has its own logger, and CollectionsHandler opens another logger for the specific class. Since the logger is not used in SolrLanguageModel, I decided to avoid adding an unused variable. This way of working seems to be the standard in Solr


private static final String MODELS_JSON_FIELD = "models";

protected static final String CLASS_KEY = "class";
protected static final String NAME_KEY = "name";
protected static final String PARAMS_KEY = "params";

private final LanguageModelStore store;
private Object managedData;

protected ManagedLanguageModelStore(
String resourceId, SolrResourceLoader loader, ManagedResourceStorage.StorageIO storageIO)
throws SolrException {
super(resourceId, loader, storageIO);
store = new LanguageModelStore();
}

/**
* Creates a model instance from the JSON map persisted in the managed resource storage.
*
* @param loader the resource loader for the current core
* @param modelMap a map containing {@code "class"}, {@code "name"}, and {@code "params"} keys
* @return the instantiated model
*/
protected abstract ModelT fromModelMap(SolrResourceLoader loader, Map<String, Object> modelMap);

private static LinkedHashMap<String, Object> toModelMap(SolrLanguageModel model) {
final LinkedHashMap<String, Object> modelMap = new LinkedHashMap<>(3, 1.0f);
modelMap.put(NAME_KEY, model.getName());
modelMap.put(CLASS_KEY, model.getModelClassName());
modelMap.put(PARAMS_KEY, model.getParams());
return modelMap;
}

@Override
protected void onManagedDataLoadedFromStorage(NamedList<?> managedInitArgs, Object managedData)
throws SolrException {
store.clear();
this.managedData = managedData;
}

public void loadStoredModels() {
log.info("------ managed models ~ loading ------");
if ((managedData != null) && (managedData instanceof List)) {
@SuppressWarnings("unchecked")
final List<Map<String, Object>> models = (List<Map<String, Object>>) managedData;
for (final Map<String, Object> model : models) {
addModelFromMap(model);
}
}
}

private void addModelFromMap(Map<String, Object> modelMap) {
try {
addModel(fromModelMap(solrResourceLoader, modelMap));
} catch (final LanguageModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}

public void addModel(ModelT model) throws SolrException {
try {
if (log.isInfoEnabled()) {
log.info("adding model {}", model.getName());
}
store.addModel(model);
} catch (final LanguageModelException e) {
throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e);
}
}

@SuppressWarnings("unchecked")
@Override
protected Object applyUpdatesToManagedData(Object updates) {
if (updates instanceof List) {
final List<Map<String, Object>> models = (List<Map<String, Object>>) updates;
for (final Map<String, Object> model : models) {
addModelFromMap(model);
}
}
if (updates instanceof Map) {
addModelFromMap((Map<String, Object>) updates);
}
return modelsAsManagedResources(store.getModels());
}

@Override
public void doDeleteChild(BaseSolrResource endpoint, String childId) {
store.delete(childId);
storeManagedData(applyUpdatesToManagedData(null));
}

@Override
public void doGet(BaseSolrResource endpoint, String childId) {
final SolrQueryResponse response = endpoint.getSolrResponse();
response.add(MODELS_JSON_FIELD, modelsAsManagedResources(store.getModels()));
}

public ModelT getModel(String modelName) {
return store.getModel(modelName);
}

private static List<Object> modelsAsManagedResources(List<? extends SolrLanguageModel> models) {
return models.stream().map(ManagedLanguageModelStore::toModelMap).collect(Collectors.toList());
}

@Override
public String toString() {
return getClass().getSimpleName() + " [store=" + store + "]";
Comment thread
nicolo-rinaldi marked this conversation as resolved.
}


// Inner Data Structure to deal with Store persistence
private class LanguageModelStore {

private final Map<String, ModelT> availableModels;

public LanguageModelStore() {
availableModels = Collections.synchronizedMap(new LinkedHashMap<>());
}

public ModelT getModel(String name) {
return availableModels.get(name);
}

public void clear() {
availableModels.clear();
}

public List<ModelT> getModels() {
synchronized (availableModels) {
final List<ModelT> availableModelsValues = new ArrayList<>(availableModels.values());
return Collections.unmodifiableList(availableModelsValues);
}
}

@Override
public String toString() {
return "LanguageModelStore [availableModels=" + availableModels.keySet() + "]";
}

public ModelT delete(String modelName) {
return availableModels.remove(modelName);
}

public void addModel(ModelT modelData) throws LanguageModelException {
final String name = modelData.getName();
if (availableModels.putIfAbsent(name, modelData) != null) {
throw new LanguageModelException(
"model '" + name + "' already exists. Please use a different name");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/** Contains model store rest related classes. */
package org.apache.solr.languagemodels.store.rest;
Loading