Huggingface-native
import re
import time

import torch
from transformers import AutoTokenizer

# BaseLLM, MODEL_CLASS, Response, and get_from_env are project-internal
# helpers assumed to be importable from the surrounding package.

class HfNativeLLM(BaseLLM):

    def load_llm_and_tokenizer(self) -> None:
        """ fetch the model from Hugging Face and load it locally """
        self.max_gpu_memory = self.convert_map(self.max_gpu_memory)
        self.auth_token = get_from_env("HF_AUTH_TOKENS")

        """ only causal LMs are supported for now """
        self.model = MODEL_CLASS[self.model_type].from_pretrained(
            self.model_name,
            device_map="auto",
            max_memory=self.max_gpu_memory,
            use_auth_token=self.auth_token
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            use_auth_token=self.auth_token
        )
        # many causal-LM tokenizers define no pad token, so reuse EOS for padding
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
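convert_map and self.max_gpu_memory are project-internal, but the max_memory value eventually handed to from_pretrained follows the standard transformers/accelerate convention: a mapping from device to memory budget. The value below is a hypothetical example, not taken from the project:

# Hypothetical example of the mapping expected by transformers' max_memory:
# keys are GPU indices (and optionally "cpu"), values are size strings.
example_max_memory = {0: "24GiB", 1: "24GiB", "cpu": "48GiB"}

# Combined with device_map="auto", accelerate shards the checkpoint across
# these devices without exceeding the per-device budgets.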
    def parse_tool_calls(self, result):
        # extract the last JSON-style tool-call list from the raw completion,
        # e.g. [{"name": ..., "parameters": {...}}]
        pattern = r'\[\{.*?\}\]'
        matches = re.findall(pattern, result)
        return matches[-1]
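What this regex picks out depends on how tool_calling_input_format instructs the model to answer; the snippet below is only an illustrative stand-in for a raw completion, showing the last JSON-style block being extracted and, as a typical next step, parsed with json.loads:

import json
import re

# made-up model output containing a tool-call block
sample_output = (
    'Sure, I will look that up. '
    '[{"name": "get_weather", "parameters": {"city": "Berlin"}}]'
)

match = re.findall(r'\[\{.*?\}\]', sample_output)[-1]
tool_calls = json.loads(match)
# -> [{'name': 'get_weather', 'parameters': {'city': 'Berlin'}}]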
    def address_syscall(self,
                        llm_syscall,
                        temperature=0.0) -> None:
        llm_syscall.set_status("executing")
        llm_syscall.set_start_time(time.time())
        self.logger.log(
            f"{llm_syscall.agent_name} is switched to executing.\n",
            level="executing"
        )
        messages = llm_syscall.request_data.messages
        tools = llm_syscall.request_data.tools
        message_return_type = llm_syscall.request_data.message_return_type

        """ the context manager only works with locally hosted (open) LLMs """
        if self.context_manager.check_restoration(llm_syscall.get_pid()):
            # resume generation from a previously saved beam-search snapshot
            restored_context = self.context_manager.gen_recover(
                llm_syscall.get_pid()
            )
            start_idx = restored_context["start_idx"]
            beams = restored_context["beams"]
            beam_scores = restored_context["beam_scores"]
            beam_attention_mask = restored_context["beam_attention_mask"]

            outputs = self.llm_generate(
                search_mode="beam_search",
                beam_size=1,
                beams=beams,
                beam_scores=beam_scores,
                beam_attention_mask=beam_attention_mask,
                max_new_tokens=self.max_new_tokens,
                start_idx=start_idx,
                timestamp=llm_syscall.get_time_limit()
            )
        else:
            """ otherwise build the prompt from scratch """
            if tools:
                messages = self.tool_calling_input_format(messages, tools)
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False
            )
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            attention_mask = input_ids != self.tokenizer.pad_token_id
            input_ids = input_ids.to(self.eval_device)
            attention_mask = attention_mask.to(self.eval_device)

            outputs = self.llm_generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                search_mode="beam_search",
                beam_size=1,
                max_new_tokens=self.max_new_tokens,
                start_idx=0,
                timestamp=llm_syscall.get_time_limit()
            )
            # TODO: temporary fix; strip the prompt tokens so that only the
            # newly generated tokens are decoded below
            outputs["result"] = outputs["result"][input_ids.shape[-1]:]

        output_ids = outputs["result"]

        """ decode the generated token ids back into text """
        result = self.tokenizer.decode(output_ids, skip_special_tokens=False)

        if outputs["finished_flag"]:  # generation finished within the time limit
            if self.context_manager.check_restoration(
                    llm_syscall.get_pid()):
                self.context_manager.clear_restoration(
                    llm_syscall.get_pid()
                )
            if tools:
                tool_calls = self.parse_tool_calls(
                    result
                )
                llm_syscall.set_response(
                    Response(
                        response_message=None,
                        tool_calls=tool_calls
                    )
                )
            else:
                llm_syscall.set_response(
                    Response(
                        response_message=result
                    )
                )
            llm_syscall.set_status("done")

        else:
            """ the syscall is suspended automatically once it reaches its time limit """
            self.logger.log(
                f"{llm_syscall.agent_name} is switched to suspending because it reached its time limit ({llm_syscall.get_time_limit()}s).\n",
                level="suspending"
            )
            # snapshot the beam-search state so generation can resume later
            self.context_manager.gen_snapshot(
                llm_syscall.get_pid(),
                context={
                    "start_idx": outputs["start_idx"],
                    "beams": outputs["beams"],
                    "beam_scores": outputs["beam_scores"],
                    "beam_attention_mask": outputs["beam_attention_mask"]
                }
            )
            if message_return_type == "json":
                result = self.parse_json_format(result)
            llm_syscall.set_response(
                Response(
                    response_message=result
                )
            )
            llm_syscall.set_status("suspending")

        llm_syscall.set_end_time(time.time())
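address_syscall only relies on a small snapshot/restore protocol from the context manager: gen_snapshot, check_restoration, gen_recover, and clear_restoration, keyed by the syscall's pid. The sketch below is an assumed, in-memory stand-in for that interface, intended only to show the data flow; the project's actual context manager may store the tensors differently.

# Assumed illustrative stand-in for the snapshot/restore interface used above;
# not the project's real implementation.
class InMemoryContextManager:
    def __init__(self):
        self._snapshots = {}

    def gen_snapshot(self, pid, context):
        # save the beam-search state (start_idx, beams, scores, attention mask)
        self._snapshots[pid] = context

    def check_restoration(self, pid):
        # True if a suspended generation exists for this pid
        return pid in self._snapshots

    def gen_recover(self, pid):
        # hand back the saved state so llm_generate can resume from start_idx
        return self._snapshots[pid]

    def clear_restoration(self, pid):
        # drop the snapshot once generation has finished
        self._snapshots.pop(pid, None)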
    def llm_generate(self,
                     input_ids: torch.Tensor = None,
                     attention_mask: torch.Tensor = None,
                     beams: torch.Tensor = None,
                     beam_scores: torch.Tensor = None,
                     beam_attention_mask: torch.Tensor = None,
                     beam_size: int = None,
                     max_new_tokens: int = None,
                     search_mode: str = None,
                     start_idx: int = 0,
                     timestamp: int = None
                     ):
        """ only beam-search generation is supported for now """
        if search_mode == "beam_search":
            output_ids = self.beam_search(
                input_ids=input_ids,
                attention_mask=attention_mask,
                beam_size=beam_size,
                beams=beams,
                beam_scores=beam_scores,
                beam_attention_mask=beam_attention_mask,
                max_new_tokens=max_new_tokens,
                start_idx=start_idx,
                timestamp=timestamp
            )
            return output_ids
        else:
            # TODO: greedy decoding support
            raise NotImplementedError
    def beam_search(self,
                    input_ids: torch.Tensor = None,
                    attention_mask: torch.Tensor = None,
                    beams=None,
                    beam_scores=None,
                    beam_attention_mask=None,
                    beam_size: int = None,
                    max_new_tokens: int = None,
                    start_idx: int = 0,
                    timestamp: int = None
                    ):
        """
        Beam search keeps several candidate token sequences (beams) in parallel
        and scores whole sequences by their cumulative log-probability, instead
        of greedily committing to the single most likely token at each step.
        """
        if beams is None or beam_scores is None or beam_attention_mask is None:
            # fresh generation: every beam starts from the same prompt
            beams = input_ids.repeat(beam_size, 1)
            beam_attention_mask = attention_mask.repeat(beam_size, 1)
            beam_scores = torch.zeros(beam_size, device=self.eval_device)

        start_time = time.time()

        finished_flag = False
        idx = start_idx
        for step in range(start_idx, max_new_tokens):
            with torch.no_grad():
                # Obtain logits for the last tokens across all beams
                outputs = self.model(beams, attention_mask=beam_attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # Apply softmax to convert logits to probabilities
                next_token_probs = torch.softmax(next_token_logits, dim=-1)

                # Calculate scores for all possible next tokens for each beam
                next_token_scores = beam_scores.unsqueeze(-1) + torch.log(next_token_probs)

                # Flatten to treat the beam and token dimensions as one
                next_token_scores_flat = next_token_scores.view(-1)

                # Select the top overall scores to find the next beams
                top_scores, top_indices = torch.topk(next_token_scores_flat, beam_size, sorted=True)

                # Determine the next beams and their corresponding tokens
                beam_indices = top_indices // next_token_probs.size(1)  # which beam each top score came from
                token_indices = top_indices % next_token_probs.size(1)  # which token within that beam was selected

                # Update beams, scores, and attention masks
                beams = torch.cat([beams[beam_indices], token_indices.unsqueeze(-1)], dim=-1)
                beam_attention_mask = torch.cat([beam_attention_mask[beam_indices], torch.ones_like(token_indices).unsqueeze(-1)], dim=-1)
                beam_scores = top_scores

                # Suspend if the per-call time budget has been used up
                if timestamp is not None and time.time() - start_time >= timestamp:
                    idx = step
                    break

                # Stop early once every beam ends with the EOS token
                if torch.all(beams[:, -1] == self.tokenizer.eos_token_id):
                    idx = step
                    finished_flag = True
                    break

                if step + 1 == max_new_tokens:
                    idx = step
                    finished_flag = True
                    break

        best_beam_idx = beam_scores.argmax()
        best_beam = beams[best_beam_idx]

        outputs = {
            "finished_flag": finished_flag,
            "start_idx": idx,
            "beams": beams,
            "beam_scores": beam_scores,
            "beam_attention_mask": beam_attention_mask,
            "result": best_beam
        }
        return outputs
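The least obvious part of the loop is recovering, after flattening, which beam and which vocabulary token each top-k score corresponds to. A small standalone check with made-up numbers (beam_size = 2, a toy vocabulary of 5 tokens) walks through the index arithmetic:

import torch

beam_size, vocab_size = 2, 5

# pretend cumulative log-probabilities of 2 beams x 5 candidate tokens
next_token_scores = torch.tensor([
    [-1.0, -0.2, -3.0, -2.5, -4.0],   # beam 0
    [-0.1, -2.0, -0.3, -5.0, -1.5],   # beam 1
])

flat = next_token_scores.view(-1)                       # shape (10,)
top_scores, top_indices = torch.topk(flat, beam_size)   # best scores overall

beam_indices = top_indices // vocab_size    # which beam each score came from
token_indices = top_indices % vocab_size    # which token within that beam

print(top_scores)     # tensor([-0.1000, -0.2000])
print(beam_indices)   # tensor([1, 0])  -> best continuations extend beam 1, then beam 0
print(token_indices)  # tensor([0, 1])  -> token ids 0 and 1 respectively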