CSModel
A CSModel object wraps a Cell2Sentence (C2S) model and tracks the path where that model is saved on disk. When needed, the model is loaded from that path for inference or finetuning. The class provides utilities for text generation and cell embedding, with Huggingface as the backend.
- csmodel.CSModel(model_name_or_path, save_dir, save_name)
Wrapper class around a Cell2Sentence model for use in cell2sentence-based workflows.
- csmodel.CSModel.__init__(self, model_name_or_path, save_dir, save_name)
Core constructor. A CSModel stores the path to a model, from which the model is loaded when needed.
- Parameters:
model_name_or_path – either the name of a Huggingface model, to start from a default LLM, or the path to an already-trained C2S model on disk, to run inference with or finetune starting from that model
save_dir – directory where the model should be saved
save_name – name to save the model under (no file extension needed)
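The class description says a CSModel keeps a path and loads the model only when needed. The sketch below illustrates that lazy-loading pattern in plain Python. It is hypothetical: `LazyModelHandle`, the `loader` callable, and the assumption that the model lives at `os.path.join(save_dir, save_name)` are illustrative choices, not the CSModel implementation.

```python
import os
from typing import Any, Callable, Optional


class LazyModelHandle:
    """Hypothetical stand-in for CSModel's path-tracking behaviour.

    Only the path is stored at construction time; the (possibly large)
    model object is created by `loader` the first time it is requested.
    """

    def __init__(self, save_dir: str, save_name: str,
                 loader: Callable[[str], Any]) -> None:
        # Assumption: the model is stored at <save_dir>/<save_name>.
        self.save_path = os.path.join(save_dir, save_name)
        self._loader = loader
        self._model: Optional[Any] = None

    def get_model(self) -> Any:
        # Load on first use and cache, so later calls reuse the same object.
        if self._model is None:
            self._model = self._loader(self.save_path)
        return self._model

    def __str__(self) -> str:
        state = "loaded" if self._model is not None else "not loaded"
        return f"LazyModelHandle(path={self.save_path!r}, {state})"


calls = []


def dummy_loader(path: str) -> dict:
    # Stands in for a real loading call such as a from_pretrained() call.
    calls.append(path)
    return {"path": path}


handle = LazyModelHandle("models", "my_c2s_model", loader=dummy_loader)
first = handle.get_model()
second = handle.get_model()
```

Keeping only the path makes the wrapper cheap to create and pass around; the cost of loading is paid once, at first use.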
- csmodel.CSModel.__str__(self)
Summarize the CSModel object as a string for debugging and logging.
- csmodel.CSModel.fine_tune(self, csdata, task: str, train_args: TrainingArguments, loss_on_response_only: bool = True, top_k_genes: int = 100, max_eval_samples: int = 500, data_split_indices_dict: dict | None = None, prompt_formatter: PromptFormatter | None = None, formatted_hf_ds: Dataset | None = None, num_proc: int = 3)
Finetune the model using the data in the provided CSData object.
- Parameters:
csdata – a CSData object to be used as input for finetuning. Alternatively, data can be any generator of sequential text that satisfies the same functional contract as a CSData object
task – name of finetuning task (see supported tasks in prompt_formatter.py). Ignored if prompt_formatter is not None.
train_args – Huggingface TrainingArguments object used to configure the Trainer
loss_on_response_only – whether to compute the loss only on the model’s response tokens, excluding the prompt
top_k_genes – number of genes to use for each cell sentence. Ignored if prompt_formatter is not None.
max_eval_samples – maximum number of samples to use for validation
data_split_indices_dict – dictionary of indices for the train, val, and (optionally) test sets. Required keys are “train” and “val”; each value should be a list of the indices of the samples in that split.
prompt_formatter – optional custom PromptFormatter object. If None, a default one will be created using task and top_k_genes parameters.
formatted_hf_ds – optional Huggingface Dataset object containing formatted data, used in cases where custom formatting is desired (e.g. multicell tasks where more complex formatting is needed).
num_proc – number of processes to use for tokenization. Defaults to 3.
- Returns:
The CSModel is updated in place.
- Return type:
None
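For data_split_indices_dict, the following sketch builds a random split with the documented keys ("train" and "val", plus an optional "test"). The helper name `make_split_indices` and its fraction/seed parameters are hypothetical conveniences, not part of the library.

```python
import random
from typing import Dict, List, Optional


def make_split_indices(n_samples: int, val_frac: float = 0.1,
                       test_frac: float = 0.0,
                       seed: Optional[int] = 0) -> Dict[str, List[int]]:
    """Randomly partition sample indices into train/val(/test) lists.

    "train" and "val" are always present; "test" only when test_frac > 0.
    """
    if not 0 <= val_frac < 1 or not 0 <= test_frac < 1 or val_frac + test_frac >= 1:
        raise ValueError("fractions must be in [0, 1) and sum to less than 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for a reproducible split
    n_val = int(n_samples * val_frac)
    n_test = int(n_samples * test_frac)
    split = {
        "val": sorted(indices[:n_val]),
        "train": sorted(indices[n_val + n_test:]),
    }
    if test_frac > 0:
        split["test"] = sorted(indices[n_val:n_val + n_test])
    return split


splits = make_split_indices(1000, val_frac=0.1, test_frac=0.1, seed=42)
```

The resulting dictionary could then be passed as `data_split_indices_dict=splits`. The splits are disjoint and together cover every index exactly once.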
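loss_on_response_only refers to a common finetuning technique: prompt tokens are excluded from the loss so the model is trained only on its answer. In PyTorch and Huggingface this is conventionally done by setting those label positions to -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`. The sketch below shows the idea on plain token-ID lists; it illustrates the general technique, not the exact code fine_tune runs, and the token IDs are made up.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(prompt_ids: List[int], response_ids: List[int],
                  loss_on_response_only: bool = True) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and response tokens and build matching labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    if loss_on_response_only:
        # Prompt positions are ignored by the loss; only the answer is scored.
        labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    else:
        # Loss over the whole sequence, as in plain causal-LM training.
        labels = list(input_ids)
    return input_ids, labels


ids, labels = build_example([101, 7, 8, 9], [42, 43, 2])
# ids    -> [101, 7, 8, 9, 42, 43, 2]
# labels -> [-100, -100, -100, -100, 42, 43, 2]
```

Labels are aligned position-by-position with input_ids; Huggingface causal-LM model classes shift them internally when computing the loss.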
- csmodel.CSModel.generate_from_prompt(self, model, prompt, max_num_tokens=1024, **kwargs)
Generate new data using the model, starting with a given prompt.
- Parameters:
model – a C2S model
prompt – a textual prompt
max_num_tokens – the maximum number of tokens to generate
kwargs – arguments for model.generate() (for generation options, see Huggingface docs: https://huggingface.co/docs/transformers/en/main_classes/text_generation). Any kwargs are passed without input validation to the model.generate() function
- Returns:
The text generated by the model, up to max_num_tokens tokens
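generate_from_prompt delegates decoding to model.generate(). To make the role of max_num_tokens concrete, here is a toy greedy decoding loop that stops after at most max_num_tokens new tokens or at an end-of-sequence token. The `next_token` callable, `toy_next_token`, and the EOS ID of 0 are hypothetical stand-ins for a real model; treating max_num_tokens as a cap on newly generated tokens follows the parameter description above.

```python
from typing import Callable, List


def greedy_generate(prompt_ids: List[int],
                    next_token: Callable[[List[int]], int],
                    max_num_tokens: int = 1024,
                    eos_id: int = 0) -> List[int]:
    """Toy greedy decoder that returns only the newly generated token IDs."""
    sequence = list(prompt_ids)
    generated: List[int] = []
    for _ in range(max_num_tokens):
        token = next_token(sequence)  # a real model would take the argmax of its logits
        if token == eos_id:
            break
        sequence.append(token)
        generated.append(token)
    return generated


def toy_next_token(seq: List[int]) -> int:
    # Hypothetical "model": emits last token + 1, then EOS once past 5.
    nxt = seq[-1] + 1
    return 0 if nxt > 5 else nxt


full = greedy_generate([1], toy_next_token, max_num_tokens=10)   # [2, 3, 4, 5]
capped = greedy_generate([1], toy_next_token, max_num_tokens=2)  # [2, 3]
```

With the real method, decoding options such as `do_sample=True` or `num_beams=4` would be passed through **kwargs to model.generate().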
- csmodel.CSModel.generate_from_prompt_batched(self, model, prompt_list, max_num_tokens=1024, **kwargs)
Batched generation with C2S model. Takes as input a model and a list of prompts to generate from.
- Parameters:
model – a C2S model
prompt_list – a list of textual prompts
max_num_tokens – the maximum number of tokens to generate
kwargs – arguments for model.generate() (for generation options, see Huggingface docs: https://huggingface.co/docs/transformers/en/main_classes/text_generation)
- Returns:
A list of generated texts, one per prompt in prompt_list
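When the prompt list is large, it can help to call a batched method on fixed-size slices to bound memory. The sketch below assumes, as the return description suggests, that the batched call returns one output per prompt in input order. `generate_in_chunks`, `chunk_size`, and `fake_batched` are hypothetical; `fake_batched` stands in for something like `lambda ps: csmodel.generate_from_prompt_batched(model, ps)`.

```python
from typing import Callable, List, Sequence


def generate_in_chunks(prompts: Sequence[str],
                       batched_fn: Callable[[List[str]], List[str]],
                       chunk_size: int = 8) -> List[str]:
    """Call a batched function on consecutive slices of `prompts`.

    Results are concatenated in input order, so output[i] answers prompts[i].
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    outputs: List[str] = []
    for start in range(0, len(prompts), chunk_size):
        chunk = list(prompts[start:start + chunk_size])
        results = batched_fn(chunk)
        if len(results) != len(chunk):
            raise RuntimeError("batched_fn must return one result per prompt")
        outputs.extend(results)
    return outputs


batch_sizes: List[int] = []


def fake_batched(ps: List[str]) -> List[str]:
    # Records each batch size and returns one "generation" per prompt.
    batch_sizes.append(len(ps))
    return [p.upper() for p in ps]


out = generate_in_chunks([f"cell {i}" for i in range(10)], fake_batched, chunk_size=4)
```

Checking the result count per chunk catches a mismatch early instead of silently misaligning prompts and outputs.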
- csmodel.CSModel.embed_cell(self, model, prompt, max_num_tokens=1024)
Embed a cell using the model, starting from the given prompt.
- Parameters:
model – a C2S model
prompt – a textual prompt
max_num_tokens – the maximum number of tokens to process from the prompt
- Returns:
The embedding of the cell computed by the model
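This section does not state how per-token model outputs are reduced to a single cell vector. A common choice for language models is mean pooling: average the last layer's hidden states over the non-padding tokens. The pure-Python sketch below shows that computation; it is an assumption about the technique, not a description of what embed_cell does internally, and the numbers are made up.

```python
from typing import List, Sequence


def mean_pool(hidden_states: Sequence[Sequence[float]],
              attention_mask: Sequence[int]) -> List[float]:
    """Average the token vectors whose attention_mask entry is 1.

    hidden_states: one vector per token, shape (seq_len, hidden_dim).
    attention_mask: 1 for real tokens, 0 for padding, length seq_len.
    """
    if len(hidden_states) != len(attention_mask):
        raise ValueError("need one mask entry per token")
    kept = [vec for vec, m in zip(hidden_states, attention_mask) if m]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]


# Three tokens with 2-dimensional hidden states; the last token is padding.
emb = mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0])
# emb -> [2.0, 3.0]
```

Masking matters: without it, the padding vector [100.0, 100.0] would dominate the average.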
- csmodel.CSModel.embed_cells_batched(self, model, prompt_list, max_num_tokens=1024)
Embed multiple cells in batched fashion using the model, one cell per prompt.
- Parameters:
model – a C2S model for cell embedding
prompt_list – a list of textual prompts
max_num_tokens – the maximum number of tokens to process from each prompt
- Returns:
The cell embeddings computed by the model, one per prompt in prompt_list
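Batching prompts of different lengths requires padding them to a common length and recording which positions hold real tokens in an attention mask; Huggingface tokenizers do this when called with `padding=True`. The sketch below shows the idea on token-ID lists. `pad_batch` is hypothetical and `pad_id=0` is an arbitrary assumption; the returned mask lets padded positions be ignored when embeddings are pooled.

```python
from typing import List, Sequence, Tuple


def pad_batch(batch: Sequence[Sequence[int]], pad_id: int = 0,
              left: bool = False) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-ID sequences to equal length; return (input_ids, attention_mask)."""
    if not batch:
        return [], []
    max_len = max(len(seq) for seq in batch)
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for seq in batch:
        n_pad = max_len - len(seq)
        pad, ones, zeros = [pad_id] * n_pad, [1] * len(seq), [0] * n_pad
        if left:
            # Left padding keeps the last real token at the end of each row.
            input_ids.append(pad + list(seq))
            attention_mask.append(zeros + ones)
        else:
            input_ids.append(list(seq) + pad)
            attention_mask.append(ones + zeros)
    return input_ids, attention_mask


ids, mask = pad_batch([[5, 6, 7], [8]])
# ids  -> [[5, 6, 7], [8, 0, 0]]
# mask -> [[1, 1, 1], [1, 0, 0]]
```

Right padding is typical for embedding; left padding is often preferred for batched generation with decoder-only models, so the sketch supports both.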
- csmodel.CSModel.push_model_to_hub(self, model_id_or_name)
Helper function to push the model to the Huggingface Hub. Note: you must be logged in to Huggingface; see: https://huggingface.co/docs/transformers/en/model_sharing
- Parameters:
model_id_or_name – name or ID of the Huggingface Hub repository to push the model to