【Python教程】【传知代码】基于图神经网络的知识追踪方法(论文复现)

零 Python教程评论68字数 4166阅读13分53秒阅读模式

概述

论文链接提出了一种基于图神经网络的知识追踪方法,称为基于图的知识追踪(GKT)。将知识结构构建为图,其中节点对应于概念,边对应于它们之间的关系,将知识追踪任务构建为图神经网络中的时间序列节点级分类问题。在两个开放数据集上的实证验证表明,方法可以更好地预测学生的表现,并且该模型比先前的方法具有更可解释的预测,其贡献如下:

1)展示了知识追踪可以重新构想为图神经网络的应用。
2)为了实现需要输入模型的图结构,在许多情况下并不明确的情况下,我们提出了各种方法,并使用实证验证进行了比较。
3)证明了所提出的方法比先前的方法更准确和可解释的预测。文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

下图是本文提出GKT的体系结构:文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

1715588979610_image.png文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

模型聚合了回答的概念及其相邻概念的隐藏状态和嵌入。这种聚合使用隐藏状态、表示正确和错误答案的输入向量 xt​,以及概念及其回答的嵌入矩阵Ex 和Ec 进行:文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

1715589027917_image.png文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

接下来,模型根据聚集的特征和知识图结构更新隐藏状态。这一步骤确保模型融合了当前概念及其在知识图中的相邻节点的信息:文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

1715589082085_image.png文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

最后,模型输出学生在下一时间步正确回答每个概念的预测概率:文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

1715589120879_image.png文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

演示效果

使用了学生数学练习日志的两个开放数据集:ASSISTments 2009-2010“skill-builder”由在线教育服务ASSISTments1(以下称为“ASSISTments”)提供和Bridge to Algebra 2006-2007[19]用于KDDCup教育数据挖掘挑战赛(以下简称“KDDCup”)。两个数据集上的每一个习题均被赋予一个人预设知识概念标签。文章源自灵鲨社区-https://www.0s52.com/bcjc/pythonjc/15352.html

利用指定的条件对各数据集进行预处理。对ASSISTments来说,把同时作答的日志合并为一个整体,然后抽取出和命名概念标签有关的日志并最终抽取出和至少十次作答的概念标签有关。在KDDCup中,我们将问题与步骤的结合看作是答案,接着从与命名且非哑元的概念标签有关的日志中提取信息,最终从至少10次回答的概念标签中提取相关日志。鉴于标签的频繁出现,将多个回答日志整合为一组有助于避免不公正的高预测表现。排除未命名和虚拟概念标签能够去除噪音。利用每一个概念标签被答出的次数为日志设定一个阈值,从而保证足够多的日志被用于去除噪音。在使用上述条件对数据集进行预处理后,为ASSISTments数据集获得了62,955个日志,由1,000名学生和101项技能组成,并为KDDCup数据集获得了98,200条日志,由1,000名学生和211项技能组成:

1715589172887_image.png

处理数据集

1715589219906_image.png

进行训练

1715589237948_image.png

实验结果

1715589251758_image.png

核心代码

接下来这段核心代码是一个基于PyTorch实现的GKT(Graph Knowledge Tracing)模型的训练与评估过程。主要包括以下几个部分:

1)__init__方法:初始化GKT模型,包括传入知识点数量、图结构、隐藏层节点数等参数,并创建GKTNet模型实例。

2)train方法:用于模型训练,接收训练数据和测试数据(可选),并根据指定的epoch数、设备类型和学习率进行训练。在每个epoch内,遍历训练数据,将数据转移到指定设备上,然后通过GKT模型进行预测,计算损失函数并进行反向传播更新模型参数。如果提供了测试数据,还会调用eval方法进行模型评估。

3)eval方法:用于模型评估,接收测试数据和设备类型,并在评估过程中将数据转移到指定设备上,然后通过GKT模型进行预测,并计算AUC和准确率作为评估指标。

4)save和load方法:用于模型的保存和加载,分别将模型参数保存到文件中,以及从文件中加载模型参数。

py

复制代码
class GKT(KTM):
    def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_params=None):
        super(GKT, self).__init__()
        self.gkt_model = GKTNet(
            ku_num,
            graph,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        # self.gkt_model = GKTNet(ku_num, graph, hidden_num)
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params)
        trainer = torch.optim.Adam(self.gkt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.gkt_model(question, data, data_mask)

                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
            print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.gkt_model.eval()
        y_true = []
        y_pred = []

        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            question: torch.Tensor = question.to(device)
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.gkt_model(question, data, data_mask)
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.numpy().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.gkt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.gkt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.gkt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)

整体而言,这段代码实现了一个GKT模型的训练与评估流程,采用了PyTorch作为深度学习框架,并提供了模型的保存和加载功能。

写在最后

虽然我们已经取得了一些初步的研究成果,但基于图神经网络的知识追踪方法仍然面临着许多挑战和机遇。首先,随着教育数据的不断增长和复杂化,如何构建更加高效、准确的知识图谱成为了一个亟待解决的问题。其次,如何结合学生的学习行为和历史数据,进一步优化图神经网络的模型结构和学习算法,提高知识追踪的精度和效率,也是我们需要深入研究的课题。

然而,正是这些挑战激发了我们对基于图神经网络的知识追踪方法的浓厚兴趣。我们相信,随着技术的不断进步和研究的深入,我们将能够克服这些挑战,并开发出更加先进、实用的知识追踪系统。这些系统不仅能够为教师提供更加精准的教学建议,帮助学生实现更高效的学习,还能够为教育资源的优化配置和个性化教育的推广提供有力支持。

零
  • 转载请务必保留本文链接:https://www.0s52.com/bcjc/pythonjc/15352.html
    本社区资源仅供用于学习和交流,请勿用于商业用途
    未经允许不得进行转载/复制/分享

发表评论