Python多階段框架實現虛擬試衣間,超逼真

2020-12-18 AI科技大本營

作者 | 李秋鍵

責編 | 晉兆雨

頭圖 | CSDN下載自視覺中國

任意姿態下的虛擬試衣因其巨大的應用潛力而引起了人們的廣泛關注。然而,現有的方法在將新穎的服裝和姿勢貼合到一個人身上的同時,很難保留服裝紋理和面部特徵(面孔、毛髮)中的細節。故在論文《Downto the Last Detail: Virtual Try-on with Detail Carving》中提出了一種新的多階段合成框架,可以很好地保留圖像顯著區域的豐富細節。

具體地說,就是提出了一個多階段的框架,將生成分解為空間對齊,然後由粗到細生成。為了更好地保留顯著區域的細節,如服裝和面部區域,我們提出了一個樹塊(樹擴張融合塊)來利用多尺度特徵在發生器網絡。通過多個階段的端到端訓練,可以聯合優化整個框架,最終使得視覺逼真度得到了顯著的提高、同時獲得了細節更為豐富的結果。在標準數據集上進行的大量實驗表明,他們提出的框架實現了最先進的性能,特別是在保存服裝紋理和面部識別的視覺細節方面。

故今天我們將在他們代碼的基礎上,實現虛擬換衣系統。具體流程如下:

實驗前的準備

首先我們使用的python版本是3.6.5所用到的模塊如下:

opencv是將用來進行圖像處理和圖片保存讀取等操作。numpy模塊用來處理矩陣數據的運算。pytorch模塊是常用的用來搭建模型和訓練的深度學習框架,和tensorflow以及Keras等具有相當的地位。json是為了讀取json存儲格式的數據。PIL庫可以完成對圖像進行批處理、生成圖像預覽、圖像格式轉換和圖像處理操作,包括圖像基本處理、像素處理、顏色處理等。argparse 是python自帶的命令行參數解析包,可以用來方便地讀取命令行參數。

網絡模型的定義和訓練

其中已經訓練好的模型地址如下:https://drive.google.com/open?id=1vQo4xNGdYe2uAtur0mDlHY7W2ZR3shWT。其中需要將其中的模型放到"./pretrained_checkpoint"目錄下。

對於數據集的存放,分為cloth_image(用來存儲衣服圖片),cloth_mask(用來分割衣服的mask,可以使用grabcut的方法進行分割保存),image(用來存儲人物圖片),parse_cihp(用來衣服語義分析的圖片結果,可以使用[CIHP_PGN](https://github.com/Engineering-Course/CIHP_PGN)的方法獲得)和pose_coco(用來存儲提取到的人物姿態特徵數據,可以使用openpose進行提取保存為josn數據即可)。

對於模型的訓練,我們需要使用到VGG19模型,網絡上可以很容易下載到,然後把它放到vgg_model文件夾下。

其中提出的一種基於目標姿態和店內服裝圖像由粗到細的多階段圖像生成框架,首先是設計了一個解析轉換網絡來預測目標語義圖,該語義圖在空間上對齊相應的身體部位,並提供更多關於軀幹和四肢形狀的結構信息。然後使用一種新的樹擴張融合塊(tree - block)算法,將空間對齊的布料與粗糙的渲染圖像融合在一起,以獲得更合理、更體面的結果。其中這個虛擬試穿網絡不僅不藉助3D信息,可以在任意姿態下將新衣服疊加到人的對應區域上,還保留和增強了顯著區域的豐富細節,如布料紋理、面部特徵等。同時還使用了空間對齊、多尺度上下文特徵聚集和顯著的區域增強,以由粗到細的方式各種難題。

(1)其中網絡主要使用pix2pix模型,其中的部分代碼如下:

class PixelDiscriminator(nn.Module):def__init__(self, input_nc,ndf=64, norm_layer=nn.InstanceNorm2d):super(PixelDiscriminator,self).__init__if type(norm_layer) ==functools.partial:use_bias =norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer ==nn.InstanceNorm2dself.net = nn.Sequential(nn.Conv2d(input_nc, ndf,kernel_size=1, stride=1, padding=0),nn.LeakyReLU(0.2, True),nn.Conv2d(ndf, ndf * 2,kernel_size=1, stride=1, padding=0, bias=use_bias),norm_layer(ndf * 2),nn.LeakyReLU(0.2, True),nn.Conv2d(ndf * 2, 1,kernel_size=1, stride=1, padding=0, bias=use_bias),nn.Sigmoid)defforward(self, input):return self.net(input)class PatchDiscriminator(nn.Module):def__init__(self, input_nc,ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):super(PatchDiscriminator,self).__init__if type(norm_layer) ==functools.partial: # no need to use biasas BatchNorm2d has affine parametersuse_bias =norm_layer.func == nn.InstanceNorm2delse:use_bias = norm_layer ==nn.InstanceNorm2dkw = 4padw = 1sequence =[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),nn.LeakyReLU(0.2, True)]nf_mult = 1nf_mult_prev = 1# channel upfor n in range(1,n_layers): # gradually increase thenumber of filtersnf_mult_prev = nf_mult #1,2,4,8nf_mult = min(2 ** n, 8)sequence += [nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,bias=use_bias),norm_layer(ndf *nf_mult),nn.LeakyReLU(0.2,True)]# channel downnf_mult_prev = nf_multnf_mult = min(2 ** n_layers,8)sequence += [nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw,bias=use_bias),norm_layer(ndf *nf_mult),nn.LeakyReLU(0.2, True)]# channel = 1 (bct, 1, x, x)sequence += [nn.Conv2d(ndf *nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction mapsequence += [nn.Sigmoid()]self.model =nn.Sequential(*sequence)

(2)生成器部分代碼:

class GenerationModel(BaseModel):defname(self):return 'Generation model:pix2pix | pix2pixHD'def__init__(self, opt):self.t0 = timeBaseModel.__init__(self,opt)self.train_mode =opt.train_mode# resume of networksresume_gmm = opt.resume_gmmresume_G_parse =opt.resume_G_parseresume_D_parse =opt.resume_D_parseresume_G_appearance =opt.resume_G_appresume_D_appearance =opt.resume_D_appresume_G_face = opt.resume_G_faceresume_D_face =opt.resume_D_face# define networkself.gmm_model =torch.nn.DataParallel(GMM(opt)).cudaself.generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf, opt.netG_parsing,opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)self.discriminator_parsing =Define_D(opt.input_nc_D_parsing, opt.ndf, opt.netD_parsing, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)self.generator_appearance =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids,with_tanh=False)self.discriminator_appearance = Define_D(opt.input_nc_D_app, opt.ndf,opt.netD_app, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)self.generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)self.discriminator_face =Define_D(opt.input_nc_D_face, opt.ndf, opt.netD_face, opt.n_layers_D,opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)if opt.train_mode == 'gmm':setattr(self,'generator', self.gmm_model)else:setattr(self,'generator', getattr(self, 'generator_' + self.train_mode))setattr(self, 'discriminator',getattr(self, 'discriminator_' + self.train_mode))# load networksself.networks_name = ['gmm','parsing', 'parsing', 'appearance', 'appearance', 'face', 'face']self.networks_model =[self.gmm_model, self.generator_parsing, self.discriminator_parsing,self.generator_appearance, self.discriminator_appearance,self.generator_face, self.discriminator_face]self.networks =dict(zip(self.networks_name, self.networks_model))self.resume_path =[resume_gmm, resume_G_parse, resume_D_parse, resume_G_appearance,resume_D_appearance, resume_G_face, resume_D_face]for network, resume inzip(self.networks_model, self.resume_path):if network != andresume != '':assert(osp.exists(resume), 'the resume not exits')print('loading...')self.load_network(network, resume, ifprint=False)# define optimizerself.optimizer_gmm =torch.optim.Adam(self.gmm_model.parameters, lr=opt.lr, betas=(0.5, 0.999))self.optimizer_parsing_G =torch.optim.Adam(self.generator_parsing.parameters, lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_parsing_D =torch.optim.Adam(self.discriminator_parsing.parameters, lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_appearance_G= torch.optim.Adam(self.generator_appearance.parameters, lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_appearance_D= torch.optim.Adam(self.discriminator_appearance.parameters, lr=opt.lr,betas=[opt.beta1, 0.999])self.optimizer_face_G =torch.optim.Adam(self.generator_face.parameters, lr=opt.lr, betas=[opt.beta1,0.999])self.optimizer_face_D =torch.optim.Adam(self.discriminator_face.parameters, lr=opt.lr,betas=[opt.beta1, 0.999])if opt.train_mode == 'gmm':self.optimizer_G =self.optimizer_gmmelif opt.joint_all:self.optimizer_G =[self.optimizer_parsing_G, self.optimizer_appearance_G, self.optimizer_face_G]setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))else:setattr(self,'optimizer_G', getattr(self, 'optimizer_' + self.train_mode + '_G'))setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))self.t1 = time

模型的使用

在模型訓練完成之後,通過命令「python demo.py --batch_size_v 80--num_workers 4 --forward_save_path 'demo/forward'」生成圖片。

(1)分別定義讀取模型函數和模型調用處理圖片函數

def load_model(model, path):checkpoint = torch.load(path)try:model.load_state_dict(checkpoint)except:model.load_state_dict(checkpoint.state_dict)model = model.cudamodel.evalprint(20*'=')for param in model.parameters:param.requires_grad = Falsedef forward(opt, paths, gpu_ids, refine_path):cudnn.enabled = Truecudnn.benchmark = Trueopt.output_nc = 3gmm = GMM(opt)gmm =torch.nn.DataParallel(gmm).cuda# 'batch'generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf,opt.netG_parsing, opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)generator_app_cpvton =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids, with_tanh=False)generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)models = [gmm,generator_parsing, generator_app_cpvton, generator_face]for model, path in zip(models,paths):load_model(model, path)print('==>loaded model')augment = {}if '0.4' in torch.__version__:augment['3'] =transforms.Compose([# transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) # change to [C, H, W]augment['1'] = augment['3']else:augment['3'] =transforms.Compose([#transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]) # change to [C, H, W]augment['1'] =transforms.Compose([# transforms.Resize(256),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # change to [C, H, W]val_dataset = DemoDataset(opt,augment=augment)val_dataloader = DataLoader(val_dataset,shuffle=False,drop_last=False,num_workers=opt.num_workers,batch_size = opt.batch_size_v,pin_memory=True)with torch.no_grad:for i, result inenumerate(val_dataloader):'warped cloth'warped_cloth =warped_image(gmm, result)if opt.warp_cloth:warped_cloth_name =result['warped_cloth_name']warped_cloth_path =os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])if notos.path.exists(os.path.split(warped_cloth_path)[0]):os.makedirs(os.path.split(warped_cloth_path)[0])utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)print('processing_%d'%i)continuesource_parse =result['source_parse'].float.cudatarget_pose_embedding =result['target_pose_embedding'].float.cudasource_image =result['source_image'].float.cudacloth_parse =result['cloth_parse'].cudacloth_image =result['cloth_image'].cudatarget_pose_img =result['target_pose_img'].float.cudacloth_parse =result['cloth_parse'].float.cudasource_parse_vis =result['source_parse_vis'].float.cuda"filter add clothinfomation"real_s =source_parseindex = [x for x inlist(range(20)) if x != 5 and x != 6 and x != 7]real_s_ =torch.index_select(real_s, 1, torch.tensor(index).cuda)input_parse =torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda'P'generate_parse =generator_parsing(input_parse) # tanhgenerate_parse =F.softmax(generate_parse, dim=1)generate_parse_argmax =torch.argmax(generate_parse, dim=1, keepdim=True).floatres = for index in range(20):res.append(generate_parse_argmax == index)generate_parse_argmax =torch.cat(res, dim=1).float"A"image_without_cloth =create_part(source_image, source_parse, 'image_without_cloth', False)input_app = torch.cat((image_without_cloth,warped_cloth, generate_parse), 1).cuda

源碼地址:

提取碼:qcj6

相關焦點

  • 家樂福推出「虛擬試衣間」 為自有品牌Tex打開銷路
    9月8日消息,日前,家樂福(需求面積:3000-15000平方米)攜漢威士集團旗下媒體Arena在旗下服裝品牌TEX官網上推出虛擬試衣間功能。據外媒報導,家樂福此舉是為促銷其自有品牌Tex不久前上線的秋冬新款女裝。
  • 德國海恩斯坦推出數位化虛擬試衣實驗室
    川北在線核心提示:原標題:德國海恩斯坦推出數位化虛擬試衣實驗室 德國海恩斯坦2020年9月推出了新型數位化試衣實驗室,使品牌商、零售商和供應商能使用先進的3D技術、虛擬試衣和可視化服務,實施或完善3D設計流程。
  • 3D 虛擬試衣時代已經來了,你試了嗎?
    圖片源於網絡 圖文無關實體零售商們正在尋求科技化的解決方案來彌補疫情對試衣間體驗帶來的限制,以增加消費者信心,促進購買。因為疫情相關政策使消費者在門店試穿衣服不再像以往那樣便捷、安全,曾在線上平臺被輕視了的虛擬試穿技術,眼下正在成為一些門店迫在眉睫的需求。
  • 研發布料仿真和AI人體重建技術,「亙星智能」要實現高精度虛擬試衣
    具體到萬億級市場規模的服裝行業,基於AI的3D人體重建技術已經實現了非接觸式人體測量,系統只需通過照片或視頻重建3D人體,用戶無需脫衣或身著緊身衣即可實現,這一技術目前正在推動虛擬試衣行業的發展,並且在汽車、軍工、人體工學設計等領域均有擴展空間。
  • [多圖]Computex 2012:用Kinect打造的夢幻女性試衣間
    今年的 Computex 除了有一眼望不盡,兩眼看不完的 Windows 8 各式筆記本、平板、AIO外,小編也在夾縫中看到一些前兩年比較夯,但是最近比較退燒的體感應用,像是因應 Kinect for Windows 推出,微軟現場就展示結合甲尚科技的iClone5,以及女性福音的虛擬試衣間,另外體感也能跟 Android 平板結合
  • 34個最優秀好用的Python開源框架
    人工智慧和深度學習的熱潮極大的帶動了Python的發展,迅速在Python生態圈中催生了大批的涉及各個方面的優秀Python開源框架,今天小編就帶你回顧下2018年度最優秀好用的Python開源框架。下面是從2018年中近10000個python開源框架中評價整理的34個最為好用的開源框架,它們細分可以分為Python Toolkit、Web、Terminal、Code Editor、Debugging、complier、Data Related、Chart8類,分布情況如下圖:Python ToolKit
  • 優衣庫的試衣間就安全了?
    2015-11-10/14:59 驅動中國 2015年11月10日消息,前段時間,優衣庫試衣間的不雅視頻事件傳得沸沸揚揚
  • 雲立方網科普:常用高效的Python爬蟲框架有哪些?
    Python是現在非常流行的程式語言,而爬蟲則是Python語言中最典型的應用,下面是總結的高效Python爬蟲框架,大家看看是否都用過呢。
  • 隱藏在試衣間中的服務細節
    先來看看大牌們不炒作不營銷,打造的試衣間美學~比如,你不得不知道,有點小潔癖的Hermés會在試衣間裡放一條絲巾,以防顧客們在試衣服的時候把妝蹭到衣服上;又比如,LV會充分考慮到選擇困難症患者,而給他們準備睡袍,好讓他們隨時走出試衣間調換款式;更有Prada給試衣間按上各種角度的鏡子,讓每一個自戀狂都恨不得抱著鏡子照到天荒地老。
  • 試衣間裡,一次拿進兩件同款式衣服
    本報訊 到大廈、商場買衣服,大多要進試衣間試一下大小。  就是這個小小的空間,近來被小偷盯上了。他們把那裡作為偷竊的中轉站,幾個人合夥玩了「一進一出」的把戲,神不知鬼不覺地把一件件成衣偷了出去。  昨天凌晨1點,杭州天水派出所。
  • 「女生試衣間為什麼總排隊?是我沒錯了」
    「女生試衣間為什麼總排隊?,是我沒錯了」逛街是女人的天性,很多女人逛街都會逛一整天,而且完全都不會覺得累,很多男人都非常討厭陪女朋友或者是老婆逛街,因為女生逛街的能力真的是太強大了,相信大家都有一個感受,那就是為什么女生在試衣間呆那麼久,為什么女試衣間總是要排隊的呢?下面這幾張圖就來告訴你看,看看你是不是也是這樣的呢?
  • 多家服裝企業紛紛開始轉戰線上 實現了「雲開張」
    不過,在線下閉店的同時,多家服裝企業紛紛開始轉戰線上,實現了「雲開張」。  北京商報記者近日調查發現,萊爾斯丹、音兒、太平鳥、卡賓等品牌紛紛建立了微信購物群,方便消費者購買。  有導購表示,目前在線上出售的部分產品,都比商場中便宜很多。「我們把商場的扣點都給顧客便宜了,皮衣都便宜了1000元以上。」上述導購說。
  • 朵拉試衣間和多個品牌達成合作,在新零售領域打通「大牌直供」
    據了解,朵拉試衣間與全球品牌方、源頭工廠合作,只做自營正品,以確保源頭直採和嚴格品控,讓店長無需囤貨便可輕鬆坐擁貨源。自2019年以來,朵拉試衣間已經與MK製造商、MarkFairwhale製造商、Armani製造商等多個服飾品牌達成戰略合作。
  • 使用python實現一個簡單計算器
    如果做一些簡單的界面,使用tkinter還是很方便的,畢竟是python自帶的庫。今天將會做下面這樣的一個計算器,可以實現基本的加減程序的運算,整體代碼邏輯比較簡單,主要是一個回調函數的理解。實現思路1.UI界面布局2.功能函數實現3.重構布局代碼4.按鈕回調函數綁定具體實現過程1.界面實現
  • 基於Android智能終端的虛擬SIM卡軟體實現
    為了解決上述所提出問題,基於智能終端的虛擬SIM卡技術得到重視和發展。通過支持虛擬SIM技術的終端消費者可以在全球覆蓋範圍內,以接近目的地價格水平使用數據上網服務。本文根據虛擬卡平臺架構,結合Android Telephony框架結構,研究虛擬卡在Android終端中的應用以及相關的技術。
  • 優衣庫試衣間「偷拍門」後續:品牌全國門店將加強自檢
    來源|都市現場綜合深圳公共轉載請註明來源昨天我們報導了,龍華艾扣(ICO)購物中心的優衣庫試衣間裡,竟然被顧客找到了一個正在運行狀態的針孔攝像頭,這也引發了市民的廣泛關注,今天記者從龍華警方了解到,目前該案件正在調查當中,事件發生後,優衣庫方面有沒有採取應對措施呢?
  • 優衣庫同公司品牌GU推出虛擬試衣模特YU
    優衣庫同公司品牌GU推出虛擬試衣模特YU來源:聯商網2020-03-17 11:32聯商網消息:日前,迅銷集團旗下快時尚品牌GU(極優)通過官方微信宣布,推出3D虛擬模特:YU。
  • 優衣庫試衣間暗藏針孔攝像頭 優衣庫門店回應:會給大家一個說法
    在最近一名女子在深圳龍華ICO購物中心優衣庫門店試衣間試衣服時,發現試衣間內竟然藏有針孔攝像頭,隨後女子慌忙逃出,並向優衣庫門店提出問題。隨後攝像頭被拆卸出來,發現是通過口香糖粘結在試衣間內。事情發生以後優衣庫表示會給廣大群眾一個說法。
  • VPF:適用於 Python 的開源視頻處理框架,加速視頻任務、提高 GPU...
    雷鋒網 AI 開發者按:近日,NVIDIA 開源了適用於 Python 的視頻處理框架「VideoProcessingFramework(VPF)」。該框架為開發人員提供了一個簡單但功能強大的 Python 工具,可用於硬體加速的視頻編碼、解碼和處理類等任務。
  • Python語言基本語法元素之格式框架:注釋、縮進、續行符
    筆者希望自己對python編程知識的加工處理,能對讀者產生作用。這次寫些基礎概念,Python語言基本語法元素。格式框架:注釋、縮進、續行符先來看看筆者寫的簡單代碼(pycharm環境下)貨幣兌換1.0:貨幣兌換3.0中的部分代碼:對比兩個代碼部分。