#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time    : 2023/12/5 20:39
# @Author  : yebulk
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset

from utils_prompt import build_train_pair

img_shape = {
    "resnet": (512, 2048),
    "clip": (49, 2048),
    "detr": (100, 256),
    "vit": (145, 1024),
}
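# The shapes above are (sequence_length, hidden_size) of each backbone's feature
# map; they are used below to zero-fill samples whose image features are missing.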

def split_chinese_string(string):
    # NOTE: the original separator character was lost to encoding (the source
    # had an empty string, which would raise ValueError). A full-width Chinese
    # colon is assumed here. Returns the text after the first separator, or the
    # whole string if no separator is present.
    split_result = string.split("：", 1)
    if len(split_result) > 1:
        return split_result[1]
    return string
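
# Example (given the assumed full-width colon above):
#   split_chinese_string("标签：内容")     ->  "内容"
#   split_chinese_string("no separator")   ->  "no separator"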

def load_data_std(args):
    # Load the problem dict and the qid split lists relative to args.data_root.
    with open(os.path.join(args.data_root, '../Data/output.json')) as f:
        problems = json.load(f)
    with open(os.path.join(args.data_root, '../Data/split.json')) as f:
        pid_splits = json.load(f)
    # captions = json.load(open(args.caption_file))["captions"]

    # for qid in problems:
    #     problems[qid]['caption'] = captions[qid] if qid in captions else ""

    train_qids = pid_splits[args.train_split]
    val_qids = pid_splits[args.val_split]
    test_qids = pid_splits[args.test_split]
    print(f"number of train problems: {len(train_qids)}\n")
    print(f"number of val problems: {len(val_qids)}\n")
    print(f"number of test problems: {len(test_qids)}\n")

    qids = {'train': train_qids, 'val': val_qids, 'test': test_qids}
    return problems, qids
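
# Usage sketch (args is the training script's namespace, with data_root and
# *_split attributes assumed; not part of the original file):
#   problems, qids = load_data_std(args)
#   train_problems = [problems[qid] for qid in qids['train']]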


def load_data_img(args):
    with open(os.path.join(args.data_root, 'output.json')) as f:
        problems = json.load(f)
    with open(os.path.join(args.data_root, 'split.json')) as f:
        pid_splits = json.load(f)
    # captions = json.load(open(args.caption_file))["captions"]

    # name_maps maps a Pic_id to its row index in the precomputed feature tensor.
    with open('/mnt/sdb1/jsy/yjw/prompt/myCode/vision_features/name_map.json') as f:
        name_maps = json.load(f)

    # Load the precomputed vision features for the chosen backbone.
    if args.img_type == "resnet":
        image_features = np.load('../vision_features_pics_Meme/resnet.npy')
        # Tile the pooled 2048-dim feature along a new axis to match the
        # (512, 2048) shape declared in img_shape.
        image_features = np.expand_dims(image_features, axis=1)
        image_features = image_features.repeat(512, axis=1)
    elif args.img_type == "clip":
        image_features = np.load('../vision_features_pics_Meme/clip.npy')
    elif args.img_type == "detr":
        image_features = torch.load('/mnt/sdb1/jsy/yjw/prompt/myCode/vision_features/detr.pth').to('cpu')
    elif args.img_type == "vit":
        image_features = torch.load("/mnt/sdb1/jsy/yjw/prompt/myCode/vision_features/vit.pth").to('cpu')
    else:
        # The original fallback called np.load on a .pth file, which would fail;
        # an unknown backbone is treated as an error instead.
        raise ValueError(f"unsupported img_type: {args.img_type}")
    print("img_features size: ", image_features.shape)

    # for qid in problems:
    #     problems[qid]['caption'] = captions[qid] if qid in captions else ""

    train_qids = pid_splits[args.train_split]
    val_qids = pid_splits[args.val_split]
    test_qids = pid_splits[args.test_split]
    print(f"number of train problems: {len(train_qids)}\n")
    print(f"number of val problems: {len(val_qids)}\n")
    print(f"number of test problems: {len(test_qids)}\n")

    qids = {'train': train_qids, 'val': val_qids, 'test': test_qids}
    return problems, qids, name_maps, image_features
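
# Usage sketch: the returned values feed MetaphorDatasetImg below.
#   problems, qids, name_maps, image_features = load_data_img(args)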


class MetaphorDatasetStd(Dataset):
    """
    Text-only dataset: builds (question, answer) training pairs and tokenizes
    them on the fly, for a DataLoader used to fine-tune the model.
    """

    def __init__(
            self, problems, qids, tokenizer, answer_len, question_len, args, test_le=None
    ):
        self.tokenizer = tokenizer
        self.data = {qid: problems[qid] for qid in qids}
        self.answer_len = answer_len
        self.question_len = question_len
        self.questions = []
        self.answers = []
        # test_le optionally points to a JSON file of earlier-stage predictions
        # ("preds"), consumed one per qid in iteration order.
        if test_le is not None:
            test_le_data = json.load(open(test_le))["preds"]
        else:
            test_le_data = None
        idx = 0
        for qid in self.data:
            if test_le_data is not None:
                curr_le_data = test_le_data[idx]
                idx += 1
            else:
                curr_le_data = None
            question, answer = build_train_pair(problems, qid, args, curr_le_data)
            self.questions.append(question)
            self.answers.append(answer)

    def __len__(self):
        return len(self.answers)

    def __getitem__(self, index):
        question_text = str(self.questions[index])
        answer_text = str(self.answers[index])

        question = self.tokenizer.batch_encode_plus(
            [question_text],
            max_length=self.question_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        answer = self.tokenizer.batch_encode_plus(
            [answer_text],
            max_length=self.answer_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        question_ids = question["input_ids"].squeeze()
        question_mask = question["attention_mask"].squeeze()
        answer_ids = answer["input_ids"].squeeze().tolist()

        return {
            "input_ids": question_ids,
            "attention_mask": question_mask,
            "labels": answer_ids,
        }
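
# Usage sketch (tokenizer and args are assumptions: a seq2seq tokenizer and the
# training script's namespace; not part of the original file):
#   from torch.utils.data import DataLoader
#   dataset = MetaphorDatasetStd(problems, qids["train"], tokenizer,
#                                answer_len=64, question_len=512, args=args)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)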

class MetaphorDatasetImg(Dataset):
    """
    Multimodal dataset: like MetaphorDatasetStd, but each item also carries the
    precomputed image features for its Pic_id.
    """

    def __init__(
            self, problems, qids, name_maps, tokenizer, question_len, answer_len, args, image_features, test_le=None
    ):
        """
        Initializes a Dataset class

        Args:
            dataframe (pandas.DataFrame): Input dataframe
            tokenizer (transformers.tokenizer): Transformers tokenizer
            source_len (int): Max length of source text
            target_len (int): Max length of target text
            source_text (str): column name of source text
            target_text (str): column name of target text
        """
        self.tokenizer = tokenizer
        self.data = {qid: problems[qid] for qid in qids}
        self.answer_len = answer_len
        self.question_len = question_len
        self.questions = []
        self.answers = []
        self.image_ids = []
        if test_le is not None:
            test_le_data = json.load(open(test_le))["preds"]
        else:
            test_le_data = None
        idx = 0
        for qid in self.data:
            if test_le_data is not None:
                curr_le_data = test_le_data[idx]
                idx += 1
            else:
                curr_le_data = None
            question, answer = build_train_pair(problems, qid, args, curr_le_data)
            self.questions.append(question)
            self.answers.append(answer)
            pic_name = str(problems[qid]['Pic_id'])
            if pic_name in name_maps:
                # Look up this picture's row in the precomputed feature tensor.
                i_vectors = image_features[int(name_maps[pic_name])]
                self.image_ids.append(i_vectors)
            else:
                # No features for this picture: fall back to a zero tensor of
                # the backbone's expected shape.
                shape = img_shape[args.img_type]
                self.image_ids.append(np.zeros(shape))

    def __len__(self):
        """returns the length of dataframe"""

        return len(self.answers)

    def __getitem__(self, index):
        """return the input ids, attention masks and target ids"""

        image_ids = self.image_ids[index]

        question_text = str(self.questions[index])
        answer_text = str(self.answers[index])

        question = self.tokenizer.batch_encode_plus(
            [question_text],
            max_length=self.question_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        answer = self.tokenizer.batch_encode_plus(
            [answer_text],
            max_length=self.answer_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        question_ids = question["input_ids"].squeeze()
        question_mask = question["attention_mask"].squeeze()
        answer_ids = answer["input_ids"].squeeze().tolist()

        image_ids = torch.tensor(image_ids).squeeze()

        return {
            "input_ids": question_ids,
            "attention_mask": question_mask,
            "labels": answer_ids,
            "image_ids": image_ids
        }
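
# Usage sketch (hedged: paths and args come from the training script, not this
# file; shown only to tie the pieces together):
#   problems, qids, name_maps, image_features = load_data_img(args)
#   dataset = MetaphorDatasetImg(problems, qids["train"], name_maps, tokenizer,
#                                question_len=512, answer_len=64, args=args,
#                                image_features=image_features)
#   sample = dataset[0]  # dict with input_ids, attention_mask, labels, image_ids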