Skip to content

validate message and batch #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Sources/AnthropicSwiftSDK/ClientError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public enum ClientError: Error {
case failedToMakeEncodableToolUseInput([String: Any])
/// SDK failed to encode `SystemPrompt` object
case failedToEncodeSystemPrompt
/// These messages are not supported by the model.
case unsupportedMessageContentContained(model: Model, messages: [Message])
/// Some unsupported features are used.
case unsupportedFeatureUsed(description: String)

/// Description of sdk internal errors.
public var localizedDescription: String {
Expand Down Expand Up @@ -63,6 +67,10 @@ public enum ClientError: Error {
return "Failed to make ToolUse.input object Encodable"
case .failedToEncodeSystemPrompt:
return "Failed to encode `SystemPrompt` object"
case let .unsupportedMessageContentContained(model, messages):
return "The model \(model.stringfy) does not support these messages: \(messages)"
case let .unsupportedFeatureUsed(description):
return "Some unsupported features are used. For more detail, see \(description)."
}
}
}
44 changes: 44 additions & 0 deletions Sources/AnthropicSwiftSDK/Entity/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,50 @@ public enum Model {
}
}

extension Model {
/// Whether this model supports Message Batches API or not.
///
/// `Claude 3.0 Sonnet` does not support it.
var isSupportBatches: Bool {
switch self {
case
.claude_3_Opus,
.claude_3_Haiku,
.claude_3_5_Sonnet,
.claude_3_5_Haiku,
.custom:
return true
case .claude_3_Sonnet:
return false
}
}

/// Whether this model supports Vision feature or not.
///
/// `Claude 3.5 Haiku` does not support it.
var isSupportVision: Bool {
switch self {
case
.claude_3_Opus,
.claude_3_Haiku,
.claude_3_Sonnet,
.claude_3_5_Sonnet,
.custom:
return true
case .claude_3_5_Haiku:
return false
}
}

func isValid(for message: Message) -> Bool {
if isSupportVision {
return true
}

return message.content.allSatisfy { $0.contentType != .image }
}
}

extension Model {
var stringfy: String {
switch self {
Expand Down
23 changes: 22 additions & 1 deletion Sources/AnthropicSwiftSDK/MessageBatches.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public struct MessageBatches {
/// - Returns: A `BatchResponse` containing the details of the created batches.
/// - Throws: An error if the request fails.
public func createBatches(batches: [MessageBatch]) async throws -> BatchResponse {
try await createBatches(
try validate(batches: batches)

return try await createBatches(
batches: batches,
anthropicHeaderProvider: DefaultAnthropicHeaderProvider(),
authenticationHeaderProvider: APIKeyAuthenticationHeaderProvider(apiKey: apiKey)
Expand Down Expand Up @@ -336,3 +338,22 @@ public struct MessageBatches {
return try anthropicJSONDecoder.decode(BatchResponse.self, from: data)
}
}

extension MessageBatches {
func validate(batches: [MessageBatch]) throws {
try batches.forEach { batch in
let model = batch.parameter.model
guard model.isSupportBatches else {
throw ClientError.unsupportedFeatureUsed(description: "The model: \(model.stringfy) does not support Message Batches API")
}

let messages = batch.parameter.messages
guard (messages.allSatisfy { model.isValid(for: $0) }) else {
throw ClientError.unsupportedMessageContentContained(
model: model,
messages: messages.filter { model.isValid(for: $0) == false }
)
}
}
}
}
19 changes: 17 additions & 2 deletions Sources/AnthropicSwiftSDK/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ public struct Messages {
tools: [Tool]? = nil,
toolChoice: ToolChoice = .auto
) async throws -> MessagesResponse {
try await createMessage(
try validate(model, for: messages)

return try await createMessage(
messages,
model: model,
system: system,
Expand Down Expand Up @@ -160,7 +162,9 @@ public struct Messages {
tools: [Tool]? = nil,
toolChoice: ToolChoice = .auto
) async throws -> AsyncThrowingStream<StreamingResponse, Error> {
try await streamMessage(
try validate(model, for: messages)

return try await streamMessage(
messages,
model: model,
system: system,
Expand Down Expand Up @@ -246,3 +250,14 @@ public struct Messages {
return try await AnthropicStreamingParser.parse(stream: data.lines).accumulated()
}
}

extension Messages {
func validate(_ model: Model, for messages: [Message]) throws {
guard (messages.allSatisfy { model.isValid(for: $0) }) else {
throw ClientError.unsupportedMessageContentContained(
model: model,
messages: messages.filter { model.isValid(for: $0) == false }
)
}
}
}
62 changes: 62 additions & 0 deletions Tests/AnthropicSwiftSDKTests/Entity/ModelTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//
// ModelTests.swift
// AnthropicSwiftSDK
//
// Created by 伊藤史 on 2024/11/27.
//

import XCTest
@testable import AnthropicSwiftSDK

final class ModelTests: XCTestCase {

func testIsSupportBatches() {
XCTAssertTrue(Model.claude_3_Opus.isSupportBatches, "claude_3_Opus should support batches.")
XCTAssertFalse(Model.claude_3_Sonnet.isSupportBatches, "claude_3_Sonnet should not support batches.")
XCTAssertTrue(Model.claude_3_Haiku.isSupportBatches, "claude_3_Haiku should support batches.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportBatches, "claude_3_5_Sonnet should support batches.")
XCTAssertTrue(Model.claude_3_5_Haiku.isSupportBatches, "claude_3_5_Haiku should support batches.")
XCTAssertTrue(Model.custom("custom-model").isSupportBatches, "Custom models should support batches.")
}

func testIsSupportVision() {
XCTAssertTrue(Model.claude_3_Opus.isSupportVision, "claude_3_Opus should support vision.")
XCTAssertTrue(Model.claude_3_Sonnet.isSupportVision, "claude_3_Sonnet should support vision.")
XCTAssertTrue(Model.claude_3_Haiku.isSupportVision, "claude_3_Haiku should support vision.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isSupportVision, "claude_3_5_Sonnet should support vision.")
XCTAssertFalse(Model.claude_3_5_Haiku.isSupportVision, "claude_3_5_Haiku should not support vision.")
XCTAssertTrue(Model.custom("custom-model").isSupportVision, "Custom models should support vision.")
}

func testIsValid() {

let textMessage = Message(role: .user, content: [.text("")])
let imageMessage = Message(role: .user, content: [.image(.init(type: .base64, mediaType: .gif, data: Data()))])
let documentMessage = Message(role: .user, content: [.document(.init(type: .base64, mediaType: .pdf, data: Data()))])

// Models that support vision
XCTAssertTrue(Model.claude_3_Opus.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Opus.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Opus.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_5_Sonnet.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertFalse(Model.claude_3_5_Haiku.isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.claude_3_5_Haiku.isValid(for: documentMessage), "claude_3_Opus should validate document messages.")

XCTAssertTrue(Model.custom("custom-model").isValid(for: textMessage), "claude_3_Opus should validate text messages.")
XCTAssertTrue(Model.custom("custom-model").isValid(for: imageMessage), "claude_3_Opus should validate image messages.")
XCTAssertTrue(Model.custom("custom-model").isValid(for: documentMessage), "claude_3_Opus should validate document messages.")
}
}
109 changes: 109 additions & 0 deletions Tests/AnthropicSwiftSDKTests/MessageBatchesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -240,4 +240,113 @@ final class MessageBatchesTests: XCTestCase {
XCTAssertEqual(error, .invalidRequestError)
}
}

func testValidate_Success() {
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.text("")
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_Opus,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertNoThrow(try messageBatches.validate(batches: batches))
}

func testValidate_ModelDoesNotSupportBatches() {
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.text("")
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_Sonnet,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedFeatureUsed(let description):
XCTAssertEqual(description, "The model: \(Model.claude_3_Sonnet.stringfy) does not support Message Batches API")
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}

func testValidate_UnsupportedMessageContentContained() {
// Arrange
let batch = MessageBatch(
customId: "",
parameter: .init(
messages: [
.init(
role: .user,
content: [
.image(.init(type: .base64, mediaType: .png, data: Data()))
]
),
.init(
role: .user,
content: [
.text("")
]
)
],
model: .claude_3_5_Haiku,
maxTokens: 1
)
)
let batches = [batch]
let messageBatches = MessageBatches(apiKey: "", session: .shared)

XCTAssertThrowsError(try messageBatches.validate(batches: batches)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let model, let messages):
XCTAssertEqual(model.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(messages.count, 1)
XCTAssertEqual(messages.first?.content.first?.contentType, .image)
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}
}
58 changes: 58 additions & 0 deletions Tests/AnthropicSwiftSDKTests/MessagesTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,62 @@ final class MessagesTests: XCTestCase {
XCTAssertEqual(response.message.usage.outputTokens, 1)
}
}

func testValidate_Success() {
let messages: [Message] = [
.init(role: .user, content: [.text("Valid message")]),
.init(role: .user, content: [.text("Another valid message")])
]
let messagesHandler = Messages(apiKey: "", session: .shared)

XCTAssertNoThrow(try messagesHandler.validate(.claude_3_Opus, for: messages))
}

func testValidate_UnsupportedMessageContentContained() {
let messages: [Message] = [
.init(role: .user, content: [.text("Valid text message")]),
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image
]
let messagesHandler = Messages(apiKey: "", session: .shared)

// Act & Assert
XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let invalidModel, let invalidMessages):
XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(invalidMessages.count, 1)
XCTAssertEqual(invalidMessages.first?.content.first?.contentType, .image)
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}

func testValidate_AllUnsupportedMessages() {
let messages: [Message] = [
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]), // Unsupported image
.init(role: .user, content: [.image(.init(type: .base64, mediaType: .png, data: Data()))]) // Unsupported image
]
let messagesHandler = Messages(apiKey: "", session: .shared)

// Act & Assert
XCTAssertThrowsError(try messagesHandler.validate(.claude_3_5_Haiku, for: messages)) { error in
guard let clientError = error as? ClientError else {
XCTFail("Expected ClientError but got \(error)")
return
}
switch clientError {
case .unsupportedMessageContentContained(let invalidModel, let invalidMessages):
XCTAssertEqual(invalidModel.stringfy, Model.claude_3_5_Haiku.stringfy)
XCTAssertEqual(invalidMessages.count, 2)
XCTAssertTrue(invalidMessages.allSatisfy { $0.content.first?.contentType == .image })
default:
XCTFail("Unexpected ClientError: \(clientError)")
}
}
}
}