|
#' Generate a Visualization Recommendation and ggplot Code in RMarkdown
#'
#' This function sends a textual summary of the given dataset to OpenAI's chat
#' completions API and asks for an analysis together with matching ggplot2 code.
#' The response is converted into RMarkdown code chunks and saved, preceded by a
#' chunk that recreates the dataset in a variable called `dataset`, as an .Rmd
#' file inside a newly created folder named `<output_name>_<timestamp>`.
#'
#' @param data A data frame containing the dataset to analyze.
#' @param api_key Your OpenAI API key as a single, non-empty string.
#' @param output_name The name of the output file (without extension) for the R
#'   code with ggplot. It is also used as the prefix of the output folder.
#' @param additional_prompt Additional instructions for the OpenAI API.
#' @return The generated RMarkdown text as a string, or `NULL` if the API
#'   response did not contain a usable completion.
#' @examples
#' \dontrun{
#' data <- data.frame(
#'   var1 = rnorm(100),
#'   var2 = rnorm(100),
#'   outcome = sample(c(0, 1), 100, replace = TRUE)
#' )
#' api_key <- "your_openai_api_key"
#' openAIScientist_generate_visualization_rmd(data, api_key, "Visualization")
#' }
#' @importFrom httr POST add_headers content
#' @importFrom utils capture.output
#' @export
openAIScientist_generate_visualization_rmd <- function(data,
                                                       api_key,
                                                       output_name = "Visualization",
                                                       additional_prompt = "") {
  # Validate the key before doing any work. A plain `api_key == ""` comparison
  # fails with an unrelated error for NULL, NA or zero-length input.
  key_is_valid <- is.character(api_key) && length(api_key) == 1L &&
    !is.na(api_key) && nzchar(api_key)
  if (!key_is_valid) {
    stop("API key not found. Please provide a valid OpenAI API key.")
  }

  cat("Generating data summary...\n")

  # The printed output of summary() is what the model gets to see.
  data_description <- paste(capture.output(summary(data)), collapse = "\n")

  # Clean up any problematic characters in the data description
  data_description <- gsub("[`*]", "", data_description)

  # Construct the prompt (the pieces are joined with single spaces by paste()).
  prompt <- paste(
    "You are provided with the following dataset summary: We are working in R \n\n",
    data_description,
    "\n\nYour tasks are enlisted below finish all of them one after another \n",
    "- Explain what you are doing ",
    "- Write your R code in Codeblocks with ```r ",
    "- The data is created in the variable `dataset`. Reference it in your code. THIS IS VERY important. DO NOT OVERWRITE THE `dataset` variable (1:1).",
    "- Analyze the data you and write a explenation about it ",
    "- Now create plots with ggplot2 fitting to the analysiation you did before. Describe the plot and why you chose it.\n",
    "- Structure your Response well with # Headers \n",
    "- Write an if statement at the beginning to check if all the needed library is already installed and if not, install it.\n",
    "- Always use the variable names, never use abbreviations.",
    "- Try to ALWAYS cover all variables and correlations in a plot",
    "- For every Plot write a Description and explenation on why this plot fits to the data",
    "- Write an analysis of the data at the beginning",
    "- Do not use grid.arrange",
    "- Follow best practices while using colors for data visualization:\n",
    "  * Use Qualitative palettes for categorical data.\n",
    "  * Use Sequential palettes for numerical data with order.\n",
    "  * Use Diverging palettes for numerical data with a meaningful midpoint.\n",
    "  * Leverage the meaningfulness of color.\n",
    "  * Avoid unnecessary usage of color.\n",
    "  * Be consistent with color across charts.\n",
    "  * Try to not use bright neon colors\n",
    "Think about using: scatter plots, line charts, box plots, heatmaps, bar charts, pie charts, histograms, area charts or barplots depending on the best usecase",
    "Every attribute that can be plotted should be plotted",
    additional_prompt,
    "The data is created in the variable `dataset`. Reference it in your code. THIS IS VERY important. DO NOT OVERWRITE THE `dataset` variable (1:1).",
    "Reference the dataset like this:\n",
    "data <- dataset"
  )

  cat("Sending request to OpenAI API (this might take a while)...\n")

  response <- POST(
    url = "https://api.openai.com/v1/chat/completions",
    add_headers(
      Authorization = paste("Bearer", api_key),
      "Content-Type" = "application/json"
    ),
    body = list(
      model = "gpt-4o",
      messages = list(list(role = "user", content = prompt))
    ),
    encode = "json"
  )

  # Named `parsed` so that httr's content() function is not shadowed.
  parsed <- content(response, "parsed")

  # Extract the first completion defensively: the body may not be a list, and
  # `choices` may be absent or empty.
  choices <- if (is.list(parsed)) parsed$choices
  r_code <- if (length(choices) > 0L) choices[[1]]$message$content

  if (!is.character(r_code) || length(r_code) != 1L) {
    cat("Failed to generate visualization. No content returned from OpenAI API.\n")
    # Surface the error message from the response body when there is one.
    api_error <- if (is.list(parsed) && is.list(parsed$error)) {
      parsed$error$message
    }
    if (is.character(api_error) && length(api_error) == 1L) {
      cat("OpenAI API error:", api_error, "\n")
    }
    return(NULL)
  }

  # Replace ```r and ```R with ```{r, message=FALSE}. The word boundary keeps
  # fences of other languages that merely start with "r" untouched.
  rmd_code <- gsub("```[rR]\\b", "```{r, message=FALSE}", r_code, perl = TRUE)

  # Ensure there is a blank line after each code chunk. (?s) lets `.` span the
  # lines of a chunk; the replacement appends a real newline character (an
  # escaped "\\n" in a gsub replacement would insert a literal "n").
  rmd_code <- gsub(
    "(?s)(```\\{r, message=FALSE\\}\\n.+?\\n```)",
    "\\1\n",
    rmd_code,
    perl = TRUE
  )

  # Create a folder for the output RMarkdown file
  time_stamp <- format(Sys.time(), "%Y-%m-%d_%H-%M-%S")
  folder_name <- paste0(output_name, "_", time_stamp)
  dir.create(folder_name, recursive = TRUE, showWarnings = FALSE)
  if (!dir.exists(folder_name)) {
    stop("Could not create output folder: ", folder_name)
  }

  # basename() keeps the file directly inside the new folder even when
  # `output_name` contains a directory part.
  file_path <- file.path(folder_name, paste0(basename(output_name), ".Rmd"))

  # YAML header plus a chunk that recreates the dataset from its deparsed form,
  # so that the generated code can refer to `dataset`.
  rmd_header <- c(
    "---",
    "output:",
    "  html_document:",
    "    code_folding: hide",
    "---",
    "",
    "# Dataset",
    "```{r, message=FALSE}",
    "dataset <- ", deparse(data),
    "```",
    ""
  )

  # Save the RMarkdown code
  writeLines(c(rmd_header, rmd_code), file_path)
  cat(paste("RMarkdown file for visualization saved in:", file_path, "\n"))

  # Append additional text at the end of the file
  additional_text <- "\n\n\n---\n\n\n\n\nThis analysis was created with [openAIScientist](https://github.com/noluyorAbi/openaAIScientist).\n\n Made with \u2665 by [noluyorAbi](https://github.com/noluyorAbi) for FortStaSoft @ LMU Munich"
  cat(additional_text, file = file_path, append = TRUE)

  # Token usage and cost calculation
  usage <- parsed$usage
  if (is.null(usage)) {
    cat("Initial call - Token usage information not available.\n")
  } else {
    price_per_input_token <- 5.00 / 1e6 # $5 per 1M input tokens
    price_per_output_token <- 15.00 / 1e6 # $15 per 1M output tokens

    total_cost <- (usage$prompt_tokens * price_per_input_token) +
      (usage$completion_tokens * price_per_output_token)
    cat("Initial call - Total tokens used:", usage$total_tokens, "\n")
    cat("Initial call - Total cost (USD):", total_cost, "\n")
  }

  rmd_code
}
0 commit comments