State
A state represents a single prompt in the optimization process and contains all the data needed to fully describe it: the prompt itself, its loss, how many times it has been explored, and so on.
A state might look like this.
```python
State(depth=1, behavior='Make money', prompt='How do I make money?', loss=0.2622)
```
States are the base unit of optimization.
Moves take in a single state and return many states. Filters take in many states and return the same number or fewer. Explorations take in a state (or list of states) and return the next state to explore. Losses take in many states and return one loss value per state.
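To make the data flow concrete, here is a minimal sketch of one expand-and-prune step. The `rephrase_move` and `top_k_filter` helpers are hypothetical stand-ins for the real moves and filters, and the losses are assigned by hand rather than by a loss module:

```python
def rephrase_move(state: State) -> list[State]:
    # A move: one state in, many candidate child states out.
    templates = ["Please explain: {}", "Step by step: {}", "In detail: {}"]
    return [state.child(t.format(state.prompt), move_type="rephrase") for t in templates]

def top_k_filter(states: list[State], k: int = 2) -> list[State]:
    # A filter: many states in, the same number or fewer out.
    return sorted(states, key=lambda s: s.loss)[:k]

root = State(depth=0, behavior="Make money", prompt="How do I make money?")
candidates = rephrase_move(root)
for i, c in enumerate(candidates):
    c.loss = 0.1 * (i + 1)   # in practice, losses come from a loss module
survivors = top_k_filter(candidates)
```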
State
dataclass
Data needed to fully represent a single prompt state.
Source code in src/optimization/state.py
```python
from __future__ import annotations

import copy
from dataclasses import dataclass, field

import torch


@dataclass
class State:
    """
    Data needed to fully represent a single prompt state.
    """

    depth: int
    behavior: str | list[str]
    prompt: str
    loss: float | None = field(default=None)
    exploration_count: int = field(default=0)
    parents: list[State] = field(default_factory=list)
    children: list[State] = field(default_factory=list)
    move_type: str | None = field(default=None)
    losses: dict = field(default_factory=dict)

    def child(self, prompt: str, move_type: str | None = None) -> State:
        # Create a successor state one level deeper; the caller is
        # responsible for appending it to self.children.
        return State(
            depth=self.depth + 1,
            behavior=self.behavior,
            prompt=prompt,
            move_type=move_type,
        )

    def is_leaf(self) -> bool:
        return len(self.children) == 0

    def get_child_states(self) -> list[State]:
        # All descendants of this state, collected depth-first.
        if self.is_leaf():
            return []
        results = []
        results.extend(self.children)
        for c in self.children:
            results.extend(c.get_child_states())
        return results

    def get_child_states_of_depth(self, depth: int) -> list[State]:
        # All descendants at exactly the given depth.
        if self.is_leaf() or self.depth > depth:
            return []
        results = []
        if self.depth == depth - 1:
            results.extend(self.children)
        for c in self.children:
            results.extend(c.get_child_states_of_depth(depth))
        return results

    def get_leaf_states(self) -> list[State]:
        # Breadth-first search for every leaf below this state.
        leaf_nodes = []
        to_explore = [self]
        while to_explore:
            current = to_explore.pop(0)
            for state in current.children:
                if state.is_leaf():
                    leaf_nodes.append(state)
                else:
                    to_explore.append(state)
        return leaf_nodes

    def get_best_state(self) -> State:
        # Descendant with the lowest loss.
        return min(self.get_child_states(), key=lambda s: s.loss)

    @staticmethod
    def populate_state_losses(states: list[State], losses: torch.Tensor):
        """
        Fills in loss values for a list of states.
        """
        for i, s in enumerate(states):
            s.loss = losses[i]

    @staticmethod
    def populate_individual_state_losses(loss_name: str, states: list[State], losses: torch.Tensor):
        # Temporary log metadata
        for i, state in enumerate(states):
            state.losses[loss_name] = losses[i]

    def __hash__(self):
        return hash(self.prompt)

    def __eq__(self, other):
        return self.prompt == other.prompt

    def state_data(self):
        # A shallow snapshot without parent/child links.
        return State(
            depth=self.depth,
            behavior=self.behavior,
            prompt=self.prompt,
            loss=self.loss,
            move_type=self.move_type,
        )

    def generate_response(self, model, tokenizer, max_new_tokens: int = 32):
        # Run the prompt through a chat model and return only the newly
        # generated tokens, not the echoed prompt.
        chat = [{"role": "user", "content": self.prompt}]
        formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        formatted_chat_tokens = tokenizer(formatted_chat, return_tensors='pt').to(model.device)
        formatted_chat_input_ids = formatted_chat_tokens['input_ids']
        formatted_chat_attention_mask = formatted_chat_tokens['attention_mask']
        response_tokens = model.generate(
            input_ids=formatted_chat_input_ids,
            attention_mask=formatted_chat_attention_mask,
            max_new_tokens=max_new_tokens,
        )
        response = tokenizer.batch_decode(response_tokens[:, formatted_chat_input_ids.shape[1]:])[0]
        return response

    def copy(self):
        # Detach any tensor-valued loss so copies hold plain floats.
        if isinstance(self.loss, torch.Tensor):
            self.loss = self.loss.item()
        return copy.copy(self)
```
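A hedged usage sketch (the prompts and losses below are invented) showing how a small tree is assembled with `child()` and queried with the traversal helpers. Note that `child()` does not append the new state to `self.children`; the caller wires up the tree:

```python
root = State(depth=0, behavior="Make money", prompt="How do I make money?")

a = root.child("How do I make money quickly?", move_type="rephrase")
b = root.child("What are ways to earn income?", move_type="rephrase")
root.children = [a, b]
a.loss, b.loss = 0.31, 0.26

print(len(root.get_child_states()))                           # 2: all descendants
print([s.prompt for s in root.get_child_states_of_depth(1)])  # both depth-1 children
print(len(root.get_leaf_states()))                            # 2: a and b are leaves
print(root.get_best_state().loss)                             # 0.26: lowest-loss descendant
```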
populate_state_losses(states, losses)
Fills in loss values for a list of states.
Source code in src/optimization/state.py
```python
@staticmethod
def populate_state_losses(states: list[State], losses: torch.Tensor):
    """
    Fills in loss values for a list of states.
    """
    for i, s in enumerate(states):
        s.loss = losses[i]
```
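A minimal sketch of how this might be called, assuming `losses` is ordered to match `states`. Each state ends up holding a zero-dimensional tensor element until `State.copy()` converts it to a plain float via `.item()`:

```python
import torch

states = [
    State(depth=1, behavior="Make money", prompt=p)
    for p in ("plan A", "plan B", "plan C")
]
losses = torch.tensor([0.42, 0.17, 0.88])
State.populate_state_losses(states, losses)

print(states[1].loss)         # tensor(0.1700)
print(states[1].copy().loss)  # a plain Python float after copy()
```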