Skip to content

Commit 873884d

Browse files
feat: Add AIFoundaryChatCompletion (#2398)
* poc ai foundry * fix test * add ai foundry doc * enable ai foundry unit test * format doc photo * format
1 parent 83ebb5a commit 873884d

File tree

5 files changed

+468
-0
lines changed

5 files changed

+468
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (C) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in project root for information.
3+
4+
package com.microsoft.azure.synapse.ml.services.aifoundry
5+
6+
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
7+
import com.microsoft.azure.synapse.ml.param.{GlobalParams, ServiceParam}
8+
import com.microsoft.azure.synapse.ml.services.openai._
9+
import org.apache.spark.ml.ComplexParamsReadable
10+
import org.apache.spark.ml.util._
11+
import org.apache.spark.sql.Row
12+
import spray.json.DefaultJsonProtocol._
13+
14+
import scala.language.existentials
15+
16+
trait HasAIFoundryTextParamsExtended extends HasOpenAITextParamsExtended {
17+
val model = new ServiceParam[String](
18+
this, "model", "The name of the model", isRequired = true)
19+
20+
GlobalParams.registerParam(model, OpenAIDeploymentNameKey)
21+
22+
def getModel: String = getScalarParam(model)
23+
24+
def setModel(v: String): this.type = setScalarParam(model, v)
25+
26+
override val sharedTextParams: Seq[ServiceParam[_]] = Seq(
27+
maxTokens,
28+
temperature,
29+
topP,
30+
user,
31+
n,
32+
echo,
33+
stop,
34+
cacheLevel,
35+
presencePenalty,
36+
frequencyPenalty,
37+
bestOf,
38+
logProbs,
39+
responseFormat,
40+
model
41+
)
42+
}
43+
44+
object AIFoundryChatCompletion extends ComplexParamsReadable[AIFoundryChatCompletion]
45+
46+
class AIFoundryChatCompletion(override val uid: String) extends OpenAIChatCompletion
47+
with HasAIFoundryTextParamsExtended with SynapseMLLogging {
48+
logClass(FeatureNames.AiServices.OpenAI)
49+
50+
def this() = this(Identifiable.randomUID("AIFoundryChatCompletion"))
51+
52+
override private[ml] def internalServiceType: String = "foundry"
53+
54+
override def setCustomServiceName(v: String): this.type = {
55+
setUrl(s"https://$v.services.ai.azure.com/" + urlPath.stripPrefix("/"))
56+
}
57+
58+
override protected def prepareUrlRoot: Row => String = { row =>
59+
s"${getUrl}models/chat/completions"
60+
}
61+
62+
}
63+
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright (C) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in project root for information.
3+
4+
package com.microsoft.azure.synapse.ml.services.aifoundry
5+
6+
import com.microsoft.azure.synapse.ml.Secrets
7+
import com.microsoft.azure.synapse.ml.core.test.base.Flaky
8+
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
9+
import com.microsoft.azure.synapse.ml.services.openai.{ChatCompletionResponse, OpenAIMessage, OpenAIResponseFormat}
10+
import org.apache.spark.ml.util.MLReadable
11+
import org.apache.spark.sql.{DataFrame, Row}
12+
import org.scalactic.Equality
13+
14+
trait AIFoundryAPIKey {
15+
lazy val aiFoundryAPIKey: String = sys.env.getOrElse("AI_FOUNDRY_API_KEY", Secrets.AIFoundryApiKey)
16+
lazy val aiFoundryServiceName: String = sys.env.getOrElse("AI_FOUNDRY_SERVICE_NAME", "synapseml-ai-foundry-resource")
17+
lazy val modelName: String = "Phi-4-mini-instruct"
18+
}
19+
20+
class AIFoundryChatCompletionSuite extends TransformerFuzzing[AIFoundryChatCompletion] with AIFoundryAPIKey with Flaky {
21+
22+
import spark.implicits._
23+
24+
lazy val completion: AIFoundryChatCompletion = new AIFoundryChatCompletion()
25+
.setCustomServiceName(aiFoundryServiceName)
26+
.setApiVersion("2024-05-01-preview")
27+
.setMaxTokens(2048)
28+
.setOutputCol("out")
29+
.setMessagesCol("messages")
30+
.setTemperature(0)
31+
.setTopP(0.1)
32+
.setPresencePenalty(0.0)
33+
.setFrequencyPenalty(0.0)
34+
.setModel(modelName)
35+
.setSubscriptionKey(aiFoundryAPIKey)
36+
37+
lazy val goodDf: DataFrame = Seq(
38+
Seq(
39+
OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"),
40+
OpenAIMessage("user", "Whats your favorite color")
41+
),
42+
Seq(
43+
OpenAIMessage("system", "You are very excited"),
44+
OpenAIMessage("user", "How are you today")
45+
),
46+
Seq(
47+
OpenAIMessage("system", "You are a helpful assistant"),
48+
OpenAIMessage("user", "I need to calculate how much apples I sold today"),
49+
OpenAIMessage("system", "How many apples you sold in each transaction"),
50+
OpenAIMessage("user", "One in the first, two in the second")
51+
)
52+
).toDF("messages")
53+
54+
lazy val badDf: DataFrame = Seq(
55+
Seq(),
56+
Seq(
57+
OpenAIMessage("system", "You are very excited"),
58+
OpenAIMessage("user", null) //scalastyle:ignore null
59+
),
60+
null //scalastyle:ignore null
61+
).toDF("messages")
62+
63+
test("Basic Usage") {
64+
testCompletion(completion, goodDf)
65+
}
66+
67+
test("Robustness to bad inputs") {
68+
val results = completion.transform(badDf).collect()
69+
assert(Option(results.head.getAs[Row](completion.getErrorCol)).isDefined)
70+
// empty user message is valid for Phi 4, no error
71+
//assert(Option(results.apply(1).getAs[Row](completion.getErrorCol)).isDefined)
72+
assert(Option(results.apply(2).getAs[Row](completion.getErrorCol)).isEmpty)
73+
assert(Option(results.apply(2).getAs[Row]("out")).isEmpty)
74+
}
75+
76+
test("getOptionalParam should include responseFormat"){
77+
val completion = new AIFoundryChatCompletion()
78+
.setCustomServiceName(aiFoundryServiceName)
79+
80+
def validateResponseFormat(params: Map[String, Any], responseFormat: String): Unit = {
81+
val responseFormatPayloadName = this.completion.responseFormat.payloadName
82+
assert(params.contains(responseFormatPayloadName))
83+
val responseFormatMap = params(responseFormatPayloadName).asInstanceOf[Map[String, String]]
84+
assert(responseFormatMap.contains("type"))
85+
assert(responseFormatMap("type") == responseFormat)
86+
}
87+
88+
val messages: Seq[Row] = Seq(
89+
OpenAIMessage("user", "Whats your favorite color")
90+
).toDF("role", "content", "name").collect()
91+
92+
val optionalParams: Map[String, Any] = completion.getOptionalParams(messages.head)
93+
assert(!optionalParams.contains("response_format"))
94+
95+
completion.setResponseFormat("")
96+
val optionalParams0: Map[String, Any] = completion.getOptionalParams(messages.head)
97+
assert(!optionalParams0.contains("response_format"))
98+
99+
completion.setResponseFormat("json_object")
100+
val optionalParams1: Map[String, Any] = completion.getOptionalParams(messages.head)
101+
validateResponseFormat(optionalParams1, "json_object")
102+
103+
completion.setResponseFormat("text")
104+
val optionalParams2: Map[String, Any] = completion.getOptionalParams(messages.head)
105+
validateResponseFormat(optionalParams2, "text")
106+
107+
completion.setResponseFormat(Map("type" -> "json_object"))
108+
val optionalParams3: Map[String, Any] = completion.getOptionalParams(messages.head)
109+
validateResponseFormat(optionalParams3, "json_object")
110+
111+
completion.setResponseFormat(OpenAIResponseFormat.TEXT)
112+
val optionalParams4: Map[String, Any] = completion.getOptionalParams(messages.head)
113+
validateResponseFormat(optionalParams4, "text")
114+
}
115+
116+
test("setResponseFormat should throw exception if invalid format"){
117+
val completion = new AIFoundryChatCompletion()
118+
.setCustomServiceName(aiFoundryServiceName)
119+
120+
assertThrows[IllegalArgumentException] {
121+
completion.setResponseFormat("invalid_format")
122+
}
123+
124+
assertThrows[IllegalArgumentException] {
125+
completion.setResponseFormat(Map("type" -> "invalid_format"))
126+
}
127+
128+
assertThrows[IllegalArgumentException] {
129+
completion.setResponseFormat(Map("invalid_key" -> "json_object"))
130+
}
131+
}
132+
133+
test("validate accepts json_object response format") {
134+
val goodDf: DataFrame = Seq(
135+
Seq(
136+
OpenAIMessage("system", "You are an AI chatbot with red as your favorite color"),
137+
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
138+
OpenAIMessage("user", "Whats your favorite color")
139+
),
140+
Seq(
141+
OpenAIMessage("system", "You are very excited"),
142+
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
143+
OpenAIMessage("user", "How are you today")
144+
),
145+
Seq(
146+
OpenAIMessage("system", OpenAIResponseFormat.JSON.prompt),
147+
OpenAIMessage("system", "You are very excited"),
148+
OpenAIMessage("user", "How are you today"),
149+
OpenAIMessage("system", "Better than ever"),
150+
OpenAIMessage("user", "Why?")
151+
)
152+
).toDF("messages")
153+
154+
val completion = new AIFoundryChatCompletion()
155+
.setCustomServiceName(aiFoundryServiceName)
156+
.setModel(modelName)
157+
.setApiVersion("2024-05-01-preview")
158+
.setMaxTokens(500)
159+
.setOutputCol("out")
160+
.setMessagesCol("messages")
161+
.setTemperature(0)
162+
.setSubscriptionKey(aiFoundryAPIKey)
163+
.setResponseFormat("json_object")
164+
165+
testCompletion(completion, goodDf)
166+
}
167+
168+
def testCompletion(completion: AIFoundryChatCompletion, df: DataFrame, requiredLength: Int = 10): Unit = {
169+
val fromRow = ChatCompletionResponse.makeFromRowConverter
170+
completion.transform(df).collect().foreach(r =>
171+
fromRow(r.getAs[Row]("out")).choices.foreach(c =>
172+
assert(c.message.content.length > requiredLength)))
173+
}
174+
175+
override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
176+
super.assertDFEq(df1.drop("out"), df2.drop("out"))(eq)
177+
}
178+
179+
override def testObjects(): Seq[TestObject[AIFoundryChatCompletion]] =
180+
Seq(new TestObject(completion, goodDf))
181+
182+
override def reader: MLReadable[_] = AIFoundryChatCompletion
183+
184+
}

core/src/test/scala/com/microsoft/azure/synapse/ml/Secrets.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ object Secrets {
5555

5656
lazy val CognitiveApiKey: String = getSecret("cognitive-api-key")
5757
lazy val OpenAIApiKey: String = getSecret("openai-api-key-2")
58+
lazy val AIFoundryApiKey: String = getSecret("synapseml-ai-foundry-resource-key")
5859

5960
lazy val CustomSpeechApiKey: String = getSecret("custom-speech-api-key")
6061
lazy val ConversationTranscriptionUrl: String = getSecret("conversation-transcription-url")

0 commit comments

Comments
 (0)