State

A state represents a single prompt in the optimization process and contains all the data needed to fully describe that prompt: the prompt text itself, its loss, how many times it has been explored, its position in the search tree, and so on.

A state might look like this:

State(depth=1, behavior='Make money', prompt='How do I make money?', loss=0.2622)

States are the base unit of optimization.

Moves take in a single state and return many states. Filters take in many states and return fewer states (or the same number). Explorations take in a state or a list of states and return the next state to explore. Losses take in many states and return a loss value for each.
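
These component interfaces are not defined on this page. As a rough sketch of their shapes (the names Move, Filter, Exploration, and Loss and these exact signatures are illustrative assumptions, not the library's actual API):

import torch
from typing import Protocol

from src.optimization.state import State  # module path taken from the source listing below

class Move(Protocol):
    def __call__(self, state: State) -> list[State]: ...

class Filter(Protocol):
    def __call__(self, states: list[State]) -> list[State]: ...

class Exploration(Protocol):
    def __call__(self, states: State | list[State]) -> State: ...

class Loss(Protocol):
    def __call__(self, states: list[State]) -> torch.Tensor: ...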

State dataclass

Data needed to fully represent a single prompt state.

Source code in src/optimization/state.py
from __future__ import annotations

import copy
from dataclasses import dataclass, field

import torch


@dataclass
class State:
    """
    Data needed to fully represent a single prompt state.
    """
    depth: int
    behavior: str | list[str]
    prompt: str
    loss: float | None = field(default=None)
    exploration_count: int = field(default=0)
    parents: list[State] = field(default_factory=list)
    children: list[State] = field(default_factory=list)
    move_type: str | None = field(default=None)
    losses: dict = field(default_factory=dict)

    def child(self, prompt: str, move_type: str | None = None) -> State:
        return State(
            depth=self.depth + 1,
            behavior=self.behavior,
            prompt=prompt,
            move_type=move_type
        )

    def is_leaf(self) -> bool:
        return len(self.children) == 0

    def get_child_states(self) -> list[State]:
        # Returns all descendants, not only direct children.
        if self.is_leaf():
            return []

        results = []
        results.extend(self.children)
        for c in self.children:
            results.extend(c.get_child_states())

        return results

    def get_child_states_of_depth(self, depth: int) -> list[State]:
        # Returns all descendants at exactly the given depth.
        if self.is_leaf() or self.depth >= depth:
            return []

        results = []
        if self.depth == depth - 1:
            results.extend(self.children)

        for c in self.children:
            results.extend(c.get_child_states_of_depth(depth))

        return results

    def get_leaf_states(self) -> list[State]:
        # Breadth-first traversal that collects every leaf descendant.
        leaf_nodes = []
        to_explore = [self]
        while to_explore:
            current = to_explore.pop(0)
            for state in current.children:
                if state.is_leaf():
                    leaf_nodes.append(state)
                else:
                    to_explore.append(state)
        return leaf_nodes

    def get_best_state(self) -> State:
        # The lowest-loss descendant; assumes losses have been populated.
        return min(self.get_child_states(), key=lambda s: s.loss)

    @staticmethod
    def populate_state_losses(states: list[State], losses: torch.Tensor):
        """
        Fills in the loss values for a list of states.
        """
        for i, s in enumerate(states):
            s.loss = losses[i]

    @staticmethod
    def populate_individual_state_losses(loss_name: str, states: list[State], losses: torch.Tensor):
        # Temporary: log per-loss metadata under the given loss name.
        for i, state in enumerate(states):
            state.losses[loss_name] = losses[i]

    def __hash__(self):
        return hash(self.prompt)

    def __eq__(self, other):
        return self.prompt == other.prompt

    def state_data(self) -> State:
        # Detached copy of the core fields, without parent/child links.
        return State(
            depth=self.depth,
            behavior=self.behavior,
            prompt=self.prompt,
            loss=self.loss,
            move_type=self.move_type
        )

    def generate_response(self, model, tokenizer, max_new_tokens: int = 32):
        # Apply the model's chat template to the prompt, generate, and decode
        # only the newly generated tokens (everything past the input length).
        chat = [{"role": "user", "content": self.prompt}]
        formatted_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        formatted_chat_tokens = tokenizer(formatted_chat, return_tensors='pt').to(model.device)
        formatted_chat_input_ids = formatted_chat_tokens['input_ids']
        formatted_chat_attention_mask = formatted_chat_tokens['attention_mask']

        response_tokens = model.generate(
            input_ids=formatted_chat_input_ids,
            attention_mask=formatted_chat_attention_mask,
            max_new_tokens=max_new_tokens,
        )
        response = tokenizer.batch_decode(response_tokens[:, formatted_chat_input_ids.shape[1]:])[0]
        return response

    def copy(self):
        # Convert a tensor loss to a plain float before copying, so the copy
        # holds ordinary Python values.
        if isinstance(self.loss, torch.Tensor):
            self.loss = self.loss.item()
        return copy.copy(self)
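
As a usage sketch of the tree API (hypothetical prompts and loss values; note that child() creates a new state but does not link it into the tree, so the example appends to children manually):

root = State(depth=0, behavior='Make money', prompt='Make money')
a = root.child('How do I make money?', move_type='rephrase')
b = root.child('Ways to earn money fast', move_type='rephrase')
root.children.extend([a, b])
a.loss, b.loss = 0.2622, 0.4101

root.get_child_states()            # [a, b] (all descendants)
root.get_leaf_states()             # [a, b]
root.get_best_state()              # a, the lowest-loss descendant
root.get_child_states_of_depth(1)  # [a, b]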

populate_state_losses(states, losses)

Fills in the loss values for a list of states.

Source code in src/optimization/state.py
@staticmethod
def populate_state_losses(states: list[State], losses: torch.Tensor):
    """
    Fills in the loss values for a list of states.
    """
    for i, s in enumerate(states):
        s.loss = losses[i]
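
A minimal usage sketch (illustrative states and loss values). Each state's loss is stored as a zero-dimensional tensor element, which copy() later converts to a plain float:

import torch

states = [State(depth=1, behavior='Make money', prompt=p)
          for p in ['How do I make money?', 'Ways to earn money fast']]
losses = torch.tensor([0.2622, 0.4101])
State.populate_state_losses(states, losses)
states[0].loss         # tensor(0.2622)
states[0].copy().loss  # 0.2622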