# ValueNet4SPARQL: src/intermediate_representation/beam.py
import copy

from intermediate_representation.semQL import Root1, Filter, N, Root, T, C, A, Sel


class ActionInfo(object):
    """Bookkeeping record for one predicted action at a single decoding step.

    Only ``action`` can be supplied at construction time; every other field
    starts at a fixed default and is expected to be filled in by the caller.
    """

    def __init__(self, action=None):
        # Step index and score both start at zero.
        self.t, self.score = 0, 0
        # -1 is the "unset" default for the parent step index.
        self.parent_t = -1
        self.action = action
        self.frontier_prod = self.frontier_field = None

        # The two fields below are only used by GenToken actions.
        self.copy_from_src = False
        self.src_token_position = -1


class Beams(object):
    """A partial SemQL action sequence (one hypothesis) tracked during beam search.

    Attributes:
        actions: SemQL actions applied so far, in order.
        action_infos: records appended by :meth:`clone_and_apply_action_info`,
            one per action applied through that method (carried over by
            :meth:`copy`).
        inputs: initialised empty; not read or written elsewhere in this class
            and not carried over by :meth:`copy`.
        score: hypothesis score; only initialised and copied by this class.
        t: number of actions applied (incremented by :meth:`apply_action`).
        is_sketch: when True, next-class inference asks actions for their
            sketch expansion (``get_next_action(is_sketch=True)``).
        sketch_step: counter carried over unchanged between clones.
        sketch_attention_history: list carried over (shallow copy) between clones.
    """

    def __init__(self, is_sketch=False):
        self.actions = []
        self.action_infos = []
        self.inputs = []
        self.score = 0.
        self.t = 0
        self.is_sketch = is_sketch
        self.sketch_step = 0
        self.sketch_attention_history = list()

    @staticmethod
    def _expand(stack, action, next_actions):
        """Consume ``action`` from the top of ``stack`` and push its expansion.

        ``next_actions`` is reversed in place so that, once pushed, its first
        element ends up on top of the stack and is expanded first.

        :param stack: list of expected classes; the top is ``stack[-1]``.
        :param action: the action being replayed.
        :param next_actions: list returned by ``action.get_next_action(...)``.
        :return: True on success. False when ``stack`` is empty or ``action``
            is not an instance of exactly the class on top of ``stack``; in
            that case ``stack`` is left unchanged.
        """
        next_actions.reverse()
        if not stack or stack[-1] is not type(action):
            return False
        stack.pop()
        # The expansion is only pushed when it holds no plain ``int`` entries
        # (original intent: "check if they are non-terminal"). Exact-type
        # check, so ``bool`` entries do not count as ints.
        if not any(type(s) is int for s in next_actions):
            stack.extend(next_actions)
        return True

    def get_availableClass(self):
        """Return the action class that must be predicted next.

        Replays ``self.actions`` against a stack that starts as ``[Root1]``.
        Each action must be an instance of exactly the class on top of the
        stack. That class is then replaced by the classes the action expands
        to (respecting ``self.is_sketch``). Whatever is left on top afterwards
        is the next expected class.

        :return: the next expected action class, or None when the derivation
            is complete (empty stack).
        :raises RuntimeError: if an action does not match the expected class,
            or is applied after the derivation was already complete.
        """
        # TODO: this could be made faster.
        # FIXME: now should change for these 11: "Filter 1 ROOT",
        stack = [Root1]
        for action in self.actions:
            next_actions = action.get_next_action(is_sketch=self.is_sketch)
            if not self._expand(stack, action, next_actions):
                raise RuntimeError(
                    "Not the right action: expected %s, got %s"
                    % (stack[-1] if stack else None, action))

        return stack[-1] if stack else None

    @classmethod
    def get_parent_action(cls, actions):
        """Return the class expected after ``actions``, annotated with its parent.

        This is the same replay as :meth:`get_availableClass`, but it uses the
        default (non-sketch) expansion ``get_next_action()``. Before each
        expansion is pushed, every entry produced by an action gets two
        attributes: ``parent`` (the producing action) and ``pt`` (that
        action's index in ``actions``).

        NOTE(review): stack entries are compared against ``type(action)``, so
        the annotated entries are presumably classes. If so, ``parent``/``pt``
        are *class* attributes, shared by all instances and overwritten by
        later calls. Callers presumably read them off the returned class.

        :param actions: list of SemQL actions decoded so far.
        :return: None for an empty ``actions``; otherwise the next expected
            class, or None when the derivation is complete.
        :raises RuntimeError: if an action does not match the expected class,
            or is applied after the derivation was already complete.
        """
        # No actions yet: there is no parent to report.
        if not actions:
            return None

        stack = [Root1]
        for idx, action in enumerate(actions):
            next_actions = action.get_next_action()
            for child in next_actions:
                child.parent = action
                child.pt = idx
            if not cls._expand(stack, action, next_actions):
                # Report the offending sequence through the exception instead
                # of printing debug output to stdout.
                decoded = "".join(str(a) for a in actions if type(a) is not C)
                raise RuntimeError(
                    "Not the right action: expected %s, got %s "
                    "(decoded actions excluding C: %s)"
                    % (stack[-1] if stack else None, action, decoded))

        return stack[-1] if stack else None

    def apply_action(self, action):
        """Append ``action`` to this hypothesis in place and advance ``t``."""
        # TODO: not fully implemented yet.
        self.t += 1
        self.actions.append(action)

    def clone_and_apply_action(self, action):
        """Return a copy of this hypothesis with ``action`` applied; ``self`` is unchanged."""
        new_hyp = self.copy()
        new_hyp.apply_action(action)
        return new_hyp

    def clone_and_apply_action_info(self, action_info):
        """Clone, apply ``action_info.action`` and record ``action_info``.

        Side effect: ``action_info.score`` is written onto
        ``action_info.action.score`` (the action object is shared, not copied).

        :return: the new hypothesis. Its ``action_infos`` is this hypothesis'
            history plus ``action_info``.
        """
        action = action_info.action
        action.score = action_info.score
        # copy() (via clone_and_apply_action) already carries over
        # sketch_step and sketch_attention_history.
        new_hyp = self.clone_and_apply_action(action)
        new_hyp.action_infos.append(action_info)
        return new_hyp

    def copy(self):
        """Return a new hypothesis sharing this one's state.

        The ``actions``, ``action_infos`` and ``sketch_attention_history``
        lists are shallow-copied, so appending to the copy does not affect
        ``self``. ``inputs`` is not carried over.
        """
        new_hyp = Beams(is_sketch=self.is_sketch)
        new_hyp.actions = list(self.actions)
        # Keep the per-action history aligned with ``actions``; previously the
        # clone started from an empty list and only kept the latest info.
        new_hyp.action_infos = list(self.action_infos)
        new_hyp.score = self.score
        new_hyp.t = self.t
        new_hyp.sketch_step = self.sketch_step
        new_hyp.sketch_attention_history = copy.copy(self.sketch_attention_history)
        return new_hyp

    def infer_n(self):
        """Return the N productions allowed at the current step.

        Falls back to the full N grammar (``N._init_grammar()``) unless more
        than four actions exist and one of the special cases below applies.
        """
        if len(self.actions) > 4:
            prev_action = self.actions[-3]
            # Filter ids above 11 introduce a nested query (per the original
            # comment); such a sub-query may only select one column.
            if isinstance(prev_action, Filter) and prev_action.id_c > 11:
                return ['N A']
            # NOTE(review): presumably Root1 id 3 is the single-query case, and
            # actions[3] is the first N (Root1, Root, Sel, N). For other Root1
            # productions, later N choices reuse that production. Confirm
            # against semQL.
            if self.actions[0].id_c != 3:
                return [self.actions[3].production]
        return N._init_grammar()

    @property
    def completed(self):
        """True when no further action class is expected."""
        return self.get_availableClass() is None

    @property
    def is_valid(self):
        """Validity of ``self.actions`` according to :meth:`check_sel_valid`."""
        return self.check_sel_valid(self.actions)

    def check_sel_valid(self, actions):
        """Return False if the SELECT part of ``actions`` repeats an id triple.

        Collection rules:

        * Collection starts at the first ``Sel`` action.
        * Following N/T/C/A actions are collected; N itself is skipped.
        * Collection stops at the first action whose type is none of
          Sel, N, T, C or A.
        * Types are compared exactly; subclasses do not match.

        The collected actions are read as consecutive triples of ``id_c``
        values. A triple occurring twice makes the sequence invalid.
        Sequences without a ``Sel``, or whose last triple is still
        incomplete, are reported as valid.
        """
        sel_types = (N, T, C, A)
        found_sel = False
        sel_actions = []
        for ac in actions:
            ac_type = type(ac)
            if ac_type is Sel:
                found_sel = True
            elif found_sel:
                if ac_type not in sel_types:
                    break
                if ac_type is not N:
                    sel_actions.append(ac)

        if not found_sel:
            return True

        # Not a complete select logical form yet.
        if len(sel_actions) % 3 != 0:
            return True

        # A set gives O(1) duplicate checks (assumes id_c values are hashable,
        # e.g. ints).
        seen = set()
        for start in range(0, len(sel_actions), 3):
            key = tuple(a.id_c for a in sel_actions[start:start + 3])
            if key in seen:
                return False
            seen.add(key)
        return True


if __name__ == '__main__':
    # Smoke run: replay a two-action sketch prefix and print the action
    # class expected next.
    demo = Beams(is_sketch=True)
    demo.actions.extend([Root1(3), Root(5)])
    print(demo.get_availableClass())