Using LLMs as Context-Aware Text Embedding Models - NV-Embed Paper Review
Review of the NV-Embed paper, in which a generative model (Mistral 7B base) shows state-of-the-art performance as a generalist text embedding model.
Can you harness the immense language understanding capabilities of generative models (e.g., large language models) to generate high-quality text embeddings? Yes!
This paper, NV-Embed (NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models), demonstrates how to finetune a base LLM (Mistral 7B) to produce state-of-the-art (SOTA) text embeddings. Perhaps more importantly, the approach offers a way to move from static embeddings (i.e., traditional embeddings that are fixed once the model is trained) to dynamic, context-aware embeddings that can be steered via instructions.
This post covers my thoughts from a quick review of the NV-Embed model, its key ideas, and results from a small experiment.
TL;DR
The approach of instruction-tuning embeddings is valuable. It gives the developer an additional lever to tune and optimize the embedding model for the task at hand, for example when building high-quality RAG pipelines for agentic systems.
My quick experiments show the approach is flexible: it generalizes to a new dataset I tested (clustering and classifying YC company descriptions), potentially opening new ways to explore, analyze and visualize data.
Text Embeddings
Text embeddings play a crucial role in numerous real-world applications. They are essential for tasks such as data analysis (generating semantic embeddings, clustering, and deriving insights), recommendation systems (suggesting items similar to a given item), and many others.
An embedding is a vector representation of a data point (which could be text, an image, a video, etc.) that encodes the semantic meaning of that data point. These embeddings can be utilized for various applications that require understanding the relationship between data points e.g., computing the relevance or similarity between data points.
One of the key advantages of embeddings is that they enable efficient processing at scale, thanks to fast approximate nearest neighbor (ANN) search libraries such as FAISS, Annoy and ScaNN. This scalability is fundamental to the Retrieval-Augmented Generation (RAG) pattern, where embeddings are used to identify the most relevant documents to feed into a language model for a given task.
In a typical workflow, all relevant passages are first embedded. Then, at query time, the input query is also embedded, and similarity is computed using metrics such as cosine similarity. This process allows for rapid retrieval of the most semantically relevant information from large document collections.
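The workflow above can be sketched in a few lines of NumPy. This is a minimal brute-force version under stated assumptions: random vectors stand in for real model outputs, and `retrieve` is a hypothetical helper name. A production system would replace the exhaustive dot product with an ANN index, but the scoring logic is the same.

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so a dot product equals cosine similarity."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)

def retrieve(query_emb: np.ndarray, passage_embs: np.ndarray, k: int = 3):
    """Return indices and cosine scores of the k passages most similar to the query."""
    q = l2_normalize(query_emb.reshape(1, -1))
    p = l2_normalize(passage_embs)
    scores = (p @ q.T).ravel()       # cosine similarity of every passage to the query
    top = np.argsort(-scores)[:k]    # highest score first
    return top, scores[top]

# Offline step: embed every passage once (random vectors stand in for a real model).
rng = np.random.default_rng(0)
passage_embs = rng.normal(size=(100, 16))

# Query time: embed the query and rank passages. A query close to passage 42 should rank it first.
query_emb = passage_embs[42] + 0.01 * rng.normal(size=16)
idx, sims = retrieve(query_emb, passage_embs, k=3)
```

Normalizing once up front is the usual design choice: after that, cosine similarity is just a matrix product, which is exactly what vector indexes are optimized for.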
Critically, the quality of data analysis systems, recommendation engines, and RAG pipelines is heavily dependent on the effectiveness of the retrieval step. Poor retrieval can compromise the entire workflow. Some common challenges in this area include:
Insufficient Semantic Modeling: Embedding models may inadequately capture the semantics of the data. This is a complex problem that has been the focus of extensive research in metric learning and contrastive learning: choosing the right model architecture, curating the right training data (e.g., hard negative mining), designing the right training objectives, and so on.
Out-of-Distribution Application: Models are often applied to tasks that diverge significantly from their training data. This is particularly common when off-the-shelf models trained on academic datasets are used to embed text containing customer-specific jargon or domain-specific language.
Query-Passage Structure Mismatch: The classic approach of embedding queries and passages, then using cosine similarity for retrieval, can break down when the structure of queries differs substantially from that of passages. While passages are often rich, self-contained chunks of text, queries can vary widely - they might be questions, concatenations of user context and direct questions, or samples from recent user interactions. In such cases, cosine similarity may not be an adequate measure of relevance.
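The contrastive learning and hard negative mining mentioned under the first challenge are commonly combined in an InfoNCE-style loss. The NumPy sketch below is the generic formulation, not necessarily NV-Embed's exact objective; the temperature of 0.05 and the one-hard-negative-per-query layout are assumptions for illustration. Each query must score its own positive passage above the other queries' positives (in-batch negatives) and above the mined hard negatives.

```python
import numpy as np

def info_nce_loss(q, p_pos, p_hard, temperature=0.05):
    """Generic InfoNCE loss with in-batch and hard negatives.

    q, p_pos: (B, d) query and positive-passage embeddings; row i of each forms a pair.
    p_hard:   (B, d) one mined hard negative per query.
    """
    def norm(x):
        return x / np.linalg.norm(x, axis=1, keepdims=True)

    q, p_pos, p_hard = norm(q), norm(p_pos), norm(p_hard)
    # Candidates for every query: all B positives (the other rows act as in-batch
    # negatives) followed by all B hard negatives.
    candidates = np.concatenate([p_pos, p_hard], axis=0)   # (2B, d)
    logits = q @ candidates.T / temperature                # (B, 2B) scaled cosine scores
    logits = logits - logits.max(axis=1, keepdims=True)    # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    rows = np.arange(len(q))
    return float(-log_probs[rows, rows].mean())            # the target for query i is column i

# Usage: well-aligned pairs should give a much lower loss than unrelated pairs.
rng = np.random.default_rng(0)
q = rng.normal(size=(8, 32))
good = info_nce_loss(q, q + 0.1 * rng.normal(size=q.shape), rng.normal(size=q.shape))
bad = info_nce_loss(q, rng.normal(size=q.shape), rng.normal(size=q.shape))
```

Harder negatives (passages that look relevant but are not) make the softmax denominator larger, which is why curating them sharpens the learned semantics.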
These challenges explain the excitement around papers like NV-Embed, which introduce new ideas for repurposing large language models (LLMs) as generalist embedding models. NV-Embed attempts to address all three problems and achieved the top ranking on the MTEB (Massive Text Embedding Benchmark) leaderboard.
What is the NV-Embed Model?
The NV-Embed model is a generalist embedding model designed to significantly enhance the performance of decoder-only large language models (LLMs) for embedding and retrieval tasks. The primary motivations behind its development were threefold: to improve the performance of decoder-only LLMs as versatile embedding models, to create a state-of-the-art embedding model using only publicly available data, and to enhance performance across a wide range of tasks, including retrieval, classification, and clustering. The model aims to address these goals through novel architectural designs and a two-stage training procedure, ultimately achieving superior results on comprehensive embedding benchmarks without relying on proprietary synthetic data from frontier LLMs like GPT-4.
Key Decisions with NV-Embed
NV-Embed is based on Mistral-7B, which lets it leverage the strong language understanding capabilities of a large pretrained model.
IMO, some of the key design decisions in the paper include:
Contrastive Instruction-Tuning Training: The model employs a two-stage contrastive instruction-tuning method.
Contrastive Training: The first stage focuses on contrastive training using retrieval datasets such as MS MARCO, Natural Questions, and FEVER.
Instruction Tuning: The second stage introduces a blend of retrieval and non-retrieval datasets, including classification, clustering, and semantic textual similarity tasks. For example, it might use prompts like "Given a web search query, retrieve relevant passages that answer the query." This approach enables the model to handle multiple types of tasks effectively within a single architecture. The paper's appendix shows example templates for each of these tasks.
Removal of Causal Attention Masks: The model eliminates the causal attention constraint during contrastive training. This departure from traditional decoder-only LLM architectures allows the model to learn bidirectional contextual representations.
Latent Attention Layer: NV-Embed introduces a novel pooling mechanism, a latent attention layer with 512 latents and 8 attention heads. It aims to produce better sentence embeddings than simpler strategies such as mean pooling or taking the last token's representation.
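To make the latent attention layer concrete, here is a NumPy sketch of how I understand it from the paper: the decoder's final token states act as queries over a trainable latent array (512 latents) that serves as both keys and values, with 8 heads, followed by an MLP and mean pooling over real tokens. The projection matrices, the 4x MLP width, `d_model=64` and the random initialization are assumptions for illustration; the real model uses Mistral's 4096-dimensional states and learned weights. The removal of the causal mask happens inside the decoder and is not shown here.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class LatentAttentionPooling:
    """Sketch: token states attend to a shared latent array, then an MLP, then masked mean pooling."""

    def __init__(self, d_model=64, num_latents=512, num_heads=8, seed=0):
        assert d_model % num_heads == 0
        rng = np.random.default_rng(seed)
        s = 1.0 / np.sqrt(d_model)
        self.h, self.dh = num_heads, d_model // num_heads
        self.latents = rng.normal(size=(num_latents, d_model)) * s  # trainable K/V "dictionary"
        self.W_q, self.W_k, self.W_v, self.W_o = (
            rng.normal(size=(d_model, d_model)) * s for _ in range(4))
        self.W1 = rng.normal(size=(d_model, 4 * d_model)) * s
        self.W2 = rng.normal(size=(4 * d_model, d_model)) * s

    def __call__(self, hidden, mask):
        """hidden: (seq_len, d_model) final decoder states; mask: (seq_len,), 1 for real tokens."""
        L, r = hidden.shape[0], self.latents.shape[0]
        q = (hidden @ self.W_q).reshape(L, self.h, self.dh).transpose(1, 0, 2)        # (h, L, dh)
        k = (self.latents @ self.W_k).reshape(r, self.h, self.dh).transpose(1, 0, 2)  # (h, r, dh)
        v = (self.latents @ self.W_v).reshape(r, self.h, self.dh).transpose(1, 0, 2)  # (h, r, dh)
        attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(self.dh))                  # (h, L, r)
        out = (attn @ v).transpose(1, 0, 2).reshape(L, -1) @ self.W_o               # (L, d_model)
        out = gelu(out @ self.W1) @ self.W2                                          # per-token MLP
        m = mask[:, None].astype(out.dtype)
        return (out * m).sum(axis=0) / m.sum()                                       # masked mean

pool = LatentAttentionPooling(d_model=64, num_latents=512, num_heads=8)
hidden = np.random.default_rng(1).normal(size=(10, 64))
mask = np.array([1] * 7 + [0] * 3)       # last 3 positions are padding
emb = pool(hidden, mask)

# Padding positions are excluded by the mask, so changing them leaves the embedding unchanged.
hidden_padded = hidden.copy()
hidden_padded[7:] = 5.0
emb_padded = pool(hidden_padded, mask)
```

Compared with plain mean pooling, each token is first re-expressed as a mixture of learned latent vectors, which gives the pooling step trainable capacity of its own.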
Experiment: Exploring Differences in Embedding Structure
Because the model is trained to follow instruction prompts, developers can influence the behaviour and quality of the embeddings at no training cost: no finetuning or additional training is needed.
Method
To explore this, I ran a simple experiment using Y Combinator data to visualize the structural impact of instructions on the extracted embeddings. The overall process was:
Extract company descriptions for all YC companies from January 2010 to August 2024.
Embed each description with the NV-Embed model under each instruction-template condition.
Visualize the embeddings: reduce them to two dimensions with t-SNE and generate an interactive plot.
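The dimensionality-reduction step could look like the sketch below, using scikit-learn's `TSNE`. The post does not show its `reduce_dimensions` helper, so this version, the seed and the perplexity rule are assumptions. t-SNE requires the perplexity to be below the number of samples, so the sketch caps scikit-learn's default of 30 for small inputs.

```python
import numpy as np
from sklearn.manifold import TSNE

def reduce_dimensions(embeddings, n_components=2, seed=42):
    """Hypothetical helper: project high-dimensional embeddings to 2-D with t-SNE for plotting."""
    X = np.asarray(embeddings, dtype=np.float32)
    # t-SNE needs perplexity < n_samples; 30 is scikit-learn's default.
    perplexity = min(30.0, (X.shape[0] - 1) / 3.0)
    tsne = TSNE(n_components=n_components, perplexity=perplexity, random_state=seed)
    return tsne.fit_transform(X)

# Usage with random unit vectors standing in for NV-Embed outputs.
rng = np.random.default_rng(0)
demo = rng.normal(size=(60, 32))
demo /= np.linalg.norm(demo, axis=1, keepdims=True)
coords = reduce_dimensions(demo)
```

On L2-normalized embeddings, Euclidean distance is a monotonic function of cosine distance, so the default Euclidean metric preserves the neighbour ordering that cosine similarity would give. t-SNE preserves local neighbourhoods rather than global distances, so cluster sizes and gaps in the plot should be read qualitatively.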
Instruction Templates
I used three instruction templates (below), plus a fourth, baseline condition with no instruction.
Given a company description, retrieve other companies that are semantically similar or are in the same domain? \nQuery:
Classify the company description, as Artificial Intelligence (AI) or not artificial intelligence. \nQuery
Classify the company description as health domain or not health domain. \nQuery
The first follows a retrieval template, while the other two follow a classification template. The code used for this experiment is available in the reference section; the relevant snippet for the semantic (retrieval) condition is:
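```python
# Retrieval-style instruction for the "semantic similarity" condition
semantic_clustering_instruction = "Given a company description, retrieve other companies that are semantically similar or are in the same domain."
# Embed every YC company description (yc_desc) under this instruction
semantic_clustering_embeddings = get_embeddings(model, yc_desc, semantic_clustering_instruction, max_length=max_seq_length)
# Project the embeddings to 2-D with t-SNE for plotting
semantic_reduced_dims = reduce_dimensions(semantic_clustering_embeddings, 2)
# Cache the full-dimensional embeddings to disk
save_json(tensor_to_json(semantic_clustering_embeddings), 'data/semantic_embeddings.json')
# Scatter plot of the 2-D points, colored by the regex-derived mentions_ai label
plot_clusters(semantic_reduced_dims, df, color_by='mentions_ai', title="Semantic Instruction Clustering of YC Companies")
```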
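The snippet relies on a `get_embeddings` helper that the post does not show. Below is a hypothetical version, assuming the `model.encode(texts, instruction=..., max_length=...)` method and the `"Instruct: {task}\nQuery: "` prefix format described on the nvidia/NV-Embed-v1 Hugging Face model card; treat both as assumptions to check against the card. A dummy encoder replaces the real model so the sketch runs without downloading weights.

```python
import numpy as np

def format_prefix(instruction):
    """Assumed NV-Embed prompt format; an empty prefix gives the no-instruction baseline."""
    if not instruction:
        return ""
    return f"Instruct: {instruction}\nQuery: "

def get_embeddings(model, texts, instruction, max_length=512, batch_size=32):
    """Hypothetical helper: embed `texts` in batches under one instruction, L2-normalized."""
    prefix = format_prefix(instruction)
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        emb = model.encode(batch, instruction=prefix, max_length=max_length)
        if hasattr(emb, "detach"):  # torch tensor (possibly on GPU) -> NumPy
            emb = emb.detach().float().cpu().numpy()
        chunks.append(np.asarray(emb, dtype=np.float32))
    emb = np.concatenate(chunks, axis=0)
    return emb / np.linalg.norm(emb, axis=1, keepdims=True)  # unit rows: dot product = cosine

class DummyEncoder:
    """Stand-in for the real model; records the prefix it receives for each batch."""
    def __init__(self):
        self.prefixes = []

    def encode(self, texts, instruction="", max_length=512):
        self.prefixes.append(instruction)
        return np.array([[len(instruction + t), 1.0, 2.0] for t in texts])

dummy = DummyEncoder()
demo_emb = get_embeddings(
    dummy,
    ["Payroll for startups", "AI for radiology", "Solar leasing"],
    "Classify the company description as health domain or not health domain",
    batch_size=2,
)
```

Passing `instruction=None` (or an empty string) produces the fourth, no-instruction condition with the same code path, which keeps the comparison between conditions controlled.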
The general hypothesis here is:
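Embeddings produced with an instruction are better suited to the instructed task than those from the base (no-instruction) condition. For example, in the AI classification condition we want to see clearer separation of AI vs non-AI companies than in the base condition; similarly, the health condition should separate health and non-health companies more clearly.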
Evaluation
So how do we verify or interpret the quality of these embeddings? The right way would be to carefully construct a labeled benchmark and compute standard metrics: retrieval metrics such as nDCG@k, and classification metrics such as accuracy, precision, recall and F1-score. These metrics quantify how well the embeddings can be used to retrieve related companies and to categorize companies into different groups.
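For reference, here is a minimal pure-Python sketch of two of the metrics named above: nDCG@k with graded relevance (using the common log2 position discount) and binary precision, recall and F1. The function names are hypothetical; libraries such as scikit-learn provide equivalent, more complete implementations.

```python
import math

def ndcg_at_k(ranked_ids, relevance, k=10):
    """nDCG@k: `ranked_ids` is the system ranking, `relevance` maps doc id -> gain (missing = 0)."""
    def dcg(gains):
        # Gain at rank i (0-based) is discounted by log2(i + 2).
        return sum(g / math.log2(i + 2) for i, g in enumerate(gains))
    actual = dcg([relevance.get(d, 0.0) for d in ranked_ids[:k]])
    ideal = dcg(sorted(relevance.values(), reverse=True)[:k])
    return actual / ideal if ideal > 0 else 0.0

def precision_recall_f1(y_true, y_pred):
    """Binary metrics for the positive class (e.g. 'is an AI company')."""
    pairs = [(bool(t), bool(p)) for t, p in zip(y_true, y_pred)]
    tp = sum(1 for t, p in pairs if t and p)
    fp = sum(1 for t, p in pairs if not t and p)
    fn = sum(1 for t, p in pairs if t and not p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

For example, if the relevant companies are "a" and "c" and the system ranks ["a", "b", "c"], the relevant item at rank 3 is discounted, so nDCG@3 is below 1; ranking ["a", "c", "b"] scores exactly 1.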
In this case, we'll take some liberties: infer a few labels heuristically, then iteratively explore each visualization to make sense of the data. For labels, we'll add two columns to the dataset, "mentions_ai" and "mentions_health", both based on regular expressions.
It's important to note that a production setting would typically require a more rigorous benchmark than this. Still, this approach serves as a starting point for our analysis.
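The post does not list its regular expressions, so the keyword patterns below are hypothetical; they only illustrate how such weak labels can be derived. Word boundaries (`\b`) matter here: without them, "ai" would match inside words like "Email".

```python
import re

# Hypothetical keyword patterns; the post's actual regexes are not shown.
AI_PATTERN = re.compile(
    r"\b(ai|artificial intelligence|machine learning|deep learning|llms?)\b", re.IGNORECASE)
HEALTH_PATTERN = re.compile(
    r"\b(health(care)?|medical|clinic(al)?|patients?|biotech)\b", re.IGNORECASE)

def add_regex_labels(rows, text_key="description"):
    """Add boolean mentions_ai / mentions_health fields to each record in place."""
    for row in rows:
        text = row.get(text_key) or ""
        row["mentions_ai"] = bool(AI_PATTERN.search(text))
        row["mentions_health"] = bool(HEALTH_PATTERN.search(text))
    return rows

rows = add_regex_labels([
    {"description": "An AI copilot for clinical notes"},
    {"description": "Payroll for small businesses"},
    {"description": "Email for teams"},
])
```

On a pandas DataFrame, `Series.str.contains` with `case=False` applies the same idea column-wise. Keyword labels like these are noisy (a company can use machine learning without saying so), which is one reason they are only a first benchmark.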
Next, we'll plot the data points colored by these labels to see how well the embeddings map to the data. For example, we expect that the condition instructed on semantic relationships will show visible clusters, and the condition instructed on AI will show clear separation between AI and non-AI companies. By visually inspecting these plots, we can gain insights into how well our embeddings capture the intended semantic information.
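Visual inspection can be complemented with a single number per condition. One option (my suggestion, not something the post does) is the silhouette score of the regex labels, computed with cosine distance on the full-dimensional embeddings rather than on the 2-D t-SNE projection, which distorts global distances. Higher scores mean the label groups are better separated, so under the hypothesis above the AI-instructed embeddings should score higher on `mentions_ai` than the baseline. Synthetic data stands in for real embeddings in the usage example.

```python
import numpy as np
from sklearn.metrics import silhouette_score

def label_separation(embeddings, labels):
    """Silhouette score of a binary labeling under cosine distance; higher = better separated."""
    labels = np.asarray(labels).astype(int)
    if len(np.unique(labels)) < 2:
        raise ValueError("need both positive and negative examples")
    return float(silhouette_score(np.asarray(embeddings), labels, metric="cosine"))

# Usage: two well-separated synthetic groups vs. labels assigned to unstructured noise.
rng = np.random.default_rng(0)
centers = np.zeros((2, 16))
centers[0, 0] = 5.0
centers[1, 1] = 5.0
labels = np.repeat([0, 1], 50)
separated = centers[labels] + 0.5 * rng.normal(size=(100, 16))
mixed = rng.normal(size=(100, 16))
sep_score = label_separation(separated, labels)
mixed_score = label_separation(mixed, labels)
```

Comparing the same label column across the four instruction conditions keeps the labels fixed and isolates the effect of the instruction.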
Subscribe to Designing with Machine Learning to keep reading this post and get 7 days of free access to the full post archives.