|
| 1 | +// Copyright (C) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. See LICENSE in project root for information. |
| 3 | + |
| 4 | +package com.microsoft.azure.synapse.ml.services.aifoundry |
| 5 | + |
| 6 | +import com.microsoft.azure.synapse.ml.Secrets |
| 7 | +import com.microsoft.azure.synapse.ml.core.test.base.Flaky |
| 8 | +import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} |
| 9 | +import com.microsoft.azure.synapse.ml.services.openai.{ChatCompletionResponse, OpenAIMessage, OpenAIResponseFormat} |
| 10 | +import org.apache.spark.ml.util.MLReadable |
| 11 | +import org.apache.spark.sql.{DataFrame, Row} |
| 12 | +import org.scalactic.Equality |
| 13 | + |
| 14 | +trait AIFoundryAPIKey { |
| 15 | + lazy val aiFoundryAPIKey: String = sys.env.getOrElse("AI_FOUNDRY_API_KEY", Secrets.AIFoundryApiKey) |
| 16 | + lazy val aiFoundryServiceName: String = sys.env.getOrElse("AI_FOUNDRY_SERVICE_NAME", "synapseml-ai-foundry-resource") |
| 17 | + lazy val modelName: String = "Phi-4-mini-instruct" |
| 18 | +} |
| 19 | + |
| 20 | +class AIFoundryChatCompletionSuite extends TransformerFuzzing[AIFoundryChatCompletion] with AIFoundryAPIKey with Flaky { |
| 21 | + |
| 22 | + import spark.implicits._ |
| 23 | + |
| 24 | + lazy val completion: AIFoundryChatCompletion = new AIFoundryChatCompletion() |
| 25 | + .setCustomServiceName(aiFoundryServiceName) |
| 26 | + .setApiVersion("2024-05-01-preview") |
| 27 | + .setMaxTokens(2048) |
| 28 | + .setOutputCol("out") |
| 29 | + .setMessagesCol("messages") |
| 30 | + .setTemperature(0) |
| 31 | + .setTopP(0.1) |
| 32 | + .setPresencePenalty(0.0) |
| 33 | + .setFrequencyPenalty(0.0) |
| 34 | + .setModel(modelName) |
| 35 | + .setSubscriptionKey(aiFoundryAPIKey) |
| 36 | + |
| 37 | + lazy val goodDf: DataFrame = Seq( |
| 38 | + Seq( |
| 39 | + OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"), |
| 40 | + OpenAIMessage("user", "Whats your favorite color") |
| 41 | + ), |
| 42 | + Seq( |
| 43 | + OpenAIMessage("system", "You are very excited"), |
| 44 | + OpenAIMessage("user", "How are you today") |
| 45 | + ), |
| 46 | + Seq( |
| 47 | + OpenAIMessage("system", "You are a helpful assistant"), |
| 48 | + OpenAIMessage("user", "I need to calculate how much apples I sold today"), |
| 49 | + OpenAIMessage("system", "How many apples you sold in each transaction"), |
| 50 | + OpenAIMessage("user", "One in the first, two in the second") |
| 51 | + ) |
| 52 | + ).toDF("messages") |
| 53 | + |
| 54 | + lazy val badDf: DataFrame = Seq( |
| 55 | + Seq(), |
| 56 | + Seq( |
| 57 | + OpenAIMessage("system", "You are very excited"), |
| 58 | + OpenAIMessage("user", null) //scalastyle:ignore null |
| 59 | + ), |
| 60 | + null //scalastyle:ignore null |
| 61 | + ).toDF("messages") |
| 62 | + |
| 63 | + test("Basic Usage") { |
| 64 | + testCompletion(completion, goodDf) |
| 65 | + } |
| 66 | + |
| 67 | + test("Robustness to bad inputs") { |
| 68 | + val results = completion.transform(badDf).collect() |
| 69 | + assert(Option(results.head.getAs[Row](completion.getErrorCol)).isDefined) |
| 70 | + // empty user message is valid for Phi 4, no error |
| 71 | + //assert(Option(results.apply(1).getAs[Row](completion.getErrorCol)).isDefined) |
| 72 | + assert(Option(results.apply(2).getAs[Row](completion.getErrorCol)).isEmpty) |
| 73 | + assert(Option(results.apply(2).getAs[Row]("out")).isEmpty) |
| 74 | + } |
| 75 | + |
| 76 | + test("getOptionalParam should include responseFormat"){ |
| 77 | + val completion = new AIFoundryChatCompletion() |
| 78 | + .setCustomServiceName(aiFoundryServiceName) |
| 79 | + |
| 80 | + def validateResponseFormat(params: Map[String, Any], responseFormat: String): Unit = { |
| 81 | + val responseFormatPayloadName = this.completion.responseFormat.payloadName |
| 82 | + assert(params.contains(responseFormatPayloadName)) |
| 83 | + val responseFormatMap = params(responseFormatPayloadName).asInstanceOf[Map[String, String]] |
| 84 | + assert(responseFormatMap.contains("type")) |
| 85 | + assert(responseFormatMap("type") == responseFormat) |
| 86 | + } |
| 87 | + |
| 88 | + val messages: Seq[Row] = Seq( |
| 89 | + OpenAIMessage("user", "Whats your favorite color") |
| 90 | + ).toDF("role", "content", "name").collect() |
| 91 | + |
| 92 | + val optionalParams: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 93 | + assert(!optionalParams.contains("response_format")) |
| 94 | + |
| 95 | + completion.setResponseFormat("") |
| 96 | + val optionalParams0: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 97 | + assert(!optionalParams0.contains("response_format")) |
| 98 | + |
| 99 | + completion.setResponseFormat("json_object") |
| 100 | + val optionalParams1: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 101 | + validateResponseFormat(optionalParams1, "json_object") |
| 102 | + |
| 103 | + completion.setResponseFormat("text") |
| 104 | + val optionalParams2: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 105 | + validateResponseFormat(optionalParams2, "text") |
| 106 | + |
| 107 | + completion.setResponseFormat(Map("type" -> "json_object")) |
| 108 | + val optionalParams3: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 109 | + validateResponseFormat(optionalParams3, "json_object") |
| 110 | + |
| 111 | + completion.setResponseFormat(OpenAIResponseFormat.TEXT) |
| 112 | + val optionalParams4: Map[String, Any] = completion.getOptionalParams(messages.head) |
| 113 | + validateResponseFormat(optionalParams4, "text") |
| 114 | + } |
| 115 | + |
| 116 | + test("setResponseFormat should throw exception if invalid format"){ |
| 117 | + val completion = new AIFoundryChatCompletion() |
| 118 | + .setCustomServiceName(aiFoundryServiceName) |
| 119 | + |
| 120 | + assertThrows[IllegalArgumentException] { |
| 121 | + completion.setResponseFormat("invalid_format") |
| 122 | + } |
| 123 | + |
| 124 | + assertThrows[IllegalArgumentException] { |
| 125 | + completion.setResponseFormat(Map("type" -> "invalid_format")) |
| 126 | + } |
| 127 | + |
| 128 | + assertThrows[IllegalArgumentException] { |
| 129 | + completion.setResponseFormat(Map("invalid_key" -> "json_object")) |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + test("validate accepts json_object response format") { |
| 134 | + val goodDf: DataFrame = Seq( |
| 135 | + Seq( |
| 136 | + OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"), |
| 137 | + OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt), |
| 138 | + OpenAIMessage("user", "Whats your favorite color") |
| 139 | + ), |
| 140 | + Seq( |
| 141 | + OpenAIMessage("system", "You are very excited"), |
| 142 | + OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt), |
| 143 | + OpenAIMessage("user", "How are you today") |
| 144 | + ), |
| 145 | + Seq( |
| 146 | + OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt), |
| 147 | + OpenAIMessage("system", "You are very excited"), |
| 148 | + OpenAIMessage("user", "How are you today"), |
| 149 | + OpenAIMessage("system", "Better than ever"), |
| 150 | + OpenAIMessage("user", "Why?") |
| 151 | + ) |
| 152 | + ).toDF("messages") |
| 153 | + |
| 154 | + val completion = new AIFoundryChatCompletion() |
| 155 | + .setCustomServiceName(aiFoundryServiceName) |
| 156 | + .setModel(modelName) |
| 157 | + .setApiVersion("2024-05-01-preview") |
| 158 | + .setMaxTokens(500) |
| 159 | + .setOutputCol("out") |
| 160 | + .setMessagesCol("messages") |
| 161 | + .setTemperature(0) |
| 162 | + .setSubscriptionKey(aiFoundryAPIKey) |
| 163 | + .setResponseFormat("json_object") |
| 164 | + |
| 165 | + testCompletion(completion, goodDf) |
| 166 | + } |
| 167 | + |
| 168 | + def testCompletion(completion: AIFoundryChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = { |
| 169 | + val fromRow = ChatCompletionResponse.makeFromRowConverter |
| 170 | + completion.transform(df).collect().foreach(r => |
| 171 | + fromRow(r.getAs[Row]("out")).choices.foreach(c => |
| 172 | + assert(c.message.content.length > requiredLength))) |
| 173 | + } |
| 174 | + |
| 175 | + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { |
| 176 | + super.assertDFEq(df1.drop("out"), df2.drop("out"))(eq) |
| 177 | + } |
| 178 | + |
| 179 | + override def testObjects(): Seq[TestObject[AIFoundryChatCompletion]] = |
| 180 | + Seq(new TestObject(completion, goodDf)) |
| 181 | + |
| 182 | + override def reader: MLReadable[_] = AIFoundryChatCompletion |
| 183 | + |
| 184 | +} |
0 commit comments