aboutsummaryrefslogtreecommitdiff
path: root/src/burn/engine/payload.cpp
blob: 270da6aa3a81a57959db2ec25a76f49813e19b2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
// Copyright (c) .NET Foundation and contributors. All rights reserved. Licensed under the Microsoft Reciprocal License. See LICENSE.TXT file in the project root for full license information.

#include "precomp.h"


// internal function declarations


// function definitions

// Parses the "Payload" child elements of pixnBundle into pPayloads.
//
//   pPayloads       - receives the allocated payload array, the payload count and a lookup dictionary.
//   pContainers     - optional; searched by id to resolve each payload's @Container.
//   pLayoutPayloads - optional; payload group that collects the payloads flagged @LayoutOnly.
//   pixnBundle      - XML node whose "Payload" children are parsed.
//
// When both pContainers and pLayoutPayloads are supplied the payloads are treated as chain payloads:
// packaging, container, download, file size and verification attributes are read, and the payload
// dictionary is keyed on sczKey. Otherwise every payload is marked as embedded and the dictionary
// is keyed on sczSourcePath.
//
// If no "Payload" nodes are selected the function exits without allocating anything.
// On failure the HRESULT of the failing step is returned; pPayloads may be partially populated.
extern "C" HRESULT PayloadsParseFromXml(
    __in BURN_PAYLOADS* pPayloads,
    __in_opt BURN_CONTAINERS* pContainers,
    __in_opt BURN_PAYLOAD_GROUP* pLayoutPayloads,
    __in IXMLDOMNode* pixnBundle
    )
{
    HRESULT hr = S_OK;
    IXMLDOMNodeList* pixnNodes = NULL;
    IXMLDOMNode* pixnNode = NULL;
    DWORD cNodes = 0;
    LPWSTR scz = NULL; // scratch string reused for every attribute that is not stored directly on the payload.
    BOOL fChainPayload = pContainers && pLayoutPayloads; // These are required when parsing chain payloads.
    BOOL fValidFileSize = FALSE;
    // Offset of the BURN_PAYLOAD member handed to the dictionary as its embedded key:
    // sczKey for chain payloads, sczSourcePath for all others.
    size_t cByteOffset = fChainPayload ? offsetof(BURN_PAYLOAD, sczKey) : offsetof(BURN_PAYLOAD, sczSourcePath);
    BOOL fXmlFound = FALSE;

    // select payload nodes
    hr = XmlSelectNodes(pixnBundle, L"Payload", &pixnNodes);
    ExitOnFailure(hr, "Failed to select payload nodes.");

    // get payload node count
    hr = pixnNodes->get_length((long*)&cNodes);
    ExitOnFailure(hr, "Failed to get payload node count.");

    // nothing to parse, leave pPayloads untouched
    if (!cNodes)
    {
        ExitFunction();
    }

    // allocate memory for payloads
    pPayloads->rgPayloads = (BURN_PAYLOAD*)MemAlloc(sizeof(BURN_PAYLOAD) * cNodes, TRUE);
    ExitOnNull(pPayloads->rgPayloads, hr, E_OUTOFMEMORY, "Failed to allocate memory for payload structs.");

    pPayloads->cPayloads = cNodes;

    // create dictionary for payloads
    hr = DictCreateWithEmbeddedKey(&pPayloads->sdhPayloads, pPayloads->cPayloads, reinterpret_cast<void**>(&pPayloads->rgPayloads), cByteOffset, DICT_FLAG_NONE);
    ExitOnFailure(hr, "Failed to create dictionary for payloads.");

    // parse payload elements
    for (DWORD i = 0; i < cNodes; ++i)
    {
        BURN_PAYLOAD* pPayload = &pPayloads->rgPayloads[i];
        // reset per payload; only set again once @FileSize has been parsed successfully
        fValidFileSize = FALSE;

        hr = XmlNextElement(pixnNodes, &pixnNode, NULL);
        ExitOnFailure(hr, "Failed to get next node.");

        // @Id
        hr = XmlGetAttributeEx(pixnNode, L"Id", &pPayload->sczKey);
        ExitOnRequiredXmlQueryFailure(hr, "Failed to get @Id.");

        // @FilePath
        hr = XmlGetAttributeEx(pixnNode, L"FilePath", &pPayload->sczFilePath);
        ExitOnRequiredXmlQueryFailure(hr, "Failed to get @FilePath.");

        // @SourcePath
        hr = XmlGetAttributeEx(pixnNode, L"SourcePath", &pPayload->sczSourcePath);
        ExitOnRequiredXmlQueryFailure(hr, "Failed to get @SourcePath.");

        if (!fChainPayload)
        {
            // All non-chain payloads are embedded in the UX container.
            pPayload->packaging = BURN_PAYLOAD_PACKAGING_EMBEDDED;
        }
        else
        {
            // @Packaging
            hr = XmlGetAttributeEx(pixnNode, L"Packaging", &scz);
            ExitOnRequiredXmlQueryFailure(hr, "Failed to get @Packaging.");

            // only the exact values "embedded" and "external" are accepted
            if (CSTR_EQUAL == ::CompareStringW(LOCALE_INVARIANT, 0, scz, -1, L"embedded", -1))
            {
                pPayload->packaging = BURN_PAYLOAD_PACKAGING_EMBEDDED;
            }
            else if (CSTR_EQUAL == ::CompareStringW(LOCALE_INVARIANT, 0, scz, -1, L"external", -1))
            {
                pPayload->packaging = BURN_PAYLOAD_PACKAGING_EXTERNAL;
            }
            else
            {
                ExitWithRootFailure(hr, E_INVALIDARG, "Invalid value for @Packaging: %ls", scz);
            }

            // @Container
            hr = XmlGetAttributeEx(pixnNode, L"Container", &scz);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @Container.");

            if (fXmlFound)
            {
                // find container
                hr = ContainerFindById(pContainers, scz, &pPayload->pContainer);
                ExitOnFailure(hr, "Failed to find container: %ls", scz);

                // counted here so the per-container dictionary created after this loop can be sized
                pPayload->pContainer->cParsedPayloads += 1;
            }
            else if (BURN_PAYLOAD_PACKAGING_EMBEDDED == pPayload->packaging)
            {
                ExitWithRootFailure(hr, E_NOTFOUND, "@Container is required for embedded payload.");
            }

            // @LayoutOnly
            hr = XmlGetYesNoAttribute(pixnNode, L"LayoutOnly", &pPayload->fLayoutOnly);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @LayoutOnly.");

            // @DownloadUrl
            hr = XmlGetAttributeEx(pixnNode, L"DownloadUrl", &pPayload->downloadSource.sczUrl);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @DownloadUrl.");

            // @FileSize
            hr = XmlGetAttributeEx(pixnNode, L"FileSize", &scz);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @FileSize.");

            if (fXmlFound)
            {
                hr = StrStringToUInt64(scz, 0, &pPayload->qwFileSize);
                ExitOnFailure(hr, "Failed to parse @FileSize.");

                fValidFileSize = TRUE;
            }

            // @CertificateRootPublicKeyIdentifier
            hr = XmlGetAttributeEx(pixnNode, L"CertificateRootPublicKeyIdentifier", &scz);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @CertificateRootPublicKeyIdentifier.");

            if (fXmlFound)
            {
                hr = StrAllocHexDecode(scz, &pPayload->pbCertificateRootPublicKeyIdentifier, &pPayload->cbCertificateRootPublicKeyIdentifier);
                ExitOnFailure(hr, "Failed to hex decode @CertificateRootPublicKeyIdentifier.");

                // presence of the public key identifier selects Authenticode verification
                pPayload->verification = BURN_PAYLOAD_VERIFICATION_AUTHENTICODE;
            }

            // @CertificateRootThumbprint
            hr = XmlGetAttributeEx(pixnNode, L"CertificateRootThumbprint", &scz);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @CertificateRootThumbprint.");

            if (fXmlFound)
            {
                hr = StrAllocHexDecode(scz, &pPayload->pbCertificateRootThumbprint, &pPayload->cbCertificateRootThumbprint);
                ExitOnFailure(hr, "Failed to hex decode @CertificateRootThumbprint.");
            }

            // @Hash
            hr = XmlGetAttributeEx(pixnNode, L"Hash", &scz);
            ExitOnOptionalXmlQueryFailure(hr, fXmlFound, "Failed to get @Hash.");

            if (fXmlFound)
            {
                hr = StrAllocHexDecode(scz, &pPayload->pbHash, &pPayload->cbHash);
                ExitOnFailure(hr, "Failed to hex decode the Payload/@Hash.");

                // the hash is only used as the verification method when no other method was selected above
                if (BURN_PAYLOAD_VERIFICATION_NONE == pPayload->verification)
                {
                    pPayload->verification = BURN_PAYLOAD_VERIFICATION_HASH;
                }
            }

            // every chain payload must be verifiable, and hash verification additionally needs a parsed @FileSize
            if (BURN_PAYLOAD_VERIFICATION_NONE == pPayload->verification)
            {
                ExitWithRootFailure(hr, E_INVALIDDATA, "There was no verification information for payload: %ls", pPayload->sczKey);
            }
            else if (BURN_PAYLOAD_VERIFICATION_HASH == pPayload->verification && !fValidFileSize)
            {
                ExitWithRootFailure(hr, E_INVALIDDATA, "File size is required when verifying by hash for payload: %ls", pPayload->sczKey);
            }

            // layout-only payloads are appended to the layout group and added to its total size
            if (pPayload->fLayoutOnly)
            {
                hr = MemEnsureArraySize(reinterpret_cast<LPVOID*>(&pLayoutPayloads->rgItems), pLayoutPayloads->cItems + 1, sizeof(BURN_PAYLOAD_GROUP_ITEM), 5);
                ExitOnFailure(hr, "Failed to allocate memory for layout payloads.");

                pLayoutPayloads->rgItems[pLayoutPayloads->cItems].pPayload = pPayload;
                ++pLayoutPayloads->cItems;

                pLayoutPayloads->qwTotalSize += pPayload->qwFileSize;
            }
        }

        hr = DictAddValue(pPayloads->sdhPayloads, pPayload);
        ExitOnFailure(hr, "Failed to add payload to payloads dictionary.");

        // prepare next iteration
        ReleaseNullObject(pixnNode);
    }

    // NOTE(review): presumably the optional attribute queries above can leave a non-S_OK
    // value in hr, hence the explicit reset - confirm against ExitOnOptionalXmlQueryFailure.
    hr = S_OK;

    // Second pass: give every referenced container its own payload dictionary, keyed on sczSourcePath.
    if (pContainers && pContainers->cContainers)
    {
        for (DWORD i = 0; i < pPayloads->cPayloads; ++i)
        {
            BURN_PAYLOAD* pPayload = &pPayloads->rgPayloads[i];
            BURN_CONTAINER* pContainer = pPayload->pContainer;

            if (!pContainer)
            {
                // payload is not associated with a container
                continue;
            }
            else if (!pContainer->sdhPayloads)
            {
                // first payload seen for this container, create its dictionary on demand
                hr = DictCreateWithEmbeddedKey(&pContainer->sdhPayloads, pContainer->cParsedPayloads, NULL, offsetof(BURN_PAYLOAD, sczSourcePath), DICT_FLAG_NONE);
                ExitOnFailure(hr, "Failed to create dictionary for container payloads.");
            }

            hr = DictAddValue(pContainer->sdhPayloads, pPayload);
            ExitOnFailure(hr, "Failed to add payload to container dictionary.");
        }
    }

LExit:
    // released on both the success and the failure path
    ReleaseObject(pixnNodes);
    ReleaseObject(pixnNode);
    ReleaseStr(scz);

    return hr;
}

// Releases everything a single payload owns: its strings, the hash and
// certificate buffers, the download source strings and the local file handle.
// A NULL pPayload is ignored. The BURN_PAYLOAD struct itself is not freed here.
extern "C" void PayloadUninitialize(
    __in BURN_PAYLOAD* pPayload
    )
{
    if (!pPayload)
    {
        return;
    }

    // identity and path strings
    ReleaseStr(pPayload->sczKey);
    ReleaseStr(pPayload->sczFilePath);
    ReleaseStr(pPayload->sczSourcePath);
    ReleaseStr(pPayload->sczLocalFilePath);
    ReleaseStr(pPayload->sczFailedLocalAcquisitionPath);
    ReleaseStr(pPayload->sczUnverifiedPath);

    // verification buffers
    ReleaseMem(pPayload->pbHash);
    ReleaseMem(pPayload->pbCertificateRootThumbprint);
    ReleaseMem(pPayload->pbCertificateRootPublicKeyIdentifier);

    // download source
    ReleaseStr(pPayload->downloadSource.sczUrl);
    ReleaseStr(pPayload->downloadSource.sczUser);
    ReleaseStr(pPayload->downloadSource.sczPassword);
    ReleaseStr(pPayload->downloadSource.sczAuthorizationHeader);

    // handle to the local copy of the payload
    ReleaseFileHandle(pPayload->hLocalFile);
}

// Tears down a payload collection: uninitializes each payload, frees the
// payload array, releases the lookup dictionary and zeroes the struct so it
// is left in an empty state.
extern "C" void PayloadsUninitialize(
    __in BURN_PAYLOADS* pPayloads
    )
{
    BURN_PAYLOAD* rgItems = pPayloads->rgPayloads;

    if (rgItems)
    {
        for (DWORD iPayload = 0; iPayload < pPayloads->cPayloads; ++iPayload)
        {
            PayloadUninitialize(&rgItems[iPayload]);
        }

        MemFree(rgItems);
    }

    ReleaseDict(pPayloads->sdhPayloads);

    // wipe all members, including the now stale array pointer and count
    memset(pPayloads, 0, sizeof(*pPayloads));
}

// Extracts every stream of the container behind pContainerContext into
// wzTargetDir. Each stream name is looked up in pPayloads->sdhPayloads, the
// matching payload is written to wzTargetDir joined with the payload's
// sczFilePath, then reopened for reading and marked as acquired.
//
// After the container is exhausted, every payload in pPayloads must have been
// acquired; otherwise E_INVALIDDATA is returned. Any failing step returns its
// own HRESULT.
extern "C" HRESULT PayloadExtractUXContainer(
    __in BURN_PAYLOADS* pPayloads,
    __in BURN_CONTAINER_CONTEXT* pContainerContext,
    __in_z LPCWSTR wzTargetDir
    )
{
    HRESULT hr = S_OK;
    HANDLE hExtractedFile = INVALID_HANDLE_VALUE;
    BURN_PAYLOAD* pCurrent = NULL;
    LPWSTR sczStream = NULL;
    LPWSTR sczTargetFolder = NULL;

    // walk the container one stream at a time until it reports no more items
    for (;;)
    {
        hr = ContainerNextStream(pContainerContext, &sczStream);
        if (E_NOMOREITEMS == hr)
        {
            // running out of streams is the normal way to leave the loop
            hr = S_OK;
            break;
        }
        ExitOnFailure(hr, "Failed to get next stream.");

        // the stream name is the key into the payload dictionary
        hr = PayloadFindEmbeddedBySourcePath(pPayloads->sdhPayloads, sczStream, &pCurrent);
        ExitOnFailure(hr, "Failed to find embedded payload: %ls", sczStream);

        // build the destination path for this payload
        hr = PathConcatRelativeToFullyQualifiedBase(wzTargetDir, pCurrent->sczFilePath, &pCurrent->sczLocalFilePath);
        ExitOnFailure(hr, "Failed to concat file paths.");

        // make sure the destination directory is there before creating the file
        hr = PathGetDirectory(pCurrent->sczLocalFilePath, &sczTargetFolder);
        ExitOnFailure(hr, "Failed to get directory portion of local file path");

        hr = DirEnsureExists(sczTargetFolder, NULL);
        ExitOnFailure(hr, "Failed to ensure directory exists");

        // write the stream content to the destination file
        hExtractedFile = ::CreateFileW(pCurrent->sczLocalFilePath, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
        ExitOnInvalidHandleWithLastError(hExtractedFile, hr, "Failed to create file: %ls", pCurrent->sczLocalFilePath);

        hr = ContainerStreamToHandle(pContainerContext, hExtractedFile);
        ExitOnFailure(hr, "Failed to extract file.");

        // Close the write handle and keep a read-only handle open instead, so the
        // file cannot be removed or tampered with while the BA is running.
        ReleaseFileHandle(hExtractedFile);

        hr = FileCreateWithRetry(pCurrent->sczLocalFilePath, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 30, 100, &pCurrent->hLocalFile);
        ExitOnFailure(hr, "Failed to open file: %ls", pCurrent->sczLocalFilePath);

        // remember that this payload is now available locally
        pCurrent->state = BURN_PAYLOAD_STATE_ACQUIRED;
    }

    // every known payload must have come out of the container
    for (DWORD iPayload = 0; iPayload < pPayloads->cPayloads; ++iPayload)
    {
        BURN_PAYLOAD* pCandidate = pPayloads->rgPayloads + iPayload;

        if (pCandidate->state < BURN_PAYLOAD_STATE_ACQUIRED)
        {
            ExitWithRootFailure(hr, E_INVALIDDATA, "Payload was not found in container: %ls", pCandidate->sczKey);
        }
    }

LExit:
    // the write handle is only still open here if extraction failed part way through
    ReleaseFileHandle(hExtractedFile);
    ReleaseStr(sczStream);
    ReleaseStr(sczTargetFolder);

    return hr;
}

// Looks up a payload in pPayloads->sdhPayloads using wzId as the key.
// On success *ppPayload receives the payload; the HRESULT of the dictionary
// lookup is returned unchanged.
extern "C" HRESULT PayloadFindById(
    __in BURN_PAYLOADS* pPayloads,
    __in_z LPCWSTR wzId,
    __out BURN_PAYLOAD** ppPayload
    )
{
    return DictGetValue(pPayloads->sdhPayloads, wzId, reinterpret_cast<void**>(ppPayload));
}

// Looks up a payload in the given dictionary using wzStreamName as the key.
// On success *ppPayload receives the payload; the HRESULT of the dictionary
// lookup is returned unchanged.
extern "C" HRESULT PayloadFindEmbeddedBySourcePath(
    __in STRINGDICT_HANDLE sdhPayloads,
    __in_z LPCWSTR wzStreamName,
    __out BURN_PAYLOAD** ppPayload
    )
{
    return DictGetValue(sdhPayloads, wzStreamName, reinterpret_cast<void**>(ppPayload));
}


// internal function definitions