 It includes basic model initialization, single and batch embedding generation, and embedding analysis.
 """

+import argparse
 import os
 import numpy as np

 from nexaai.embedder import Embedder, EmbeddingConfig

 def main():
-    model_path = os.path.expanduser(
-        "~/.cache/nexa.ai/nexa_sdk/models/NexaAI/jina-v2-fp16-mlx/model.safetensors")
+    parser = argparse.ArgumentParser(description="NexaAI Embedding Example")
+    parser.add_argument("--model", default="~/.cache/nexa.ai/nexa_sdk/models/NexaAI/jina-v2-fp16-mlx/model.safetensors",
+                        help="Path to the embedding model")
+    parser.add_argument("--texts", nargs="+",
+                        default=["On-device AI is a type of AI that is processed on the device itself, rather than in the cloud.",
+                                 "Nexa AI allows you to run state-of-the-art AI models locally on CPU, GPU, or NPU — from instant use cases to production deployments.",
+                                 "A ragdoll is a breed of cat that is known for its long, flowing hair and gentle personality.",
+                                 "The capital of France is Paris."],
+                        help="Texts to embed")
+    parser.add_argument("--query", default="what is on device AI",
+                        help="Query text for similarity analysis")
+    parser.add_argument("--batch-size", type=int, help="Batch size for processing")
+    parser.add_argument("--plugin-id", default="cpu_gpu", help="Plugin ID to use")
+    args = parser.parse_args()

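A quick way to exercise the new flags without touching a shell: `argparse` accepts an explicit argv list, so overrides can be checked in-process. A minimal sketch that rebuilds only the two flags it tests (the rest of the parser is omitted):

```python
import argparse

# Mirrors two of the flags added in the hunk above.
parser = argparse.ArgumentParser(description="NexaAI Embedding Example")
parser.add_argument("--batch-size", type=int, help="Batch size for processing")
parser.add_argument("--plugin-id", default="cpu_gpu", help="Plugin ID to use")

# parse_args() takes an explicit list, which is handy for quick checks and tests.
args = parser.parse_args(["--plugin-id", "mlx", "--batch-size", "2"])
assert args.plugin_id == "mlx" and args.batch_size == 2
```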
-    # For now, this modality is only supported in MLX.
-    embedder: Embedder = Embedder.from_(
-        name_or_path=model_path, plugin_id="mlx")
+    model_path = os.path.expanduser(args.model)
+    embedder = Embedder.from_(name_or_path=model_path, plugin_id=args.plugin_id)
     print('Embedder loaded successfully!')

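The removed comment said this modality was MLX-only, and the default plugin is now `cpu_gpu`. If the MLX backend should still be preferred where it can run, a platform gate is one option; a sketch, assuming `"mlx"` remains a valid plugin ID on Apple silicon and `"cpu_gpu"` elsewhere (this diff shows both IDs but confirms neither constraint):

```python
import platform

# Hypothetical backend selection: "mlx" and "cpu_gpu" are the two plugin IDs
# visible in this diff; MLX runs only on Apple silicon, hence the gate.
def pick_plugin_id() -> str:
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        return "mlx"
    return "cpu_gpu"

print(pick_plugin_id())
```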
     dim = embedder.get_embedding_dim()
     print(f"Dimension: {dim}")

-    texts = [
-        "On-device AI is a type of AI that is processed on the device itself, rather than in the cloud.",
-        "Nexa AI allows you to run state-of-the-art AI models locally on CPU, GPU, or NPU — from instant use cases to production deployments.",
-        "A ragdoll is a breed of cat that is known for its long, flowing hair and gentle personality.",
-        "The capital of France is Paris."
-    ]
+    batch_size = args.batch_size or len(args.texts)
     embeddings = embedder.generate(
-        texts=texts, config=EmbeddingConfig(batch_size=len(texts)))
+        texts=args.texts, config=EmbeddingConfig(batch_size=batch_size))

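One edge case in the fallback `args.batch_size or len(args.texts)`: `or` tests truthiness, so an explicit `--batch-size 0` is treated as if the flag were omitted. An `is None` check keeps the two cases distinct; a small self-contained sketch:

```python
from argparse import Namespace

# Hypothetical parsed args with an explicit (if nonsensical) zero.
args = Namespace(batch_size=0, texts=["a", "b", "c"])

# The diff's fallback: 0 is falsy, so it silently becomes len(texts).
assert (args.batch_size or len(args.texts)) == 3

# Stricter: fall back only when the flag was actually omitted.
batch_size = args.batch_size if args.batch_size is not None else len(args.texts)
assert batch_size == 0
```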
     print("\n" + "="*80)
     print("GENERATED EMBEDDINGS")
     print("="*80)

-    for i, (text, embedding) in enumerate(zip(texts, embeddings)):
+    for i, (text, embedding) in enumerate(zip(args.texts, embeddings)):
         print(f"\nText {i+1}:")
         print(f"  Content: {text}")
         print(f"  Embedding shape: {len(embedding)} dimensions")
         print(f"  First 10 elements: {embedding[:10]}")
         print("-" * 70)

-    # Generate embedding for query
-    query = "what is on device AI"
     print(f"\n" + "="*80)
     print("QUERY PROCESSING")
     print("="*80)
-    print(f"Query: '{query}'")
+    print(f"Query: '{args.query}'")

     query_embedding = embedder.generate(
-        texts=[query], config=EmbeddingConfig(batch_size=1))[0]
+        texts=[args.query], config=EmbeddingConfig(batch_size=1))[0]
     print(f"Query embedding shape: {len(query_embedding)} dimensions")

-    # Compute inner product between query and all texts
     print(f"\n" + "="*80)
     print("SIMILARITY ANALYSIS (Inner Product)")
     print("="*80)

-    for i, (text, embedding) in enumerate(zip(texts, embeddings)):
-        # Convert to numpy arrays for easier computation
+    for i, (text, embedding) in enumerate(zip(args.texts, embeddings)):
         query_vec = np.array(query_embedding)
         text_vec = np.array(embedding)
-
-        # Compute inner product (dot product)
         inner_product = np.dot(query_vec, text_vec)

         print(f"\nText {i+1}:")
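The section above ranks texts by raw inner product. Unless the model emits L2-normalized vectors (not established anywhere in this diff), vector magnitude can dominate that ranking; cosine similarity is the scale-invariant alternative. A sketch using the same numpy primitives the script already imports:

```python
import numpy as np

def cosine_similarity(a, b) -> float:
    # Inner product of L2-normalized vectors; range [-1, 1], scale-invariant.
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # ~1.0 (parallel)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))            # 0.0 (orthogonal)
```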