Query Foundation Models with Databricks SQL

This notebook shows how to use the ai_query SQL function to query models from the Foundation Model API. The ai_query function provides a simple and familiar SQL interface to Large Language Models via the Foundation Model API. It also provides a straightforward way to integrate LLMs into existing SQL-based Databricks workflows.

Basic Syntax

In its simplest form, the ai_query function takes a model endpoint name and a text prompt and returns a single result.

Generate Completions

We can generate chat or text completions with supported Foundation Model API models as follows:

%sql
SELECT ai_query("databricks-llama-2-70b-chat", "What is model logging in MLflow?");
ai_query(databricks-llama-2-70b-chat, What is model logging in MLflow?, returnTypeDefaultValue)
In MLflow, model logging is the process of tracking and recording events and metrics related to the performance and behavior of machine learning models during their deployment. It is an essential feature of the MLflow platform that enables data scientists and engineers to monitor, analyze, and optimize the performance of their models in production environments.

Model logging in MLflow involves collecting data on various aspects of model behavior, such as:

1. Model predictions: MLflow logs the predictions made by the model for each input it receives, along with the corresponding output.
2. Model performance: MLflow tracks metrics such as accuracy, precision, recall, F1 score, and other custom metrics that are relevant to the specific model and task.
3. Data distribution: MLflow logs information about the distribution of the data that the model processes, including histograms, min/max values, and other summary statistics.
4. Model internals: MLflow can log information about the internal workings of the model, such as the activation values of layers, the gradients, and other parameters that can help in understanding the model's behavior.
5. Environmental variables: MLflow logs information about the environment in which the model is running, including the hardware, software, and other relevant details.

The logged data is stored in a database, and users can query and visualize it using MLflow's built-in tools or third-party visualization platforms.

Model logging in MLflow provides several benefits, including:

1. Model performance optimization: By analyzing the logged data, users can identify areas where the model can be improved, such as biases, errors, or inefficiencies.
2. Explainability and interpretability: Logging model behavior helps users understand how the model works, how it makes predictions, and what factors it considers most important.
3. Compliance and auditing: Model logging is essential for ensuring compliance with data privacy and security regulations, as it provides a record of all data processed by the model.
4. Collaboration and knowledge sharing: Logged data can be shared across teams, allowing data scientists and engineers to collaborate more effectively and build upon each other's work.

In summary, model logging in MLflow is a powerful feature that enables users to monitor, analyze, and optimize the performance of their machine learning models in production environments. It provides valuable insights into model behavior, helps improve model performance, and ensures compliance with data privacy and security regulations.

Generate Embeddings

Similarly, we can generate embeddings with:

%sql
SELECT ai_query("databricks-bge-large-en", "What is model logging in MLflow?");
ai_query(databricks-bge-large-en, What is model logging in MLflow?, returnTypeDefaultValue)
List(0.015533447, 0.010643005, -0.023101807, 0.003955841, 0.015945435, -0.008522034, 0.008705139, -0.016174316, 0.01689148, 0.07293701, 0.011688232, -0.002708435, 0.01134491, -0.038848877, -0.03729248, 0.024841309, 0.005264282, -0.0552063, -0.030548096, -0.01939392, 0.02368164, -0.009109497, -0.027511597, -0.0287323, -0.028457642, 0.028030396, 0.0052375793, -0.041992188, 0.07183838, 0.03515625, -0.019210815, 0.028701782, -0.028656006, -0.035491943, -0.0076789856, -0.052093506, 0.0068855286, -0.02078247, -0.032684326, -0.05606079, 0.04458618, 0.0051231384, 0.062683105, -0.07006836, -0.016723633, -0.0096588135, 0.03274536, 0.035369873, 0.05819702, -0.018249512, -0.037628174, 0.001039505, -0.016433716, -0.02645874, -0.027130127, -0.009841919, 0.0029411316, 0.0126953125, -0.04421997, 0.009216309, 0.05053711, 0.03024292, 0.0039634705, -0.07965088, 0.03475952, 0.0077400208, -0.034454346, -0.0018749237, -0.0054092407, -0.01966858, -0.031280518, 0.015731812, -0.01499939, -0.0021762848, -0.038391113, 0.039398193, 0.029510498, -0.008079529, -7.343292E-5, 0.03164673, 0.008239746, 0.021133423, 0.011207581, -0.007129669, -0.0076942444, -0.048431396, 0.008850098, -0.0013341904, -0.033599854, -0.024429321, 0.008117676, 0.06713867, 0.005508423, 0.016174316, 0.023910522, 0.03488159, -0.039245605, 0.028518677, -0.028259277, -0.014930725, 0.025894165, 0.091308594, 0.0178833, 0.0036945343, -0.052947998, -0.016433716, 0.040405273, 0.007171631, -0.0657959, -0.05441284, -0.018997192, 0.027297974, -0.042114258, 0.016723633, -0.044433594, 0.007865906, 0.022369385, -0.012069702, -0.045715332, -0.03286743, 0.03552246, 0.007858276, 0.022918701, -0.05697632, -0.018371582, 0.004917145, -0.0030345917, 0.070617676, -0.0259552, 0.017532349, -0.014915466, -0.043548584, 0.027282715, 0.023895264, -0.0151901245, -0.02067566, -0.009025574, 0.029724121, 0.02255249, 0.012702942, 0.014640808, 0.0027122498, -0.0012168884, 0.07788086, 0.019760132, 0.02911377, 0.029663086, -0.017608643, -0.014350891, 
-0.0032291412, -0.019973755, -0.0141067505, -0.004081726, 0.044769287, -0.028457642, -0.002281189, -0.06072998, -0.0025081635, 0.0057792664, -0.0015172958, 0.016052246, 0.02255249, -0.060424805, 0.035888672, -0.027893066, 0.029525757, -0.027816772, -0.010948181, 8.1825256E-4, -0.03326416, 0.019042969, -0.009544373, -0.0030918121, 0.007575989, 0.026489258, -1.1074543E-4, -0.017669678, 0.03555298, 0.039794922, 0.0024280548, -0.0024032593, -0.023239136, 0.022399902, 0.0054893494, 0.023788452, 0.0104904175, 0.038848877, 0.013130188, -0.024230957, 0.017364502, -0.005142212, 0.07525635, -0.051513672, 0.042999268, -0.035125732, -0.019958496, -0.067993164, 0.006866455, 0.0015010834, -0.017730713, -0.0129470825, 0.038513184, -0.051696777, -2.477169E-4, -0.06524658, 0.020492554, 0.026733398, 0.043701172, -0.016235352, 0.030883789, 0.03543091, -0.013473511, -0.023727417, -0.07299805, 0.010658264, -0.020248413, 0.0129776, 7.02858E-4, -0.0015134811, -0.02029419, 0.021591187, -0.013931274, 0.031402588, 0.020019531, -0.0014801025, -0.080566406, -0.0029716492, 0.031707764, -0.006111145, -0.002696991, 0.019088745, 0.010269165, 0.03543091, 0.055114746, 0.037200928, 0.008621216, 0.07244873, 0.023925781, 0.010787964, 0.0065460205, 0.0067214966, 0.006401062, -0.006790161, 0.033172607, 0.023849487, -0.025665283, -0.022262573, 0.00667572, -0.003107071, -0.014022827, -0.03567505, 0.043792725, 0.048065186, -0.029556274, -0.01335144, 0.021850586, 0.011764526, 0.047332764, -0.022216797, 0.010093689, 0.009811401, 0.033813477, -0.018249512, -0.0703125, 0.015823364, 0.06573486, 0.008163452, -0.012290955, -9.608269E-4, -0.04119873, 0.008323669, -0.030136108, -0.08654785, -0.02848816, -0.013298035, 0.0036945343, -0.026229858, -0.047424316, -0.025894165, -0.0023708344, 0.008468628, -0.018585205, -0.013648987, 0.03387451, 0.0069732666, 0.001876831, -0.054382324, 0.029388428, -0.05404663, 0.024261475, -0.030639648, -0.029922485, -0.009498596, -0.02607727, 2.2256374E-4, 0.018615723, -0.003255844, 
-0.048034668, -0.041046143, -0.0052604675, -0.018341064, -0.009117126, 0.012840271, -0.018112183, -0.050689697, 0.027694702, 0.03366089, -0.04458618, 0.017333984, 0.06451416, -0.03375244, 0.020690918, -0.018554688, 0.002243042, -0.03152466, 0.024139404, 0.056640625, -0.020019531, 0.017868042, -0.055908203, -0.011405945, 0.019088745, 5.5265427E-4, -0.021987915, -0.017166138, 0.057525635, 0.033050537, -0.031707764, 0.01828003, 0.0026435852, 0.009994507, -0.03527832, 0.022857666, 0.008728027, 0.068115234, 0.0491333, -0.051239014, -0.06213379, -0.03050232, 0.03149414, 0.0211792, 0.022079468, 0.01537323, 0.032592773, -0.014228821, -0.0042533875, 0.03503418, -0.028213501, 0.0032291412, 0.007575989, -0.006313324, -0.0047035217, 0.003080368, 0.04446411, -0.011917114, 0.023544312, 0.02168274, -0.0041122437, 0.00793457, -0.006134033, 0.014190674, 0.010574341, -0.0014619827, 0.0079574585, -0.035461426, -0.03717041, -0.0104599, 0.0033397675, 0.0440979, -0.0597229, 0.020614624, -3.042221E-4, -0.046325684, -0.006061554, -0.019332886, -0.039367676, 0.007293701, 0.0128479, 0.048034668, -0.03881836, -0.026412964, 0.052337646, 0.0017824173, 0.034240723, 0.04928589, -0.001241684, -0.054260254, -0.023498535, -0.013069153, 0.016311646, -0.024337769, 0.025054932, -0.025146484, -0.0056610107, -0.022583008, -0.04397583, 0.040618896, 0.072753906, 0.0052490234, 0.0042877197, 0.07336426, -0.0031795502, 0.0013828278, 0.026779175, -0.005344391, 0.01751709, -0.022003174, 0.039855957, -0.00466156, 0.02281189, 0.006111145, -0.025970459, 0.006755829, 0.019073486, 0.02571106, -0.023086548, -0.015808105, 0.042022705, 0.023330688, 0.019241333, -0.027893066, -0.0059165955, -0.023773193, 0.011657715, 0.031402588, -0.04647827, 0.04006958, -0.00623703, 0.022476196, 0.012390137, -0.015945435, -0.05999756, -0.022216797, -0.012077332, -0.04714966, 0.067871094, -0.004634857, -0.029663086, -0.030380249, 0.005329132, 0.009880066, 0.0020484924, -0.029525757, -0.014030457, 0.004550934, -0.013908386, 0.043884277, 
-0.009468079, -0.017562866, -0.046844482, 0.029754639, -0.038208008, 0.0017518997, -0.06756592, 0.03137207, -0.0073547363, -6.0129166E-4, -0.016738892, -0.005897522, 0.018218994, -0.027267456, 0.010322571, 0.012420654, -0.024734497, -0.042297363, 0.050720215, 0.061157227, -0.017364502, 0.041503906, 0.029281616, -0.009399414, 0.013053894, -0.010414124, -0.02420044, -0.005405426, -0.003080368, -0.0028190613, -0.0063285828, -0.004146576, 0.01360321, -0.05291748, 0.033966064, 0.010620117, -0.0013084412, 0.021011353, -0.040618896, -0.03866577, -0.016418457, 0.0095825195, 0.005756378, 0.028839111, 0.011985779, 0.014083862, -0.027008057, -0.0049095154, -0.029251099, -0.046569824, -0.009925842, 0.031555176, 0.03857422, 0.038879395, -0.031066895, -0.0335083, 0.056915283, -0.013389587, -0.017196655, -0.003358841, 0.036743164, 0.03967285, 0.03338623, -0.012237549, -0.029647827, 0.0070610046, -0.053527832, -7.638931E-4, -0.0209198, 0.021499634, -0.0062294006, 0.008079529, -0.014259338, 8.840561E-4, -0.037994385, -0.01890564, -0.012687683, -0.0029411316, -0.003206253, 0.0067634583, -0.037872314, -0.020339966, 0.008644104, 0.01739502, 0.0010557175, -0.0018062592, -0.029510498, -0.06738281, -0.009933472, -0.011917114, 0.047424316, -0.04953003, 2.5689602E-5, 0.03237915, 0.035705566, -0.03277588, -0.010444641, -0.046325684, 0.039215088, -0.003753662, 0.045715332, 0.00982666, -0.024002075, -0.032470703, -0.013748169, 0.028656006, 0.02017212, -0.042877197, -0.0287323, -0.03262329, -0.0047836304, -0.008705139, 0.016082764, 0.011238098, 0.020767212, -0.01424408, -5.7399273E-5, -0.04638672, -0.039154053, 0.0012655258, -0.02659607, 0.066833496, -0.05319214, 0.022949219, -0.0049057007, 0.009613037, -0.023025513, 0.0012283325, -0.030334473, 0.004722595, -0.04953003, -0.06317139, 0.052001953, -0.029800415, -0.013175964, -0.016403198, -0.0060539246, 0.049804688, 0.009597778, 0.0045204163, 0.053527832, 0.016616821, -0.012237549, -0.049865723, 0.020553589, -0.0184021, -0.0019893646, 
0.007698059, -3.7789345E-4, -0.052856445, -0.023330688, 0.013290405, -0.035125732, 0.009384155, 0.016540527, 0.016937256, 0.0134887695, 0.03338623, 0.005218506, -0.04537964, -0.03225708, 0.02734375, 0.031982422, 0.05697632, 0.03982544, 0.025421143, 0.029754639, 0.026016235, -0.004295349, -0.029037476, -0.040893555, 0.049194336, 0.002450943, -0.019058228, -0.023605347, 0.018508911, -0.03756714, -0.053649902, 0.006538391, 0.0016727448, 0.0075683594, -0.02130127, -0.02079773, -0.029891968, -0.008972168, 0.006717682, 0.003332138, 0.0023059845, 0.024612427, -0.008743286, -0.008583069, -0.054504395, 0.008621216, 0.04345703, -0.04284668, 0.030410767, -0.026901245, -0.035369873, -0.012825012, 0.01626587, 0.019485474, -0.010429382, 0.021072388, 0.05203247, -0.017578125, 0.05923462, 0.013710022, 0.037872314, 0.031021118, -0.015899658, -0.047210693, -0.021347046, 8.711815E-4, 0.025619507, 0.03503418, 0.0069847107, 0.02758789, 0.045288086, 0.0446167, 0.0031833649, -0.058624268, -0.001042366, -0.049835205, -0.041809082, -0.03829956, -0.011123657, -0.0027675629, -0.0154953, -0.001531601, -0.039154053, -0.020339966, 0.01083374, -0.024627686, -0.0027694702, -0.032989502, 0.0074157715, -0.03817749, 0.008728027, -0.040924072, 0.016036987, 0.030685425, 0.018814087, -0.027435303, 0.024749756, -0.008552551, 0.03665161, -0.024642944, -0.028182983, 0.0066871643, 0.0073432922, 0.003068924, 0.0096588135, -0.06365967, 0.039093018, 0.004047394, -0.011955261, 0.02798462, 0.003944397, -0.013893127, -0.044036865, -8.111E-4, 0.007575989, 3.077984E-4, -0.01727295, -0.022399902, 0.043823242, -0.020614624, -0.010314941, -0.013298035, -0.03302002, 0.018432617, 0.024642944, 0.01058197, 0.030029297, 0.045135498, 0.026779175, -0.034362793, 0.03805542, 0.0056381226, 0.019622803, 0.07446289, 0.008087158, -0.006324768, -0.06262207, 0.04019165, 0.009613037, -0.044006348, 0.029388428, -0.00806427, -0.02128601, -0.012138367, -0.046783447, 0.016555786, 0.07336426, 0.020599365, -0.029785156, -0.025741577, 
-0.0075263977, -0.029922485, -0.042053223, -0.07965088, 0.02243042, 0.024627686, -0.023330688, 0.00920105, -0.0118255615, 0.012237549, 0.0068740845, 0.045928955, 0.026504517, 0.018356323, 0.00579834, 0.0016012192, 0.0056533813, -0.01625061, 5.7029724E-4, 0.016555786, -0.007171631, 0.00207901, -0.042266846, 0.017181396, 0.061431885, 0.016082764, 0.008598328, 0.0053634644, -0.025726318, -0.008117676, 0.031204224, -0.034698486, -0.010726929, -0.01838684, 0.04083252, 0.018112183, 0.015106201, -0.036102295, -0.02482605, 0.025497437, 1.18136406E-4, 0.021850586, -0.008666992, 0.009338379, -8.8500977E-4, 0.041107178, 0.014480591, 0.015220642, -0.015853882, 0.010871887, -0.010559082, 0.026168823, 0.033569336, 0.03164673, -0.026550293, 0.052703857, 0.024093628, -0.058776855, -0.017120361, 0.0803833, 0.029891968, 0.023376465, -0.010368347, -0.044311523, 0.031982422, -3.4308434E-4, -0.0018835068, -0.014328003, -0.002029419, -0.03390503, -0.011955261, 0.018432617, 0.011833191, 0.012931824, -0.050750732, -0.018875122, -0.01235199, 0.074035645, 0.0012493134, 0.0035686493, 0.023071289, 0.034484863, 0.022903442, 0.034698486, 0.0039863586, 0.039001465, 0.03314209, -0.012916565, 0.022537231, -0.006816864, -3.8146973E-4, -0.036315918, -0.01071167, -0.019699097, -0.010093689, -0.031799316, -0.045959473, 3.2114983E-4, -0.0066604614, -0.0011262894, -0.09106445, -0.0017852783, -0.03274536, -0.0015258789, 0.05328369, -0.003929138, 0.015411377, 0.05142212, 0.05255127, 0.04385376, 0.013244629, 0.073791504, 0.003232956, -0.009376526, 0.031707764, -0.014312744, 0.013000488, 0.009963989, -0.017807007, -0.016845703, 0.039794922, 0.022018433, 0.033355713, 0.0021018982, -0.07147217, 0.010971069, 0.022125244, -0.017456055, -0.06549072, 0.033081055, 0.0013437271, -0.02458191, -0.018356323, 0.06414795, 0.03173828, 0.022109985, -0.0021629333, 0.009635925, -0.0064086914, -0.016540527, -0.029830933, 0.033599854, -0.0149002075, -0.03579712, 0.0074386597, -0.028564453, -0.002035141, -0.021987915, 
-0.019607544, -0.028549194, 0.022140503, 0.022155762, 0.056640625, -0.021881104, -0.03262329, -0.07122803, 0.041229248, 0.02658081, -0.019332886, 0.079589844, 0.04562378, -0.04647827, -0.029067993, -0.013259888, 0.023269653, 0.007320404, 0.009567261, -0.009773254, 0.0034503937, -0.0104522705, -0.06713867, -0.0074310303, 6.709099E-4, 0.060943604, -0.009529114, -0.009780884, -0.04449463, -0.037841797, -0.0010290146, -0.06713867, 0.03677368, 0.04244995, -0.016906738, -0.0056266785, -0.040649414, 0.21972656, 0.0647583, 0.019897461, -3.7240982E-4, 0.09436035, 0.07550049, 0.0109939575, 0.013023376, 4.5466423E-4, -0.04675293, -0.0050621033, 0.009567261, 0.0066833496, 0.055267334, 0.0028247833, 0.05142212, -0.062316895, 0.01953125, -0.00712204, -0.060943604, -0.05215454, 0.05529785, 9.622574E-4, 0.044708252, 0.02835083, -0.011764526, 0.011962891, -0.023101807, 0.04336548, -0.05319214, 0.023223877, -0.03253174, -0.017959595, -0.015716553, -0.02671814, -0.010398865, 0.026657104, -0.005756378, 0.0013360977, 0.04083252, 0.0071144104, -0.023513794, 0.011230469, -0.017501831, -0.03265381, 0.04083252, -0.024230957, 0.010009766, 7.019043E-4, -0.041168213, 0.018371582, -0.014533997, 0.009887695, -0.05178833, -0.04537964, 0.029922485, -0.007663727, -0.05142212, -0.007007599, 0.03338623, -5.9604645E-4, 0.01864624, -0.020889282, 0.02281189, -0.015533447, 0.029800415, -0.038635254, -0.053100586, -0.03640747, -0.033233643, -0.00274086, 0.004146576, -0.029281616, -0.043395996, -0.0047683716, 0.046203613, -0.016784668, 0.013038635, -0.01789856, -0.02281189, -0.021072388, -0.0019197464, 0.027145386, -0.05038452, 0.019454956, -0.0061569214, -0.024749756, 0.019500732, 0.0011720657, -0.020965576, 0.027542114, 0.035827637, -0.008544922, 0.021377563, 0.04034424)
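An embedding such as the one above is usually compared with other embeddings rather than read directly. The following is a minimal sketch of cosine similarity in plain Python. The function name and the toy three-dimensional vectors are illustrative assumptions, not part of the notebook; real vectors would come from ai_query results.

```python
import math


def cosine_similarity(a, b):
    """Return the cosine similarity of two equal-length numeric vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy vectors standing in for real embeddings (illustrative values only).
query_vec = [1.0, 0.0, 1.0]
doc_same_direction = [2.0, 0.0, 2.0]
doc_orthogonal = [0.0, 3.0, 0.0]

print(cosine_similarity(query_vec, doc_same_direction))
print(cosine_similarity(query_vec, doc_orthogonal))
```

Vectors pointing in the same direction score close to 1, and unrelated (orthogonal) vectors score 0, which is why this measure is a common choice for ranking texts by semantic closeness.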

Applying AI SQL Functions to multiple rows

In SQL, we rarely generate one-off LLM completions; more often we apply these functions to every row of a table. As an example, let’s use ai_query to correct some SQL queries and explain the corrections. We’ll start by generating a small synthetic dataset of SQL queries with errors.

# Synthetic data: (id, SQL statement containing an error)
data = [
    (1, "SELECT FROM users WHERE id = 101;"),
    (2, "INSERT INTO products (name, price) VALSES ('Laptop', 999.99);"),
    (3, "UPDATE orders SET price = '299.99' WHERE od_id = 1;")
]

# Define schema
schema = ["id", "raw_sql"]

# Create DataFrame
df = spark.createDataFrame(data, schema)

# Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("sql_errors")

Now, we’ll use ai_query to fix these issues.

%sql
SELECT
  id,
  raw_sql,
  ai_query(
    "databricks-mixtral-8x7b-instruct",
    "You are a SQL expert. Fix the following SQL. Only return the corrected SQL; do not explain: " || raw_sql
  ) AS corrected_sql
FROM
  sql_errors;
id | raw_sql | corrected_sql
1 | SELECT FROM users WHERE id = 101; | SELECT * FROM users WHERE id = 101;
2 | INSERT INTO products (name, price) VALSES ('Laptop', 999.99); | INSERT INTO products (name, price) VALUES ('Laptop', 999.99);
3 | UPDATE orders SET price = '299.99' WHERE od_id = 1; | UPDATE orders SET price = 299.99 WHERE od\_id = 1;
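The model output for row 3 contains an escaped underscore (od\_id), and model responses can also carry stray leading or trailing whitespace. If the corrected SQL is going to be executed or compared, a small cleanup step may help. The sketch below is one possible approach in Python; the helper name clean_model_sql is hypothetical and not part of any Databricks API.

```python
def clean_model_sql(text):
    """Normalize a model-returned SQL string (hypothetical helper)."""
    # Undo Markdown-style escaping of underscores, e.g. od\_id -> od_id.
    cleaned = text.replace("\\_", "_")
    # Drop leading and trailing whitespace around the statement.
    return cleaned.strip()


# Sample strings modeled on the corrected_sql column shown above.
raw_outputs = [
    " SELECT * FROM users WHERE id = 101;",
    "UPDATE orders SET price = 299.99 WHERE od\\_id = 1;",
]

cleaned = [clean_model_sql(s) for s in raw_outputs]
for statement in cleaned:
    print(statement)
```

Note that this replacement would also alter a deliberate `\_` inside a LIKE pattern, so it is only appropriate when such patterns are not expected in the output.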

We can expand on this and request explanations for the fixes in a new column:

%sql
WITH CorrectedSQL AS (
  SELECT
    id,
    raw_sql,
    ai_query(
      "databricks-mixtral-8x7b-instruct",
      "You are a SQL expert. Fix the following SQL. Only return the corrected SQL; do not explain: " || raw_sql
    ) AS corrected_sql
  FROM
    sql_errors
)

SELECT
  c.id,
  c.raw_sql,
  c.corrected_sql,
  ai_query(
    "databricks-mixtral-8x7b-instruct",
    "Given the original SQL: '" || c.raw_sql || "' and the corrected SQL: '" || c.corrected_sql || "', summarize the changes and explain why they were necessary."
  ) AS explanation_of_fixes
FROM
  CorrectedSQL c;
id | raw_sql | corrected_sql | explanation_of_fixes
1 | SELECT FROM users WHERE id = 101; | SELECT * FROM users WHERE id = 101; | 1. Asterisk (*) was added in the 'SELECT' clause. 2. The space before 'FROM' was removed. The correction was necessary to correctly specify the columns to be returned in the query result. In the original SQL, 'SELECT FROM' is incorrect syntax because it lacks the specific column(s) to be selected. Adding an asterisk (*) in the 'SELECT' clause selects all columns from the 'users' table where the 'id' is 101. The removal of the space before 'FROM' is for standard SQL syntax compliance.
2 | INSERT INTO products (name, price) VALSES ('Laptop', 999.99); | INSERT INTO products (name, price) VALUES ('Laptop', 999.99); | Changes: The original SQL statement used "VALSES" instead of "VALUES". Explanation: The "VALUES" keyword is used in SQL to specify the values to be inserted into the table. The keyword "VALSES" is not a recognized SQL keyword, hence the error. The corrected SQL statement uses the correct "VALUES" keyword, making it a valid SQL statement to insert a new record into the "products" table with a name of "Laptop" and a price of 999.99.
3 | UPDATE orders SET price = '299.99' WHERE od_id = 1; | UPDATE orders SET price = 299.99 WHERE od\_id = 1; | The changes made to the original SQL query are minimal but significant. The single quotes around the value 299.99 in the original query have been removed in the corrected query. This change was necessary because the price field is likely a numeric type (like DECIMAL or FLOAT), and numeric values should not be enclosed in single quotes. Single quotes are used to denote string literals in SQL, not numeric values. Including single quotes around a numeric value can lead to unexpected results or errors, such as a data type conversion error. Therefore, it's essential to format numeric values correctly in SQL queries to ensure data integrity and consistency.
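Both queries build their prompts by concatenating a fixed instruction with column values using ||. To see exactly what the model receives for a given row, the same strings can be assembled in plain Python. This is a minimal sketch; the helper names build_fix_prompt and build_explain_prompt are hypothetical, and the prompt text is copied from the queries above.

```python
FIX_INSTRUCTION = (
    "You are a SQL expert. Fix the following SQL. "
    "Only return the corrected SQL; do not explain: "
)


def build_fix_prompt(raw_sql):
    """Mirror the first query's prompt: instruction || raw_sql."""
    return FIX_INSTRUCTION + raw_sql


def build_explain_prompt(raw_sql, corrected_sql):
    """Mirror the second query's prompt, which quotes both statements."""
    return (
        "Given the original SQL: '" + raw_sql
        + "' and the corrected SQL: '" + corrected_sql
        + "', summarize the changes and explain why they were necessary."
    )


# Values taken from row 2 of the example dataset and its corrected output.
raw = "INSERT INTO products (name, price) VALSES ('Laptop', 999.99);"
corrected = "INSERT INTO products (name, price) VALUES ('Laptop', 999.99);"

print(build_fix_prompt(raw))
print(build_explain_prompt(raw, corrected))
```

Printing the second prompt shows that the single quotes around 'Laptop' sit inside the single quotes the prompt uses as delimiters. The model handled this in the run above, but a different delimiter may be less ambiguous for statements that contain string literals.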

Conclusion

You now know how to query Foundation Models from Databricks SQL and can start integrating calls to foundation models into your Databricks SQL work.

Further Reading