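• Data file preview:
`# The dataset is located under /root/data/ag_news_csv on the virtual machine`
• File description:
◦ train.csv holds the training data (120,000 records); test.csv holds the validation data (7,600 records); classes.txt explains the labels (news topics) and contains four words, 'World', 'Sports', 'Business', 'Sci/Tech', one for each of the four topics; readme.txt is the English description of the dataset.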
• train.csv preview:
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
"3","Carlyle Looks Toward Commercial Aerospace (Reuters)","Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market."
"3","Oil and Economy Cloud Stocks' Outlook (Reuters)","Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums."
"3","Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)","Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday."
"3","Oil prices soar to all-time record, posing new menace to US economy (AFP)","AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
"3","Stocks End Up, But Near Year Lows (Reuters)","Reuters - Stocks ended slightly higher on Friday\but stayed near lows for the year as oil prices surged past #36;46\a barrel, offsetting a positive outlook from computer maker\Dell Inc. (DELL.O)"
"3","Money Funds Fell in Latest Week (AP)","AP - Assets of the nation's retail money market mutual funds fell by #36;1.17 billion in the latest week to #36;849.98 trillion, the Investment Company Institute said Thursday."
"3","Fed minutes show dissent over inflation (USATODAY.com)","USATODAY.com - Retail sales bounced back a bit in July, and new claims for jobless benefits fell last week, the government said Thursday, indicating the economy is improving from a midsummer slump."
"3","Safety Net (Forbes.com)","Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, ""buying insurance was the furthest thing from my mind,"" says Riley."
"3","Wall St. Bears Claw Back Into the Black"," NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again."
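• File content description:
◦ train.csv consists of 3 columns separated by ',': the label, the news title, and the news description. The label is one of "1", "2", "3", "4", corresponding in order to the entries in classes.txt.
◦ test.csv has the same format and meaning as train.csv.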
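To make the column layout concrete, here is a minimal sketch that parses two rows copied from the preview with Python's standard `csv` module and maps each label to its topic name. The `CLASSES` list and the `parse_rows` helper are illustrative names introduced here, not part of the dataset or of torchtext, and the sketch assumes classes.txt lists the topics in the order given above.

```python
import csv
import io

# Topic names in the order listed in classes.txt (label "1" is the first entry).
CLASSES = ["World", "Sports", "Business", "Sci/Tech"]

# Two rows copied from the train.csv preview, kept in memory for this demo.
SAMPLE = '''"3","Wall St. Bears Claw Back Into the Black"," NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again."
"3","Oil prices soar to all-time record, posing new menace to US economy (AFP)","AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
'''


def parse_rows(text):
    """Split CSV text into (topic, title, description) tuples."""
    rows = []
    for label, title, description in csv.reader(io.StringIO(text)):
        # Labels are 1-based strings, so shift by one to index CLASSES.
        rows.append((CLASSES[int(label) - 1], title, description))
    return rows


rows = parse_rows(SAMPLE)
for topic, title, _ in rows:
    print(topic, "|", title)
```

Note that the comma inside the second title does not split the column, because the field is quoted. To read the real file instead of the in-memory sample, the same loop can be run over a file opened with `open(path, newline="", encoding="utf-8")`.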
`from torchtext.legacy.datasets.text_classification import (
    _csv_iterator, _create_data_from_iterator, TextClassificationDataset)
from torchtext.utils import extract_archive
from torchtext.vocab import build_vocab_from_iterator, Vocab


def setup_datasets(ngrams=2, vocab_train=None, vocab_test=None, include_unk=False):
    # Paths to the training and validation files (relative to the working directory)
    train_csv_path = 'data/ag_news_csv/train.csv'
    test_csv_path = 'data/ag_news_csv/test.csv'

    # Build the training vocabulary unless a Vocab object was passed in
    if vocab_train is None:
        vocab_train = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab_train, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")

    # Build the test vocabulary unless a Vocab object was passed in.
    # Note: this vocabulary is separate from the training one, so the same
    # token can receive a different index in the two datasets.
    if vocab_test is None:
        vocab_test = build_vocab_from_iterator(_csv_iterator(test_csv_path, ngrams))
    elif not isinstance(vocab_test, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")

    # Convert each text to token indices and collect the set of labels
    train_data, train_labels = _create_data_from_iterator(
        vocab_train, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    test_data, test_labels = _create_data_from_iterator(
        vocab_test, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)

    # The symmetric difference is non-empty if a label appears in only one set
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")

    return (TextClassificationDataset(vocab_train, train_data, train_labels),
            TextClassificationDataset(vocab_test, test_data, test_labels))


train_dataset, test_dataset = setup_datasets()
print("train_dataset", train_dataset)`
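The `ngrams=2` argument means that each text is represented by its individual tokens plus every pair of adjacent tokens. The sketch below illustrates that idea in plain Python. `with_ngrams` is a hypothetical helper written for this illustration, not torchtext's implementation, and the library's own tokenization and normalization may differ. The last lines also show the set symmetric difference (`^`) that the code above uses to compare label sets; the label sets here are made up for the demo.

```python
def with_ngrams(tokens, n):
    """Return the tokens followed by all 2..n-grams, joined with spaces."""
    result = list(tokens)
    for size in range(2, n + 1):
        for start in range(len(tokens) - size + 1):
            result.append(" ".join(tokens[start:start + size]))
    return result


tokens = "oil prices soar".split()
features = with_ngrams(tokens, 2)
print(features)  # ['oil', 'prices', 'soar', 'oil prices', 'prices soar']

# Symmetric difference: labels that occur in exactly one of the two sets.
train_labels = {0, 1, 2, 3}   # illustrative label sets, not read from the files
test_labels = {0, 1, 2, 3}
mismatch = train_labels ^ test_labels
print(len(mismatch))  # 0, so the label sets match
```

Adding bigrams keeps some local word-order information that a plain bag of single tokens would lose, at the cost of a larger vocabulary.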
The implementation of the whole case study can be divided into the following five steps