Source code for genos.api.embedding_extractor
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import requests
import time
import logging
from typing import List, Union
from .health import BaseAPI
from ..exceptions import APIRequestError, ValidationError
import torch
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
[docs]
class EmbeddingExtractorAPI(BaseAPI):
    """Wrapper class for the Embedding Extraction API endpoint.

    Sends a single nucleotide sequence to the embedding extraction endpoint
    and returns the response with the embedding converted to a
    ``torch.Tensor``. Only single-sequence extraction is implemented here;
    non-string input (including lists) is rejected.
    """

    # Maximum accepted sequence length, in characters (bases).
    MAX_SEQUENCE_LENGTH = 128000

    def __init__(self, session: requests.Session, base_url: str, timeout: int = 30, config=None):
        """
        Initialize the EmbeddingExtractorAPI client.

        Args:
            session (requests.Session): Reusable HTTP session for API requests.
            base_url (str): Base URL of the Genos service.
            timeout (int, optional): Request timeout in seconds. Default is 30.
            config (GenosConfig, optional): Configuration object used to look
                up the endpoint path via ``config.get_endpoint(...)``. When
                omitted, requests are sent to ``base_url`` with no path
                appended.
        """
        super().__init__(session, base_url, timeout)
        self.config = config

    def extract(self, sequence: str, model_name: str = "Genos-1.2B",
                pooling_method: str = "mean") -> dict:
        """
        Extract a numerical embedding for a single nucleotide sequence.

        Args:
            sequence (str): DNA sequence string. Must be non-empty and at most
                ``MAX_SEQUENCE_LENGTH`` characters long.
            model_name (str, optional): Model name to use. Default is
                "Genos-1.2B". Documented options: "Genos-1.2B", "Genos-10B".
                The value is forwarded to the API without client-side checks.
            pooling_method (str, optional): Pooling method. Default is "mean".
                Documented options: "mean", "max", "last", "none". The value
                is forwarded to the API without client-side checks.

        Returns:
            dict: A dictionary with exactly three keys, each taken from the
            API response (``None`` when the response lacks the key):

            - "result": the response payload, with ``result["embedding"]``
              converted to a ``torch.Tensor``. NOTE(review): the payload is
              presumably also expected to carry "token_count",
              "embedding_shape" and "embedding_dim"; this client does not
              verify them.
            - "status": status value reported by the API.
            - "message": message reported by the API.

        Raises:
            ValueError: If ``sequence`` is not a string.
            ValidationError: If ``sequence`` is empty or too long.
            APIRequestError: If the response has no ``result.embedding`` field.

        Examples:
            >>> result = embedding_api.extract("ATCGATCGATCG")
            >>> embedding = result["result"]["embedding"]  # torch.Tensor
        """
        # Batch (list) input is not supported by this method, so anything that
        # is not a plain string is rejected up front.
        if not isinstance(sequence, str):
            raise ValueError("sequence must be a string")
        return self._extract_single(sequence, model_name, pooling_method)

    def _extract_single(self, sequence: str, model_name: str, pooling_method: str) -> dict:
        """Internal helper: request the embedding for one sequence.

        Validates the sequence length, POSTs the request through the base
        class's ``_make_request``, converts the returned embedding into a
        ``torch.Tensor`` and returns the "result", "status" and "message"
        entries of the response.

        Raises:
            ValidationError: If ``sequence`` is empty or longer than
                ``MAX_SEQUENCE_LENGTH``.
            APIRequestError: If the response has no ``result.embedding`` field.
        """
        # Validate input before spending a network round trip.
        if not sequence:
            raise ValidationError("sequence cannot be empty")
        if len(sequence) > self.MAX_SEQUENCE_LENGTH:
            raise ValidationError(
                f"sequence length cannot exceed {self.MAX_SEQUENCE_LENGTH:,} bases"
            )

        # Resolve the endpoint path from the configuration when available;
        # without a config the bare base URL is used (backward compatible).
        endpoint = self.config.get_endpoint("embedding.extract") if self.config else ""
        url = f"{self.base_url}{endpoint}"
        payload = {
            "sequence": sequence,
            "model_name": model_name,
            "pooling_method": pooling_method
        }

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        started = time.perf_counter()
        data = self._make_request('POST', url, json=payload)
        elapsed_time = time.perf_counter() - started

        # NOTE: the response "status" field is not checked here; it is passed
        # through to the caller in the returned dictionary.
        try:
            raw_embedding = data["result"]["embedding"]
        except (KeyError, TypeError) as err:
            # Surface a malformed response as an API error instead of leaking
            # a bare KeyError/TypeError to the caller.
            raise APIRequestError(
                "API response is missing the 'result.embedding' field"
            ) from err
        data["result"]["embedding"] = torch.tensor(raw_embedding)

        # Lazy %-style arguments: the message is only formatted when the INFO
        # level is actually enabled.
        logger.info(
            "⏱️ Embedding extraction completed in %.4fs (sequence_length=%s)",
            elapsed_time,
            data.get("sequence_length", "N/A"),
        )

        return {key: data.get(key) for key in ("result", "status", "message")}