1 /**************************************************************
2  *
3  * Licensed to the Apache Software Foundation (ASF) under one
4  * or more contributor license agreements.  See the NOTICE file
5  * distributed with this work for additional information
6  * regarding copyright ownership.  The ASF licenses this file
7  * to you under the Apache License, Version 2.0 (the
8  * "License"); you may not use this file except in compliance
9  * with the License.  You may obtain a copy of the License at
10  *
11  *   http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing,
14  * software distributed under the License is distributed on an
15  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16  * KIND, either express or implied.  See the License for the
17  * specific language governing permissions and limitations
18  * under the License.
19  *
20  *************************************************************/
21 
22 #undef UNICODE
23 #undef _UNICODE
24 
25 #pragma once
26 
27 #ifdef _MSC_VER
28 #pragma warning(push, 1) /* disable warnings within system headers */
29 #endif
30 #define WIN32_LEAN_AND_MEAN
31 #include <windows.h>
32 #include <msiquery.h>
33 #include <imagehlp.h>
34 #include <tchar.h>
35 #include <strsafe.h>
36 #ifdef _MSC_VER
37 #pragma warning(pop)
38 #endif
39 
40 #include <malloc.h>
41 #include <time.h>
42 #include <string>
43 #include <hash_map>
44 
45 const DWORD PE_Signature = 0x00004550;
46 typedef std::pair< std::string, bool > StringPair;
47 typedef	std::hash_map< std::string, bool > ExcludeLibsMap;
48 
49 #ifdef DEBUG
OutputDebugStringFormat(LPCSTR pFormat,...)50 static void OutputDebugStringFormat( LPCSTR pFormat, ... )
51 {
52 	CHAR    buffer[1024];
53 	va_list args;
54 
55 	va_start( args, pFormat );
56 	StringCchVPrintfA( buffer, sizeof(buffer), pFormat, args );
57 	OutputDebugStringA( buffer );
58 }
59 #else
OutputDebugStringFormat(LPCSTR,...)60 static void OutputDebugStringFormat( LPCSTR, ... )
61 {
62 }
63 #endif
64 
IsValidHandle(HANDLE handle)65 static bool IsValidHandle( HANDLE handle )
66 {
67 	return NULL != handle && INVALID_HANDLE_VALUE != handle;
68 }
69 
GetMsiProperty(MSIHANDLE handle,const std::string & sProperty)70 static std::string GetMsiProperty(MSIHANDLE handle, const std::string& sProperty)
71 {
72 	std::string result;
73     TCHAR		szDummy[1] = TEXT("");
74     DWORD		nChars = 0;
75 
76     if (MsiGetProperty(handle, sProperty.c_str(), szDummy, &nChars) == ERROR_MORE_DATA)
77     {
78         DWORD nBytes = ++nChars * sizeof(TCHAR);
79         LPTSTR buffer = reinterpret_cast<LPTSTR>(_alloca(nBytes));
80         ZeroMemory( buffer, nBytes );
81         MsiGetProperty(handle, sProperty.c_str(), buffer, &nChars);
82         result = buffer;
83     }
84     return result;
85 }
86 
rebaseImage(const std::string & filePath,LPVOID address)87 static BOOL rebaseImage( const std::string& filePath, LPVOID address )
88 {
89 	ULONG ulOldImageSize;
90 	ULONG_PTR lpOldImageBase;
91 	ULONG ulNewImageSize;
92 	ULONG_PTR lpNewImageBase = reinterpret_cast<ULONG_PTR>(address);
93 
94 	BOOL bResult = ReBaseImage(
95 		filePath.c_str(),
96 		"",
97 		TRUE,
98 		FALSE,
99 		FALSE,
100 		0,
101 		&ulOldImageSize,
102 		&lpOldImageBase,
103 		&ulNewImageSize,
104 		&lpNewImageBase,
105 		(ULONG)time(NULL) );
106 
107 	return bResult;
108 }
109 
rebaseImage(MSIHANDLE,const std::string & sFilePath,LPVOID address)110 static BOOL rebaseImage( MSIHANDLE /*handle*/, const std::string& sFilePath, LPVOID address )
111 {
112 	std::string	mystr;
113 	mystr = "Full file: " + sFilePath;
114 
115 	BOOL bResult = rebaseImage( sFilePath, address );
116 
117 	if ( !bResult )
118 	{
119 		OutputDebugStringFormat( "Rebasing library %s failed", mystr.c_str() );
120 	}
121 
122 	return bResult;
123 }
124 
rebaseImagesInFolder(MSIHANDLE handle,const std::string & sPath,LPVOID address,ExcludeLibsMap & rExcludeMap)125 static BOOL rebaseImagesInFolder( MSIHANDLE handle, const std::string& sPath, LPVOID address, ExcludeLibsMap& rExcludeMap )
126 {
127 	std::string     sDir     = sPath;
128 	std::string	    sPattern = sPath + TEXT("*.dll");
129 	WIN32_FIND_DATA	aFindFileData;
130 
131 	HANDLE hFind = FindFirstFile( sPattern.c_str(), &aFindFileData );
132 	if ( IsValidHandle(hFind) )
133 	{
134 		BOOL fSuccess = false;
135 
136 		do
137 		{
138 			std::string sFileName = aFindFileData.cFileName;
139 			if ( rExcludeMap.find( sFileName ) == rExcludeMap.end() )
140 			{
141 				OutputDebugStringFormat( "Rebase library: %s", sFileName.c_str() );
142                 std::string	sLibFile = sDir +  sFileName;
143                 rebaseImage( handle, sLibFile, address );
144 			}
145 			else
146 			{
147 				OutputDebugStringFormat( "Exclude library %s from rebase", sFileName.c_str() );
148 			}
149 
150 			fSuccess = FindNextFile( hFind, &aFindFileData );
151 		}
152 		while ( fSuccess );
153 
154 		FindClose( hFind );
155 	}
156 
157 	return ERROR_SUCCESS;
158 }
159 
rebaseImages(MSIHANDLE handle,LPVOID pAddress,ExcludeLibsMap & rMap)160 static BOOL rebaseImages( MSIHANDLE handle, LPVOID pAddress, ExcludeLibsMap& rMap )
161 {
162 	std::string sInstallPath = GetMsiProperty(handle, TEXT("INSTALLLOCATION"));
163 
164 	std::string sBasisDir  = sInstallPath + TEXT("Basis\\program\\");
165 	std::string sOfficeDir = sInstallPath + TEXT("program\\");
166 	std::string sUreDir    = sInstallPath + TEXT("URE\\bin\\");
167 
168 	BOOL bResult = rebaseImagesInFolder( handle, sBasisDir, pAddress, rMap );
169 	bResult &= rebaseImagesInFolder( handle, sOfficeDir, pAddress, rMap );
170 	bResult &= rebaseImagesInFolder( handle, sUreDir, pAddress, rMap );
171 
172 	return bResult;
173 }
174 
IsServerSystem(MSIHANDLE)175 static BOOL IsServerSystem( MSIHANDLE /*handle*/ )
176 {
177 	OSVERSIONINFOEX osVersionInfoEx;
178 	osVersionInfoEx.dwOSVersionInfoSize = sizeof(OSVERSIONINFOEX);
179 	GetVersionEx(reinterpret_cast<LPOSVERSIONINFO>(&osVersionInfoEx));
180 
181 	if ( osVersionInfoEx.wProductType != VER_NT_WORKSTATION )
182 	{
183         OutputDebugStringFormat( "Server system detected. No rebase necessary!" );
184 		return TRUE;
185 	}
186 	else
187     {
188         OutputDebugStringFormat( "Client system detected. Rebase necessary!" );
189         return FALSE;
190     }
191 }
192 
InitExcludeFromRebaseList(MSIHANDLE handle,ExcludeLibsMap & rMap)193 static void InitExcludeFromRebaseList( MSIHANDLE handle, ExcludeLibsMap& rMap )
194 {
195 	size_t      nPos( 0 );
196     const TCHAR cDelim = ',';
197 	std::string sLibsExcluded = GetMsiProperty(handle, TEXT("EXCLUDE_FROM_REBASE"));
198 
199     while ( nPos < sLibsExcluded.size() )
200 	{
201 	    size_t nDelPos = sLibsExcluded.find_first_of( cDelim, nPos );
202 
203 		std::string sExcludedLibName;
204 		if ( nDelPos != std::string::npos )
205 		{
206 			sExcludedLibName = sLibsExcluded.substr( nPos, nDelPos - nPos );
207 		    nPos = nDelPos+1;
208 		}
209 		else
210 		{
211 			sExcludedLibName = sLibsExcluded.substr( nPos );
212 			nPos = sLibsExcluded.size();
213 		}
214 
215 		if ( sExcludedLibName.size() > 0 )
216 		{
217 			OutputDebugStringFormat( "Insert library %s into exclude from rebase list", sExcludedLibName.c_str() );
218 			rMap.insert( StringPair( sExcludedLibName, true ));
219 		}
220 	}
221 }
222 
RebaseLibrariesOnProperties(MSIHANDLE handle)223 extern "C" BOOL __stdcall RebaseLibrariesOnProperties( MSIHANDLE handle )
224 {
225 	static LPVOID pDefault = reinterpret_cast<LPVOID>(0x10000000);
226 
227 	OutputDebugStringFormat( "RebaseLibrariesOnProperties has been called" );
228 	std::string sDontOptimizeLibs = GetMsiProperty(handle, TEXT("DONTOPTIMIZELIBS"));
229 	if ( sDontOptimizeLibs.length() > 0 && sDontOptimizeLibs == "1" )
230 	{
231         OutputDebugStringFormat( "Don't optimize libraries set. No rebase necessary!" );
232 		return TRUE;
233 	}
234 
235 	if ( !IsServerSystem( handle ))
236 	{
237 		ExcludeLibsMap aExcludeLibsMap;
238 		InitExcludeFromRebaseList( handle, aExcludeLibsMap );
239 
240 		return rebaseImages( handle, pDefault, aExcludeLibsMap );
241 	}
242 
243 	return TRUE;
244 }
245