diff --git a/arch/x86/kvm/svm/sev.c b/arch/x86/kvm/svm/sev.c
index 75e0b21ad07c9e89a54131be45bdcca3fa88af4e..61c4bf4b3a0a87fc815faa1e706846116b7af35d 100644
--- a/arch/x86/kvm/svm/sev.c
+++ b/arch/x86/kvm/svm/sev.c
@@ -595,43 +595,50 @@ static int sev_es_sync_vmsa(struct vcpu_svm *svm)
 	return 0;
 }
 
-static int sev_launch_update_vmsa(struct kvm *kvm, struct kvm_sev_cmd *argp)
+static int __sev_launch_update_vmsa(struct kvm *kvm, struct kvm_vcpu *vcpu,
+				    int *error)
 {
-	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
 	struct sev_data_launch_update_vmsa vmsa;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	int ret;
+
+	/* Perform some pre-encryption checks against the VMSA */
+	ret = sev_es_sync_vmsa(svm);
+	if (ret)
+		return ret;
+
+	/*
+	 * The LAUNCH_UPDATE_VMSA command will perform in-place encryption of
+	 * the VMSA memory content (i.e it will write the same memory region
+	 * with the guest's key), so invalidate it first.
+	 */
+	clflush_cache_range(svm->vmsa, PAGE_SIZE);
+
+	vmsa.reserved = 0;
+	vmsa.handle = to_kvm_svm(kvm)->sev_info.handle;
+	vmsa.address = __sme_pa(svm->vmsa);
+	vmsa.len = PAGE_SIZE;
+	return sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_VMSA, &vmsa, error);
+}
+
+static int sev_launch_update_vmsa(struct kvm *kvm, struct kvm_sev_cmd *argp)
+{
 	struct kvm_vcpu *vcpu;
 	int i, ret;
 
 	if (!sev_es_guest(kvm))
 		return -ENOTTY;
 
-	vmsa.reserved = 0;
-
 	kvm_for_each_vcpu(i, vcpu, kvm) {
-		struct vcpu_svm *svm = to_svm(vcpu);
-
-		/* Perform some pre-encryption checks against the VMSA */
-		ret = sev_es_sync_vmsa(svm);
+		ret = mutex_lock_killable(&vcpu->mutex);
 		if (ret)
 			return ret;
 
-		/*
-		 * The LAUNCH_UPDATE_VMSA command will perform in-place
-		 * encryption of the VMSA memory content (i.e it will write
-		 * the same memory region with the guest's key), so invalidate
-		 * it first.
-		 */
-		clflush_cache_range(svm->vmsa, PAGE_SIZE);
+		ret = __sev_launch_update_vmsa(kvm, vcpu, &argp->error);
 
-		vmsa.handle = sev->handle;
-		vmsa.address = __sme_pa(svm->vmsa);
-		vmsa.len = PAGE_SIZE;
-		ret = sev_issue_cmd(kvm, SEV_CMD_LAUNCH_UPDATE_VMSA, &vmsa,
-				    &argp->error);
+		mutex_unlock(&vcpu->mutex);
 		if (ret)
 			return ret;
-
-		svm->vcpu.arch.guest_state_protected = true;
 	}
 
 	return 0;