Skip to content

Improve ollama container reuse feature and cache model #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Spring-TestContainers

[![Build Status](https://github.com/flowinquiry/spring-testcontainers/actions/workflows/gradle.yml/badge.svg)](https://github.com/flowinquiry/spring-testcontainers/actions/workflows/gradle.yml)
[![Maven Central](https://img.shields.io/maven-central/v/io.flowinquiry.testcontainers/spring-testcontainers?label=Maven%20Central)](https://search.maven.org/artifact/io.flowinquiry.testcontainers/spring-testcontainers)

Spring-TestContainers is a Java library that makes it easier to write integration tests with Testcontainers, especially when you're using Spring or Spring Boot. It handles the setup and lifecycle of containers for you, so you can focus on testing—not boilerplate.

We originally built this for FlowInquiry to make our own testing smoother. It worked so well, we decided to share it as a standalone library so other teams can take advantage of it too.
Expand Down Expand Up @@ -146,9 +149,9 @@ Add the core library along with the database module(s) you plan to use. Each dat

```kotlin
// Add one or more of the following database modules
testImplementation("io.flowinquiry.testcontainers:postgresql:0.9.1") // PostgreSQL support
testImplementation("io.flowinquiry.testcontainers:mysql:0.9.1") // MySQL support
testImplementation("io.flowinquiry.testcontainers:ollama:0.9.1") // Ollama support
testImplementation("io.flowinquiry.testcontainers:postgresql:<!-- Replace with the latest version -->") // PostgreSQL support
testImplementation("io.flowinquiry.testcontainers:mysql:<!-- Replace with the latest version -->") // MySQL support
testImplementation("io.flowinquiry.testcontainers:ollama:<!-- Replace with the latest version -->") // Ollama support
```

### Maven
Expand All @@ -161,7 +164,7 @@ testImplementation("io.flowinquiry.testcontainers:ollama:0.9.1") // Ollama s
<dependency>
<groupId>io.flowinquiry.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<version>0.9.0</version>
<version><!-- Replace with the latest version --></version>
<scope>test</scope>
</dependency>
<dependency>
Expand All @@ -170,15 +173,15 @@ testImplementation("io.flowinquiry.testcontainers:ollama:0.9.1") // Ollama s
<dependency>
<groupId>io.flowinquiry.testcontainers</groupId>
<artifactId>mysql</artifactId>
<version>0.9.1</version>
<version><!-- Replace with the latest version --></version>
<scope>test</scope>
</dependency>

<!-- Add this dependency to test Ollama container -->
<dependency>
<groupId>io.flowinquiry.testcontainers</groupId>
<artifactId>ollama</artifactId>
<version>0.9.1</version>
<version><!-- Replace with the latest version --></version>
<scope>test</scope>
</dependency>
```
Expand Down
1 change: 1 addition & 0 deletions examples/springboot-ollama/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies {
implementation(libs.bundles.spring.ai)
testImplementation(platform(libs.junit.bom))
testImplementation(libs.junit.jupiter)
testImplementation(libs.junit.jupiter.params)
testImplementation(libs.junit.platform.launcher)
testImplementation(libs.spring.boot.starter.test)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import io.flowinquiry.testcontainers.ai.OllamaOptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.slf4j.Logger;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -23,7 +25,7 @@
@EnableOllamaContainer(
dockerImage = "ollama/ollama",
version = "0.9.0",
model = "smollm2:135m",
model = "llama3:latest",
options = @OllamaOptions(temperature = "0.7", topP = "0.5"))
@ActiveProfiles("test")
public class OllamaDemoAppTest {
Expand Down Expand Up @@ -53,16 +55,21 @@ public void testHealthEndpoint() {
assertTrue(response.contains("Ollama Chat Controller is up and running"));
}

@Test
public void testChatClient() {
@ParameterizedTest
@CsvSource({
"What is the result of 1+2? Give the value only, 3",
"How many letter 'r' in the word 'Hello'? Give the value only, 0"
})
public void testChatClient(String prompt, String expectedResult) {
log.info("Testing chat client directly");
String prompt = "What is Spring AI?";
log.info("Sending prompt: {}", prompt);

String content = chatClient.prompt().user(prompt).call().content();

log.info("Received response: {}", content);
assertNotNull(content);
assertFalse(content.isEmpty());
assertTrue(
content.contains(expectedResult), "Response should contain '" + expectedResult + "'");
}
}
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# https://docs.gradle.org/current/userguide/build_environment.html#sec:gradle_configuration_properties

org.gradle.configuration-cache=true
version=0.9.1
version=0.9.2

1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ spring-ai = "1.0.0"
junit-bom = { group = "org.junit", name = "junit-bom", version.ref = "junit-jupiter" }
junit-jupiter = { group = "org.junit.jupiter", name = "junit-jupiter" }
junit-jupiter-api = { group = "org.junit.jupiter", name = "junit-jupiter-api" }
junit-jupiter-params = { group = "org.junit.jupiter", name = "junit-jupiter-params" }
junit-jupiter-engine = { group = "org.junit.jupiter", name = "junit-jupiter-engine" }
junit-platform-launcher = { group = "org.junit.platform", name = "junit-platform-launcher" }
spring-bom = { group = "org.springframework", name = "spring-framework-bom", version.ref = "spring" }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package io.flowinquiry.testcontainers.ai;

import static io.flowinquiry.testcontainers.ContainerType.OLLAMA;
import static org.testcontainers.containers.BindMode.READ_WRITE;

import io.flowinquiry.testcontainers.ContainerType;
import io.flowinquiry.testcontainers.Slf4jOutputConsumer;
import io.flowinquiry.testcontainers.SpringAwareContainerProvider;
import java.io.IOException;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.PropertiesPropertySource;
import org.testcontainers.containers.Container;
import org.testcontainers.ollama.OllamaContainer;

/**
Expand Down Expand Up @@ -46,7 +49,8 @@ public ContainerType getContainerType() {
*/
@Override
protected OllamaContainer createContainer() {
return new OllamaContainer(dockerImage + ":" + version);
return new OllamaContainer(dockerImage + ":" + version)
.withFileSystemBind("/tmp/ollama-cache", "/root/.ollama", READ_WRITE);
}

/**
Expand All @@ -60,14 +64,31 @@ protected OllamaContainer createContainer() {
@Override
public void start() {
super.start();

Logger containerLog = LoggerFactory.getLogger(OllamaContainerProvider.class);
container.followOutput(new Slf4jOutputConsumer(containerLog));

try {
log.info("Starting pull model {}", enableContainerAnnotation.model());
container.execInContainer("ollama", "pull", enableContainerAnnotation.model());
pullModelIfMissing(enableContainerAnnotation.model());
} catch (IOException | InterruptedException e) {
throw new RuntimeException(e);
}
}

private void pullModelIfMissing(String modelName) throws IOException, InterruptedException {
Container.ExecResult result = container.execInContainer("ollama", "list");
String output = result.getStdout();

if (!output.contains(modelName)) {
log.info("Model '{}' not found in ollama cache. Pulling...", modelName);
Container.ExecResult pullResult = container.execInContainer("ollama", "pull", modelName);
log.info("Pull complete: {}", pullResult.getStdout());
} else {
log.info("Model '{}' already exists. Skipping pull.", modelName);
}
}

/**
* Applies Ollama-specific configuration to the Spring environment.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.flowinquiry.testcontainers;

import org.slf4j.Logger;
import org.slf4j.event.Level;
import org.testcontainers.containers.output.BaseConsumer;
import org.testcontainers.containers.output.OutputFrame;

/**
* An implementation of {@link BaseConsumer} that routes container output to SLF4J logging. This
* consumer allows for different log levels to be used for STDOUT and STDERR streams.
*
* <p>Usage example:
*
* <pre>
* Logger logger = LoggerFactory.getLogger(MyClass.class);
* GenericContainer container = new GenericContainer("some-image")
* .withLogConsumer(new Slf4jOutputConsumer(logger));
* </pre>
*/
public class Slf4jOutputConsumer extends BaseConsumer<Slf4jOutputConsumer> {

/** The SLF4J logger to which container output will be written. */
private final Logger logger;

/** The log level to use for STDOUT output from the container. */
private final Level stdoutLogLevel;

/** The log level to use for STDERR output from the container. */
private final Level stderrLogLevel;

/**
* Creates a new Slf4jOutputConsumer with default log levels. STDOUT messages will be logged at
* DEBUG level, and STDERR messages at ERROR level.
*
* @param logger the SLF4J logger to which container output will be written
*/
public Slf4jOutputConsumer(Logger logger) {
this(logger, Level.DEBUG, Level.ERROR);
}

/**
* Creates a new Slf4jOutputConsumer with custom log levels for STDOUT and STDERR.
*
* @param logger the SLF4J logger to which container output will be written
* @param stdoutLogLevel the log level to use for STDOUT output
* @param stderrLogLevel the log level to use for STDERR output
*/
public Slf4jOutputConsumer(Logger logger, Level stdoutLogLevel, Level stderrLogLevel) {
this.logger = logger;
this.stdoutLogLevel = stdoutLogLevel;
this.stderrLogLevel = stderrLogLevel;
}

/**
* Processes an output frame from a container and logs it using the configured SLF4J logger.
*
* <p>The method:
*
* <ul>
* <li>Skips null or empty frames
* <li>Determines the appropriate log level based on the frame type (STDOUT or STDERR)
* <li>Logs the message with the frame type as a prefix
* </ul>
*
* @param outputFrame the output frame to process
*/
@Override
public void accept(OutputFrame outputFrame) {
if (outputFrame == null || outputFrame.getBytes() == null) return;

String message = outputFrame.getUtf8String().trim();
if (message.isEmpty()) return;

Level levelToUse =
switch (outputFrame.getType()) {
case STDOUT -> stdoutLogLevel;
case STDERR -> stderrLogLevel;
case END -> null;
};

if (levelToUse != null) {
logAtLevel(levelToUse, "[{}] {}", outputFrame.getType(), message);
}
}

/**
* Logs a message at the specified SLF4J level.
*
* @param level the SLF4J level at which to log the message
* @param format the message format string
* @param args the arguments to be formatted into the message string
*/
private void logAtLevel(Level level, String format, Object... args) {
switch (level) {
case TRACE -> logger.trace(format, args);
case DEBUG -> logger.debug(format, args);
case INFO -> logger.info(format, args);
case WARN -> logger.warn(format, args);
case ERROR -> logger.error(format, args);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public abstract class SpringAwareContainerProvider<

private static final Logger log = LoggerFactory.getLogger(SpringAwareContainerProvider.class);

private static boolean reuseContainerSupport =
TestcontainersConfiguration.getInstance().environmentSupportsReuse();

/** The version of the container image to use. */
protected String version;

Expand All @@ -43,12 +46,17 @@ public final void initContainerInstance(A enableContainerAnnotation) {
enableContainerAnnotation.annotationType().getMethod("dockerImage");
Method versionMethod = enableContainerAnnotation.annotationType().getMethod("version");

log.info("Initializing JDBC container with image {}:{}", dockerImage, version);
log.info("Initializing the container with image {}:{}", dockerImage, version);
this.version = (String) versionMethod.invoke(enableContainerAnnotation);
this.dockerImage = (String) dockerImageMethod.invoke(enableContainerAnnotation);

container = createContainer();
container.withReuse(TestcontainersConfiguration.getInstance().environmentSupportsReuse());
container.withReuse(reuseContainerSupport);
log.info(
"Created the container with image {}:{} with reuse {}",
dockerImage,
version,
reuseContainerSupport);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new IllegalArgumentException(
"Annotation "
Expand All @@ -74,7 +82,7 @@ public void start() {
/** Stops the container. This method is called when the Spring context is closed. */
@Override
public void stop() {
if (!TestcontainersConfiguration.getInstance().environmentSupportsReuse()) {
if (!reuseContainerSupport) {
container.stop();
}
}
Expand Down
Loading