Reliable Zero-Shot Classification with the Trustworthy Language Model
In zero-shot classification, we use a foundation model to classify input data into predefined categories (a.k.a. classes) without having to train the model on a dataset manually annotated with these categories. This leverages the pre-trained model’s world knowledge to accomplish tasks that would otherwise require much more work training classical machine learning models from scratch. The problem with zero-shot classification of text with LLMs is that we don’t know which LLM classifications we can trust. Most LLMs are prone to hallucination and will often predict a category even when their world knowledge does not suffice to justify this prediction.
This tutorial demonstrates how you can easily replace any LLM with Cleanlab’s Trustworthy Language Model (TLM) to gauge the trustworthiness of each zero-shot classification. Use the TLM to ensure reliable classification where you know which model predictions cannot be trusted. Before this tutorial, we recommend completing the TLM quickstart tutorial.
Setup
Using TLM requires a Cleanlab account. Sign up for one here if you haven’t yet. If you’ve already signed up, check your email for a personal login link.
The Python client package can be installed using pip:
%pip install cleanlab-studio
import re
import pandas as pd
from tqdm import tqdm
from difflib import SequenceMatcher
from cleanlab_studio import Studio
In Python, launch your Cleanlab Studio client using your API key.
# Get your API key from https://app.cleanlab.ai/account after creating an account.
studio = Studio("<insert your API key>")
Let’s load an example classification dataset. Here we consider legal documents from the “US” Jurisdiction of the Multi_Legal_Pile, a large-scale multilingual legal dataset that spans over 24 languages. We aim to classify each document into one of three categories: [caselaw, contracts, legislation].
We’ll prompt our TLM to categorize each document and record its response and associated trustworthiness score. You can use the ideas from this tutorial to improve LLMs for any other text classification task!
First download our example dataset and then load it into a DataFrame.
wget -nc 'https://cleanlab-public.s3.amazonaws.com/Datasets/zero_shot_classification.csv'
df = pd.read_csv('zero_shot_classification.csv')
df.head(2)
 | index | text |
---|---|---|
0 | 0 | Probl2B\n0/NV Form\nRev. June 2014\n\n\n\n ... |
1 | 1 | UNITED STATES DI... |
Perform Zero Shot Classification with TLM
Let’s initialize a TLM object. Here we use default TLM settings, but check out the TLM quickstart tutorial for configuration options that can produce better results.
tlm = studio.TLM()
Next, let’s define a prompt template to instruct the TLM on how to classify each document’s text. Write your prompt just as you would with any other LLM when adapting it for zero-shot classification. A good prompt template lists all the possible categories a document can be classified as, gives formatting instructions for the LLM response, and of course includes the text of the document itself.
'You are an expert Legal Document Auditor. Classify the following document into a single category that best represents it. The categories are: {categories}. In your response, first provide a brief explanation as to why the document belongs to a specific category and then on a new line write "Category: <category document belongs to>". \nDocument: {document}'
If you have a couple labeled examples from different classes, you may be able to get better LLM predictions via few-shot prompting (where these examples + their classes are embedded within the prompt). Here we’ll stick with zero-shot classification for simplicity, but note that TLM can also be used for few-shot classification just like any other LLM.
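As a rough illustration of few-shot prompting, the sketch below embeds a handful of labeled examples in the prompt ahead of the document to classify. The helper `build_few_shot_prompt` and the example documents are hypothetical and are not part of the Cleanlab client; the resulting string could be passed to the TLM like any other prompt.

```python
# Hypothetical labeled examples; in practice these would come from your own data.
few_shot_examples = [
    ("THIS LOAN AGREEMENT is made between the Lender and the Borrower ...", "contracts"),
    ("UNITED STATES DISTRICT COURT ... Plaintiff v. Defendant ... ORDER", "caselaw"),
]

categories = ["caselaw", "contracts", "legislation"]


def build_few_shot_prompt(document: str, examples: list, categories: list) -> str:
    """Embed labeled examples in the prompt ahead of the document to classify."""
    category_list = ", ".join(categories)
    # Each example is shown in the same "Document / Category" layout we ask the LLM to follow.
    example_block = "\n\n".join(
        f"Document: {text}\nCategory: {label}" for text, label in examples
    )
    return (
        "You are an expert Legal Document Auditor. Classify the following document "
        f"into a single category that best represents it. The categories are: [{category_list}]. "
        "Here are some labeled examples:\n\n"
        f"{example_block}\n\n"
        "In your response, first provide a brief explanation as to why the document belongs "
        'to a specific category and then on a new line write "Category: <category document belongs to>".\n'
        f"Document: {document}"
    )


prompt = build_few_shot_prompt("AN ACT to amend the Clean Air Act ...", few_shot_examples, categories)
print(prompt)
```

Keeping the examples in the same layout as the requested output format tends to make the response easier to parse, though this is a design choice rather than a requirement.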
Let’s apply the above prompt template to all documents in our dataset and form the list of prompts we want to run. For one arbitrary document, we print the actual corresponding prompt fed into the TLM below.
zero_shot_prompt_template = 'You are an expert Legal Document Auditor. Classify the following document into a single category that best represents it. The categories are: {categories}. In your response, first provide a brief explanation as to why the document belongs to a specific category and then on a new line write "Category: <category document belongs to>". \nDocument: {document}'
categories = ['caselaw', 'contracts', 'legislation']
string_categories = str(categories).replace('\'', '')
# Create a DataFrame to store results and apply the prompt template to all examples
results_df = df.copy()
results_df['prompt'] = results_df['text'].apply(lambda x: zero_shot_prompt_template.format(categories=string_categories, document=x))
print(f"{results_df.at[7, 'prompt']}")
Now we prompt the TLM and save the output responses and their associated trustworthiness scores for all examples. We recommend the try_prompt() method to run TLM over datasets with many examples.
outputs = tlm.try_prompt(results_df['prompt'].to_list())
results_df[["response","trustworthiness_score"]] = pd.DataFrame(outputs)
Parse raw LLM Responses into Category Predictions
Our prompt template asks the LLM to explain its predictions, which can boost their accuracy. We now parse out the classification prediction, which should be exactly one of the categories for each document. Because LLMs don’t necessarily follow output formatting instructions perfectly, we define a function that parses out only the expected categories. If no value out of the possible categories is directly mentioned in the response, the category with greatest string similarity to the response is returned (along with a warning).
Note: If there are no close matches between the LLM response and any of the possible categories, then the last entry of the categories list is returned. We can add an “other” category to account for bad responses that are hard to parse into a specific category.
categories_with_bad_parse = categories + ["other"]
categories_with_bad_parse
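To make the string-similarity fallback concrete, here is a minimal, self-contained sketch using `difflib.SequenceMatcher` from the standard library. The helper name `closest_category` and the 0.5 cutoff parameter `min_similarity` are illustrative assumptions, and unlike the tutorial’s helper this sketch lowercases both strings before comparing them.

```python
from difflib import SequenceMatcher

categories = ["caselaw", "contracts", "legislation", "other"]


def closest_category(response: str, categories: list, min_similarity: float = 0.5) -> str:
    """Return the category most similar to `response`, or the last category if none is close."""
    # ratio() is 2*M/T, where M is the number of matching characters and T the total length of both strings.
    scores = {c: SequenceMatcher(None, response.lower(), c.lower()).ratio() for c in categories}
    best = max(scores, key=scores.get)
    # Fall back to the last entry (here "other") when nothing is similar enough.
    return best if scores[best] >= min_similarity else categories[-1]


print(closest_category("contract", categories))
print(closest_category("I am not sure what this document is.", categories))
```

Because the ratio is normalized by the combined length of both strings, a long free-form response rarely reaches the cutoff against a short category name, so such responses typically land in the fallback category.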
Optional: Define helper methods to parse categories and better display results.
import warnings


def parse_category(
    response: str,
    categories: list,
    disable_warnings: bool = False,
) -> str:
    """Extracts one of the provided categories from the response using regex patterns. Returns last extracted category if multiple exist.
    If no category out of the possible `categories` is directly mentioned in the response, the category with greatest string similarity to the response is returned (along with a warning).
    If there are no close matches between the LLM response and any of the possible `categories`, then the last entry of the `categories` list is returned.

    Params
    ------
    response: Response from the LLM
    categories: List of expected categories, the last value of this list should be considered the default/baseline value (e.g. “other”),
        that value will be returned if there are no close matches.
    disable_warnings: If True, print warnings are disabled
    """
    response_str = str(response)

    # Create a regex pattern that matches any of the allowed categories
    escaped_categories = [re.escape(output) for output in categories]
    categories_pattern = "(" + "|".join(escaped_categories) + ")"

    # Parse category if LLM response is properly formatted
    exact_matches = re.findall(categories_pattern, response_str, re.IGNORECASE)
    if len(exact_matches) > 0:
        # Matching is case-insensitive, so map the match back to the category's canonical spelling
        canonical = {category.lower(): category for category in categories}
        return canonical.get(exact_matches[-1].lower(), str(exact_matches[-1]))

    # If there are no exact matches to a specific category, return the closest category based on string similarity.
    best_match = max(
        categories, key=lambda x: SequenceMatcher(None, response_str, x).ratio()
    )
    similarity_score = SequenceMatcher(None, response_str, best_match).ratio()

    if similarity_score < 0.5:
        warning_message = (
            f"None of the categories remotely match raw LLM output: {response_str}.\n"
            + "Returning the last entry in the `categories` list."
        )
        best_match = categories[-1]
    else:
        warning_message = f"None of the categories match raw LLM output: {response_str}"

    if not disable_warnings:
        warnings.warn(warning_message)

    return best_match


def display_result(results_df: pd.DataFrame, index: int):
    """Displays the TLM result for the example from the dataset whose `index` is provided."""
    print(f"TLM predicted category: {results_df.iloc[index].predicted_category}")
    print(f"TLM trustworthiness score: {results_df.iloc[index].trustworthiness_score}\n")
    print(results_df.iloc[index].text)
results_df['predicted_category'] = results_df['response'].apply(lambda x: parse_category(x, categories_with_bad_parse))
Analyze Classification Results
Let’s first inspect the most trustworthy predictions from our model. We sort the TLM outputs over our documents to see which predictions received the highest trustworthiness scores.
results_df = results_df.sort_values(by='trustworthiness_score', ascending=False)
display_result(results_df, index=0)
A document about “DEPARTMENT OF TRANSPORTATION National Highway Traffic Safety Administration” clearly pertains to a legislative measure, so it makes sense that the TLM classifies it into the “legislation” category with a high trustworthiness score.
display_result(results_df, index=1)
Another document, titled “National Oil and Hazardous Substances Pollution Contingency Plan; National Priorities List”, also clearly pertains to a legislative measure, so it makes sense that the TLM classifies it into the “legislation” category with a high trustworthiness score.
display_result(results_df, index=2)
This document about “Amendment to Loan Agreement” is very clearly a contract, so it makes sense that the TLM classifies it into the “contracts” category with a high trustworthiness score.
Least Trustworthy Predictions
Now let’s see which classifications predicted by the model are least trustworthy. We sort the data by trustworthiness scores in the opposite order to see which predictions received the lowest scores. Observe how model classifications with the lowest trustworthiness scores are often incorrect, corresponding to examples with vague/irrelevant text or documents possibly belonging to more than one category.
results_df = results_df.sort_values(by='trustworthiness_score')
display_result(results_df, index=0)
This is clearly not a contract but instead a caselaw document with a case number. It’s good to see that the TLM gives a very low trustworthiness score.
display_result(results_df, index=1)
This document is also clearly caselaw, but the model predicted it to be contracts. It’s good to see that the TLM gives a very low trustworthiness score.
display_result(results_df, index=3)
This document clearly does not belong in any of the three categories, as it is just a series of image titles. It makes sense that the TLM gives it a low trustworthiness score.
How to use Trustworthiness Scores?
If you have time/resources, your team can manually review the LLM classifications of low-trustworthiness responses and provide a better human classification instead. If not, you can determine a trustworthiness threshold below which responses seem too unreliable to use, and have the model abstain from predicting in such cases (i.e. outputting “I don’t know” instead).
The overall magnitude/range of the trustworthiness scores may differ between datasets, so we recommend choosing thresholds separately for each application. First consider the relative trustworthiness levels between different data points before considering the overall magnitude of these scores for individual data points.
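The abstention strategy described above might be sketched as follows. The predictions, the scores, the 0.7 threshold, and the helper `apply_abstention` are all made up for illustration; a suitable threshold depends on your data and application.

```python
# Hypothetical (predicted_category, trustworthiness_score) pairs for illustration only.
predictions = [
    ("legislation", 0.94),
    ("contracts", 0.91),
    ("caselaw", 0.62),
    ("contracts", 0.31),
]

THRESHOLD = 0.7  # Assumed value; tune it for your own application and data.


def apply_abstention(predictions: list, threshold: float, abstain_label: str = "I don't know") -> list:
    """Keep predictions scoring at or above `threshold`; replace the rest with `abstain_label`."""
    return [
        category if score >= threshold else abstain_label
        for category, score in predictions
    ]


final_labels = apply_abstention(predictions, THRESHOLD)
# Positions of the low-trust predictions, e.g. to route them to human reviewers.
needs_review = [i for i, (_, score) in enumerate(predictions) if score < THRESHOLD]
print(final_labels)
print(needs_review)
```

The same list of flagged positions can serve either purpose mentioned above: sending documents to manual review, or simply withholding the model’s prediction.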
Measuring Classification Accuracy with Ground Truth Labels
Our example dataset happens to have labels for each document, so we can load them in to assess the accuracy of our model predictions. We’ll study the impact on accuracy as we abstain from making predictions for examples receiving lower trustworthiness scores.
wget -nc 'https://cleanlab-public.s3.amazonaws.com/Datasets/zero_shot_classification_labels.csv'
df_ground_truth = pd.read_csv('zero_shot_classification_labels.csv')
df = pd.merge(results_df, df_ground_truth, on=['index'], how='outer')
df['is_correct'] = df['type'] == df['predicted_category']
df.head()
 | index | text | prompt | response | trustworthiness_score | predicted_category | type | is_correct |
---|---|---|---|---|---|---|---|---|
0 | 0 | Probl2B\n0/NV Form\nRev. June 2014\n\n\n\n ... | You are an expert Legal Document Auditor. Clas... | The document is a formal request for early ter... | 0.874957 | caselaw | caselaw | True |
1 | 1 | UNITED STATES DI... | You are an expert Legal Document Auditor. Clas... | The document is a court order from a United St... | 0.935663 | caselaw | caselaw | True |
2 | 2 | \n \n FEDERAL COMMUNICATIONS COMMI... | You are an expert Legal Document Auditor. Clas... | The document is a Notice of Proposed Rule Maki... | 0.938619 | legislation | legislation | True |
3 | 3 | \n \n DEPARTMENT OF COMMERCE\n ... | You are an expert Legal Document Auditor. Clas... | The document is a notice from the National Oce... | 0.927012 | legislation | legislation | True |
4 | 4 | EXHIBIT 10.14\n\nAMENDMENT NO. 1 TO\n\nCHANGE ... | You are an expert Legal Document Auditor. Clas... | The document is an amendment to a severance ag... | 0.934622 | contracts | contracts | True |
print('TLM zero-shot classification accuracy over all documents: ', df['is_correct'].sum() / df.shape[0])
Next suppose we instead abstain from making predictions on 50% of the documents flagged with the lowest trustworthiness scores (e.g. having experts manually categorize these documents instead).
quantile = 0.5 # Play with value to observe the accuracy vs. number of abstained examples tradeoff
filtered_df = df[df['trustworthiness_score'] > df['trustworthiness_score'].quantile(quantile)]
acc = filtered_df['is_correct'].sum() / filtered_df.shape[0]
print(f'TLM zero-shot classification accuracy over the documents within the top-{(1-quantile) * 100}% of trustworthiness scores: {acc}')
This shows the benefit of considering the TLM’s trustworthiness score for zero-shot classification over having to rely on results from a standard LLM.
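To explore the tradeoff between accuracy and the number of abstained examples more systematically, one could sweep over several coverage levels rather than a single quantile. The sketch below uses made-up scores and correctness flags, and the helper `accuracy_at_coverage` is hypothetical. It keeps the top fraction of examples by rank, which is a slight variation on the quantile filter used above.

```python
# Hypothetical (trustworthiness_score, is_correct) pairs for illustration only.
scored = [
    (0.95, True), (0.93, True), (0.90, True), (0.85, True),
    (0.70, False), (0.65, True), (0.40, False), (0.20, False),
]


def accuracy_at_coverage(scored: list, coverage: float) -> float:
    """Accuracy over the top `coverage` fraction of examples, ranked by trustworthiness score."""
    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)
    # Always keep at least one example so the accuracy is well defined.
    n_kept = max(1, round(coverage * len(ranked)))
    kept = ranked[:n_kept]
    return sum(is_correct for _, is_correct in kept) / n_kept


for coverage in (1.0, 0.75, 0.5, 0.25):
    print(f"coverage={coverage:.2f}  accuracy={accuracy_at_coverage(scored, coverage):.3f}")
```

With real TLM outputs, the `scored` list could be built from the `trustworthiness_score` and `is_correct` columns of the DataFrame above.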