基于ModelArts实现Text2SQL

举报
HWCloudAI 发表于 2022/12/26 10:13:36 2022/12/26
【摘要】 M-SQL: Multi-Task Representation Learning for Single-Table Text2sql Generation虽然之前对 Text2SQL 的研究提供了一些可行的解决方案,但大多数都是基于列表示提取值。如果查询中有多个值,并且这些值属于不同的列,则以前基于列表示的方法无法准确提取值。该论文提出了一种基于预训练 BERT 的新神经网络架构,称为 ...

M-SQL: Multi-Task Representation Learning for Single-Table Text2sql Generation

虽然之前对 Text2SQL 的研究提供了一些可行的解决方案,但大多数都是基于列表示提取值。如果查询中有多个值,并且这些值属于不同的列,则以前基于列表示的方法无法准确提取值。该论文提出了一种基于预训练 BERT 的新神经网络架构,称为 M-SQL。基于列的值提取分为值提取和值列匹配两个模块。

论文地址:https://ieeexplore.ieee.org/document/9020099

具体算法介绍:https://marketplace.huaweicloud.cn/markets/aihub/modelhub/detail/?id=4d1c0887-8ef0-4133-bd02-c4d69255377a

注意事项:

1.本案例使用框架:PyTorch1.4.0
2.本案例使用硬件:GPU: 1*NVIDIA-V100NV32(32GB) | CPU: 8 核 64GB
3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码
4.JupyterLab的详细用法: 请参考《ModelAtrs JupyterLab使用指导》
5.碰到问题的解决办法**:** 请参考《ModelAtrs JupyterLab常见问题解决办法》

1.下载代码和数据集

运行下面代码,进行数据和代码的下载和解压缩

使用TableQA数据集,数据位于m-sql/TableQA/中

import os
# 数据代码下载
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/m-sql.zip
# 解压缩
os.system('unzip m-sql.zip -d ./')
os.chdir('./m-sql')

2.训练

2.1安装依赖库

!pip install -r pip-requirements.txt

2.2训练所需参数和函数

import os
import argparse
import shutil
import sqlite3
import time

import tqdm
import torch
import random as python_random
from transformers import BertTokenizer, BertModel
import logging
import numpy as np
from model import Loss_sw_se, Seq2SQL_v1
# import moxing as mox

from sql_utils.utils_tableqa import load_tableqa, get_loader, get_fields, get_g, get_g_wvi, get_wemb_bert, \
    pred_sw_se, convert_pr_wvi_to_string, generate_sql_i, extract_val, normalize_sql, get_acc, get_acc_x, \
    save_for_evaluation, load_tableqa_data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def construct_hyper_param(parser):

    parser.add_argument("--eval", default='False', type=str)
    parser.add_argument("--no_save", default='False', type=str)
    parser.add_argument("--toy_model", default='False', type=str)
    parser.add_argument("--toy_size", default=16, type=int)

    parser.add_argument('--tepoch', default=1, type=int)
    parser.add_argument('--print_per_step', default=50, type=int)
    parser.add_argument("--bS", default=32, type=int,
                        help="Batch size")
    parser.add_argument("--accumulate_gradients", default=1, type=int,
                        help="The number of accumulation of backpropagation to effectivly increase the batch size.")
    parser.add_argument('--fine_tune',
                        default='False', type=str,
                        help="If present, BERT is trained.")

    parser.add_argument("--data_url", default='./TableQA', type=str,
                        help="Saving path of model file, logfile and result file.")
    parser.add_argument("--train_url", default='./data_and_model/', type=str,
                        help="Saving path of model file, logfile and result file.")

    parser.add_argument("--vocab_file",
                        default='vocab.txt', type=str,
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--max_seq_length",
                        default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences ")
    parser.add_argument("--num_target_layers",
                        default=1, type=int,
                        help="The Number of final layers of BERT to be used in downstream task.")
    parser.add_argument('--lr_bert', default=1e-5, type=float, help='BERT model learning rate.')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        help="random seed for initialization")
    parser.add_argument('--do_lower_case', default='False', type=str, help='whether to use lower case.')
    parser.add_argument("--bert_url", default='./pre-trained_weights/chinese_wwm_ext_pytorch/', type=str,
                        help="Path or model name of BERT")
    parser.add_argument("--load_weight", default='./trained_model/model/best_model.pth', type=str,
                        help="model path to load")
    parser.add_argument('--dr', default=0, type=float, help="Dropout rate.")
    parser.add_argument('--lr', default=1e-3, type=float, help="Learning rate.")
    parser.add_argument('--num_warmup_steps', default=-1, type=int, help="num_warmup_steps")
    parser.add_argument("--split", default='val', type=str, help='prefix of jsonl and db files')

    args, _ = parser.parse_known_args()
    python_random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    args.do_lower_case = args.do_lower_case == 'True'
    args.fine_tune = args.fine_tune == 'True'
    args.no_save = args.no_save == 'True'
    args.eval = args.eval == 'True'
    args.toy_model = args.toy_model == 'True'

    return args


def get_bert(bert_path):
    tokenizer = BertTokenizer.from_pretrained(bert_path)
    model_bert = BertModel.from_pretrained(bert_path)
    bert_config = model_bert.config
    model_bert.to(device)

    return model_bert, tokenizer, bert_config


def update_lr(param_groups, current_step, num_warmup_steps, start_lr):
    if current_step <= num_warmup_steps:
        warmup_frac_done = current_step / num_warmup_steps
        new_lr = start_lr * warmup_frac_done
        for param_group in param_groups:
            param_group['lr'] = new_lr


def get_opt(model, model_bert, fine_tune):
    if fine_tune:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)

        opt_bert = torch.optim.Adam(filter(lambda p: p.requires_grad, model_bert.parameters()),
                                    lr=args.lr_bert, weight_decay=0)
    else:
        opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                               lr=args.lr, weight_decay=0)
        opt_bert = None

    return opt, opt_bert


def get_models(args, logger, bert_model, trained=False, path_model=None, eval=False):
    # some constants
    if not eval:
        logger.info(f"Batch_size = {args.bS * args.accumulate_gradients}")
        logger.info(f"BERT parameters:")
        logger.info(f"learning rate: {args.lr_bert}")
        logger.info(f"Fine-tune BERT: {args.fine_tune}")

    # Get BERT
    model_bert, tokenizer, bert_config = get_bert(bert_model)
    iS = bert_config.hidden_size * args.num_target_layers
    logger.info(bert_config.to_json_string())
    # Get Seq-to-SQL
    if not eval:
        logger.info(f"Seq-to-SQL: the number of final BERT layers to be used: {args.num_target_layers}")
        logger.info(f"Seq-to-SQL: learning rate = {args.lr}")
    model = Seq2SQL_v1(iS, args.dr)
    model = model.to(device)

    if trained:
        assert path_model != None
        if torch.cuda.is_available():
            res = torch.load(path_model)
        else:
            res = torch.load(path_model, map_location='cpu')
        model_bert.load_state_dict(res['model_bert'])
        model_bert.to(device)
        model.load_state_dict(res['model'])
        model.to(device)

    return model, model_bert, tokenizer, bert_config


def get_data(path_wikisql, args):
    train_data, train_table, dev_data, dev_table = load_tableqa(path_wikisql, args.toy_model, args.toy_size,
                                                                no_hs_tok=True)
    train_loader, dev_loader = get_loader(train_data, dev_data, args.bS, shuffle_train=True)

    return train_data, train_table, dev_data, dev_table, train_loader, dev_loader


def train(train_loader, train_table, model, model_bert, opt, bert_config, tokenizer,
          max_seq_length, num_target_layers, accumulate_gradients, print_per_step, logger,
          current_step, st_pos=0, opt_bert=None):
    model.train()
    model_bert.train()
    torch.autograd.set_detect_anomaly(True)
    ave_loss = 0
    cnt = 0

    for iB, t in enumerate(train_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, train_table, no_hs_t=True, no_sql_t=True)
        # nlu  : natural language utterance
        # nlu_t: tokenized nlu
        # sql_i: canonical form of SQL query
        # sql_q: full SQL query text. Not used.
        # sql_t: tokenized SQL query
        # tb   : table
        # hs_t : tokenized headers. Not used.

        g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_wv = get_g(sql_i)
        g_wvi, g_tags, g_value_match = get_g_wvi(t, g_wc)

        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for t in t_to_tt_idx:
            l_n_t.append(len(t))
        # wemb_n: natural language embedding
        # wemb_h: header embedding
        # l_n: token lengths of each question
        # l_hpu: header token lengths
        # l_hs: the number of columns (headers) of the tables.

        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs,
                                      t_to_tt_idx=t_to_tt_idx,
                                      g_sn=g_sn, g_sc=g_sc, g_sa=g_sa, g_wo=g_wo,
                                      g_wnop=g_wnop, g_wc=g_wc, g_wvi=g_wvi,
                                      g_tags=g_tags, g_vm=g_value_match)

        # Calculate loss & step
        loss = Loss_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match,
                          g_sn, g_sc, g_sa, g_wnop, g_wc, g_wo, g_tags, g_value_match)
        if iB % accumulate_gradients == 0:
            opt.zero_grad()
            if opt_bert:
                opt_bert.zero_grad()
            loss.backward()
            if accumulate_gradients == 1:
                update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
                opt.step()
                if opt_bert:
                    update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                    opt_bert.step()
                current_step += 1
        elif iB % accumulate_gradients == (accumulate_gradients - 1):
            loss.backward()
            update_lr(opt.param_groups, current_step, args.num_warmup_steps, args.lr)
            opt.step()
            if opt_bert:
                update_lr(opt_bert.param_groups, current_step, args.num_warmup_steps, args.lr_bert)
                opt_bert.step()
            current_step += 1
        else:
            loss.backward()

        # statistics
        ave_loss += loss.item()

        if iB % print_per_step == 0:
            log = f'[Train Batch {iB}] '
            logs = []
            logs.append(f'average loss: {"%.4f" % (ave_loss / cnt,)}')
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('暂停训练,如需完整训练删除这个IF分支即可')
            break

    ave_loss /= cnt
    return ave_loss, current_step


def test(data_loader, data_table, model, model_bert, bert_config, tokenizer, max_seq_length,
         num_target_layers, print_per_step, logger, path_db, st_pos=0):
    model.eval()
    model_bert.eval()

    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0

    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()
    results = []
    for iB, t in enumerate(data_loader):
        cnt += len(t)
        if cnt < st_pos:
            continue
        # Get fields
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)

        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for t in t_to_tt_idx:
            l_n_t.append(len(t))

        # score
        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)

        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
        pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)

        # Saving for the official evaluation later.
        for b, pr_sql_i1 in enumerate(pr_sql_i):
            normalize_sql(pr_sql_i1, tb[b])
            results1 = {}
            results1["sql"] = pr_sql_i1
            results1["gold_sql"] = sql_i[b]
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            sn, sc, sa, co, wn, wc, wo, wv, cond, sql = \
                get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
            cnt_sn += sn
            cnt_sc += sc
            cnt_sa += sa
            cnt_wnop += (co and wn)
            cnt_wc += wc
            cnt_wo += wo
            cnt_wv += wv

            cnt_lx += sql
            results1['correct'] = sql

            execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
            cnt_x += execution
            results1['ex_correct'] = execution
            results1['result'] = res

            results.append(results1)

        # print acc
        cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
                cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_lx + cnt_x) / 2]
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]
        if iB % print_per_step == 0:
            log = f'[Test Batch {iB}] '
            logs = []
            for k, metric in enumerate(cnts):
                logs.append(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
            logger.info(log + ', '.join(logs))
        if iB == 150:
            logger.info('暂停训练,如需完整训练删除这个IF分支即可')
            break

    acc_sn = cnt_sn / cnt
    acc_sc = cnt_sc / cnt
    acc_sa = cnt_sa / cnt
    acc_wnop = cnt_wnop / cnt
    acc_wc = cnt_wc / cnt
    acc_wo = cnt_wo / cnt
    acc_wv = cnt_wv / cnt
    acc_lx = cnt_lx / cnt
    acc_x = cnt_x / cnt
    acc_mx = (acc_lx + acc_x) / 2

    acc = [acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc,
           acc_wo, acc_wv, acc_lx, acc_x, acc_mx]

    return acc, results, acc_lx


def print_result(epoch, acc, dname, logger=None):
    if logger:
        logger.info(f'------------ {dname} results ------------')
        if dname == 'dev':
            acc_sn, acc_sc, acc_sa, acc_wnop, acc_wc, \
            acc_wo, acc_wv, acc_lx, acc_x, acc_mx = acc
            logger.info(
                f" Epoch: {epoch}, s-num: {acc_sn:.4f}, s-col: {acc_sc:.4f},"
                f" s-col-agg: {acc_sa:.4f}, w-num-op: {acc_wnop:.4f},"
                f" w-col: {acc_wc:.4f}, w-col-op: {acc_wo:.4f}, w-col-value: {acc_wv:.4f},"
                f" acc_lx: {acc_lx:.4f}, acc_x: {acc_x:.4f}, acc_mx: {acc_mx:.4f}"
            )
        else:
            logger.info(f" Epoch: {epoch}, average loss: {acc}")


def get_logger(log_fp=None):
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] %(message)s')
    logger = logging.getLogger(__name__)
    if log_fp:
        handler = logging.FileHandler(log_fp)
        handler.setLevel(logging.INFO)
        formatter = logging.Formatter('[%(asctime)s] %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


def predict(data_loader, data_table, model, model_bert, bert_config, tokenizer,
            max_seq_length, num_target_layers, path_db):

    model.eval()
    model_bert.eval()

    results = []

    cnt = 0
    cnt_sn = 0
    cnt_sc = 0
    cnt_sa = 0
    cnt_wnop = 0
    cnt_wc = 0
    cnt_wo = 0
    cnt_wv = 0
    cnt_lx = 0
    cnt_x = 0

    db_conn = sqlite3.connect(path_db)
    cursor = db_conn.cursor()

    for iB, t in tqdm.tqdm(enumerate(data_loader)):
        nlu, nlu_t, sql_i, sql_q, sql_t, tb, hs_t, hds = get_fields(t, data_table, no_hs_t=True, no_sql_t=True)
        wemb_cls, wemb_n, wemb_h, l_n, l_hpu, l_hs, \
        nlu_tt, t_to_tt_idx, tt_to_t_idx \
            = get_wemb_bert(bert_config, model_bert, tokenizer, nlu_t, hds, max_seq_length,
                            num_out_layers_n=num_target_layers, num_out_layers_h=num_target_layers)
        l_n_t = []
        for t in t_to_tt_idx:
            l_n_t.append(len(t))

        s_sn, s_sc, s_sa, s_wnop, s_wc, \
        s_wo, s_tags, s_match = model(wemb_cls, wemb_n, l_n_t, wemb_h, l_hpu, l_hs, t_to_tt_idx)
        # prediction
        pr_sn, pr_sc, pr_sa, pr_wn, pr_conn_op, \
        pr_wc, pr_wo, pr_tags, pr_wvi = pred_sw_se(s_sn, s_sc, s_sa, s_wnop, s_wc, s_wo, s_tags, s_match, l_n_t)
        pr_wv_str = convert_pr_wvi_to_string(pr_wvi, nlu_t)
        pr_sql_i = generate_sql_i(pr_sc, pr_sa, pr_conn_op, pr_wc, pr_wo, pr_wv_str, nlu)
        value_indexes, value_nums = extract_val(pr_tags, l_n_t)

        for b, pr_sql_i1 in enumerate(pr_sql_i):
            cnt += 1
            results1 = {}
            normalize_sql(pr_sql_i1, tb[b])
            results1["table_id"] = tb[b]["id"]
            results1["nlu"] = nlu[b]
            results1["sql"] = pr_sql_i1
            if sql_i[b]:
                results1["gold_sql"] = sql_i[b]
            results1['value_indexes'] = value_indexes[b]
            results1['value_nums'] = value_nums[b]
            results1['pr_wc'] = pr_wc[b]
            if sql_i[b]:
                sn, sc, sa, co, wn, wc, wo, wv, cond, sql =\
                    get_acc(sql_i[b], pr_sql_i1, pr_wc[b], pr_wo[b], tb[b], normalized=True)
                cnt_sn += sn
                cnt_sc += sc
                cnt_sa += sa
                cnt_wnop += (wn and co)
                cnt_wc += wc
                cnt_wo += wo
                cnt_wv += wv
                cnt_lx += sql
                results1['correct'] = sql

                execution, res = get_acc_x(sql_i[b], pr_sql_i1, tb[b], cursor)
                cnt_x += execution
                results1['ex_correct'] = execution
                results1['result'] = res

            results.append(results1)

    cnts = [cnt_sn, cnt_sc, cnt_sa, cnt_wnop, cnt_wc,
            cnt_wo, cnt_wv, cnt_lx, cnt_x, (cnt_x + cnt_lx) / 2]
    if sum(cnts) > 0:
        cnt_desc = [
            's-num', 's-col', 's-col-agg', 'w-num-op', 'w-col',
            'w-col-op', 'w-col-value', 'acc_lx', 'acc_x', 'acc_mx'
        ]

        logger.info('--------- eval result ---------')
        for k, metric in enumerate(cnts):
            logger.info(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,))
    else:
        cnts = None
        cnt_desc = None

    return results, cnt, cnts, cnt_desc

2.3开始训练

if __name__ == '__main__':

    # Hyper parameters
    parser = argparse.ArgumentParser()
    args = construct_hyper_param(parser)
    save_path = args.train_url
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if not args.eval:
        _model_path = './trained_model/model/'
        shutil.copytree(_model_path, os.path.join(save_path, 'model'))

    t = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    log_fp = os.path.join(save_path, f'{t}.log')
    logger = get_logger(log_fp)
    logger.info(f"BERT-Model: {args.bert_url}")

    trained = args.load_weight is not None and args.load_weight != 'None'
    load_path = None
    if trained:
        load_path = '/home/work/modelarts/inputs/best_model.pt'
        if args.load_weight and args.load_weight.startswith('obs://'):
            if not os.path.exists(load_path):
                mox.file.copy_parallel(args.load_weight, load_path)
                print('copy %s to %s' % (args.load_weight, load_path))
            else:
                print(load_path, 'already exists')
        else:
            load_path = args.load_weight

    train_input_dir = args.data_url
    bert_model = args.bert_url

    # Paths
    path_wikisql = train_input_dir
    path_val_db = os.path.join(train_input_dir, 'val.db')
    path_save_for_evaluation = save_path

    # Build & Load models
    if args.eval and not trained:
        print('in eval mode, "--load_weight" must be provided!')
        exit(-1)

    if not trained:
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model, eval=args.eval)
    else:
        path_model = load_path
        model, model_bert, tokenizer, bert_config = get_models(args, logger, bert_model,
                                                               trained=True, path_model=path_model,
                                                               eval=args.eval)

    if not args.eval:

        train_data, train_table, dev_data, dev_table, train_loader, dev_loader = get_data(path_wikisql, args)
        opt, opt_bert = get_opt(model, model_bert, args.fine_tune)

        acc_lx_t_best = -1
        epoch_best = -1
        current_step = 1
        for epoch in range(args.tepoch):
            # train
            logger.info(f'Training Epoch {epoch}')
            ave_loss_train, current_step = train(train_loader,
                                                 train_table,
                                                 model,
                                                 model_bert,
                                                 opt,
                                                 bert_config,
                                                 tokenizer,
                                                 args.max_seq_length,
                                                 args.num_target_layers,
                                                 args.accumulate_gradients,
                                                 args.print_per_step,
                                                 logger=logger,
                                                 current_step=current_step,
                                                 opt_bert=opt_bert,
                                                 st_pos=0)

            # check DEV
            with torch.no_grad():
                logger.info(f'Testing on dev Epoch {epoch}:')
                acc_dev, results_dev, \
                    dev_acc_lx = test(dev_loader,
                                      dev_table,
                                      model,
                                      model_bert,
                                      bert_config,
                                      tokenizer,
                                      args.max_seq_length,
                                      args.num_target_layers,
                                      args.print_per_step,
                                      logger=logger,
                                      path_db=path_val_db,
                                      st_pos=0)

            print_result(epoch, ave_loss_train, 'train', logger=logger)
            print_result(epoch, acc_dev, 'dev', logger=logger)

            # save results for the official evaluation
            path_save_file = save_for_evaluation(path_save_for_evaluation,
                                                 results_dev, 'dev', epoch=epoch)
            # mox.file.copy_parallel(path_save_file,
            #                        args.train_url + f'results_dev-{epoch}.jsonl')

            # save best model
            # Based on Dev Set logical accuracy lx
            if dev_acc_lx > acc_lx_t_best:
                acc_lx_t_best = dev_acc_lx
                epoch_best = epoch
                # save model
                if not args.no_save:
                    state = {'model': model.state_dict(),
                             'model_bert': model_bert.state_dict()}
                    torch.save(state, os.path.join(save_path, 'model', f'best_model.pth'))

            logger.info(f" Best Dev lx acc: {acc_lx_t_best} at epoch: {epoch_best}")
    else:
        try:
            dev_data, dev_table = load_tableqa_data(path_wikisql, mode=args.split, no_hs_tok=True)
        except Exception:
            logger.error('未找到输入文件!')
            exit(-1)
        dev_loader = torch.utils.data.DataLoader(
            batch_size=args.bS,
            dataset=dev_data,
            shuffle=False,
            num_workers=1,
            collate_fn=lambda x: x
        )
        with torch.no_grad():
            results, cnt, cnts, cnt_desc \
                    = predict(dev_loader,
                              dev_table,
                              model,
                              model_bert,
                              bert_config,
                              tokenizer,
                              args.max_seq_length,
                              args.num_target_layers,
                              os.path.join(train_input_dir, args.split + '.db'))
        save_for_evaluation(os.path.join(save_path, 'pred_results.jsonl'),
                            results, args.split, 'pred', use_filename=True)
        if cnts:
            with open(os.path.join(save_path, 'eval_result.txt'), 'w') as f_eval:
                f_eval.write('--------- eval result ---------\n')
                for k, metric in enumerate(cnts):
                    f_eval.write(cnt_desc[k] + ': ' + '%.4f' % (metric / cnt,) + '\n')



3.模型测试

from trained_model.model.customize_service import *
if __name__ == '__main__':
    model_path = r'./outputs/model/best_model.pth'
    my_model = ModelClass('', model_path)
    data = {
        "question": "近四周成交量小于3574套并且环比低于69.7%的城市有几个",
        "table_id": "252c7b6b302e11e995ee542696d6e445"
    }
    data = my_model._preprocess(data)
    result = my_model._inference(data)
    print(json.dumps(dict(result), ensure_ascii=False, indent=2))
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。