Skip to content

Commit 02cdc82

Browse files
feat: Assignment selection for chat (#155)
Co-authored-by: Asad Ali <asad.ali@arbisoft.com>
1 parent 87e7cfd commit 02cdc82

21 files changed

+1371
-26
lines changed

package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@
9292
"@storybook/addon-links": "^9.0.13",
9393
"@storybook/addon-onboarding": "^9.0.13",
9494
"@storybook/addon-webpack5-compiler-swc": "^3.0.0",
95+
"@storybook/blocks": "^8.6.14",
9596
"@storybook/nextjs": "^9.0.13",
9697
"@storybook/react-webpack5": "^9.0.13",
98+
"@storybook/test": "^8.6.14",
9799
"@swc/jest": "^0.2.37",
98100
"@testing-library/dom": "^10.4.0",
99101
"@testing-library/jest-dom": "^6.6.3",

src/bundles/AiDrawer/AiDrawerManager.stories.tsx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,7 @@ export const AiDrawerManagerStory: Story = {
183183
http.get(CONTENT_FILE_URL, () => {
184184
return HttpResponse.json(sampleResponse)
185185
}),
186-
http.post(TRACKING_EVENTS_ENDPOINT, async ({ request }) => {
187-
const body = await request.json()
188-
console.log("TrackingEvent", body)
186+
http.post(TRACKING_EVENTS_ENDPOINT, () => {
189187
return HttpResponse.json({ success: true })
190188
}),
191189
...handlers,

src/components/AiChat/AiChat.stories.tsx

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import * as React from "react"
22
import type { Meta, StoryObj } from "@storybook/nextjs"
3+
import { http, HttpResponse } from "msw"
34
import { AiChat } from "./AiChat"
45
import type { AiChatProps } from "./types"
56
import styled from "@emotion/styled"
@@ -9,6 +10,7 @@ import { MathJaxContext } from "better-react-mathjax"
910

1011
const TEST_API_STREAMING = "http://localhost:4567/streaming"
1112
const TEST_API_JSON = "http://localhost:4567/json"
13+
const TEST_API_PROBLEM_SET_LIST = "http://localhost:4567/problem_set_list"
1214

1315
const INITIAL_MESSAGES: AiChatProps["initialMessages"] = [
1416
{
@@ -41,7 +43,21 @@ const meta: Meta<typeof AiChat> = {
4143
title: "smoot-design/AI/AiChat",
4244
component: AiChat,
4345
parameters: {
44-
msw: { handlers },
46+
msw: {
47+
handlers: [
48+
http.get(TEST_API_PROBLEM_SET_LIST, () => {
49+
return HttpResponse.json({
50+
problem_set_titles: [
51+
"Assignment 1",
52+
"Assignment 2",
53+
"Assignment 3",
54+
"Assignment 4",
55+
],
56+
})
57+
}),
58+
...handlers,
59+
],
60+
},
4561
},
4662
render: (args) => <AiChat {...args} />,
4763
decorators: (Story, context) => {
@@ -94,6 +110,34 @@ export const JsonResponses: Story = {
94110
},
95111
}
96112

113+
export const AssignmentSelection: Story = {
114+
args: {
115+
requestOpts: {
116+
apiUrl: TEST_API_STREAMING,
117+
transformBody: (messages, body) => ({
118+
message: messages[messages.length - 1].content,
119+
problem_set_title: body?.problem_set_title,
120+
}),
121+
},
122+
initialMessages: [
123+
{
124+
content:
125+
"Hi! Please select an assignment from the dropdown menu to begin.",
126+
role: "assistant",
127+
},
128+
],
129+
conversationStarters: [],
130+
entryScreenEnabled: false,
131+
problemSetListUrl: TEST_API_PROBLEM_SET_LIST,
132+
problemSetInitialMessages: [
133+
{
134+
role: "assistant",
135+
content: "Which question are you working on?",
136+
},
137+
],
138+
},
139+
}
140+
97141
const ScrollComponent: FC<AiChatProps> = (args) => {
98142
const ref = useRef<HTMLDivElement>(null)
99143
const [scrollElement, setScrollElement] = useState<HTMLDivElement | null>(

src/components/AiChat/AiChat.test.tsx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,13 @@ describe("AiChat", () => {
179179
await user.paste("User message")
180180
await user.click(screen.getByRole("button", { name: "Send" }))
181181

182-
expect(transformBody).toHaveBeenCalledWith([
183-
expect.objectContaining(initialMessages[0]),
184-
expect.objectContaining({ content: "User message", role: "user" }),
185-
])
182+
expect(transformBody).toHaveBeenCalledWith(
183+
[
184+
expect.objectContaining(initialMessages[0]),
185+
expect.objectContaining({ content: "User message", role: "user" }),
186+
],
187+
{},
188+
)
186189
expect(mockFetch).toHaveBeenCalledTimes(1)
187190
expect(mockFetch).toHaveBeenCalledWith(
188191
API_URL,

src/components/AiChat/AiChat.tsx

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ import { useScrollSnap } from "../ScrollSnap/useScrollSnap"
1818
import type { Message } from "@ai-sdk/react"
1919
import Markdown from "./Markdown"
2020
import EllipsisIcon from "./EllipsisIcon"
21+
import { SimpleSelectField } from "../SimpleSelect/SimpleSelect"
22+
import { useFetch } from "./utils"
23+
import { SelectChangeEvent } from "@mui/material/Select"
2124

2225
const classes = {
2326
root: "MitAiChat--root",
@@ -69,6 +72,16 @@ const ChatContainer = styled.div<{ externalScroll: boolean }>(
6972
}),
7073
)
7174

75+
const AssignmentSelect = styled(SimpleSelectField)({
76+
width: "295px",
77+
"> div": {
78+
width: "inherit",
79+
},
80+
label: {
81+
display: "none",
82+
},
83+
})
84+
7285
const MessagesContainer = styled(ScrollSnap)<{ externalScroll: boolean }>(
7386
({ externalScroll }) => ({
7487
display: "flex",
@@ -224,12 +237,17 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
224237
ref,
225238
useMathJax = false,
226239
onSubmit,
240+
problemSetListUrl,
241+
problemSetInitialMessages,
227242
...others // Could contain data attributes
228243
}) => {
229244
const containerRef = useRef<HTMLDivElement>(null)
230245
const messagesContainerRef = useRef<HTMLDivElement>(null)
231246
const chatScreenRef = useRef<HTMLDivElement>(null)
232247
const promptInputRef = useRef<HTMLDivElement>(null)
248+
const { response: problemSetListResponse } = useFetch<{
249+
problem_set_titles: string[]
250+
}>(problemSetListUrl)
233251

234252
const {
235253
messages,
@@ -242,6 +260,9 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
242260
error,
243261
initialMessages,
244262
status,
263+
additionalBody,
264+
setMessages,
265+
setAdditionalBody,
245266
} = useAiChat()
246267

247268
useScrollSnap({
@@ -281,6 +302,18 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
281302
})
282303
}
283304

305+
const onProblemSetChange = (event: SelectChangeEvent<string | string[]>) => {
306+
if (problemSetInitialMessages) {
307+
setMessages(
308+
problemSetInitialMessages.map((message, i) => ({
309+
...message,
310+
id: `initial-${i}`,
311+
})),
312+
)
313+
}
314+
setAdditionalBody?.({ problem_set_title: event.target.value as string })
315+
}
316+
284317
const lastMsg = messages[messages.length - 1]
285318

286319
const externalScroll = !!scrollElement
@@ -317,7 +350,30 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
317350
askTimTitle={askTimTitle}
318351
externalScroll={externalScroll}
319352
className={classNames(className, classes.title)}
353+
control={
354+
problemSetListResponse?.problem_set_titles?.length ? (
355+
<AssignmentSelect
356+
label="Assignments"
357+
options={[
358+
{
359+
value: "",
360+
label: "Select an assignment",
361+
disabled: true,
362+
},
363+
...problemSetListResponse.problem_set_titles.map(
364+
(title) => ({
365+
value: title,
366+
label: title,
367+
}),
368+
),
369+
]}
370+
value={additionalBody?.problem_set_title ?? ""}
371+
onChange={onProblemSetChange}
372+
/>
373+
) : null
374+
}
320375
/>
376+
321377
<MessagesContainer
322378
className={classes.messagesContainer}
323379
externalScroll={externalScroll}

src/components/AiChat/AiChatContext.tsx

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
import * as React from "react"
22
import { useChat, UseChatHelpers } from "@ai-sdk/react"
33
import type { RequestOpts, AiChatMessage, AiChatContextProps } from "./types"
4-
import { useMemo, createContext } from "react"
4+
import { useMemo, createContext, useState } from "react"
55
import retryingFetch from "../../utils/retryingFetch"
66
import { getCookie } from "../../utils/getCookie"
77

88
const identity = <T,>(x: T): T => x
99

10-
const getFetcher: (requestOpts: RequestOpts) => typeof fetch =
11-
(requestOpts: RequestOpts) => async (url, opts) => {
10+
const getFetcher: (
11+
requestOpts: RequestOpts,
12+
additionalBody: Record<string, string>,
13+
) => typeof fetch =
14+
(requestOpts: RequestOpts, additionalBody: Record<string, string> = {}) =>
15+
async (url, opts) => {
1216
if (typeof opts?.body !== "string") {
1317
console.error("Unexpected body type.")
1418
return retryingFetch(url, opts)
1519
}
16-
const messages: AiChatMessage[] = JSON.parse(opts?.body).messages
20+
const parsedBody = JSON.parse(opts?.body)
21+
const messages: AiChatMessage[] = parsedBody.messages
1722
const transformBody: RequestOpts["transformBody"] =
1823
requestOpts.transformBody ?? identity
1924
const options: RequestInit = {
2025
...opts,
21-
body: JSON.stringify(transformBody(messages)),
26+
body: JSON.stringify(transformBody(messages, additionalBody)),
2227
...requestOpts.fetchOpts,
2328
headers: {
2429
...opts?.headers,
@@ -43,6 +48,8 @@ const getFetcher: (requestOpts: RequestOpts) => typeof fetch =
4348
*/
4449
type AiChatContextResult = UseChatHelpers & {
4550
initialMessages: AiChatMessage[] | null
51+
additionalBody?: Record<string, string>
52+
setAdditionalBody?: (body: Record<string, string>) => void
4653
}
4754
const AiChatContext = createContext<AiChatContextResult | null>(null)
4855

@@ -66,8 +73,19 @@ const AiChatProvider: React.FC<AiChatContextProps> = ({
6673
)
6774
}, [_initialMessages])
6875

69-
const fetcher = useMemo(() => getFetcher(requestOpts), [requestOpts])
70-
const { messages: unparsed, ...others } = useChat({
76+
const [additionalBody, setAdditionalBody] = useState<Record<string, string>>(
77+
{},
78+
)
79+
80+
const fetcher = useMemo(
81+
() => getFetcher(requestOpts, additionalBody),
82+
[requestOpts, additionalBody],
83+
)
84+
const {
85+
messages: unparsed,
86+
setMessages,
87+
...others
88+
} = useChat({
7189
api: requestOpts.apiUrl,
7290
streamProtocol: "text",
7391
fetch: fetcher,
@@ -100,7 +118,14 @@ const AiChatProvider: React.FC<AiChatContextProps> = ({
100118
* Ensure that child state is reset when chatId changes.
101119
*/
102120
key={chatId}
103-
value={{ initialMessages, messages, ...others }}
121+
value={{
122+
initialMessages,
123+
messages,
124+
setMessages,
125+
additionalBody,
126+
setAdditionalBody,
127+
...others,
128+
}}
104129
>
105130
{children}
106131
</AiChatContext.Provider>

src/components/AiChat/ChatTitle.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,22 @@ const AskTimTitle = styled.div(({ theme }) => ({
3737
width: "24px",
3838
height: "24px",
3939
},
40+
"&& p": {
41+
margin: 0,
42+
},
4043
}))
4144

4245
type ChatTitleProps = {
4346
askTimTitle?: string
4447
externalScroll?: boolean
48+
control?: React.ReactNode
4549
className?: string
4650
}
4751

4852
const ChatTitle = ({
4953
askTimTitle,
5054
externalScroll,
55+
control,
5156
className,
5257
}: ChatTitleProps) => {
5358
if (!askTimTitle) return null
@@ -60,6 +65,7 @@ const ChatTitle = ({
6065
{askTimTitle}
6166
</Typography>
6267
</AskTimTitle>
68+
{control}
6369
</Container>
6470
)
6571
}

src/components/AiChat/types.ts

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ type RequestOpts = {
1818
*
1919
* JSON.stringify is applied to the return value.
2020
*/
21-
transformBody?: (messages: AiChatMessage[]) => unknown
21+
transformBody?: (
22+
messages: AiChatMessage[],
23+
body?: Record<string, string>,
24+
) => unknown
2225
/**
2326
* Extra options to pass to fetch.
2427
*/
@@ -52,6 +55,9 @@ type AiChatContextProps = {
5255
initialMessages?: Omit<AiChatMessage, "id">[]
5356

5457
children?: React.ReactNode
58+
59+
additionalBody?: Record<string, string>
60+
setAdditionalBody?: (body: Record<string, string>) => void
5561
}
5662

5763
type AiChatDisplayProps = {
@@ -106,7 +112,7 @@ type AiChatDisplayProps = {
106112
scrollElement?: HTMLElement | null
107113

108114
/**
109-
* If true, the chat will display math equations using MathJax.
115+
* If true, the chat will display math equations using MathJax..
110116
* Defaults to false.
111117
*/
112118
useMathJax?: boolean
@@ -116,6 +122,19 @@ type AiChatDisplayProps = {
116122
*/
117123
autofocus?: boolean
118124

125+
/**
126+
* URL to fetch problem set list for dropdown.
127+
*
128+
* The problem set selection is passed as the second argument to the `transformBody` function
129+
* provided as `{ problem_set_title: string }`.
130+
*/
131+
problemSetListUrl?: string
132+
133+
/**
134+
* Initial messages to display on problem set selection.
135+
*/
136+
problemSetInitialMessages?: Omit<AiChatMessage, "id">[]
137+
119138
onSubmit?: (
120139
messageText: string,
121140
meta: {

0 commit comments

Comments
 (0)