Skip to content

Cache control support #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Sources/AnthropicSwiftSDK-Bedrock/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public struct Messages {
public func createMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -88,7 +88,7 @@ public struct Messages {
public func streamMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down
4 changes: 2 additions & 2 deletions Sources/AnthropicSwiftSDK-VertexAI/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public struct Messages {
public func createMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -92,7 +92,7 @@ public struct Messages {
public func streamMessage(
_ messages: [Message],
model: Model = .claude_3_Haiku,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down
4 changes: 4 additions & 0 deletions Sources/AnthropicSwiftSDK/ClientError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public enum ClientError: Error {
case failedToDecodeToolUseContent
/// SDK failed to make `ToolUse.input` encodable
case failedToMakeEncodableToolUseInput([String: Any])
/// SDK failed to encode `SystemPrompt` object
case failedToEncodeSystemPrompt

/// Description of sdk internal errors.
public var localizedDescription: String {
Expand All @@ -59,6 +61,8 @@ public enum ClientError: Error {
return "Failed to decode into tool use content"
case .failedToMakeEncodableToolUseInput:
return "Failed to make ToolUse.input object Encodable"
case .failedToEncodeSystemPrompt:
return "Failed to encode `SystemPrompt` object"
}
}
}
56 changes: 56 additions & 0 deletions Sources/AnthropicSwiftSDK/Entity/SystemPrompt.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//
// SystemPrompt.swift
//
//
// Created by 伊藤史 on 2024/09/04.
//

import Foundation

/// The cache lifetime for a prompt segment, serialized as the `type`
/// field of a `cache_control` object in the Messages API request.
public enum CacheControl: String {
    /// The only supported cache type; the raw value "ephemeral" corresponds to this 5-minute lifetime.
    case ephemeral
}

// MARK: - Encodable

extension CacheControl: Encodable {
    enum CodingKeys: String, CodingKey {
        case type
    }

    /// Encodes the cache control as a keyed object of the form
    /// `{"type": "<rawValue>"}` as required by the Messages API.
    public func encode(to encoder: any Encoder) throws {
        var keyed = encoder.container(keyedBy: CodingKeys.self)
        try keyed.encode(rawValue, forKey: .type)
    }
}

/// A single system-prompt entry for the Messages API.
///
/// A system prompt provides context and instructions to Claude; each entry
/// is serialized as `{"type": "text", "text": ..., "cache_control": ...}`.
public enum SystemPrompt {
    /// A plain-text prompt with an optional prompt-caching directive.
    case text(String, CacheControl?)

    /// The wire value of the `type` field.
    private var type: String {
        // Every current case is a text prompt; the tag is therefore constant.
        "text"
    }
}

// MARK: - Encodable

extension SystemPrompt: Encodable {
    enum CodingKeys: String, CodingKey {
        case type
        case text
        case cacheControl = "cache_control"
    }

    /// Encodes the prompt as `{"type": "text", "text": ..., "cache_control": ...}`,
    /// omitting the `cache_control` key entirely when no cache control is set.
    public func encode(to encoder: any Encoder) throws {
        var container = encoder.container(keyedBy: CodingKeys.self)
        try container.encode(type, forKey: .type)

        // An exhaustive switch replaces the previous `guard case` + throw of
        // `ClientError.failedToEncodeSystemPrompt`: with a single enum case the
        // failure path was unreachable dead code, and a future new case now
        // becomes a compile-time error instead of a silent runtime throw.
        switch self {
        case let .text(text, cacheControl):
            try container.encode(text, forKey: .text)
            if let cacheControl {
                try container.encode(cacheControl, forKey: .cacheControl)
            }
        }
    }
}
12 changes: 6 additions & 6 deletions Sources/AnthropicSwiftSDK/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public struct Messages {
public func createMessage(
_ messages: [Message],
model: Model = .claude_3_Opus,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -87,7 +87,7 @@ public struct Messages {
public func createMessage(
_ messages: [Message],
model: Model = .claude_3_Opus,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -177,7 +177,7 @@ public struct Messages {
forToolUseRequest toolUseRequest: MessagesResponse,
priviousMessages messages: [Message],
model: Model = .claude_3_Opus,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -240,7 +240,7 @@ public struct Messages {
public func streamMessage(
_ messages: [Message],
model: Model = .claude_3_Opus,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -288,7 +288,7 @@ public struct Messages {
public func streamMessage(
_ messages: [Message],
model: Model = .claude_3_Opus,
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequence: [String]? = nil,
Expand Down Expand Up @@ -374,7 +374,7 @@ public struct Messages {
_ stream: AsyncThrowingStream<StreamingResponse, Error>,
messages: [Message],
model: Model,
system: String?,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData?,
stopSequence: [String]?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct DefaultAnthropicHeaderProvider: AnthropicHeaderProvider {
/// content type of response, now only support JSON
let contentType = "application/json"

private let betaDescription = "messages-2023-12-15"
private let betaDescription = "prompt-caching-2024-07-31"

func getAnthropicAPIHeaders() -> [String: String] {
var headers: [String: String] = [
Expand Down
4 changes: 2 additions & 2 deletions Sources/AnthropicSwiftSDK/Network/MessagesRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public struct MessagesRequest: Encodable {
/// System prompt.
///
/// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role.
public let system: String?
public let system: [SystemPrompt]
/// The maximum number of tokens to generate before stopping.
///
/// Note that our models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate.
Expand Down Expand Up @@ -56,7 +56,7 @@ public struct MessagesRequest: Encodable {
public init(
model: Model = .claude_3_Opus,
messages: [Message],
system: String? = nil,
system: [SystemPrompt] = [],
maxTokens: Int,
metaData: MetaData? = nil,
stopSequences: [String]? = nil,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import AWSBedrockRuntime

final class AnthropicBedrockClientTests: XCTestCase {
func testInvokeModelContainEncodedMessageRequest() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: [], maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
Expand All @@ -29,7 +29,7 @@ final class AnthropicBedrockClientTests: XCTestCase {
}

func testInvokeModelNotContainUnnecessaryParameters() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: [], maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
Expand All @@ -41,7 +41,7 @@ final class AnthropicBedrockClientTests: XCTestCase {
}

func testInvokeModelWithResponseStreamContainEncodedMessageRequest() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: [], maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelWithResponseStreamInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
Expand All @@ -58,7 +58,7 @@ final class AnthropicBedrockClientTests: XCTestCase {
}

func testInvokeModelWithResponseStreamNotContainUnnecessaryParameters() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: [], maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelWithResponseStreamInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
Expand Down Expand Up @@ -152,7 +152,7 @@ extension MessagesRequest: Decodable {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.init(
messages: try container.decode([Message].self, forKey: .messages),
system: try? container.decode(String.self, forKey: .system),
system: try container.decode([SystemPrompt].self, forKey: .system),
maxTokens: try container.decode(Int.self, forKey: .maxTokens),
stopSequences: try container.decode([String].self, forKey: .stopSequences),
temperature: try container.decode(Double.self, forKey: .temperature),
Expand Down Expand Up @@ -221,3 +221,32 @@ extension Content: Equatable {
}
}
}

extension SystemPrompt: Equatable {
    /// Two prompts are equal when both their text and their cache-control
    /// settings match.
    public static func == (lhs: SystemPrompt, rhs: SystemPrompt) -> Bool {
        switch (lhs, rhs) {
        case let (.text(lhsText, lhsCache), .text(rhsText, rhsCache)):
            return lhsText == rhsText && lhsCache == rhsCache
        }
    }
}

extension SystemPrompt: Decodable {
    enum CodingKeys: String, CodingKey {
        case text
        // Without an explicit raw value the key decoded was "cacheControl",
        // but the Encodable side writes "cache_control" — the mismatch (hidden
        // by the previous `try?`) silently dropped cache control on round-trip.
        case cacheControl = "cache_control"
    }

    /// Decodes the keyed form produced by the `Encodable` conformance.
    /// - Throws: `DecodingError` when `text` is missing or a present
    ///   `cache_control` value is malformed.
    public init(from decoder: any Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)

        self = .text(
            try container.decode(String.self, forKey: .text),
            // `decodeIfPresent` returns nil only when the key is absent;
            // a present-but-invalid value now surfaces as an error instead
            // of being swallowed by `try?`.
            try container.decodeIfPresent(CacheControl.self, forKey: .cacheControl)
        )
    }
}

extension CacheControl: Decodable {
    /// Decodes the keyed form produced by the `Encodable` conformance
    /// (`{"type": "<rawValue>"}`).
    ///
    /// The previously synthesized RawRepresentable decoder expected a bare
    /// string (`"ephemeral"`), so values written by `encode(to:)` could never
    /// be decoded back — the round-trip failure was masked by the callers'
    /// use of `try?`.
    public init(from decoder: any Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        let rawType = try container.decode(String.self, forKey: .type)
        guard let value = CacheControl(rawValue: rawType) else {
            throw DecodingError.dataCorruptedError(
                forKey: .type,
                in: container,
                debugDescription: "Unknown cache control type: \(rawType)"
            )
        }
        self = value
    }
}
39 changes: 39 additions & 0 deletions Tests/AnthropicSwiftSDKTests/Entity/SystemPromptTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// SystemPromptTests.swift
//
//
// Created by Fumito Ito on 2024/09/06.
//

import XCTest
@testable import AnthropicSwiftSDK

final class SystemPromptTests: XCTestCase {
    let encoder = JSONEncoder()

    /// Without cache control, only `type` and `text` appear in the payload.
    func testEncodeSystemPrompt() throws {
        let prompt = SystemPrompt.text("this is test text for system prompt test", nil)
        let json = try XCTUnwrap(prompt.toDictionary(encoder))

        XCTAssertEqual(json["type"] as? String, "text")
        XCTAssertEqual(json["text"] as? String, "this is test text for system prompt test")
        XCTAssertNil(json.index(forKey: "cache_control"))
    }

    /// With cache control, a nested `{"type": "ephemeral"}` object is emitted
    /// under the `cache_control` key.
    func testEncodeSystemPromptWithCacheControl() throws {
        let prompt = SystemPrompt.text("this is test text for system prompt test with cache control", .ephemeral)
        let json = try XCTUnwrap(prompt.toDictionary(encoder))

        XCTAssertEqual(json["type"] as? String, "text")
        XCTAssertEqual(json["text"] as? String, "this is test text for system prompt test with cache control")
        let cacheControl = try XCTUnwrap(json["cache_control"] as? [String: String])
        XCTAssertEqual(cacheControl["type"], "ephemeral")
    }
}

extension SystemPrompt {
    /// Encodes the prompt with `encoder` and re-parses the data into a JSON
    /// dictionary so tests can assert on individual fields.
    func toDictionary(_ encoder: JSONEncoder) throws -> [String: Any]? {
        let data = try encoder.encode(self)
        let object = try JSONSerialization.jsonObject(with: data, options: [])
        return object as? [String: Any]
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ final class AnthropicAPIClientTests: XCTestCase {
XCTAssertEqual(headers!["x-api-key"], "test-api-key")
XCTAssertEqual(headers!["anthropic-version"], "2023-06-01")
XCTAssertEqual(headers!["Content-Type"], "application/json")
XCTAssertEqual(headers!["anthropic-beta"], "messages-2023-12-15")
XCTAssertEqual(headers!["anthropic-beta"], "prompt-caching-2024-07-31")

expectation.fulfill()
})
Expand All @@ -90,7 +90,7 @@ final class AnthropicAPIClientTests: XCTestCase {
XCTAssertEqual(headers!["x-api-key"], "test-api-key")
XCTAssertEqual(headers!["anthropic-version"], "2023-06-01")
XCTAssertEqual(headers!["Content-Type"], "application/json")
XCTAssertEqual(headers!["anthropic-beta"], "messages-2023-12-15")
XCTAssertEqual(headers!["anthropic-beta"], "prompt-caching-2024-07-31")

expectation.fulfill()
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ final class DefaultAnthropicHeaderProviderTests: XCTestCase {
let provider = DefaultAnthropicHeaderProvider(useBeta: true)
let headers = provider.getAnthropicAPIHeaders()

XCTAssertEqual(headers["anthropic-beta"], "messages-2023-12-15")
XCTAssertEqual(headers["anthropic-beta"], "prompt-caching-2024-07-31")
}

func testBetaHeaderShouldNotBeProvidedIfUseBeta() {
Expand Down
Loading