Fast Classification with Provisioned Throughput
This notebook shows how to use the Provisioned Throughput Foundation Model APIs for a high-throughput classification task. In particular, we will use the mpt-30b-instruct model to classify the emotions of short texts.
Getting Started: Create a Provisioned Throughput Serving Endpoint
To get started, follow the Databricks instructions for provisioned throughput Foundation Model APIs to create a provisioned throughput serving endpoint. You can get the model from the Databricks Marketplace: navigate to “Marketplace” in your Databricks workspace and search for “MPT Models.” Save the models to a catalog in your workspace, then serve the model you want by following the provisioned throughput instructions. Setting up the endpoint may take a few minutes.
The rest of this notebook assumes the provisioned throughput endpoint is already running. The throughput and latency figures reported below depend on the tokens-per-second limit configured for the endpoint.
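Because endpoint setup can take a few minutes, it may help to wait until the endpoint reports that it is ready before sending traffic. The sketch below assumes that the JSON returned by the Databricks REST call `GET /api/2.0/serving-endpoints/{name}` reports readiness as `state.ready == "READY"`. The helpers `endpoint_is_ready` and `wait_until_ready` are hypothetical, and the fetch function is passed in so that the polling logic does not depend on a particular HTTP client.

```python
import time
from typing import Callable


def endpoint_is_ready(endpoint_info: dict) -> bool:
    """Return True if a serving-endpoint description reports READY.

    Assumption: readiness lives under endpoint_info["state"]["ready"],
    as in the GET /api/2.0/serving-endpoints/{name} response.
    """
    return endpoint_info.get("state", {}).get("ready") == "READY"


def wait_until_ready(
    fetch_info: Callable[[], dict],
    timeout_s: float = 1800.0,
    poll_interval_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll fetch_info() until the endpoint is ready or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        info = fetch_info()
        if endpoint_is_ready(info):
            return info
        if time.monotonic() >= deadline:
            raise TimeoutError("Endpoint did not become ready in time")
        # Not ready yet: wait before asking again
        sleep(poll_interval_s)
```

After the setup cell that defines `API_ROOT`, `API_TOKEN`, and `endpoint_name` has run, `fetch_info` could be something like `lambda: requests.get(f"{API_ROOT}/api/2.0/serving-endpoints/{endpoint_name}", headers={"Authorization": f"Bearer {API_TOKEN}"}).json()`.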
Classification Example
This example uses mpt-30b-instruct to classify the emotions of the texts in the DAIR AI Emotion dataset. The focus is throughput rather than accuracy: the classification itself could likely be improved by tweaking the prompt, the model parameters, or the choice of model.
Set Up Libraries and Environment Variables
# Install or upgrade the libraries this notebook uses, then restart Python so they load
%pip install --upgrade aiohttp tqdm datasets
dbutils.library.restartPython()
# Get the API endpoint and token for the current notebook context
API_ROOT = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()
API_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
endpoint_name = "<your-endpoint-name>"
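Load the Dataset
from datasets import load_dataset
# cache_dir points to a workspace-specific Unity Catalog Volume; replace it with a path you can write to, or omit it
emotion = load_dataset("dair-ai/emotion", cache_dir="/Volumes/daniel_liden/examples/datasets")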
emotion_messages = emotion['train']['text']
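The prompt asks for one of six uppercase labels, and the dataset's integer labels cover the same six emotions. If you later want to compare predictions against the ground truth, a mapping like the one below may help. It is a sketch that assumes the `ClassLabel` order sadness, joy, love, anger, fear, surprise; check that against `emotion['train'].features['label'].names` in your environment before relying on it. The names `EMOTION_LABELS`, `VALID_LABELS`, and `label_id_to_prompt_label` are hypothetical.

```python
# Assumed ClassLabel order for dair-ai/emotion; verify with
# emotion['train'].features['label'].names
EMOTION_LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]

# The uppercase forms the prompt template asks the model to return
VALID_LABELS = {name.upper() for name in EMOTION_LABELS}


def label_id_to_prompt_label(label_id: int, names=EMOTION_LABELS) -> str:
    """Map an integer dataset label to the uppercase label used in the prompt."""
    return names[label_id].upper()
```

For example, `gold_labels = [label_id_to_prompt_label(i) for i in emotion['train']['label']]` produces one reference label per message in `emotion_messages`.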
complete_prompts = []
prompt_template = """Your task is to classify the emotion of the provided message. Only use one of the specified emotions for your classification. The valid classifications are: JOY, SADNESS, ANGER, FEAR, LOVE, SURPRISE. No other classifications will be accepted. Even if none of these seems like a perfect fit, pick the closest one.
### Instruction:
Read the message below and classify its emotion:
<message>{}</message>
Expected Format:
Respond with the emotion in the following format only:
EMOTION
Do not include any explanation or additional text.
### Examples:
- Message: "I was shocked by the testimony!"
Response: SURPRISE
- Message: "I don't know how I'm ever going to move on after my team's loss..."
Response: SADNESS
Ensure your response strictly follows the given format and only uses one of the specified emotion classifications.
### Response:\n"""
# Loop through each message in the emotion_messages list
for message in emotion_messages:
    formatted_prompt = prompt_template.format(message)
    complete_prompts.append(formatted_prompt)
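The loop that fills `complete_prompts` can also be written as a list comprehension. One detail worth knowing: `str.format` only parses braces in the template, not in the arguments, so messages that themselves contain `{` or `}` are inserted verbatim. The template must contain no other literal braces, which holds for `prompt_template` above. The helper name `build_prompts` is hypothetical.

```python
from typing import List


def build_prompts(template: str, messages: List[str]) -> List[str]:
    """Fill the single {} placeholder in template with each message."""
    # Braces inside a message are not interpreted by str.format
    return [template.format(message) for message in messages]
```

In the notebook this would be `complete_prompts = build_prompts(prompt_template, emotion_messages)`.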
Configure API calls with aiohttp
To improve throughput, we use the aiohttp library to make API calls concurrently.
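The worker defined in the next cell relies on `asyncio.Semaphore` to cap the number of requests in flight. The standalone sketch below shows that mechanism with simulated requests (`asyncio.sleep` stands in for the HTTP round trip) and reports the peak concurrency it observes. `bounded_gather` is a hypothetical name and is not part of the notebook's pipeline.

```python
import asyncio


async def bounded_gather(n_jobs: int, limit: int, delay_s: float = 0.01) -> int:
    """Run n_jobs simulated requests with at most `limit` in flight.

    Returns the peak number of requests observed running at once.
    """
    semaphore = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def fake_request() -> None:
        nonlocal in_flight, peak
        # Blocks once `limit` jobs are holding the semaphore
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(delay_s)  # stands in for the HTTP round trip
            in_flight -= 1

    await asyncio.gather(*(fake_request() for _ in range(n_jobs)))
    return peak
```

In a notebook cell, `await bounded_gather(100, 15)` should report a peak of 15; in a script, use `asyncio.run(bounded_gather(100, 15))`. All jobs are scheduled up front, and the semaphore alone decides how many run at the same time.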
import aiohttp
import asyncio
import time
import statistics
from tqdm.asyncio import tqdm # Import tqdm for asyncio
# Invocation URL and request headers for the serving endpoint
endpoint_url = f"{API_ROOT}/serving-endpoints/{endpoint_name}/invocations"
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {API_TOKEN}"
}
latencies = []
# One slot per prompt so each response stays aligned with its prompt;
# slots for failed requests remain None
data = [None] * len(complete_prompts)
async def worker(index, prompt, concurrency_semaphore):
    async with concurrency_semaphore:  # Use the passed semaphore for concurrency control
        input_data = {
            "inputs": {
                "prompt": [prompt]
            },
            "params": {
                "max_tokens": 10,
                "temperature": 0.2,
                "top_p": 0.1
            }
        }
        request_start_time = time.perf_counter()
        try:
            # Generous 3-hour total timeout per request
            timeout = aiohttp.ClientTimeout(total=3 * 3600)
            # A new session per request keeps worker self-contained; sharing one
            # session across workers would let requests reuse connections
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(endpoint_url, headers=headers, json=input_data) as response:
                    if response.ok:
                        response_data = await response.json()
                        latencies.append(time.perf_counter() - request_start_time)
                        # Store by index: tasks finish out of order under asyncio.as_completed
                        data[index] = response_data
                    else:
                        error_response = await response.text()
                        print(f"Request failed, status: {response.status}, error: {error_response}")
        except Exception as e:
            print(f"An error occurred: {e}")
async def process_prompts(complete_prompts, num_concurrent_calls):
    concurrency_semaphore = asyncio.Semaphore(num_concurrent_calls)  # Create semaphore based on num_concurrent_calls
    tasks = [worker(i, prompt, concurrency_semaphore) for i, prompt in enumerate(complete_prompts)]
    for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Generating Data"):
        await task
async def main(complete_prompts, num_concurrent_calls=15):  # Default concurrency level set to 15
    print("Starting data generation...")
    await process_prompts(complete_prompts, num_concurrent_calls)
    if latencies:
        median_latency = statistics.median(latencies)
        print(f"Median latency (s): {median_latency}")
    else:
        print("No data collected.")
Here is a quick overview of what the functions above do:
worker(index, prompt, concurrency_semaphore): Sends a single asynchronous HTTP POST request with the provided prompt, waiting on the concurrency semaphore so that only a limited number of requests are in flight at once, and records the latency and response data.
process_prompts(complete_prompts, num_concurrent_calls): Creates one worker per prompt, enforces the concurrency limit with a shared semaphore, and shows a progress bar as tasks complete.
main(complete_prompts, num_concurrent_calls=15): Starts data generation by calling process_prompts with the complete list of prompts and the specified number of concurrent calls, then prints the median latency once all requests have finished.
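As written, worker prints failed requests and moves on, so those prompts get no prediction. If the endpoint rejects requests when traffic exceeds its provisioned capacity (HTTP 429 is the conventional status for rate limiting, but the exact codes this endpoint returns are an assumption here), retrying with exponential backoff and jitter is a common remedy. The helper below is a hypothetical, generic sketch: `send` is any coroutine function that returns `(status_code, body)`, and the set of retryable statuses is an assumption.

```python
import asyncio
import random
from typing import Awaitable, Callable, Tuple

# Status codes treated as transient; an assumed set, not an exhaustive
# list of what the serving endpoint can return
RETRYABLE_STATUSES = {429, 500, 502, 503, 504}


async def post_with_retries(
    send: Callable[[], Awaitable[Tuple[int, object]]],
    max_attempts: int = 5,
    base_delay_s: float = 1.0,
    max_delay_s: float = 30.0,
) -> Tuple[int, object]:
    """Call send() until it returns a non-retryable status or attempts run out.

    Sleeps for a random time up to an exponentially growing cap
    ("full jitter") between attempts.
    """
    for attempt in range(max_attempts):
        status, body = await send()
        # Success, a non-transient error, or the last attempt: hand the result back
        if status not in RETRYABLE_STATUSES or attempt == max_attempts - 1:
            return status, body
        delay = min(max_delay_s, base_delay_s * 2 ** attempt)
        await asyncio.sleep(random.uniform(0, delay))
    raise ValueError("max_attempts must be at least 1")
```

Inside worker, `send` could be a small closure that performs the `session.post` call with `input_data` and returns `response.status` together with the parsed body. Randomizing the wait spreads out retries so that many workers throttled at the same moment do not all retry in lockstep.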
Run the API Requests
await main(complete_prompts, num_concurrent_calls=60)
Starting data generation...
Generating Data: 100%|██████████| 16000/16000 [09:33<00:00, 27.92it/s]
Median latency (s): 2.189871907234192
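As a sanity check, Little's law (average requests in flight = throughput × average time per request) relates the printed figures to the configured concurrency. The sketch below copies the numbers from the output above and uses the median latency as a stand-in for the mean, which is an approximation.

```python
# Figures copied from the progress bar and summary printed above
requests_completed = 16000
elapsed_s = 9 * 60 + 33  # "09:33" from the tqdm progress bar
median_latency_s = 2.189871907234192

# Requests per second over the whole run (tqdm shows 27.92it/s)
throughput_rps = requests_completed / elapsed_s

# Little's law: requests in flight ≈ throughput × time per request
# (median used in place of mean, so this is only approximate)
implied_concurrency = throughput_rps * median_latency_s

print(f"throughput: {throughput_rps:.2f} req/s, implied concurrency: {implied_concurrency:.1f}")
```

This gives roughly 27.92 requests per second and about 61 requests in flight, close to `num_concurrent_calls=60`. That suggests that, at this setting, throughput is roughly concurrency divided by per-request latency; whether raising concurrency helps further depends on the endpoint's tokens-per-second limit.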
Preview the Output
data[:10]
[{'predictions': [{'candidates': [{'text': 'FEAR',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 220,
'output_tokens': 3,
'total_tokens': 223}}]},
{'predictions': [{'candidates': [{'text': 'LOVE',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 216,
'output_tokens': 3,
'total_tokens': 219}}]},
{'predictions': [{'candidates': [{'text': 'LOVE',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 238,
'output_tokens': 3,
'total_tokens': 241}}]},
{'predictions': [{'candidates': [{'text': 'SADNESS',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 224,
'output_tokens': 4,
'total_tokens': 228}}]},
{'predictions': [{'candidates': [{'text': 'FEELING',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 216,
'output_tokens': 4,
'total_tokens': 220}}]},
{'predictions': [{'candidates': [{'text': 'JOY',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 262,
'output_tokens': 3,
'total_tokens': 265}}]},
{'predictions': [{'candidates': [{'text': 'LOVE',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 262,
'output_tokens': 3,
'total_tokens': 265}}]},
{'predictions': [{'candidates': [{'text': 'SURPRISE',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 264,
'output_tokens': 5,
'total_tokens': 269}}]},
{'predictions': [{'candidates': [{'text': 'LOVE',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 251,
'output_tokens': 3,
'total_tokens': 254}}]},
{'predictions': [{'candidates': [{'text': 'FEAR',
'metadata': {'finish_reason': 'stop'}}],
'metadata': {'input_tokens': 211,
'output_tokens': 3,
'total_tokens': 214}}]}]
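One of the ten previewed responses is `FEELING`, which is not among the six labels the prompt allows, so it is worth checking all outputs before using them. The sketch below assumes the response shape shown in the preview (`{'predictions': [{'candidates': [{'text': ...}]}]}`), normalizes each prediction, counts the labels, and collects the positions of out-of-vocabulary or missing outputs. The helper names are hypothetical.

```python
from collections import Counter
from typing import List, Optional, Tuple

# The six labels allowed by the prompt template
VALID_EMOTIONS = {"JOY", "SADNESS", "ANGER", "FEAR", "LOVE", "SURPRISE"}


def extract_label(response: Optional[dict]) -> Optional[str]:
    """Return the normalized generated label, or None for a missing response."""
    if response is None:
        return None
    text = response["predictions"][0]["candidates"][0]["text"]
    return text.strip().upper()


def summarize_labels(responses: List[Optional[dict]]) -> Tuple[Counter, List[int]]:
    """Count predicted labels and collect indices of invalid or missing outputs."""
    counts: Counter = Counter()
    invalid: List[int] = []
    for i, response in enumerate(responses):
        label = extract_label(response)
        counts[label] += 1
        if label not in VALID_EMOTIONS:
            invalid.append(i)
    return counts, invalid
```

Running `counts, invalid = summarize_labels(data)` gives a label distribution plus a list of positions whose outputs could be re-prompted or mapped to the closest valid label.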
Conclusion
This notebook provides a sketch of one approach to accomplish high-throughput classification with the Provisioned Throughput Foundation Model APIs.