forked from nzc/dnn_ctr
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: main.py
More file actions
13 lines (10 loc) · 674 Bytes
/
main.py
File metadata and controls
13 lines (10 loc) · 674 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
# -*- coding:utf-8 -*-
import os

import torch

from model import DeepFM
from utils import data_preprocess
# Load the train and test splits. Both calls use the same category file
# ('./data/category_emb.csv'), presumably so that categories map to the same
# indices in each split.
# NOTE(review): assumes read_criteo_data returns a dict containing at least
# 'index', 'value', 'label' and 'feature_sizes' — the only keys read below.
result_dict = data_preprocess.read_criteo_data('./data/tiny_train_input.csv', './data/category_emb.csv')
test_dict = data_preprocess.read_criteo_data('./data/tiny_test_input.csv', './data/category_emb.csv')

# GPU index used for training. Defaults to 2 (the previously hard-coded value);
# set DEEPFM_CUDA_DEVICE to pick another GPU on machines with fewer devices.
CUDA_DEVICE = int(os.environ.get('DEEPFM_CUDA_DEVICE', '2'))

with torch.cuda.device(CUDA_DEVICE):
    # NOTE(review): 39 is DeepFM's first positional argument; presumably the
    # number of input fields (Criteo: 13 numeric + 26 categorical) — confirm
    # against DeepFM.__init__.
    deepfm = DeepFM.DeepFM(39, result_dict['feature_sizes'], verbose=True, use_cuda=True,
                           weight_decay=0.0001, use_fm=True, use_ffm=False,
                           use_deep=True).cuda()
    # Train on the train split and pass the test split as the validation set.
    # NOTE(review): 'ealry_stopping' looks like a typo, but it is presumably
    # the exact keyword DeepFM.fit declares — confirm there before renaming,
    # otherwise this call would raise TypeError.
    deepfm.fit(result_dict['index'], result_dict['value'], result_dict['label'],
               test_dict['index'], test_dict['value'], test_dict['label'],
               ealry_stopping=True, refit=True)