From b1d1e523f5cdadce0cbf105179b33c014d5ec9eb Mon Sep 17 00:00:00 2001
From: Sean Hall <r.sean.hall@gmail.com>
Date: Fri, 16 Apr 2021 13:38:16 -0500
Subject: Add OnCachePayloadExtract*.

---
 src/engine/apply.cpp          | 58 +++++++++++++++++++++++++++--
 src/engine/container.h        |  1 +
 src/engine/plan.cpp           |  1 +
 src/engine/userexperience.cpp | 87 +++++++++++++++++++++++++++++++++++++++++++
 src/engine/userexperience.h   | 19 ++++++++++
 5 files changed, 162 insertions(+), 4 deletions(-)

(limited to 'src/engine')

diff --git a/src/engine/apply.cpp b/src/engine/apply.cpp
index f80baf26..ab7fa077 100644
--- a/src/engine/apply.cpp
+++ b/src/engine/apply.cpp
@@ -16,6 +16,7 @@ enum BURN_CACHE_PROGRESS_TYPE
     BURN_CACHE_PROGRESS_TYPE_ACQUIRE,
     BURN_CACHE_PROGRESS_TYPE_VERIFY,
     BURN_CACHE_PROGRESS_TYPE_CONTAINER_OR_PAYLOAD_VERIFY,
+    BURN_CACHE_PROGRESS_TYPE_EXTRACT,
 };
 
 // structs
@@ -43,6 +44,7 @@ typedef struct _BURN_CACHE_PROGRESS_CONTEXT
     BURN_CONTAINER* pContainer;
     BURN_PACKAGE* pPackage;
     BURN_PAYLOAD_GROUP_ITEM* pPayloadGroupItem;
+    BURN_PAYLOAD* pPayload;
 
     BOOL fCancel;
     HRESULT hrError;
@@ -875,6 +877,12 @@ static HRESULT ApplyExtractContainer(
         pContainer->qwCommittedCacheProgress = 0;
     }
 
+    if (pContainer->qwCommittedExtractProgress)
+    {
+        pContext->qwSuccessfulCacheProgress -= pContainer->qwCommittedExtractProgress;
+        pContainer->qwCommittedExtractProgress = 0;
+    }
+
     if (!pContainer->fActuallyAttached)
     {
         hr = ApplyAcquireContainerOrPayload(pContext, pContainer, NULL, NULL);
@@ -890,8 +898,18 @@ static HRESULT ApplyExtractContainer(
         CacheSetLastUsedSource(pContext->pVariables, pContext->sczLastUsedFolderCandidate, pContainer->sczFilePath);
     }
 
-    pContext->qwSuccessfulCacheProgress += pContainer->qwExtractSizeTotal;
-    pContainer->qwCommittedCacheProgress += pContainer->qwExtractSizeTotal;
+    if (pContainer->qwExtractSizeTotal < pContainer->qwCommittedExtractProgress)
+    {
+        AssertSz(FALSE, "Container extracted more than planned.");
+        pContext->qwSuccessfulCacheProgress -= pContainer->qwCommittedExtractProgress;
+        pContext->qwSuccessfulCacheProgress += pContainer->qwExtractSizeTotal;
+    }
+    else
+    {
+        pContext->qwSuccessfulCacheProgress += pContainer->qwExtractSizeTotal - pContainer->qwCommittedExtractProgress;
+    }
+
+    pContainer->qwCommittedExtractProgress = pContainer->qwExtractSizeTotal;
 
 LExit:
     ReleaseNullStr(pContext->sczLastUsedFolderCandidate);
@@ -1100,6 +1118,11 @@ static HRESULT ExtractContainer(
     BURN_CONTAINER_CONTEXT context = { };
     HANDLE hContainerHandle = INVALID_HANDLE_VALUE;
     LPWSTR sczExtractPayloadId = NULL;
+    BURN_CACHE_PROGRESS_CONTEXT progress = { };
+
+    progress.pCacheContext = pContext;
+    progress.pContainer = pContainer;
+    progress.type = BURN_CACHE_PROGRESS_TYPE_EXTRACT;
 
     // If the container is actually attached, then it was planned to be acquired through hSourceEngineFile.
     if (pContainer->fActuallyAttached)
@@ -1119,11 +1142,29 @@ static HRESULT ExtractContainer(
             BURN_PAYLOAD* pExtract = pContext->pPayloads->rgPayloads + iExtract;
             if (pExtract->sczUnverifiedPath && pExtract->cRemainingInstances && CSTR_EQUAL == ::CompareStringW(LOCALE_INVARIANT, 0, sczExtractPayloadId, -1, pExtract->sczSourcePath, -1))
             {
+                progress.pPayload = pExtract;
+
                 hr = PreparePayloadDestinationPath(pExtract->sczUnverifiedPath);
                 ExitOnFailure(hr, "Failed to prepare payload destination path: %ls", pExtract->sczUnverifiedPath);
 
+                hr = UserExperienceOnCachePayloadExtractBegin(pContext->pUX, pContainer->sczId, pExtract->sczKey);
+                if (FAILED(hr))
+                {
+                    UserExperienceOnCachePayloadExtractComplete(pContext->pUX, pContainer->sczId, pExtract->sczKey, hr);
+                    ExitOnRootFailure(hr, "BA aborted cache payload extract begin.");
+                }
+
                 // TODO: Send progress when extracting stream to file.
                 hr = ContainerStreamToFile(&context, pExtract->sczUnverifiedPath);
+                // Error handling happens after sending complete message to BA.
+
+                // If succeeded, send 100% complete here to make sure progress was sent to the BA.
+                if (SUCCEEDED(hr))
+                {
+                    hr = CompleteCacheProgress(&progress, pExtract->qwFileSize);
+                }
+
+                UserExperienceOnCachePayloadExtractComplete(pContext->pUX, pContainer->sczId, pExtract->sczKey, hr);
                 ExitOnFailure(hr, "Failed to extract payload: %ls from container: %ls", sczExtractPayloadId, pContainer->sczId);
 
                 fExtracted = TRUE;
@@ -1754,7 +1795,12 @@ static HRESULT CompleteCacheProgress(
     if (PROGRESS_CONTINUE == dwResult)
     {
         pContext->pCacheContext->qwSuccessfulCacheProgress += qwFileSize;
-        if (pContext->pContainer)
+
+        if (pContext->pPayload)
+        {
+            pContext->pContainer->qwCommittedExtractProgress += qwFileSize;
+        }
+        else if (pContext->pContainer)
         {
             pContext->pContainer->qwCommittedCacheProgress += qwFileSize;
         }
@@ -1800,7 +1846,7 @@ static DWORD CALLBACK CacheProgressRoutine(
     DWORD dwResult = PROGRESS_CONTINUE;
     BURN_CACHE_PROGRESS_CONTEXT* pProgress = static_cast<BURN_CACHE_PROGRESS_CONTEXT*>(lpData);
     LPCWSTR wzPackageOrContainerId = pProgress->pContainer ? pProgress->pContainer->sczId : pProgress->pPackage ? pProgress->pPackage->sczId : NULL;
-    LPCWSTR wzPayloadId = pProgress->pPayloadGroupItem ? pProgress->pPayloadGroupItem->pPayload->sczKey : NULL;
+    LPCWSTR wzPayloadId = pProgress->pPayloadGroupItem ? pProgress->pPayloadGroupItem->pPayload->sczKey : pProgress->pPayload ? pProgress->pPayload->sczKey : NULL;
     DWORD64 qwCacheProgress = pProgress->pCacheContext->qwSuccessfulCacheProgress + TotalBytesTransferred.QuadPart;
     if (qwCacheProgress > pProgress->pCacheContext->qwTotalCacheSize)
     {
@@ -1823,6 +1869,10 @@ static DWORD CALLBACK CacheProgressRoutine(
         hr = UserExperienceOnCacheContainerOrPayloadVerifyProgress(pProgress->pCacheContext->pUX, wzPackageOrContainerId, wzPayloadId, TotalBytesTransferred.QuadPart, TotalFileSize.QuadPart, dwOverallPercentage);
         ExitOnRootFailure(hr, "BA aborted container or payload verify: %ls", wzPayloadId);
         break;
+    case BURN_CACHE_PROGRESS_TYPE_EXTRACT:
+        hr = UserExperienceOnCachePayloadExtractProgress(pProgress->pCacheContext->pUX, wzPackageOrContainerId, wzPayloadId, TotalBytesTransferred.QuadPart, TotalFileSize.QuadPart, dwOverallPercentage);
+        ExitOnRootFailure(hr, "BA aborted extract container: %ls, payload: %ls", wzPackageOrContainerId, wzPayloadId);
+        break;
     }
 
 LExit:
diff --git a/src/engine/container.h b/src/engine/container.h
index 7c5c2b5f..3174eb38 100644
--- a/src/engine/container.h
+++ b/src/engine/container.h
@@ -78,6 +78,7 @@ typedef struct _BURN_CONTAINER
     LPWSTR sczUnverifiedPath;
     DWORD64 qwExtractSizeTotal;
     DWORD64 qwCommittedCacheProgress;
+    DWORD64 qwCommittedExtractProgress;
 } BURN_CONTAINER;
 
 typedef struct _BURN_CONTAINERS
diff --git a/src/engine/plan.cpp b/src/engine/plan.cpp
index 8421d87b..bf929835 100644
--- a/src/engine/plan.cpp
+++ b/src/engine/plan.cpp
@@ -1831,6 +1831,7 @@ static void ResetPlannedContainerState(
     pContainer->fPlanned = FALSE;
     pContainer->qwExtractSizeTotal = 0;
     pContainer->qwCommittedCacheProgress = 0;
+    pContainer->qwCommittedExtractProgress = 0;
 }
 
 static void ResetPlannedPayloadsState(
diff --git a/src/engine/userexperience.cpp b/src/engine/userexperience.cpp
index 02c67fc5..279a00b5 100644
--- a/src/engine/userexperience.cpp
+++ b/src/engine/userexperience.cpp
@@ -758,6 +758,93 @@ LExit:
     return hr;
 }
 
+EXTERN_C BAAPI UserExperienceOnCachePayloadExtractBegin(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId
+    )
+{
+    HRESULT hr = S_OK;
+    BA_ONCACHEPAYLOADEXTRACTBEGIN_ARGS args = { };
+    BA_ONCACHEPAYLOADEXTRACTBEGIN_RESULTS results = { };
+
+    args.cbSize = sizeof(args);
+    args.wzContainerId = wzContainerId;
+    args.wzPayloadId = wzPayloadId;
+
+    results.cbSize = sizeof(results);
+
+    hr = SendBAMessage(pUserExperience, BOOTSTRAPPER_APPLICATION_MESSAGE_ONCACHEPAYLOADEXTRACTBEGIN, &args, &results);
+    ExitOnFailure(hr, "BA OnCachePayloadExtractBegin failed.");
+
+    if (results.fCancel)
+    {
+        hr = HRESULT_FROM_WIN32(ERROR_INSTALL_USEREXIT);
+    }
+
+LExit:
+    return hr;
+}
+
+EXTERN_C BAAPI UserExperienceOnCachePayloadExtractComplete(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId,
+    __in HRESULT hrStatus
+    )
+{
+    HRESULT hr = S_OK;
+    BA_ONCACHEPAYLOADEXTRACTCOMPLETE_ARGS args = { };
+    BA_ONCACHEPAYLOADEXTRACTCOMPLETE_RESULTS results = { };
+
+    args.cbSize = sizeof(args);
+    args.wzContainerId = wzContainerId;
+    args.wzPayloadId = wzPayloadId;
+    args.hrStatus = hrStatus;
+
+    results.cbSize = sizeof(results);
+
+    hr = SendBAMessage(pUserExperience, BOOTSTRAPPER_APPLICATION_MESSAGE_ONCACHEPAYLOADEXTRACTCOMPLETE, &args, &results);
+    ExitOnFailure(hr, "BA OnCachePayloadExtractComplete failed.");
+
+LExit:
+    return hr;
+}
+
+EXTERN_C BAAPI UserExperienceOnCachePayloadExtractProgress(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId,
+    __in DWORD64 dw64Progress,
+    __in DWORD64 dw64Total,
+    __in DWORD dwOverallPercentage
+    )
+{
+    HRESULT hr = S_OK;
+    BA_ONCACHEPAYLOADEXTRACTPROGRESS_ARGS args = { };
+    BA_ONCACHEPAYLOADEXTRACTPROGRESS_RESULTS results = { };
+
+    args.cbSize = sizeof(args);
+    args.wzContainerId = wzContainerId;
+    args.wzPayloadId = wzPayloadId;
+    args.dw64Progress = dw64Progress;
+    args.dw64Total = dw64Total;
+    args.dwOverallPercentage = dwOverallPercentage;
+
+    results.cbSize = sizeof(results);
+
+    hr = SendBAMessage(pUserExperience, BOOTSTRAPPER_APPLICATION_MESSAGE_ONCACHEPAYLOADEXTRACTPROGRESS, &args, &results);
+    ExitOnFailure(hr, "BA OnCachePayloadExtractProgress failed.");
+
+    if (results.fCancel)
+    {
+        hr = HRESULT_FROM_WIN32(ERROR_INSTALL_USEREXIT);
+    }
+
+LExit:
+    return hr;
+}
+
 EXTERN_C BAAPI UserExperienceOnCacheVerifyBegin(
     __in BURN_USER_EXPERIENCE* pUserExperience,
     __in_z_opt LPCWSTR wzPackageOrContainerId,
diff --git a/src/engine/userexperience.h b/src/engine/userexperience.h
index d3dfb810..a848e60d 100644
--- a/src/engine/userexperience.h
+++ b/src/engine/userexperience.h
@@ -196,6 +196,25 @@ BAAPI UserExperienceOnCachePackageComplete(
     __in HRESULT hrStatus,
     __inout BOOTSTRAPPER_CACHEPACKAGECOMPLETE_ACTION* pAction
     );
+BAAPI UserExperienceOnCachePayloadExtractBegin(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId
+    );
+BAAPI UserExperienceOnCachePayloadExtractComplete(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId,
+    __in HRESULT hrStatus
+    );
+BAAPI UserExperienceOnCachePayloadExtractProgress(
+    __in BURN_USER_EXPERIENCE* pUserExperience,
+    __in_z_opt LPCWSTR wzContainerId,
+    __in_z_opt LPCWSTR wzPayloadId,
+    __in DWORD64 dw64Progress,
+    __in DWORD64 dw64Total,
+    __in DWORD dwOverallPercentage
+    );
 BAAPI UserExperienceOnCacheVerifyBegin(
     __in BURN_USER_EXPERIENCE* pUserExperience,
     __in_z_opt LPCWSTR wzPackageOrContainerId,
-- 
cgit v1.2.3-55-g6feb