Skip to content

Commit

Permalink
Merge pull request #5218 from inception-project/feature/5211-AI-assis…
Browse files Browse the repository at this point in the history
…tant-prototype

#5211 - AI assistant prototype
  • Loading branch information
reckart authored Jan 11, 2025
2 parents 57913bf + 393f7a8 commit 4821ade
Show file tree
Hide file tree
Showing 66 changed files with 790 additions and 541 deletions.
17 changes: 17 additions & 0 deletions inception/inception-app-webapp/src/test/resources/log4j2-test.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<Configuration status="WARN">
<Appenders>
<Console name="ConsoleAppender" target="SYSTEM_OUT">
<PatternLayout pattern="%d{yyyy-MM-dd HH:mm:ss} %level{length=5} %logger{1} - %msg%n" />
</Console>
</Appenders>

<Loggers>
<!-- <Logger name="de.tudarmstadt.ukp.dkpro.core.api.datasets.DatasetFactory" level="INFO"/> -->
<!-- <Logger name="de.tudarmstadt.ukp.inception.recommendation" level="TRACE"/> -->

<Root level="WARN">
<AppenderRef ref="ConsoleAppender" />
</Root>
</Loggers>
</Configuration>
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ public void processUserMessage(String aSessionOwner, Project aProject,
catch (IOException e) {
var errorMessage = MTextMessage.builder() //
.withActor("Error")
.withRole(SYSTEM) //
.withRole(SYSTEM).internal() //
.withMessage("Error: " + e.getMessage()) //
.build();
recordMessage(aSessionOwner, aProject, errorMessage);
Expand Down Expand Up @@ -350,10 +350,21 @@ private AssistentState getState(String aSessionOwner, Project aProject)
{
synchronized (states) {
return states.computeIfAbsent(new AssistentStateKey(aSessionOwner, aProject.getId()),
(v) -> new AssistentState());
(v) -> newState());
}
}

private AssistentState newState()
{
var state = new AssistentState();
// state.upsertMessage(MTextMessage.builder() //
// .withActor(properties.getNickname()) //
// .withRole(SYSTEM) //
// .withMessage("Hi") //
// .build());
return state;
}

private void clearState(Project aProject)
{
Validate.notNull(aProject, "Project must be specified");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,16 @@ public MTextMessage generate(List<MTextMessage> aMessasges,
var startTime = System.currentTimeMillis();
var response = ollamaClient.generate(properties.getUrl(), request,
msg -> streamMessage(aCallback, responseId, msg));
var tokens = response.getEvalCount();
var endTime = System.currentTimeMillis();

// Send a final and complete message also including final metrics
return newMessage(responseId)
.withMessage(response.getMessage().content()) //
.withPerformance(new MPerformanceMetrics(endTime - startTime)) //
.withPerformance(MPerformanceMetrics.builder() //
.withDuration(endTime - startTime) //
.withTokens(tokens) //
.build()) //
// Include all refs in the final message again just to be sure
.withReferences(references.values()) //
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public List<MTextMessage> retrieve(ChatContext aAssistant, MTextMessage aMessage
for (var chunk : chunks) {
var reference = MReference.builder() //
//.withId(String.valueOf(references.size() + 1)) //
.withId(UUID.randomUUID().toString()) //
.withId(UUID.randomUUID().toString().substring(0,8)) //
.withDocumentId(chunk.documentId()) //
.withDocumentName(chunk.documentName()) //
.withBegin(chunk.begin()) //
Expand Down Expand Up @@ -117,13 +117,21 @@ public List<MTextMessage> retrieve(ChatContext aAssistant, MTextMessage aMessage
Input:
{
"document": "The Eiffel Tower is located in Paris, France. It is one of the most famous landmarks in the world.",
"ref-id": "123"
"document": "The Eiffel Tower is located in Paris, France.",
"ref-id": "917"
}
{
"document": "It is one of the most famous landmarks in the world.",
"ref-id": "735"
}
{
"document": The Eiffel Tower was built from 1887 to 1889.",
"ref-id": "582"
}
Response:
The Eiffel Tower is located in Paris, France {{ref::123}}.
It is one of the most famous landmarks in the world {{ref::123}}.
The Eiffel Tower is a famous landmark located in Paris, France {{ref::917}} {{ref::735}}.
It was built from 1887 to 1889 {{ref::582}}.
Now, use the same pattern to process the following document:
""",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ private void autoDetectEmbeddingDimension()
embeddingProperties.getModel(), embeddingProperties.getDimension());
}
catch (Exception e) {
LOG.error("Unable to auto-detect embedding dimension - using default", e);
if (LOG.isDebugEnabled()) {
LOG.warn("Unable to auto-detect embedding dimension - using default", e);
}
else {
LOG.warn("Unable to auto-detect embedding dimension - using default");
}
embeddingProperties.setDimension(1024);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
* @param duration time it took to produce the messages in milliseconds
*/
@JsonSerialize
public record MPerformanceMetrics(long duration) {
public record MPerformanceMetrics(long duration, int tokens) {

private MPerformanceMetrics(Builder builder)
{
this(builder.duration);
this(builder.duration, builder.tokens);
}

public MPerformanceMetrics merge(MPerformanceMetrics aPerformance)
Expand All @@ -49,6 +49,7 @@ public static Builder builder()
public static final class Builder
{
private long duration;
private int tokens;

private Builder()
{
Expand All @@ -60,6 +61,12 @@ public Builder withDuration(long aDuration)
return this;
}

public Builder withTokens(int aTokens)
{
tokens = aTokens;
return this;
}

public MPerformanceMetrics build()
{
return new MPerformanceMetrics(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ public List<MTextMessage> retrieve(ChatContext aAssistant, MTextMessage aMessage
{
var dtf = DateTimeFormatter.ofLocalizedDateTime(FormatStyle.MEDIUM);
return asList(MTextMessage.builder() //
.withActor("Current time provider").withRole(SYSTEM).internal() //
.withActor("Current time provider") //
.withRole(SYSTEM).internal() //
.withMessage("The current time is " + LocalDateTime.now(ZoneOffset.UTC).format(dtf)) //
.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
interface MPerformanceMetrics {
duration: number,
tokens: number
}
interface MReference {
Expand Down Expand Up @@ -300,6 +301,7 @@
{#if message.performance}
<div class="message-footer">
<span><i class="far fa-clock me-1"/>{message.performance.duration / 1000}s</span>
<span><i class="far me-1"/>{(message.performance.tokens / (message.performance.duration / 1000)).toFixed(2)}tps</span>
</div>
{/if}
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
<!-- <Logger name="org.springframework.security" level="TRACE" /> -->
<!-- <Logger name="org.springframework.security.authorization" level="TRACE" /> -->
<!-- <Logger name="de.tudarmstadt.ukp.inception.support.test.websocket" level="TRACE" /> -->
<Logger name="de.tudarmstadt.ukp.inception.assistant" level="TRACE"/>
<!-- <Logger name="de.tudarmstadt.ukp.inception.assistant" level="TRACE"/> -->

<Root level="WARN">
<AppenderRef ref="ConsoleAppender" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import static de.tudarmstadt.ukp.inception.annotation.feature.link.LinkFeatureMultiplicityMode.MULTIPLE_TARGETS_MULTIPLE_ROLES;
import static de.tudarmstadt.ukp.inception.annotation.feature.link.LinkFeatureMultiplicityMode.MULTIPLE_TARGETS_ONE_ROLE;
import static de.tudarmstadt.ukp.inception.annotation.feature.link.LinkFeatureMultiplicityMode.ONE_TARGET_MULTIPLE_ROLES;
import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_SOURCE;
import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_TARGET;
import static de.tudarmstadt.ukp.inception.support.uima.AnnotationBuilder.buildAnnotation;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
Expand Down Expand Up @@ -79,7 +81,6 @@
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.dkpro.core.api.syntax.type.dependency.Dependency;
import de.tudarmstadt.ukp.inception.support.WebAnnoConst;

public class CasDiffTest
{
Expand Down Expand Up @@ -370,7 +371,7 @@ public void relationStackedSpansTest() throws Exception
casByUser.put("user2", jcasB.getCas());

var diffAdapters = asList(new RelationDiffAdapter("webanno.custom.Relation",
WebAnnoConst.FEAT_REL_TARGET, WebAnnoConst.FEAT_REL_SOURCE, "value"));
FEAT_REL_TARGET, FEAT_REL_SOURCE, "value"));

var result = doDiff(diffAdapters, casByUser).toResult();

Expand All @@ -379,7 +380,6 @@ public void relationStackedSpansTest() throws Exception
assertThat(result.getIncompleteConfigurationSets()).isEmpty();
assertThat(calculateState(result)).isEqualTo(AGREE);
}

}

@Nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package de.tudarmstadt.ukp.clarin.webanno.curation.casdiff;

import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_SOURCE;
import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_TARGET;
import static de.tudarmstadt.ukp.inception.support.uima.AnnotationBuilder.buildAnnotation;
import static de.tudarmstadt.ukp.inception.support.uima.FeatureStructureBuilder.buildFS;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -49,7 +51,6 @@
import de.tudarmstadt.ukp.clarin.webanno.tsv.WebannoTsv3XReader;
import de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport;
import de.tudarmstadt.ukp.inception.annotation.layer.span.SpanLayerSupport;
import de.tudarmstadt.ukp.inception.support.WebAnnoConst;

public class CurationTestUtils
{
Expand Down Expand Up @@ -202,8 +203,8 @@ public static TypeSystemDescription createCustomTypeSystem(String aType, String
else if (RelationLayerSupport.TYPE.equals(aType)) {
var td = type.addType(aTypeName, "", TYPE_NAME_ANNOTATION);

td.addFeature(WebAnnoConst.FEAT_REL_TARGET, "", aAttacheType);
td.addFeature(WebAnnoConst.FEAT_REL_SOURCE, "", aAttacheType);
td.addFeature(FEAT_REL_TARGET, "", aAttacheType);
td.addFeature(FEAT_REL_SOURCE, "", aAttacheType);

for (var feature : aFeatures) {
td.addFeature(feature, "", TYPE_NAME_STRING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

package de.tudarmstadt.ukp.clarin.webanno.curation.casdiff;

import static de.tudarmstadt.ukp.inception.support.WebAnnoConst.RELATION_TYPE;
import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_SOURCE;
import static de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport.FEAT_REL_TARGET;
import static java.util.Arrays.asList;
import static org.apache.uima.fit.factory.CollectionReaderFactory.createReader;

Expand Down Expand Up @@ -48,8 +49,8 @@

import de.tudarmstadt.ukp.clarin.webanno.tsv.WebannoTsv2Reader;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;
import de.tudarmstadt.ukp.inception.annotation.layer.relation.RelationLayerSupport;
import de.tudarmstadt.ukp.inception.annotation.layer.span.SpanLayerSupport;
import de.tudarmstadt.ukp.inception.support.WebAnnoConst;

public class DiffTestUtils
{
Expand Down Expand Up @@ -204,19 +205,19 @@ public static TypeSystemDescription createCustomTypeSystem(String aType, String
List<String> aFeatures, String aAttacheType)
throws Exception
{
TypeSystemDescription type = new TypeSystemDescription_impl();
var type = new TypeSystemDescription_impl();
if (SpanLayerSupport.TYPE.equals(aType)) {
TypeDescription td = type.addType(aTypeName, "", CAS.TYPE_NAME_ANNOTATION);
var td = type.addType(aTypeName, "", CAS.TYPE_NAME_ANNOTATION);
for (String feature : aFeatures) {
td.addFeature(feature, "", CAS.TYPE_NAME_STRING);
}

}
else if (aType.equals(RELATION_TYPE)) {
TypeDescription td = type.addType(aTypeName, "", CAS.TYPE_NAME_ANNOTATION);
else if (aType.equals(RelationLayerSupport.TYPE)) {
var td = type.addType(aTypeName, "", CAS.TYPE_NAME_ANNOTATION);

td.addFeature(WebAnnoConst.FEAT_REL_TARGET, "", aAttacheType);
td.addFeature(WebAnnoConst.FEAT_REL_SOURCE, "", aAttacheType);
td.addFeature(FEAT_REL_TARGET, "", aAttacheType);
td.addFeature(FEAT_REL_SOURCE, "", aAttacheType);

for (String feature : aFeatures) {
td.addFeature(feature, "", "uima.cas.String");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
<AppenderRef ref="ConsoleAppender" />
</Root>

<Logger name="de.tudarmstadt.ukp.clarin.webanno.curation.casdiff" level="TRACE" />
<Logger name="de.tudarmstadt.ukp.inception.curation.merge" level="TRACE" />
<Logger name="de.tudarmstadt.ukp.inception.curation.merge.strategy" level="TRACE" />
<!-- <Logger name="de.tudarmstadt.ukp.clarin.webanno.curation.casdiff" level="TRACE" /> -->
<!-- <Logger name="de.tudarmstadt.ukp.inception.curation.merge" level="TRACE" /> -->
<!-- <Logger name="de.tudarmstadt.ukp.inception.curation.merge.strategy" level="TRACE" /> -->
</Loggers>
</Configuration>
5 changes: 0 additions & 5 deletions inception/inception-diag/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
Expand All @@ -54,7 +54,7 @@ static class Config
{
}

@MockBean
@MockitoBean
AnnotationSchemaService annotationService;

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
Expand All @@ -53,7 +53,7 @@ static class Config
{
}

@MockBean
@MockitoBean
AnnotationSchemaService annotationService;

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.test.context.bean.override.mockito.MockitoBean;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
Expand All @@ -55,7 +55,7 @@ static class Config
{
}

@MockBean
@MockitoBean
private AnnotationSchemaService annotationService;

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,6 @@ public void test() throws Exception

assertFalse(result);

messages.forEach(System.out::println);
// messages.forEach(System.out::println);
}
}
Loading

0 comments on commit 4821ade

Please sign in to comment.